From d1cd004d715b498d4f1bad142be3dd1e953464ab Mon Sep 17 00:00:00 2001 From: Tpt Date: Sat, 15 Oct 2022 19:03:31 +0200 Subject: [PATCH] Adds magic functions to stubs --- python/generate_stubs.py | 76 +++++++++++++++++++++++++++------------- 1 file changed, 51 insertions(+), 25 deletions(-) diff --git a/python/generate_stubs.py b/python/generate_stubs.py index 4a28875f..0be189f5 100644 --- a/python/generate_stubs.py +++ b/python/generate_stubs.py @@ -6,7 +6,7 @@ import logging import re import subprocess from functools import reduce -from typing import Set +from typing import Set, List, Mapping, Any AST_LOAD = ast.Load() AST_ELLIPSIS = ast.Ellipsis() @@ -22,6 +22,18 @@ GENERICS = { value=ast.Name(id="typing", ctx=AST_LOAD), attr="List", ctx=AST_LOAD ), } +OBJECT_MEMBERS = dict(inspect.getmembers(object)) + + +ATTRIBUTES_BLACKLIST = { + "__class__", + "__dir__", + "__doc__", + "__init_subclass__", + "__module__", + "__new__", + "__subclasshook__", +} def module_stubs(module) -> ast.Module: @@ -34,9 +46,7 @@ def module_stubs(module) -> ast.Module: elif inspect.isclass(member_value): classes.append(class_stubs(member_name, member_value, types_to_import)) elif inspect.isbuiltin(member_value): - functions.append( - function_stub(member_name, member_value, types_to_import, is_init=False) - ) + functions.append(function_stub(member_name, member_value, types_to_import)) else: logging.warning(f"Unsupported root construction {member_name}") return ast.Module( @@ -48,29 +58,32 @@ def module_stubs(module) -> ast.Module: def class_stubs(cls_name: str, cls_def, types_to_import: Set[str]) -> ast.ClassDef: - attributes = [] - methods = [] + attributes: List[ast.AST] = [] + methods: List[ast.AST] = [] + magic_methods: List[ast.AST] = [] for (member_name, member_value) in inspect.getmembers(cls_def): if member_name == "__init__": try: inspect.signature(cls_def) # we check it actually exists methods = [ - function_stub(member_name, cls_def, types_to_import, is_init=True) + function_stub(member_name, cls_def, types_to_import) ] + methods except ValueError as e: if "no signature found" not in str(e): raise ValueError( f"Error while parsing signature of {cls_name}.__init__: {e}" ) - elif member_name.startswith("__"): + elif member_name in ATTRIBUTES_BLACKLIST or member_value == OBJECT_MEMBERS.get( + member_name + ): pass elif inspect.isdatadescriptor(member_value): attributes.extend( data_descriptor_stub(member_name, member_value, types_to_import) ) elif inspect.isroutine(member_value): - methods.append( - function_stub(member_name, member_value, types_to_import, is_init=False) + (magic_methods if member_name.startswith("__") else methods).append( + function_stub(member_name, member_value, types_to_import) ) else: logging.warning(f"Unsupported member {member_name} of class {cls_name}") @@ -80,7 +93,12 @@ def class_stubs(cls_name: str, cls_def, types_to_import: Set[str]) -> ast.ClassD cls_name, bases=[], keywords=[], - body=(([build_doc_comment(doc)] if doc else []) + attributes + methods) + body=( + ([build_doc_comment(doc)] if doc else []) + + attributes + + methods + + magic_methods + ) or [AST_ELLIPSIS], decorator_list=[ ast.Attribute( @@ -113,17 +131,15 @@ def data_descriptor_stub( return (assign, build_doc_comment(doc_comment)) if doc_comment else (assign,) -def function_stub( - fn_name: str, fn_def, types_to_import: Set[str], is_init: bool -) -> ast.FunctionDef: +def function_stub(fn_name: str, fn_def, types_to_import: Set[str]) -> ast.FunctionDef: body = [] doc = inspect.getdoc(fn_def) - if doc is not None and not is_init: + if doc is not None and not fn_name.startswith("__"): body.append(build_doc_comment(doc)) return ast.FunctionDef( fn_name, - arguments_stub(fn_def, doc or "", types_to_import, is_init), + arguments_stub(fn_name, fn_def, doc or "", types_to_import), body or [AST_ELLIPSIS], decorator_list=[], returns=returns_stub(doc, types_to_import) if doc else None, @@ -131,9 +147,11 @@ def function_stub( ) -def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool): - real_parameters = inspect.signature(callable).parameters - if is_init: +def arguments_stub(callable_name, callable_def, doc: str, types_to_import: Set[str]): + real_parameters: Mapping[str, inspect.Parameter] = inspect.signature( + callable_def + ).parameters + if callable_name == "__init__": real_parameters = { "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY), **real_parameters, @@ -161,9 +179,13 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool) kwarg = None defaults = [] for param in real_parameters.values(): - if param.name != "self" and param.name not in parsed_param_types: + if ( + param.name != "self" + and param.name not in parsed_param_types + and (callable_name == "__init__" or not callable_name.startswith("__")) + ): raise ValueError( - f"The parameter {param.name} has no type definition in the function documentation" + f"The parameter {param.name} of {callable_name} has no type definition in the function documentation" ) param_ast = ast.arg( arg=param.name, annotation=parsed_param_types.get(param.name) @@ -173,9 +195,13 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool) if param.default != param.empty: default_ast = ast.Constant(param.default) if param.name not in optional_params: - raise ValueError(f"Parameter {param.name} is optional according to the type but not flagged as such in the doc") + raise ValueError( + f"Parameter {param.name} is optional according to the type but not flagged as such in the doc" + ) elif param.name in optional_params: - raise ValueError(f"Parameter {param.name} is optional according to the documentation but has no default value") + raise ValueError( + f"Parameter {param.name} is optional according to the documentation but has no default value" + ) if param.kind == param.POSITIONAL_ONLY: posonlyargs.append(param_ast) @@ -234,10 +260,10 @@ def parse_type_to_ast(type_str: str, types_to_import: Set[str]): tokens.append(current_token) # let's first parse nested parenthesis - stack = [[]] + stack: List[List[Any]] = [[]] for token in tokens: if token == "(": - l = [] + l: List[str] = [] stack[-1].append(l) stack.append(l) elif token == ")":