|
|
|
import argparse
|
|
|
|
import ast
|
|
|
|
import importlib
|
|
|
|
import inspect
|
|
|
|
import logging
|
|
|
|
import re
|
|
|
|
import subprocess
|
|
|
|
from functools import reduce
|
|
|
|
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
def path_to_type(*elements: str) -> ast.AST:
    """Build the AST of a (possibly dotted) name, e.g. ``("typing", "Any")`` -> ``typing.Any``.

    :param elements: name components, outermost first; at least one is required
        (an empty call raises ``IndexError``).
    :return: an ``ast.Name`` for a single component, otherwise nested ``ast.Attribute`` nodes.
    """
    return reduce(
        lambda owner, attribute: ast.Attribute(value=owner, attr=attribute, ctx=ast.Load()),
        elements[1:],
        ast.Name(id=elements[0], ctx=ast.Load()),
    )
|
|
|
|
|
|
|
|
|
|
|
|
# Members of the plain ``object`` type, by name. A class member equal to the ``object`` one
# (i.e. inherited unchanged) is left out of the generated class stub.
OBJECT_MEMBERS = dict(inspect.getmembers(object))

# Fallback signatures of "magic" members, used when the documentation does not give types.
# - ``None``: never emit a stub for this member.
# - ``(parameter types, return type)``: parameter types are matched positionally against the
#   parameters following ``self``; the return type is used when there is no ``:rtype:``.
BUILTINS: Dict[str, Union[None, Tuple[List[ast.AST], ast.AST]]] = {
    "__annotations__": None,
    "__bool__": ([], path_to_type("bool")),
    "__bytes__": ([], path_to_type("bytes")),
    "__class__": None,
    "__contains__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__del__": None,
    "__delattr__": ([path_to_type("str")], path_to_type("None")),
    # Like __setitem__, the return value of __delitem__ is ignored by Python: it returns None.
    "__delitem__": ([path_to_type("typing", "Any")], path_to_type("None")),
    "__dict__": None,
    "__dir__": None,
    "__doc__": None,
    "__eq__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__format__": ([path_to_type("str")], path_to_type("str")),
    "__ge__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__getattribute__": ([path_to_type("str")], path_to_type("typing", "Any")),
    "__getitem__": ([path_to_type("typing", "Any")], path_to_type("typing", "Any")),
    "__gt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__hash__": ([], path_to_type("int")),
    "__init__": ([], path_to_type("None")),
    "__init_subclass__": None,
    "__iter__": ([], path_to_type("typing", "Any")),
    "__le__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__len__": ([], path_to_type("int")),
    "__lt__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__module__": None,
    "__ne__": ([path_to_type("typing", "Any")], path_to_type("bool")),
    "__new__": None,
    "__next__": ([], path_to_type("typing", "Any")),
    "__reduce__": None,
    "__reduce_ex__": None,
    "__repr__": ([], path_to_type("str")),
    "__setattr__": (
        [path_to_type("str"), path_to_type("typing", "Any")],
        path_to_type("None"),
    ),
    "__setitem__": (
        [path_to_type("typing", "Any"), path_to_type("typing", "Any")],
        path_to_type("None"),
    ),
    "__sizeof__": None,
    "__str__": ([], path_to_type("str")),
    "__subclasshook__": None,
}
|
|
|
|
|
|
|
|
|
|
|
|
def module_stubs(module: Any) -> ast.Module:
    """Build the AST of the stub file of ``module``.

    Members whose name starts with ``__`` are ignored. Classes go through ``class_stubs``
    and builtin functions go through ``function_stub``. Any other member only triggers
    a logged warning.

    :param module: the imported module to introspect.
    :return: a module made of one ``import`` per module referenced by the generated
        annotations (``typing`` always included, sorted by name), then the class stubs,
        then the function stubs.
    """
    needed_imports = {"typing"}
    class_nodes = []
    function_nodes = []
    for name, value in inspect.getmembers(module):
        path = [module.__name__, name]
        if name.startswith("__"):
            continue  # dunder attributes are not exported in the stub
        if inspect.isclass(value):
            class_nodes.append(class_stubs(name, value, path, needed_imports))
        elif inspect.isbuiltin(value):
            function_nodes.append(function_stub(name, value, path, needed_imports, in_class=False))
        else:
            logging.warning(f"Unsupported root construction {name}")
    # The import list is only complete once every member has been processed.
    import_nodes = [ast.Import(names=[ast.alias(name=module_name)]) for module_name in sorted(needed_imports)]
    return ast.Module(body=import_nodes + class_nodes + function_nodes, type_ignores=[])
