Fork of https://github.com/oxigraph/oxigraph.git for the purposes of the NextGraph project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 
oxigraph/python/generate_stubs.py

506 lines
17 KiB

import argparse
import ast
import importlib
import inspect
import logging
import re
import subprocess
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
def _path_to_type(*elements: str) -> ast.AST:
    """Build the AST of a dotted name such as ``typing.Any``.

    The first element becomes an ``ast.Name``; each following element wraps
    the expression built so far in an ``ast.Attribute``. Every node uses the
    shared load context.

    :param elements: the successive parts of the dotted path. At least one
        part is required because the first one is read unconditionally.
    :return: the AST node representing ``elements[0].elements[1]...``.
    """
    root, attributes = elements[0], elements[1:]
    node: ast.AST = ast.Name(id=root, ctx=AST_LOAD)
    for attribute in attributes:
        node = ast.Attribute(value=node, attr=attribute, ctx=AST_LOAD)
    return node
# Shared AST building blocks reused by the stub builders below.
AST_LOAD = ast.Load()
# ``ast.Ellipsis`` is a deprecated alias that merely returned
# ``ast.Constant(value=...)``; it is removed in Python 3.14, so the constant
# node is built directly.
AST_ELLIPSIS = ast.Constant(value=...)
AST_STORE = ast.Store()
AST_TYPING_ANY = _path_to_type("typing", "Any")

# Generic constructors accepted in documentation types (e.g. ``list(str)``),
# mapped to the ``typing`` generic that is subscripted in the stub.
GENERICS = {
    "iterable": _path_to_type("typing", "Iterable"),
    "iterator": _path_to_type("typing", "Iterator"),
    "list": _path_to_type("typing", "List"),
    "io": _path_to_type("typing", "IO"),
}

# Members every class gets from ``object``; class members equal to these are
# not emitted in the stubs.
OBJECT_MEMBERS = dict(inspect.getmembers(object))

# Known special members.
# * ``None``: the member is never emitted in the stubs.
# * ``(parameter types, return type)``: types used for the parameters (``self``
#   excluded) and, when the documentation has no ``:rtype:``, for the return
#   value.
BUILTINS: Dict[str, Union[None, Tuple[List[ast.AST], ast.AST]]] = {
    "__annotations__": None,
    "__bool__": ([], _path_to_type("bool")),
    "__bytes__": ([], _path_to_type("bytes")),
    "__class__": None,
    "__contains__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__del__": None,
    "__delattr__": ([_path_to_type("str")], _path_to_type("None")),
    "__delitem__": ([AST_TYPING_ANY], AST_TYPING_ANY),
    "__dict__": None,
    "__dir__": None,
    "__doc__": None,
    "__eq__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__format__": ([_path_to_type("str")], _path_to_type("str")),
    "__ge__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__getattribute__": ([_path_to_type("str")], AST_TYPING_ANY),
    "__getitem__": ([AST_TYPING_ANY], AST_TYPING_ANY),
    "__gt__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__hash__": ([], _path_to_type("int")),
    "__init__": ([], _path_to_type("None")),
    "__init_subclass__": None,
    "__iter__": ([], AST_TYPING_ANY),
    "__le__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__len__": ([], _path_to_type("int")),
    "__lt__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__module__": None,
    "__ne__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__new__": None,
    "__next__": ([], AST_TYPING_ANY),
    "__reduce__": None,
    "__reduce_ex__": None,
    "__repr__": ([], _path_to_type("str")),
    "__setattr__": ([_path_to_type("str"), AST_TYPING_ANY], _path_to_type("None")),
    "__setitem__": ([AST_TYPING_ANY, AST_TYPING_ANY], AST_TYPING_ANY),
    "__sizeof__": None,
    "__str__": ([], _path_to_type("str")),
    "__subclasshook__": None,
}
def module_stubs(module: Any) -> ast.Module:
    """Build the stub module describing the public content of ``module``.

    Members whose name starts with ``__`` are skipped. Classes and builtin
    functions are converted to stubs; any other member is reported with a
    warning and left out.

    :param module: the imported module to inspect.
    :return: an ``ast.Module`` made of the required ``import`` statements
        (sorted by name), then the class stubs, then the function stubs.
    """
    # The stub builders add to this set every module their annotations refer to.
    types_to_import = {"typing"}
    class_nodes = []
    function_nodes = []
    for member_name, member_value in inspect.getmembers(module):
        element_path = [module.__name__, member_name]
        if member_name.startswith("__"):
            continue
        if inspect.isclass(member_value):
            class_nodes.append(
                class_stubs(member_name, member_value, element_path, types_to_import)
            )
        elif inspect.isbuiltin(member_value):
            function_nodes.append(
                function_stub(
                    member_name,
                    member_value,
                    element_path,
                    types_to_import,
                    in_class=False,
                )
            )
        else:
            logging.warning(f"Unsupported root construction {member_name}")

    # Built only now: the set has been filled while the members were converted.
    import_nodes = [
        ast.Import(names=[ast.alias(name=imported)])
        for imported in sorted(types_to_import)
    ]
    return ast.Module(
        body=import_nodes + class_nodes + function_nodes,
        type_ignores=[],
    )
