|
|
@ -77,9 +77,7 @@ def module_stubs(module: Any) -> ast.Module: |
|
|
|
if member_name.startswith("__"): |
|
|
|
if member_name.startswith("__"): |
|
|
|
pass |
|
|
|
pass |
|
|
|
elif inspect.isclass(member_value): |
|
|
|
elif inspect.isclass(member_value): |
|
|
|
classes.append( |
|
|
|
classes.append(class_stubs(member_name, member_value, element_path, types_to_import)) |
|
|
|
class_stubs(member_name, member_value, element_path, types_to_import) |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
elif inspect.isbuiltin(member_value): |
|
|
|
elif inspect.isbuiltin(member_value): |
|
|
|
functions.append( |
|
|
|
functions.append( |
|
|
|
function_stub( |
|
|
|
function_stub( |
|
|
@ -93,16 +91,12 @@ def module_stubs(module: Any) -> ast.Module: |
|
|
|
else: |
|
|
|
else: |
|
|
|
logging.warning(f"Unsupported root construction {member_name}") |
|
|
|
logging.warning(f"Unsupported root construction {member_name}") |
|
|
|
return ast.Module( |
|
|
|
return ast.Module( |
|
|
|
body=[ast.Import(names=[ast.alias(name=t)]) for t in sorted(types_to_import)] |
|
|
|
body=[ast.Import(names=[ast.alias(name=t)]) for t in sorted(types_to_import)] + classes + functions, |
|
|
|
+ classes |
|
|
|
|
|
|
|
+ functions, |
|
|
|
|
|
|
|
type_ignores=[], |
|
|
|
type_ignores=[], |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def class_stubs( |
|
|
|
def class_stubs(cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str]) -> ast.ClassDef: |
|
|
|
cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str] |
|
|
|
|
|
|
|
) -> ast.ClassDef: |
|
|
|
|
|
|
|
attributes: List[ast.AST] = [] |
|
|
|
attributes: List[ast.AST] = [] |
|
|
|
methods: List[ast.AST] = [] |
|
|
|
methods: List[ast.AST] = [] |
|
|
|
magic_methods: List[ast.AST] = [] |
|
|
|
magic_methods: List[ast.AST] = [] |
|
|
@ -124,20 +118,11 @@ def class_stubs( |
|
|
|
] |
|
|
|
] |
|
|
|
except ValueError as e: |
|
|
|
except ValueError as e: |
|
|
|
if "no signature found" not in str(e): |
|
|
|
if "no signature found" not in str(e): |
|
|
|
raise ValueError( |
|
|
|
raise ValueError(f"Error while parsing signature of {cls_name}.__init_") from e |
|
|
|
f"Error while parsing signature of {cls_name}.__init_" |
|
|
|
elif member_value == OBJECT_MEMBERS.get(member_name) or BUILTINS.get(member_name, ()) is None: |
|
|
|
) from e |
|
|
|
|
|
|
|
elif ( |
|
|
|
|
|
|
|
member_value == OBJECT_MEMBERS.get(member_name) |
|
|
|
|
|
|
|
or BUILTINS.get(member_name, ()) is None |
|
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
pass |
|
|
|
pass |
|
|
|
elif inspect.isdatadescriptor(member_value): |
|
|
|
elif inspect.isdatadescriptor(member_value): |
|
|
|
attributes.extend( |
|
|
|
attributes.extend(data_descriptor_stub(member_name, member_value, current_element_path, types_to_import)) |
|
|
|
data_descriptor_stub( |
|
|
|
|
|
|
|
member_name, member_value, current_element_path, types_to_import |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
elif inspect.isroutine(member_value): |
|
|
|
elif inspect.isroutine(member_value): |
|
|
|
(magic_methods if member_name.startswith("__") else methods).append( |
|
|
|
(magic_methods if member_name.startswith("__") else methods).append( |
|
|
|
function_stub( |
|
|
|
function_stub( |
|
|
@ -154,9 +139,7 @@ def class_stubs( |
|
|
|
target=ast.Name(id=member_name, ctx=AST_STORE), |
|
|
|
target=ast.Name(id=member_name, ctx=AST_STORE), |
|
|
|
annotation=ast.Subscript( |
|
|
|
annotation=ast.Subscript( |
|
|
|
value=_path_to_type("typing", "Tuple"), |
|
|
|
value=_path_to_type("typing", "Tuple"), |
|
|
|
slice=ast.Tuple( |
|
|
|
slice=ast.Tuple(elts=[_path_to_type("str"), ast.Ellipsis()], ctx=AST_LOAD), |
|
|
|
elts=[_path_to_type("str"), ast.Ellipsis()], ctx=AST_LOAD |
|
|
|
|
|
|
|
), |
|
|
|
|
|
|
|
ctx=AST_LOAD, |
|
|
|
ctx=AST_LOAD, |
|
|
|
), |
|
|
|
), |
|
|
|
value=ast.Constant(member_value), |
|
|
|
value=ast.Constant(member_value), |
|
|
@ -164,9 +147,7 @@ def class_stubs( |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
else: |
|
|
|
else: |
|
|
|
logging.warning( |
|
|
|
logging.warning(f"Unsupported member {member_name} of class {'.'.join(element_path)}") |
|
|
|
f"Unsupported member {member_name} of class {'.'.join(element_path)}" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
doc = inspect.getdoc(cls_def) |
|
|
|
doc = inspect.getdoc(cls_def) |
|
|
|
doc_comment = build_doc_comment(doc) if doc else None |
|
|
|
doc_comment = build_doc_comment(doc) if doc else None |
|
|
@ -174,13 +155,7 @@ def class_stubs( |
|
|
|
cls_name, |
|
|
|
cls_name, |
|
|
|
bases=[], |
|
|
|
bases=[], |
|
|
|
keywords=[], |
|
|
|
keywords=[], |
|
|
|
body=( |
|
|
|
body=(([doc_comment] if doc_comment else []) + attributes + methods + magic_methods + constants) |
|
|
|
([doc_comment] if doc_comment else []) |
|
|
|
|
|
|
|
+ attributes |
|
|
|
|
|
|
|
+ methods |
|
|
|
|
|
|
|
+ magic_methods |
|
|
|
|
|
|
|
+ constants |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
or [AST_ELLIPSIS], |
|
|
|
or [AST_ELLIPSIS], |
|
|
|
decorator_list=[_path_to_type("typing", "final")], |
|
|
|
decorator_list=[_path_to_type("typing", "final")], |
|
|
|
) |
|
|
|
) |
|
|
@ -239,9 +214,7 @@ def function_stub( |
|
|
|
arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import), |
|
|
|
arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import), |
|
|
|
body or [AST_ELLIPSIS], |
|
|
|
body or [AST_ELLIPSIS], |
|
|
|
decorator_list=decorator_list, |
|
|
|
decorator_list=decorator_list, |
|
|
|
returns=returns_stub(fn_name, doc, element_path, types_to_import) |
|
|
|
returns=returns_stub(fn_name, doc, element_path, types_to_import) if doc else None, |
|
|
|
if doc |
|
|
|
|
|
|
|
else None, |
|
|
|
|
|
|
|
lineno=0, |
|
|
|
lineno=0, |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
@ -253,9 +226,7 @@ def arguments_stub( |
|
|
|
element_path: List[str], |
|
|
|
element_path: List[str], |
|
|
|
types_to_import: Set[str], |
|
|
|
types_to_import: Set[str], |
|
|
|
) -> ast.arguments: |
|
|
|
) -> ast.arguments: |
|
|
|
real_parameters: Mapping[str, inspect.Parameter] = inspect.signature( |
|
|
|
real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(callable_def).parameters |
|
|
|
callable_def |
|
|
|
|
|
|
|
).parameters |
|
|
|
|
|
|
|
if callable_name == "__init__": |
|
|
|
if callable_name == "__init__": |
|
|
|
real_parameters = { |
|
|
|
real_parameters = { |
|
|
|
"self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY), |
|
|
|
"self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY), |
|
|
@ -285,9 +256,7 @@ def arguments_stub( |
|
|
|
if type.endswith(", optional"): |
|
|
|
if type.endswith(", optional"): |
|
|
|
optional_params.add(match[0]) |
|
|
|
optional_params.add(match[0]) |
|
|
|
type = type[:-10] |
|
|
|
type = type[:-10] |
|
|
|
parsed_param_types[match[0]] = convert_type_from_doc( |
|
|
|
parsed_param_types[match[0]] = convert_type_from_doc(type, element_path, types_to_import) |
|
|
|
type, element_path, types_to_import |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# we parse the parameters |
|
|
|
# we parse the parameters |
|
|
|
posonlyargs = [] |
|
|
|
posonlyargs = [] |
|
|
@ -303,9 +272,7 @@ def arguments_stub( |
|
|
|
f"The parameter {param.name} of {'.'.join(element_path)} " |
|
|
|
f"The parameter {param.name} of {'.'.join(element_path)} " |
|
|
|
"has no type definition in the function documentation" |
|
|
|
"has no type definition in the function documentation" |
|
|
|
) |
|
|
|
) |
|
|
|
param_ast = ast.arg( |
|
|
|
param_ast = ast.arg(arg=param.name, annotation=parsed_param_types.get(param.name)) |
|
|
|
arg=param.name, annotation=parsed_param_types.get(param.name) |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_ast = None |
|
|
|
default_ast = None |
|
|
|
if param.default != param.empty: |
|
|
|
if param.default != param.empty: |
|
|
@ -346,9 +313,7 @@ def arguments_stub( |
|
|
|
) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def returns_stub( |
|
|
|
def returns_stub(callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str]) -> Optional[ast.AST]: |
|
|
|
callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str] |
|
|
|
|
|
|
|
) -> Optional[ast.AST]: |
|
|
|
|
|
|
|
m = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE) |
|
|
|
m = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE) |
|
|
|
if len(m) == 0: |
|
|
|
if len(m) == 0: |
|
|
|
builtin = BUILTINS.get(callable_name) |
|
|
|
builtin = BUILTINS.get(callable_name) |
|
|
@ -359,22 +324,16 @@ def returns_stub( |
|
|
|
"has no type definition using :rtype: in the function documentation" |
|
|
|
"has no type definition using :rtype: in the function documentation" |
|
|
|
) |
|
|
|
) |
|
|
|
if len(m) > 1: |
|
|
|
if len(m) > 1: |
|
|
|
raise ValueError( |
|
|
|
raise ValueError(f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}") |
|
|
|
f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
return convert_type_from_doc(m[0], element_path, types_to_import) |
|
|
|
return convert_type_from_doc(m[0], element_path, types_to_import) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def convert_type_from_doc( |
|
|
|
def convert_type_from_doc(type_str: str, element_path: List[str], types_to_import: Set[str]) -> ast.AST: |
|
|
|
type_str: str, element_path: List[str], types_to_import: Set[str] |
|
|
|
|
|
|
|
) -> ast.AST: |
|
|
|
|
|
|
|
type_str = type_str.strip() |
|
|
|
type_str = type_str.strip() |
|
|
|
return parse_type_to_ast(type_str, element_path, types_to_import) |
|
|
|
return parse_type_to_ast(type_str, element_path, types_to_import) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_type_to_ast( |
|
|
|
def parse_type_to_ast(type_str: str, element_path: List[str], types_to_import: Set[str]) -> ast.AST: |
|
|
|
type_str: str, element_path: List[str], types_to_import: Set[str] |
|
|
|
|
|
|
|
) -> ast.AST: |
|
|
|
|
|
|
|
# let's tokenize |
|
|
|
# let's tokenize |
|
|
|
tokens = [] |
|
|
|
tokens = [] |
|
|
|
current_token = "" |
|
|
|
current_token = "" |
|
|
@ -412,26 +371,18 @@ def parse_type_to_ast( |
|
|
|
else: |
|
|
|
else: |
|
|
|
or_groups[-1].append(e) |
|
|
|
or_groups[-1].append(e) |
|
|
|
if any(not g for g in or_groups): |
|
|
|
if any(not g for g in or_groups): |
|
|
|
raise ValueError( |
|
|
|
raise ValueError(f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}") |
|
|
|
f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
new_elements: List[ast.AST] = [] |
|
|
|
new_elements: List[ast.AST] = [] |
|
|
|
for group in or_groups: |
|
|
|
for group in or_groups: |
|
|
|
if len(group) == 1 and isinstance(group[0], str): |
|
|
|
if len(group) == 1 and isinstance(group[0], str): |
|
|
|
parts = group[0].split(".") |
|
|
|
parts = group[0].split(".") |
|
|
|
if any(not p for p in parts): |
|
|
|
if any(not p for p in parts): |
|
|
|
raise ValueError( |
|
|
|
raise ValueError(f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}") |
|
|
|
f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
if len(parts) > 1: |
|
|
|
if len(parts) > 1: |
|
|
|
types_to_import.add(parts[0]) |
|
|
|
types_to_import.add(parts[0]) |
|
|
|
new_elements.append(_path_to_type(*parts)) |
|
|
|
new_elements.append(_path_to_type(*parts)) |
|
|
|
elif ( |
|
|
|
elif len(group) == 2 and isinstance(group[0], str) and isinstance(group[1], list): |
|
|
|
len(group) == 2 |
|
|
|
|
|
|
|
and isinstance(group[0], str) |
|
|
|
|
|
|
|
and isinstance(group[1], list) |
|
|
|
|
|
|
|
): |
|
|
|
|
|
|
|
if group[0] not in GENERICS: |
|
|
|
if group[0] not in GENERICS: |
|
|
|
raise ValueError( |
|
|
|
raise ValueError( |
|
|
|
f"Constructor {group[0]} is not supported in type '{type_str}' used by {'.'.join(element_path)}" |
|
|
|
f"Constructor {group[0]} is not supported in type '{type_str}' used by {'.'.join(element_path)}" |
|
|
@ -444,9 +395,7 @@ def parse_type_to_ast( |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
) |
|
|
|
else: |
|
|
|
else: |
|
|
|
raise ValueError( |
|
|
|
raise ValueError(f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}") |
|
|
|
f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
return ( |
|
|
|
return ( |
|
|
|
ast.Subscript( |
|
|
|
ast.Subscript( |
|
|
|
value=_path_to_type("typing", "Union"), |
|
|
|
value=_path_to_type("typing", "Union"), |
|
|
@ -471,33 +420,21 @@ def build_doc_comment(doc: str) -> Optional[ast.Expr]: |
|
|
|
return ast.Expr(value=ast.Constant(text)) if text else None |
|
|
|
return ast.Expr(value=ast.Constant(text)) if text else None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def format_with_black(code: str) -> str: |
|
|
|
def format_with_ruff(file: str) -> None: |
|
|
|
result = subprocess.run( |
|
|
|
subprocess.check_call(["python", "-m", "ruff", "format", file]) |
|
|
|
["python", "-m", "black", "-t", "py38", "--pyi", "-"], |
|
|
|
|
|
|
|
input=code.encode(), |
|
|
|
|
|
|
|
capture_output=True, |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
result.check_returncode() |
|
|
|
|
|
|
|
return result.stdout.decode() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
if __name__ == "__main__": |
|
|
|
parser = argparse.ArgumentParser( |
|
|
|
parser = argparse.ArgumentParser(description="Extract Python type stub from a python module.") |
|
|
|
description="Extract Python type stub from a python module." |
|
|
|
parser.add_argument("module_name", help="Name of the Python module for which generate stubs") |
|
|
|
) |
|
|
|
|
|
|
|
parser.add_argument( |
|
|
|
|
|
|
|
"module_name", help="Name of the Python module for which generate stubs" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
parser.add_argument( |
|
|
|
parser.add_argument( |
|
|
|
"out", |
|
|
|
"out", |
|
|
|
help="Name of the Python stub file to write to", |
|
|
|
help="Name of the Python stub file to write to", |
|
|
|
type=argparse.FileType("wt"), |
|
|
|
type=argparse.FileType("wt"), |
|
|
|
) |
|
|
|
) |
|
|
|
parser.add_argument( |
|
|
|
parser.add_argument("--ruff", help="Formats the generated stubs using Ruff", action="store_true") |
|
|
|
"--black", help="Formats the generated stubs using Black", action="store_true" |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
|
args = parser.parse_args() |
|
|
|
stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name))) |
|
|
|
stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name))) |
|
|
|
if args.black: |
|
|
|
|
|
|
|
stub_content = format_with_black(stub_content) |
|
|
|
|
|
|
|
args.out.write(stub_content) |
|
|
|
args.out.write(stub_content) |
|
|
|
|
|
|
|
if args.ruff: |
|
|
|
|
|
|
|
format_with_ruff(args.out.name) |
|
|
|