|
|
|
|
|
|
|
|
|
|
|
|
def class_stubs(cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str]) -> ast.ClassDef:
    """Build the stub of a class as a ``@typing.final`` class definition.

    The class body contains, in this order:
    1. the class docstring, if any;
    2. the data descriptors;
    3. the regular methods, with ``__init__`` first;
    4. the magic methods;
    5. the remaining non-``None`` class constants.

    Members equal to the ``object`` ones are skipped, and so are members mapped to ``None``
    in ``BUILTINS``. An empty body is replaced by ``...``.

    :param cls_name: name of the class in the stub.
    :param cls_def: the class to introspect.
    :param element_path: path of the class (e.g. ``[module, class]``), extended with member
        names for member stubs and used in error messages.
    :param types_to_import: modules referenced by generated annotations; updated in place.
    :raises ValueError: if the ``__init__`` stub cannot be built for a reason other than the
        class having no introspectable signature.
    """
    attributes: List[ast.AST] = []
    methods: List[ast.AST] = []
    magic_methods: List[ast.AST] = []
    constants: List[ast.AST] = []

    for member_name, member_value in inspect.getmembers(cls_def):
        current_element_path = [*element_path, member_name]
        if member_name == "__init__":
            # The constructor stub is built from the class itself (its signature is the
            # constructor one); classes without a signature simply get no __init__ stub.
            try:
                inspect.signature(cls_def)  # we check it actually exists
                methods = [
                    function_stub(
                        member_name,
                        cls_def,
                        current_element_path,
                        types_to_import,
                        in_class=True,
                    ),
                    *methods,
                ]
            except ValueError as e:
                if "no signature found" not in str(e):
                    raise ValueError(f"Error while parsing signature of {cls_name}.__init__") from e
        elif member_value == OBJECT_MEMBERS.get(member_name) or BUILTINS.get(member_name, ()) is None:
            # Inherited unchanged from ``object``, or explicitly excluded.
            pass
        elif inspect.isdatadescriptor(member_value):
            attributes.extend(data_descriptor_stub(member_name, member_value, current_element_path, types_to_import))
        elif inspect.isroutine(member_value):
            target = magic_methods if member_name.startswith("__") else methods
            target.append(
                function_stub(
                    member_name,
                    member_value,
                    current_element_path,
                    types_to_import,
                    in_class=True,
                )
            )
        elif member_name == "__match_args__":
            # Emitted as ``__match_args__: tuple[str, ...] = (<the actual names>)``.
            constants.append(
                ast.AnnAssign(
                    target=ast.Name(id=member_name, ctx=ast.Store()),
                    annotation=ast.Subscript(
                        value=path_to_type("tuple"),
                        slice=ast.Tuple(elts=[path_to_type("str"), ast.Constant(value=...)], ctx=ast.Load()),
                        ctx=ast.Load(),
                    ),
                    value=ast.Constant(member_value),
                    simple=1,
                )
            )
        elif member_value is not None:
            # Other constants are annotated with the (unqualified) name of their class.
            constants.append(
                ast.AnnAssign(
                    target=ast.Name(id=member_name, ctx=ast.Store()),
                    annotation=concatenated_path_to_type(
                        member_value.__class__.__name__, element_path, types_to_import
                    ),
                    value=ast.Constant(value=...),
                    simple=1,
                )
            )
        else:
            logging.warning(f"Unsupported member {member_name} of class {'.'.join(element_path)}")

    doc = inspect.getdoc(cls_def)
    doc_comment = build_doc_comment(doc) if doc else None
    body: List[ast.AST] = ([doc_comment] if doc_comment else []) + attributes + methods + magic_methods + constants
    return ast.ClassDef(
        cls_name,
        bases=[],
        keywords=[],
        # ``ast.Ellipsis`` is deprecated (removed in Python 3.14) and a class body must be
        # made of statements, hence the explicit ``Expr(Constant(...))`` placeholder.
        body=body or [ast.Expr(value=ast.Constant(value=...))],
        decorator_list=[path_to_type("typing", "final")],
    )