def class_stubs(
    cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str]
) -> ast.ClassDef:
    """Build the stub of a class from its runtime members.

    Members are sorted into data descriptors (attributes), regular methods,
    magic methods (name starting with ``__``) and the ``__match_args__``
    constant. Members identical to the ones of ``object`` or flagged ``None``
    in ``BUILTINS`` are ignored; any other unsupported member is reported with
    a warning.

    :param cls_name: name given to the generated class.
    :param cls_def: the class object to inspect.
    :param element_path: dotted path of the class, split into its parts; used
        in error and warning messages.
    :param types_to_import: set receiving the modules used by the annotations.
    :return: the class definition, decorated with ``typing.final``. Its body is
        the docstring (if any), attributes, methods, magic methods and
        constants, or a lone ellipsis when all of these are empty.
    :raises ValueError: if building the ``__init__`` stub fails for another
        reason than the class having no signature.
    """
    attributes: List[ast.AST] = []
    methods: List[ast.AST] = []
    magic_methods: List[ast.AST] = []
    constants: List[ast.AST] = []
    for member_name, member_value in inspect.getmembers(cls_def):
        current_element_path = [*element_path, member_name]
        if member_name == "__init__":
            try:
                inspect.signature(cls_def)  # we check it actually exists
                # The constructor is described by the signature of the class
                # itself and is listed before the other methods.
                methods = [
                    function_stub(
                        member_name,
                        cls_def,
                        current_element_path,
                        types_to_import,
                        in_class=True,
                    ),
                    *methods,
                ]
            except ValueError as e:
                # A class without signature simply gets no __init__ stub.
                if "no signature found" not in str(e):
                    raise ValueError(
                        f"Error while parsing signature of {cls_name}.__init__"
                    ) from e
        elif (
            member_value == OBJECT_MEMBERS.get(member_name)
            or BUILTINS.get(member_name, ()) is None
        ):
            pass
        elif inspect.isdatadescriptor(member_value):
            attributes.extend(
                data_descriptor_stub(
                    member_name, member_value, current_element_path, types_to_import
                )
            )
        elif inspect.isroutine(member_value):
            (magic_methods if member_name.startswith("__") else methods).append(
                function_stub(
                    member_name,
                    member_value,
                    current_element_path,
                    types_to_import,
                    in_class=True,
                )
            )
        elif member_name == "__match_args__":
            # Emitted as ``__match_args__: typing.Tuple[str, ...] = <value>``.
            # ``ast.Constant(value=...)`` replaces the deprecated
            # ``ast.Ellipsis()`` alias, which is removed in Python 3.14.
            constants.append(
                ast.AnnAssign(
                    target=ast.Name(id=member_name, ctx=AST_STORE),
                    annotation=ast.Subscript(
                        value=_path_to_type("typing", "Tuple"),
                        slice=ast.Tuple(
                            elts=[_path_to_type("str"), ast.Constant(value=...)],
                            ctx=AST_LOAD,
                        ),
                        ctx=AST_LOAD,
                    ),
                    value=ast.Constant(member_value),
                    simple=1,
                )
            )
        else:
            logging.warning(
                f"Unsupported member {member_name} of class {'.'.join(element_path)}"
            )

    doc = inspect.getdoc(cls_def)
    doc_comment = build_doc_comment(doc) if doc else None
    return ast.ClassDef(
        cls_name,
        bases=[],
        keywords=[],
        body=(
            ([doc_comment] if doc_comment else [])
            + attributes
            + methods
            + magic_methods
            + constants
        )
        or [AST_ELLIPSIS],
        decorator_list=[_path_to_type("typing", "final")],
    )
