import argparse
import ast
import importlib
import inspect
import logging
import re
import subprocess
from typing import Any, Dict, List, Mapping, Optional, Set, Tuple, Union
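
# This script generates .pyi type stubs for a module (typically a native/extension module):
# it introspects the module with `inspect`, reads :type:/:rtype: annotations from docstrings,
# and unparses the resulting AST into stub source.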


def _path_to_type(*elements: str) -> ast.AST:
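    """Build an ast.Name / ast.Attribute chain from a dotted path, e.g. ("typing", "Any") -> typing.Any."""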
    base: ast.AST = ast.Name(id=elements[0], ctx=AST_LOAD)
    for e in elements[1:]:
        base = ast.Attribute(value=base, attr=e, ctx=AST_LOAD)
    return base


AST_LOAD = ast.Load()
AST_ELLIPSIS = ast.Expr(value=ast.Constant(...))  # "..." placeholder statement (ast.Ellipsis is deprecated)
AST_STORE = ast.Store()
AST_TYPING_ANY = _path_to_type("typing", "Any")
GENERICS = {
    "iterable": _path_to_type("typing", "Iterable"),
    "iterator": _path_to_type("typing", "Iterator"),
    "list": _path_to_type("typing", "List"),
    "io": _path_to_type("typing", "IO"),
}
OBJECT_MEMBERS = dict(inspect.getmembers(object))


# Signatures for well-known dunder members: None means "skip this member in the stub",
# while a (parameter types, return type) tuple provides its signature.
BUILTINS: Dict[str, Union[None, Tuple[List[ast.AST], ast.AST]]] = {
    "__annotations__": None,
    "__bool__": ([], _path_to_type("bool")),
    "__bytes__": ([], _path_to_type("bytes")),
    "__class__": None,
    "__contains__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__del__": None,
    "__delattr__": ([_path_to_type("str")], _path_to_type("None")),
    "__delitem__": ([AST_TYPING_ANY], AST_TYPING_ANY),
    "__dict__": None,
    "__dir__": None,
    "__doc__": None,
    "__eq__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__format__": ([_path_to_type("str")], _path_to_type("str")),
    "__ge__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__getattribute__": ([_path_to_type("str")], AST_TYPING_ANY),
    "__getitem__": ([AST_TYPING_ANY], AST_TYPING_ANY),
    "__gt__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__hash__": ([], _path_to_type("int")),
    "__init__": ([], _path_to_type("None")),
    "__init_subclass__": None,
    "__iter__": ([], AST_TYPING_ANY),
    "__le__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__len__": ([], _path_to_type("int")),
    "__lt__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__module__": None,
    "__ne__": ([AST_TYPING_ANY], _path_to_type("bool")),
    "__new__": None,
    "__next__": ([], AST_TYPING_ANY),
    "__reduce__": None,
    "__reduce_ex__": None,
    "__repr__": ([], _path_to_type("str")),
    "__setattr__": ([_path_to_type("str"), AST_TYPING_ANY], _path_to_type("None")),
    "__setitem__": ([AST_TYPING_ANY, AST_TYPING_ANY], AST_TYPING_ANY),
    "__sizeof__": None,
    "__str__": ([], _path_to_type("str")),
    "__subclasshook__": None,
}


def module_stubs(module: Any) -> ast.Module:
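    """Introspect a module and build an ast.Module with stubs for its classes and built-in functions."""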
    types_to_import = {"typing"}
    classes = []
    functions = []
    for member_name, member_value in inspect.getmembers(module):
        element_path = [module.__name__, member_name]
        if member_name.startswith("__"):
            pass
        elif inspect.isclass(member_value):
            classes.append(
                class_stubs(member_name, member_value, element_path, types_to_import)
            )
        elif inspect.isbuiltin(member_value):
            functions.append(
                function_stub(
                    member_name,
                    member_value,
                    element_path,
                    types_to_import,
                    in_class=False,
                )
            )
        else:
            logging.warning(f"Unsupported root member {member_name}")
    return ast.Module(
        body=[ast.Import(names=[ast.alias(name=t)]) for t in sorted(types_to_import)]
        + classes
        + functions,
        type_ignores=[],
    )


def class_stubs(
    cls_name: str, cls_def: Any, element_path: List[str], types_to_import: Set[str]
) -> ast.ClassDef:
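    """Build an ast.ClassDef stub covering the class attributes, methods, magic methods and constants."""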
    attributes: List[ast.AST] = []
    methods: List[ast.AST] = []
    magic_methods: List[ast.AST] = []
    constants: List[ast.AST] = []
    for member_name, member_value in inspect.getmembers(cls_def):
        current_element_path = [*element_path, member_name]
        if member_name == "__init__":
            try:
                inspect.signature(cls_def)  # check that a signature actually exists
                methods = [
                    function_stub(
                        member_name,
                        cls_def,
                        current_element_path,
                        types_to_import,
                        in_class=True,
                    ),
                    *methods,
                ]
            except ValueError as e:
                if "no signature found" not in str(e):
                    raise ValueError(
                        f"Error while parsing signature of {cls_name}.__init__"
                    ) from e
        elif (
            member_value == OBJECT_MEMBERS.get(member_name)
            or BUILTINS.get(member_name, ()) is None
        ):
            pass
        elif inspect.isdatadescriptor(member_value):
            attributes.extend(
                data_descriptor_stub(
                    member_name, member_value, current_element_path, types_to_import
                )
            )
        elif inspect.isroutine(member_value):
            (magic_methods if member_name.startswith("__") else methods).append(
                function_stub(
                    member_name,
                    member_value,
                    current_element_path,
                    types_to_import,
                    in_class=True,
                )
            )
        elif member_name == "__match_args__":
            constants.append(
                ast.AnnAssign(
                    target=ast.Name(id=member_name, ctx=AST_STORE),
                    annotation=ast.Subscript(
                        value=_path_to_type("typing", "Tuple"),
                        slice=ast.Tuple(
                            elts=[_path_to_type("str"), ast.Constant(...)], ctx=AST_LOAD
                        ),
                        ctx=AST_LOAD,
                    ),
                    value=ast.Constant(member_value),
                    simple=1,
                )
            )
        else:
            logging.warning(
                f"Unsupported member {member_name} of class {'.'.join(element_path)}"
            )

    doc = inspect.getdoc(cls_def)
    doc_comment = build_doc_comment(doc) if doc else None
    return ast.ClassDef(
        cls_name,
        bases=[],
        keywords=[],
        body=(
            ([doc_comment] if doc_comment else [])
            + attributes
            + methods
            + magic_methods
            + constants
        )
        or [AST_ELLIPSIS],
        decorator_list=[_path_to_type("typing", "final")],
    )


def data_descriptor_stub(
    data_desc_name: str,
    data_desc_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
) -> Union[Tuple[ast.AnnAssign, ast.Expr], Tuple[ast.AnnAssign]]:
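    """Build an annotated assignment (and optional doc comment) for a data descriptor such as a property."""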
    annotation = None
    doc_comment = None

    doc = inspect.getdoc(data_desc_def)
    if doc is not None:
        annotation = returns_stub(data_desc_name, doc, element_path, types_to_import)
        m = re.findall(r"^ *:return: *(.*) *$", doc, re.MULTILINE)
        if len(m) == 1:
            doc_comment = m[0]
        elif len(m) > 1:
            raise ValueError(
                f"Multiple return annotations found with :return: in {'.'.join(element_path)} documentation"
            )

    assign = ast.AnnAssign(
        target=ast.Name(id=data_desc_name, ctx=AST_STORE),
        annotation=annotation or AST_TYPING_ANY,
        simple=1,
    )
    doc_comment = build_doc_comment(doc_comment) if doc_comment else None
    return (assign, doc_comment) if doc_comment else (assign,)


def function_stub(
    fn_name: str,
    fn_def: Any,
    element_path: List[str],
    types_to_import: Set[str],
    *,
    in_class: bool,
) -> ast.FunctionDef:
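    """Build an ast.FunctionDef stub from a callable's signature and docstring annotations."""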
    body: List[ast.AST] = []
    doc = inspect.getdoc(fn_def)
    if doc is not None:
        doc_comment = build_doc_comment(doc)
        if doc_comment is not None:
            body.append(doc_comment)

    decorator_list = []
    # callables exposed on a class that carry a __self__ are emitted as static methods
    if in_class and hasattr(fn_def, "__self__"):
        decorator_list.append(ast.Name("staticmethod"))

    return ast.FunctionDef(
        fn_name,
        arguments_stub(fn_name, fn_def, doc or "", element_path, types_to_import),
        body or [AST_ELLIPSIS],
        decorator_list=decorator_list,
        returns=returns_stub(fn_name, doc, element_path, types_to_import)
        if doc
        else None,
        lineno=0,
    )


def arguments_stub(
    callable_name: str,
    callable_def: Any,
    doc: str,
    element_path: List[str],
    types_to_import: Set[str],
) -> ast.arguments:
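    """Build the ast.arguments of a callable, typing parameters from BUILTINS and :type: docstring annotations."""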
    real_parameters: Mapping[str, inspect.Parameter] = inspect.signature(
        callable_def
    ).parameters
    if callable_name == "__init__":
        real_parameters = {
            "self": inspect.Parameter("self", inspect.Parameter.POSITIONAL_ONLY),
            **real_parameters,
        }

    parsed_param_types = {}
    optional_params = set()

    # Parameter types for well-known magic methods (from the BUILTINS table)
    builtin = BUILTINS.get(callable_name)
    if isinstance(builtin, tuple):
        param_names = list(real_parameters.keys())
        if param_names and param_names[0] == "self":
            del param_names[0]
        for name, t in zip(param_names, builtin[0]):
            parsed_param_types[name] = t

    # Types declared with :type: annotations in the docstring
    for match in re.findall(r"^ *:type *([a-z_]+): ([^\n]*) *$", doc, re.MULTILINE):
        if match[0] not in real_parameters:
            raise ValueError(
                f"The parameter {match[0]} of {'.'.join(element_path)} "
                "is defined in the documentation but not in the function signature"
            )
        type_desc = match[1]
        if type_desc.endswith(", optional"):
            optional_params.add(match[0])
            type_desc = type_desc[: -len(", optional")]
        parsed_param_types[match[0]] = convert_type_from_doc(
            type_desc, element_path, types_to_import
        )

    # build an AST node and default value for each parameter of the real signature
    posonlyargs = []
    args = []
    vararg = None
    kwonlyargs = []
    kw_defaults = []
    kwarg = None
    defaults = []
    for param in real_parameters.values():
        if param.name != "self" and param.name not in parsed_param_types:
            raise ValueError(
                f"The parameter {param.name} of {'.'.join(element_path)} "
                "has no type definition in the function documentation"
            )
        param_ast = ast.arg(
            arg=param.name, annotation=parsed_param_types.get(param.name)
        )

        default_ast = None
        if param.default != param.empty:
            default_ast = ast.Constant(param.default)
            if param.name not in optional_params:
                raise ValueError(
                    f"Parameter {param.name} of {'.'.join(element_path)} "
                    "has a default value but is not flagged as optional in the documentation"
                )
        elif param.name in optional_params:
            raise ValueError(
                f"Parameter {param.name} of {'.'.join(element_path)} "
                "is optional according to the documentation but has no default value"
            )

        if param.kind == param.POSITIONAL_ONLY:
            posonlyargs.append(param_ast)
            defaults.append(default_ast)
        elif param.kind == param.POSITIONAL_OR_KEYWORD:
            args.append(param_ast)
            defaults.append(default_ast)
        elif param.kind == param.VAR_POSITIONAL:
            vararg = param_ast
        elif param.kind == param.KEYWORD_ONLY:
            kwonlyargs.append(param_ast)
            kw_defaults.append(default_ast)
        elif param.kind == param.VAR_KEYWORD:
            kwarg = param_ast

    return ast.arguments(
        posonlyargs=posonlyargs,
        args=args,
        vararg=vararg,
        kwonlyargs=kwonlyargs,
        kw_defaults=kw_defaults,
        defaults=defaults,
        kwarg=kwarg,
    )


def returns_stub(
    callable_name: str, doc: str, element_path: List[str], types_to_import: Set[str]
) -> Optional[ast.AST]:
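    """Return the annotation from the :rtype: docstring entry, falling back to BUILTINS for magic methods."""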
    m = re.findall(r"^ *:rtype: *([^\n]*) *$", doc, re.MULTILINE)
    if len(m) == 0:
        builtin = BUILTINS.get(callable_name)
        if isinstance(builtin, tuple) and builtin[1] is not None:
            return builtin[1]
        raise ValueError(
            f"The return type of {'.'.join(element_path)} "
            "has no type definition using :rtype: in the function documentation"
        )
    if len(m) > 1:
        raise ValueError(
            f"Multiple return type annotations found with :rtype: for {'.'.join(element_path)}"
        )
    return convert_type_from_doc(m[0], element_path, types_to_import)


def convert_type_from_doc(
    type_str: str, element_path: List[str], types_to_import: Set[str]
) -> ast.AST:
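    """Normalize a documentation type string and parse it into an annotation AST."""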
    type_str = type_str.strip()
    return parse_type_to_ast(type_str, element_path, types_to_import)


def parse_type_to_ast(
    type_str: str, element_path: List[str], types_to_import: Set[str]
) -> ast.AST:
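    """Parse a documentation type expression such as "str or list(int)" into an annotation AST."""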
    # let's tokenize
    tokens = []
    current_token = ""
    for c in type_str:
        if "a" <= c <= "z" or "A" <= c <= "Z" or c == ".":
            current_token += c
        else:
            if current_token:
                tokens.append(current_token)
            current_token = ""
            if c != " ":
                tokens.append(c)
    if current_token:
        tokens.append(current_token)

    # let's first group tokens by nested parentheses
    stack: List[List[Any]] = [[]]
    for token in tokens:
        if token == "(":
            children: List[str] = []
            stack[-1].append(children)
            stack.append(children)
        elif token == ")":
            stack.pop()
        else:
            stack[-1].append(token)

    # then recursively convert the token tree into type annotations
    def parse_sequence(sequence: List[Any]) -> ast.AST:
        # we split based on "or"
        or_groups: List[List[str]] = [[]]
        for e in sequence:
            if e == "or":
                or_groups.append([])
            else:
                or_groups[-1].append(e)
        if any(not g for g in or_groups):
            raise ValueError(
                f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
            )

        new_elements: List[ast.AST] = []
        for group in or_groups:
            if len(group) == 1 and isinstance(group[0], str):
                parts = group[0].split(".")
                if any(not p for p in parts):
                    raise ValueError(
                        f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
                    )
                if len(parts) > 1:
                    types_to_import.add(parts[0])
                new_elements.append(_path_to_type(*parts))
            elif (
                len(group) == 2
                and isinstance(group[0], str)
                and isinstance(group[1], list)
            ):
                if group[0] not in GENERICS:
                    raise ValueError(
                        f"Constructor {group[0]} is not supported in type '{type_str}' used by {'.'.join(element_path)}"
                    )
                new_elements.append(
                    ast.Subscript(
                        value=GENERICS[group[0]],
                        slice=parse_sequence(group[1]),
                        ctx=AST_LOAD,
                    )
                )
            else:
                raise ValueError(
                    f"Not able to parse type '{type_str}' used by {'.'.join(element_path)}"
                )
        return (
            ast.Subscript(
                value=_path_to_type("typing", "Union"),
                slice=ast.Tuple(elts=new_elements, ctx=AST_LOAD),
                ctx=AST_LOAD,
            )
            if len(new_elements) > 1
            else new_elements[0]
        )

    return parse_sequence(stack[0])


def build_doc_comment(doc: str) -> Optional[ast.Expr]:
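    """Build a docstring expression from doc text, dropping the :type: and :rtype: annotation lines."""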
    lines = [line.strip() for line in doc.split("\n")]
    clean_lines = []
    for line in lines:
        if line.startswith((":type", ":rtype")):
            continue
        clean_lines.append(line)
    text = "\n".join(clean_lines).strip()
    return ast.Expr(value=ast.Constant(text)) if text else None


def format_with_black(code: str) -> str:
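    """Format the generated stub code with Black in --pyi mode, run as a subprocess."""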
    result = subprocess.run(
        ["python", "-m", "black", "-t", "py37", "--pyi", "-"],
        input=code.encode(),
        capture_output=True,
    )
    result.check_returncode()
    return result.stdout.decode()


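# Example invocation (script and module names are illustrative):
#   python generate_stubs.py my_native_module my_native_module.pyi --black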
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Extract a Python type stub from a Python module."
    )
    parser.add_argument(
        "module_name", help="Name of the Python module for which to generate stubs"
    )
    parser.add_argument(
        "out",
        help="Name of the Python stub file to write to",
        type=argparse.FileType("wt"),
    )
    parser.add_argument(
        "--black", help="Formats the generated stubs using Black", action="store_true"
    )
    args = parser.parse_args()
    stub_content = ast.unparse(module_stubs(importlib.import_module(args.module_name)))
    stub_content = stub_content.replace(
        ", /", ""
    )  # TODO: remove when targeting Python 3.8+
    if args.black:
        stub_content = format_with_black(stub_content)
    args.out.write(stub_content)