|
|
|
|
|
|
|
|
|
|
|
|
def data_descriptor_stub(
    data_desc_name: str,
    data_desc_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
) -> Union[Tuple[ast.AnnAssign, ast.Expr], Tuple[ast.AnnAssign]]:
    """Build the stub of a data descriptor (such as a property) as an annotated name.

    When the descriptor has documentation, the annotation is its ``:rtype:`` (via
    ``returns_stub``). The text of a single ``:return:`` line becomes a docstring placed
    right after the annotation. Without documentation the annotation is ``typing.Any``.

    :raises ValueError: if the documentation contains several ``:return:`` lines, or if
        ``returns_stub`` rejects it.
    """
    doc = inspect.getdoc(data_desc_def)
    annotation = None
    return_text = None
    if doc is not None:
        annotation = returns_stub(data_desc_name, doc, element_path, types_to_import)
        return_lines = re.findall(r"^ *:return: *(.*) *$", doc, re.MULTILINE)
        if len(return_lines) > 1:
            raise ValueError(
                f"Multiple return annotations found with :return: in {'.'.join(element_path)} documentation"
            )
        if return_lines:
            return_text = return_lines[0]

    assign = ast.AnnAssign(
        target=ast.Name(id=data_desc_name, ctx=ast.Store()),
        annotation=annotation or path_to_type("typing", "Any"),
        simple=1,
    )
    if not return_text:
        return (assign,)
    doc_expr = build_doc_comment(return_text)
    if doc_expr:
        return (assign, doc_expr)
    return (assign,)