def data_descriptor_stub(
    data_desc_name: str,
    data_desc_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
) -> Union[Tuple[ast.AnnAssign, ast.Expr], Tuple[ast.AnnAssign]]:
    """Build the annotated attribute describing a data descriptor.

    The attribute type is read from the ``:rtype:`` field of the descriptor
    documentation and falls back to ``typing.Any`` when there is no
    documentation. The text of a single ``:return:`` field, if present, is
    emitted as a documentation string right after the attribute.

    :param data_desc_name: name of the attribute.
    :param data_desc_def: the descriptor object to inspect.
    :param element_path: dotted path of the attribute, split into its parts;
        used in error messages.
    :param types_to_import: set receiving the modules used by the annotation.
    :return: ``(assignment, documentation)`` or ``(assignment,)`` when there is
        no documentation to emit.
    :raises ValueError: if the documentation has several ``:return:`` fields.
    """
    annotation = None
    description = None
    doc = inspect.getdoc(data_desc_def)
    if doc is not None:
        annotation = returns_stub(data_desc_name, doc, element_path, types_to_import)
        descriptions = re.findall(r"^ *:return: *(.*) *$", doc, re.MULTILINE)
        if len(descriptions) > 1:
            raise ValueError(
                f"Multiple return annotations found with :return: in {'.'.join(element_path)} documentation"
            )
        if descriptions:
            description = descriptions[0]

    assign = ast.AnnAssign(
        target=ast.Name(id=data_desc_name, ctx=AST_STORE),
        annotation=annotation or AST_TYPING_ANY,
        simple=1,
    )
    doc_comment = build_doc_comment(description) if description else None
    if doc_comment:
        return (assign, doc_comment)
    return (assign,)
def function_stub(
    fn_name: str,
    fn_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
    *,
    in_class: bool,
) -> ast.FunctionDef:
    """Build the stub of a function or method.

    :param fn_name: name given to the generated function.
    :param fn_def: the callable to inspect (signature and documentation).
    :param element_path: dotted path of the callable, split into its parts;
        used in error messages.
    :param types_to_import: set receiving the modules used by the annotations.
    :param in_class: whether the callable is a class member; such a callable
        exposing a ``__self__`` attribute is decorated with ``staticmethod``.
    :return: the function definition. Its body is the cleaned documentation,
        or a lone ellipsis when there is none; the return annotation is only
        set when the callable has a non-empty documentation.
    """
    doc = inspect.getdoc(fn_def)
    doc_comment = build_doc_comment(doc) if doc is not None else None
    body: List[ast.AST] = [doc_comment] if doc_comment is not None else []

    decorator_list = (
        [ast.Name("staticmethod")] if in_class and hasattr(fn_def, "__self__") else []
    )

    arguments = arguments_stub(
        fn_name, fn_def, doc or "", element_path, types_to_import
    )
    returns = (
        returns_stub(fn_name, doc, element_path, types_to_import) if doc else None
    )
    return ast.FunctionDef(
        fn_name,
        arguments,
        body or [AST_ELLIPSIS],
        decorator_list=decorator_list,
        returns=returns,
        lineno=0,
    )
