Adds magic functions to stubs

pull/262/head
Tpt 2 years ago committed by Thomas Tanon
parent 5e13aee5be
commit d1cd004d71
  1. 76
      python/generate_stubs.py

@ -6,7 +6,7 @@ import logging
import re import re
import subprocess import subprocess
from functools import reduce from functools import reduce
from typing import Set from typing import Set, List, Mapping, Any
AST_LOAD = ast.Load() AST_LOAD = ast.Load()
AST_ELLIPSIS = ast.Ellipsis() AST_ELLIPSIS = ast.Ellipsis()
@ -22,6 +22,18 @@ GENERICS = {
value=ast.Name(id="typing", ctx=AST_LOAD), attr="List", ctx=AST_LOAD value=ast.Name(id="typing", ctx=AST_LOAD), attr="List", ctx=AST_LOAD
), ),
} }
OBJECT_MEMBERS = dict(inspect.getmembers(object))
ATTRIBUTES_BLACKLIST = {
"__class__",
"__dir__",
"__doc__",
"__init_subclass__",
"__module__",
"__new__",
"__subclasshook__",
}
def module_stubs(module) -> ast.Module: def module_stubs(module) -> ast.Module:
@ -34,9 +46,7 @@ def module_stubs(module) -> ast.Module:
elif inspect.isclass(member_value): elif inspect.isclass(member_value):
classes.append(class_stubs(member_name, member_value, types_to_import)) classes.append(class_stubs(member_name, member_value, types_to_import))
elif inspect.isbuiltin(member_value): elif inspect.isbuiltin(member_value):
functions.append( functions.append(function_stub(member_name, member_value, types_to_import))
function_stub(member_name, member_value, types_to_import, is_init=False)
)
else: else:
logging.warning(f"Unsupported root construction {member_name}") logging.warning(f"Unsupported root construction {member_name}")
return ast.Module( return ast.Module(
@ -48,29 +58,32 @@ def module_stubs(module) -> ast.Module:
def class_stubs(cls_name: str, cls_def, types_to_import: Set[str]) -> ast.ClassDef: def class_stubs(cls_name: str, cls_def, types_to_import: Set[str]) -> ast.ClassDef:
attributes = [] attributes: List[ast.AST] = []
methods = [] methods: List[ast.AST] = []
magic_methods: List[ast.AST] = []
for (member_name, member_value) in inspect.getmembers(cls_def): for (member_name, member_value) in inspect.getmembers(cls_def):
if member_name == "__init__": if member_name == "__init__":
try: try:
inspect.signature(cls_def) # we check it actually exists inspect.signature(cls_def) # we check it actually exists
methods = [ methods = [
function_stub(member_name, cls_def, types_to_import, is_init=True) function_stub(member_name, cls_def, types_to_import)
] + methods ] + methods
except ValueError as e: except ValueError as e:
if "no signature found" not in str(e): if "no signature found" not in str(e):
raise ValueError( raise ValueError(
f"Error while parsing signature of {cls_name}.__init__: {e}" f"Error while parsing signature of {cls_name}.__init__: {e}"
) )
elif member_name.startswith("__"): elif member_name in ATTRIBUTES_BLACKLIST or member_value == OBJECT_MEMBERS.get(
member_name
):
pass pass
elif inspect.isdatadescriptor(member_value): elif inspect.isdatadescriptor(member_value):
attributes.extend( attributes.extend(
data_descriptor_stub(member_name, member_value, types_to_import) data_descriptor_stub(member_name, member_value, types_to_import)
) )
elif inspect.isroutine(member_value): elif inspect.isroutine(member_value):
methods.append( (magic_methods if member_name.startswith("__") else methods).append(
function_stub(member_name, member_value, types_to_import, is_init=False) function_stub(member_name, member_value, types_to_import)
) )
else: else:
logging.warning(f"Unsupported member {member_name} of class {cls_name}") logging.warning(f"Unsupported member {member_name} of class {cls_name}")
@ -80,7 +93,12 @@ def class_stubs(cls_name: str, cls_def, types_to_import: Set[str]) -> ast.ClassD
cls_name, cls_name,
bases=[], bases=[],
keywords=[], keywords=[],
body=(([build_doc_comment(doc)] if doc else []) + attributes + methods) body=(
([build_doc_comment(doc)] if doc else [])
+ attributes
+ methods
+ magic_methods
)
or [AST_ELLIPSIS], or [AST_ELLIPSIS],
decorator_list=[ decorator_list=[
ast.Attribute( ast.Attribute(
@ -113,17 +131,15 @@ def data_descriptor_stub(
return (assign, build_doc_comment(doc_comment)) if doc_comment else (assign,) return (assign, build_doc_comment(doc_comment)) if doc_comment else (assign,)
def function_stub( def function_stub(fn_name: str, fn_def, types_to_import: Set[str]) -> ast.FunctionDef:
fn_name: str, fn_def, types_to_import: Set[str], is_init: bool
) -> ast.FunctionDef:
body = [] body = []
doc = inspect.getdoc(fn_def) doc = inspect.getdoc(fn_def)
if doc is not None and not is_init: if doc is not None and not fn_name.startswith("__"):
body.append(build_doc_comment(doc)) body.append(build_doc_comment(doc))
return ast.FunctionDef( return ast.FunctionDef(
fn_name, fn_name,
arguments_stub(fn_def, doc or "", types_to_import, is_init), arguments_stub(fn_name, fn_def, doc or "", types_to_import),
body or [AST_ELLIPSIS], body or [AST_ELLIPSIS],
decorator_list=[], decorator_list=[],
returns=returns_stub(doc, types_to_import) if doc else None, returns=returns_stub(doc, types_to_import) if doc else None,
@ -131,9 +147,11 @@ def function_stub(
) )
def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool): def arguments_stub(callable_name, callable_def, doc: str, types_to_import: Set[str]):
real_parameters = inspect.signature(callable).parameters real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(
if is_init: callable_def
).parameters
if callable_name == "__init__":
real_parameters = { real_parameters = {
"self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY), "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
**real_parameters, **real_parameters,
@ -161,9 +179,13 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool)
kwarg = None kwarg = None
defaults = [] defaults = []
for param in real_parameters.values(): for param in real_parameters.values():
if param.name != "self" and param.name not in parsed_param_types: if (
param.name != "self"
and param.name not in parsed_param_types
and (callable_name == "__init__" or not callable_name.startswith("__"))
):
raise ValueError( raise ValueError(
f"The parameter {param.name} has no type definition in the function documentation" f"The parameter {param.name} of {callable_name} has no type definition in the function documentation"
) )
param_ast = ast.arg( param_ast = ast.arg(
arg=param.name, annotation=parsed_param_types.get(param.name) arg=param.name, annotation=parsed_param_types.get(param.name)
@ -173,9 +195,13 @@ def arguments_stub(callable, doc: str, types_to_import: Set[str], is_init: bool)
if param.default != param.empty: if param.default != param.empty:
default_ast = ast.Constant(param.default) default_ast = ast.Constant(param.default)
if param.name not in optional_params: if param.name not in optional_params:
raise ValueError(f"Parameter {param.name} is optional according to the type but not flagged as such in the doc") raise ValueError(
f"Parameter {param.name} is optional according to the type but not flagged as such in the doc"
)
elif param.name in optional_params: elif param.name in optional_params:
raise ValueError(f"Parameter {param.name} is optional according to the documentation but has no default value") raise ValueError(
f"Parameter {param.name} is optional according to the documentation but has no default value"
)
if param.kind == param.POSITIONAL_ONLY: if param.kind == param.POSITIONAL_ONLY:
posonlyargs.append(param_ast) posonlyargs.append(param_ast)
@ -234,10 +260,10 @@ def parse_type_to_ast(type_str: str, types_to_import: Set[str]):
tokens.append(current_token) tokens.append(current_token)
# let's first parse nested parenthesis # let's first parse nested parenthesis
stack = [[]] stack: List[List[Any]] = [[]]
for token in tokens: for token in tokens:
if token == "(": if token == "(":
l = [] l: List[str] = []
stack[-1].append(l) stack[-1].append(l)
stack.append(l) stack.append(l)
elif token == ")": elif token == ")":

Loading…
Cancel
Save