|
|
|
|
|
|
|
|
|
|
|
|
def function_stub(
    fn_name: str,
    fn_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
    *,
    in_class: bool,
) -> ast.FunctionDef:
    """Build the stub of a function or method.

    The documentation, without its ``:type``/``:rtype`` lines (see ``build_doc_comment``),
    becomes the stub docstring. An undocumented function gets a ``...`` body.

    :param fn_name: name of the function in the stub.
    :param fn_def: the callable to introspect (``class_stubs`` passes the class itself for
        ``__init__``).
    :param element_path: path of the function, used in error messages.
    :param types_to_import: modules referenced by generated annotations; updated in place.
    :param in_class: whether the function is a class member. Class members having a
        ``__self__`` attribute are decorated with ``@staticmethod``.
        NOTE(review): presumably targets static/class methods of extension types
        accessed on the class; confirm.
    :raises ValueError: propagated from ``arguments_stub``/``returns_stub`` when the
        documentation and the signature are inconsistent.
    """
    body: List[ast.AST] = []
    doc = inspect.getdoc(fn_def)
    if doc is not None:
        doc_comment = build_doc_comment(doc)
        if doc_comment is not None:
            body.append(doc_comment)

    decorator_list: List[ast.AST] = []
    if in_class and hasattr(fn_def, "__self__"):
        decorator_list.append(ast.Name(id="staticmethod", ctx=ast.Load()))

    return ast.FunctionDef(
        fn_name,
        arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import),
        # ``ast.Ellipsis`` is deprecated (removed in Python 3.14) and a function body must
        # be made of statements, hence the explicit ``Expr(Constant(...))`` placeholder.
        body or [ast.Expr(value=ast.Constant(value=...))],
        decorator_list=decorator_list,
        returns=returns_stub(fn_name, doc, element_path, types_to_import) if doc else None,
        lineno=0,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def arguments_stub(
    callable_name: str,
    callable_def: Any,
    doc: str,
    element_path: List[str],
    types_to_import: Set[str],
) -> ast.arguments:
    """Build the ``ast.arguments`` of a stub from a callable's signature and documentation.

    Parameter types come from ``BUILTINS`` for magic methods, then from
    ``:type name: T`` documentation lines, which take precedence. A documented type ending
    in ``, optional`` marks a parameter that must have a default value.

    :param callable_name: name of the callable (``__init__`` gets a ``self`` prepended).
    :param callable_def: the callable whose ``inspect.signature`` is used.
    :param doc: the callable documentation, possibly empty.
    :param element_path: path of the callable, used in error messages.
    :param types_to_import: modules referenced by generated annotations; updated in place.
    :raises ValueError: in any of these cases:
        - a documented parameter is not in the signature;
        - a parameter other than ``self`` has no type;
        - the ``optional`` flag disagrees with the presence of a default value.
    """
    real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(callable_def).parameters
    if callable_name == "__init__":
        # ``class_stubs`` passes the class object itself for ``__init__``, whose signature
        # omits ``self``: add it back as the first parameter.
        real_parameters = {
            "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
            **real_parameters,
        }

    parsed_param_types: Dict[str, ast.AST] = {}
    optional_params: Set[str] = set()

    # Types for magic functions, matched positionally against the parameters after ``self``
    builtin = BUILTINS.get(callable_name)
    if isinstance(builtin, tuple):
        param_names = list(real_parameters.keys())
        if param_names and param_names[0] == "self":
            del param_names[0]
        for name, param_type in zip(param_names, builtin[0]):
            parsed_param_types[name] = param_type

    # Types from the documentation (``\w+`` accepts any identifier, digits included)
    for param_name, type_str in re.findall(r"^ *:type *(\w+): ([^\n]*) *$", doc, re.MULTILINE):
        if param_name not in real_parameters:
            raise ValueError(
                f"The parameter {param_name} of {'.'.join(element_path)} "
                "is defined in the documentation but not in the function signature"
            )
        # The greedy capture keeps trailing spaces, which would hide the ", optional" suffix.
        type_str = type_str.strip()
        if type_str.endswith(", optional"):
            optional_params.add(param_name)
            type_str = type_str[: -len(", optional")]
        parsed_param_types[param_name] = convert_type_from_doc(type_str, element_path, types_to_import)

    # we parse the parameters
    posonlyargs: List[ast.arg] = []
    args: List[ast.arg] = []
    vararg: Optional[ast.arg] = None
    kwonlyargs: List[ast.arg] = []
    kw_defaults: List[Optional[ast.AST]] = []
    kwarg: Optional[ast.arg] = None
    positional_defaults: List[Optional[ast.AST]] = []
    for param in real_parameters.values():
        if param.name != "self" and param.name not in parsed_param_types:
            raise ValueError(
                f"The parameter {param.name} of {'.'.join(element_path)} "
                "has no type definition in the function documentation"
            )
        param_ast = ast.arg(arg=param.name, annotation=parsed_param_types.get(param.name))

        default_ast = None
        if param.default is not inspect.Parameter.empty:
            default_ast = ast.Constant(param.default)
            if param.name not in optional_params:
                raise ValueError(
                    f"Parameter {param.name} of {'.'.join(element_path)} "
                    "is optional according to the type but not flagged as such in the doc"
                )
        elif param.name in optional_params:
            raise ValueError(
                f"Parameter {param.name} of {'.'.join(element_path)} "
                "is optional according to the documentation but has no default value"
            )

        if param.kind == param.POSITIONAL_ONLY:
            posonlyargs.append(param_ast)
            positional_defaults.append(default_ast)
        elif param.kind == param.POSITIONAL_OR_KEYWORD:
            args.append(param_ast)
            positional_defaults.append(default_ast)
        elif param.kind == param.VAR_POSITIONAL:
            vararg = param_ast
        elif param.kind == param.KEYWORD_ONLY:
            kwonlyargs.append(param_ast)
            kw_defaults.append(default_ast)  # kw_defaults uses None for "no default"
        elif param.kind == param.VAR_KEYWORD:
            kwarg = param_ast

    return ast.arguments(
        posonlyargs=posonlyargs,
        args=args,
        vararg=vararg,
        kwonlyargs=kwonlyargs,
        kw_defaults=kw_defaults,
        # ``defaults`` only lists the defaults of the *last* positional parameters. Valid
        # signatures never put a defaultless positional parameter after a defaulted one,
        # so all ``None`` placeholders are leading and can be dropped.
        defaults=[d for d in positional_defaults if d is not None],
        kwarg=kwarg,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def returns_stub(callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str]) -> Optional[ast.AST]:
    """Return the annotation AST of a callable's return type.

    The type comes from the single ``:rtype:`` line of ``doc``. Without such a line, the
    ``BUILTINS`` return type of ``callable_name`` is used, if any.

    :raises ValueError: if ``doc`` has several ``:rtype:`` lines, or has none and the
        callable is not a known magic member.
    """
    rtypes = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE)
    if len(rtypes) > 1:
        raise ValueError(f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}")
    if rtypes:
        return convert_type_from_doc(rtypes[0], element_path, types_to_import)

    fallback = BUILTINS.get(callable_name)
    if isinstance(fallback, tuple) and fallback[1] is not None:
        return fallback[1]
    raise ValueError(
        f"The return type of {'.'.join(element_path)} "
        "has no type definition using :rtype: in the function documentation"
    )