def arguments_stub(
    callable_name: str,
    callable_def: Any,
    doc: str,
    element_path: List[str],
    types_to_import: Set[str],
) -> ast.arguments:
    """Build the annotated argument list of a callable.

    Parameter names, kinds and defaults come from ``inspect.signature``; types
    come from the ``BUILTINS`` table for known magic methods and from the
    ``:type name: ...`` fields of the documentation, the latter taking
    precedence. A documented type ending with ``, optional`` flags a parameter
    that must have a default value.

    :param callable_name: name of the callable; ``__init__`` gets an extra
        leading positional-only ``self`` parameter.
    :param callable_def: the object whose signature is read.
    :param doc: documentation of the callable (may be empty).
    :param element_path: dotted path of the callable, split into its parts;
        used in error messages.
    :param types_to_import: set receiving the modules used by the annotations.
    :return: the ``ast.arguments`` node of the stub.
    :raises ValueError: if a documented parameter is not in the signature, if a
        parameter other than ``self`` has no type, or if the ``, optional``
        flag and the presence of a default value disagree.
    """
    real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(
        callable_def
    ).parameters
    if callable_name == "__init__":
        real_parameters = {
            "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
            **real_parameters,
        }

    parsed_param_types = {}
    optional_params = set()

    # Types for magic functions types
    builtin = BUILTINS.get(callable_name)
    if isinstance(builtin, tuple):
        candidate_names = list(real_parameters)
        if candidate_names[:1] == ["self"]:
            candidate_names = candidate_names[1:]
        parsed_param_types.update(zip(candidate_names, builtin[0]))

    # Types from comment
    documented = re.findall(r"^ *:type *([a-z_]+): ([^\n]*) *$", doc, re.MULTILINE)
    for param_name, type_text in documented:
        if param_name not in real_parameters:
            raise ValueError(
                f"The parameter {param_name} of {'.'.join(element_path)} "
                "is defined in the documentation but not in the function signature"
            )
        if type_text.endswith(", optional"):
            optional_params.add(param_name)
            type_text = type_text[:-10]  # drops the ", optional" suffix
        parsed_param_types[param_name] = convert_type_from_doc(
            type_text, element_path, types_to_import
        )

    # we parse the parameters
    posonlyargs = []
    args = []
    vararg = None
    kwonlyargs = []
    kw_defaults = []
    kwarg = None
    defaults = []
    for param in real_parameters.values():
        name = param.name
        if name != "self" and name not in parsed_param_types:
            raise ValueError(
                f"The parameter {name} of {'.'.join(element_path)} "
                "has no type definition in the function documentation"
            )
        param_ast = ast.arg(arg=name, annotation=parsed_param_types.get(name))

        # The signature and the documentation must agree on optionality.
        has_default = param.default != param.empty
        flagged_optional = name in optional_params
        default_ast = None
        if has_default:
            default_ast = ast.Constant(param.default)
            if not flagged_optional:
                raise ValueError(
                    f"Parameter {name} of {'.'.join(element_path)} "
                    "is optional according to the type but not flagged as such in the doc"
                )
        elif flagged_optional:
            raise ValueError(
                f"Parameter {name} of {'.'.join(element_path)} "
                "is optional according to the documentation but has no default value"
            )

        kind = param.kind
        if kind == param.POSITIONAL_ONLY:
            posonlyargs.append(param_ast)
            defaults.append(default_ast)
        elif kind == param.POSITIONAL_OR_KEYWORD:
            args.append(param_ast)
            defaults.append(default_ast)
        elif kind == param.VAR_POSITIONAL:
            vararg = param_ast
        elif kind == param.KEYWORD_ONLY:
            kwonlyargs.append(param_ast)
            kw_defaults.append(default_ast)
        elif kind == param.VAR_KEYWORD:
            kwarg = param_ast

    return ast.arguments(
        posonlyargs=posonlyargs,
        args=args,
        vararg=vararg,
        kwonlyargs=kwonlyargs,
        kw_defaults=kw_defaults,
        defaults=defaults,
        kwarg=kwarg,
    )