|
|
|
|
|
|
|
|
|
|
|
|
def convert_type_from_doc(type_str: str, element_path: List[str], types_to_import: Set[str]) -> ast.AST:
    """Convert a documentation type string to an annotation AST, ignoring surrounding whitespace."""
    return parse_type_to_ast(
        type_str.strip(),
        element_path,
        types_to_import,
    )
|
|
|
|
|
|
|
|
|
|
|
|
def parse_type_to_ast(type_str: str, element_path: List[str], types_to_import: Set[str]) -> ast.AST:
    """Parse a documentation type expression into an annotation AST.

    Supported grammar (informally)::

        type      := item ("or" item)*
        item      := NAME | NAME "[" subscript "]"
        subscript := element ("," element)*
        element   := type | "..."

    ``NAME`` is a possibly dotted identifier (letters, digits, ``_``). The module part of a
    dotted name is added to ``types_to_import``. ``A or B`` becomes ``A | B``, and
    ``X[A, B]`` becomes a subscript with a tuple slice.

    :raises ValueError: if ``type_str`` does not follow this grammar, including unbalanced
        brackets.
    """

    def parse_error() -> ValueError:
        return ValueError(f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}")

    # let's tokenize: identifier characters ("." included, for dotted paths and "...")
    # accumulate into one token; any other non-whitespace character is a token by itself
    tokens: List[str] = []
    current_token = ""
    for c in type_str:
        if c.isalnum() or c in "._":
            current_token += c
        else:
            if current_token:
                tokens.append(current_token)
                current_token = ""
            if not c.isspace():
                tokens.append(c)
    if current_token:
        tokens.append(current_token)

    # nest the tokens found between brackets into sub-lists
    stack: List[List[Any]] = [[]]
    for token in tokens:
        if token == "[":
            children: List[Any] = []
            stack[-1].append(children)
            stack.append(children)
        elif token == "]":
            if len(stack) == 1:  # closing bracket without a matching opening one
                raise parse_error()
            stack.pop()
        else:
            stack[-1].append(token)
    if len(stack) != 1:  # an opening bracket was never closed
        raise parse_error()

    def split_on(sequence: List[Any], separator: str) -> List[List[Any]]:
        """Split ``sequence`` on ``separator`` tokens, rejecting empty parts."""
        groups: List[List[Any]] = [[]]
        for e in sequence:
            if isinstance(e, str) and e == separator:
                groups.append([])
            else:
                groups[-1].append(e)
        if any(not g for g in groups):
            raise parse_error()
        return groups

    def parse_sequence(sequence: List[Any]) -> ast.AST:
        """Parse ``item ("or" item)*`` into a chain of ``|`` operations."""
        new_elements: List[ast.AST] = []
        for group in split_on(sequence, "or"):
            if len(group) == 1 and isinstance(group[0], str):
                new_elements.append(concatenated_path_to_type(group[0], element_path, types_to_import))
            elif len(group) == 2 and isinstance(group[0], str) and isinstance(group[1], list):
                new_elements.append(
                    ast.Subscript(
                        value=concatenated_path_to_type(group[0], element_path, types_to_import),
                        slice=parse_subscript(group[1]),
                        ctx=ast.Load(),
                    )
                )
            else:
                raise parse_error()
        return reduce(lambda left, right: ast.BinOp(left=left, op=ast.BitOr(), right=right), new_elements)

    def parse_subscript(sequence: List[Any]) -> ast.AST:
        """Parse the content of ``[...]``: a single element or a comma-separated tuple."""
        elements = [
            ast.Constant(value=...) if part == ["..."] else parse_sequence(part)
            for part in split_on(sequence, ",")
        ]
        if len(elements) == 1:
            return elements[0]
        return ast.Tuple(elts=elements, ctx=ast.Load())

    return parse_sequence(stack[0])
|
|
|
|
|
|
|
|
|
|
|
|
def concatenated_path_to_type(path: str, element_path: List[str], types_to_import: Set[str]) -> ast.AST:
    """Turn a dotted name such as ``"os.PathLike"`` into a Name/Attribute AST.

    The module part of the name (everything before the last dot), if any, is added to
    ``types_to_import``.

    :raises ValueError: if ``path`` is empty or contains an empty component
        (e.g. ``"a..b"``); ``types_to_import`` is left untouched in that case.
    """
    parts = path.split(".")
    if not all(parts):
        raise ValueError(f"Not able to parse type '{path}' used by {'.'.join(element_path)}")
    module = ".".join(parts[:-1])
    if module:
        types_to_import.add(module)
    return path_to_type(*parts)
|
|
|
|
|
|
|
|
|
|
|
|
def build_doc_comment(doc: str) -> Optional[ast.Expr]:
    """Turn a documentation string into a docstring statement for a stub.

    Every line is stripped, and lines starting with ``:type`` or ``:rtype`` are dropped
    (their information goes into annotations).

    :return: an ``ast.Expr`` holding the remaining text, or ``None`` when nothing is left.
    """
    kept = [
        stripped
        for stripped in (raw_line.strip() for raw_line in doc.split("\n"))
        if not stripped.startswith((":type", ":rtype"))
    ]
    text = "\n".join(kept).strip()
    if not text:
        return None
    return ast.Expr(value=ast.Constant(text))
|
|
|
|
|
|
|
|
|
|
|
|
def format_with_ruff(file: str) -> None:
    """Format ``file`` in place with ``ruff format``.

    Ruff is run as a module of the interpreter executing this script (``sys.executable``).
    This uses the Ruff installed in the same environment instead of whatever ``python``
    comes first on ``PATH``, if any.

    :raises subprocess.CalledProcessError: if Ruff exits with a non-zero status.
    """
    import sys

    # sys.executable may be empty in embedded interpreters: fall back to the old behavior.
    subprocess.check_call([sys.executable or "python", "-m", "ruff", "format", file])
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract Python type stub from a python module.")
    parser.add_argument("module_name", help="Name of the Python module for which generate stubs")
    parser.add_argument(
        "out",
        help="Name of the Python stub file to write to",
        type=argparse.FileType("wt"),
    )
    parser.add_argument("--ruff", help="Formats the generated stubs using Ruff", action="store_true")
    args = parser.parse_args()

    stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name)))
    # Close (hence flush) the output before Ruff rewrites it in place. Otherwise the stub may
    # still sit in Python's write buffer, so Ruff would read an empty or truncated file and
    # the buffered content would be written over Ruff's output at exit.
    with args.out:
        args.out.write(stub_content)
    if args.ruff:
        format_with_ruff(args.out.name)
|