def returns_stub(
    callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str]
) -> Optional[ast.AST]:
    """Build the return annotation of a callable from its documentation.

    The type is read from the single ``:rtype:`` field of ``doc``. Without
    such a field, the return type registered in ``BUILTINS`` for
    ``callable_name`` is used.

    :param callable_name: name of the callable, used for the ``BUILTINS`` lookup.
    :param doc: documentation of the callable.
    :param element_path: dotted path of the callable, split into its parts;
        used in error messages.
    :param types_to_import: set receiving the modules used by the annotation.
    :return: the AST of the return type.
    :raises ValueError: if there are several ``:rtype:`` fields, or none and no
        usable ``BUILTINS`` entry.
    """
    declared = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE)
    if len(declared) > 1:
        raise ValueError(
            f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}"
        )
    if declared:
        return convert_type_from_doc(declared[0], element_path, types_to_import)

    # No :rtype: field: fall back on the well-known magic method types.
    builtin = BUILTINS.get(callable_name)
    if isinstance(builtin, tuple) and builtin[1] is not None:
        return builtin[1]
    raise ValueError(
        f"The return type of {'.'.join(element_path)} "
        "has no type definition using :rtype: in the function documentation"
    )
def convert_type_from_doc(
    type_str: str, element_path: List[str], types_to_import: Set[str]
) -> ast.AST:
    """Convert a type written in the documentation into its AST.

    Surrounding whitespace is removed before handing the text to
    ``parse_type_to_ast``.

    :param type_str: the type as written in the documentation.
    :param element_path: dotted path of the documented element, split into its
        parts; used in error messages.
    :param types_to_import: set receiving the modules used by the type.
    :return: the AST of the type.
    """
    return parse_type_to_ast(type_str.strip(), element_path, types_to_import)
def parse_type_to_ast(
    type_str: str, element_path: List[str], types_to_import: Set[str]
) -> ast.AST:
    """Parse a documentation type such as ``list(str) or foo.Bar`` into an AST.

    Supported syntax:

    * a name, possibly dotted (``str``, ``foo.Bar``); the first part of a
      dotted name is recorded in ``types_to_import``;
    * a generic constructor listed in ``GENERICS`` followed by its
      parenthesized argument (``list(str)``);
    * alternatives separated by ``or``, converted to ``typing.Union[...]``.

    :param type_str: the type to parse.
    :param element_path: dotted path of the documented element, split into its
        parts; used in error messages.
    :param types_to_import: set receiving the modules used by the type.
    :return: the AST of the type.
    :raises ValueError: if the parentheses are unbalanced, if a generic
        constructor is not supported or if the type cannot be parsed.
    """
    # let's tokenize: runs of ASCII letters and dots form one token, spaces
    # only separate tokens and any other character is a token on its own
    tokens: List[str] = []
    current_token = ""
    for c in type_str:
        if "a" <= c <= "z" or "A" <= c <= "Z" or c == ".":
            current_token += c
            continue
        if current_token:
            tokens.append(current_token)
            current_token = ""
        if c != " ":
            tokens.append(c)
    if current_token:
        tokens.append(current_token)

    # let's first parse nested parenthesis: each "(" opens a child list that
    # is appended to its parent, each ")" goes back to the parent
    stack: List[List[Any]] = [[]]
    for token in tokens:
        if token == "(":
            children: List[Any] = []
            stack[-1].append(children)
            stack.append(children)
        elif token == ")":
            # Only the root list left: this ")" closes nothing. Without this
            # check the root would be popped and an IndexError raised later.
            if len(stack) == 1:
                raise ValueError(
                    f"Unbalanced parentheses in type '{type_str}' used by {'.'.join(element_path)}"
                )
            stack.pop()
        else:
            stack[-1].append(token)
    # A "(" that is never closed must not be silently accepted.
    if len(stack) != 1:
        raise ValueError(
            f"Unbalanced parentheses in type '{type_str}' used by {'.'.join(element_path)}"
        )

    # then it's easy
    def parse_sequence(sequence: List[Any]) -> ast.AST:
        # we split based on "or"
        or_groups: List[List[Any]] = [[]]
        for e in sequence:
            if e == "or":
                or_groups.append([])
            else:
                or_groups[-1].append(e)
        # An empty group means an empty type or a dangling "or".
        if any(not g for g in or_groups):
            raise ValueError(
                f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
            )

        new_elements: List[ast.AST] = []
        for group in or_groups:
            if len(group) == 1 and isinstance(group[0], str):
                # A plain, possibly dotted, name.
                parts = group[0].split(".")
                if any(not p for p in parts):
                    raise ValueError(
                        f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
                    )
                if len(parts) > 1:
                    types_to_import.add(parts[0])
                new_elements.append(_path_to_type(*parts))
            elif (
                len(group) == 2
                and isinstance(group[0], str)
                and isinstance(group[1], list)
            ):
                # A generic constructor followed by its parenthesized argument.
                if group[0] not in GENERICS:
                    raise ValueError(
                        f"Constructor {group[0]} is not supported in type '{type_str}' used by {'.'.join(element_path)}"
                    )
                new_elements.append(
                    ast.Subscript(
                        value=GENERICS[group[0]],
                        slice=parse_sequence(group[1]),
                        ctx=AST_LOAD,
                    )
                )
            else:
                raise ValueError(
                    f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
                )
        return (
            ast.Subscript(
                value=_path_to_type("typing", "Union"),
                slice=ast.Tuple(elts=new_elements, ctx=AST_LOAD),
                ctx=AST_LOAD,
            )
            if len(new_elements) > 1
            else new_elements[0]
        )

    return parse_sequence(stack[0])
def build_doc_comment(doc: str) -> Optional[ast.Expr]:
    """Turn a documentation text into a docstring statement for the stubs.

    Every line is stripped of its surrounding whitespace, and the lines
    starting with ``:type`` or ``:rtype`` are dropped because that information
    is carried by the annotations of the stub.

    :param doc: the documentation text.
    :return: an expression statement holding the cleaned text, or ``None`` if
        nothing is left after cleaning.
    """
    kept_lines = [
        stripped
        for stripped in (raw.strip() for raw in doc.split("\n"))
        if not stripped.startswith((":type", ":rtype"))
    ]
    text = "\n".join(kept_lines).strip()
    if not text:
        return None
    return ast.Expr(value=ast.Constant(text))
def format_with_black(code: str) -> str:
    """Format stub source code with Black.

    Black is run as ``python -m black`` in stub (``--pyi``) mode, reading the
    code from its standard input.

    :param code: the source code to format.
    :return: the formatted source code, decoded from Black's standard output.
    :raises subprocess.CalledProcessError: if Black exits with a non-zero code.
    """
    command = ["python", "-m", "black", "-t", "py37", "--pyi", "-"]
    completed = subprocess.run(
        command,
        input=code.encode(),
        capture_output=True,
        check=True,
    )
    return completed.stdout.decode()
if __name__ == "__main__":
    # Command line: <module_name> <out> [--black]
    cli = argparse.ArgumentParser(
        description="Extract Python type stub from a python module."
    )
    cli.add_argument(
        "module_name", help="Name of the Python module for which generate stubs"
    )
    cli.add_argument(
        "out",
        help="Name of the Python stub file to write to",
        type=argparse.FileType("wt"),
    )
    cli.add_argument(
        "--black", help="Formats the generated stubs using Black", action="store_true"
    )
    options = cli.parse_args()

    # Import the target module and render its stubs as source code.
    target_module = importlib.import_module(options.module_name)
    rendered = ast.unparse(module_stubs(target_module))
    # Strip the positional-only marker from the rendered signatures.
    # TODO: remove when targeting Python 3.8+
    rendered = rendered.replace(", /", "")
    if options.black:
        rendered = format_with_black(rendered)
    options.out.write(rendered)