From c40c81447ec46f546dbc999659332b9cd1cfb72c Mon Sep 17 00:00:00 2001 From: Tpt Date: Tue, 14 Mar 2023 21:52:02 +0100 Subject: [PATCH] Python: Optimizes copy on basic model classes Immutable values do not need to be actually copied --- python/generate_stubs.py | 15 ++++--- python/src/model.rs | 91 ++++++++++++++++++++++++++++++++++++++ python/tests/test_model.py | 25 ++++++++++- 3 files changed, 125 insertions(+), 6 deletions(-) diff --git a/python/generate_stubs.py b/python/generate_stubs.py index 3c2054d6..c8c281a5 100644 --- a/python/generate_stubs.py +++ b/python/generate_stubs.py @@ -153,12 +153,13 @@ def class_stubs( ) doc = inspect.getdoc(cls_def) + doc_comment = build_doc_comment(doc) if doc else None return ast.ClassDef( cls_name, bases=[], keywords=[], body=( - ([build_doc_comment(doc)] if doc else []) + ([doc_comment] if doc_comment else []) + attributes + methods + magic_methods @@ -193,7 +194,8 @@ def data_descriptor_stub( annotation=annotation or AST_TYPING_ANY, simple=1, ) - return (assign, build_doc_comment(doc_comment)) if doc_comment else (assign,) + doc_comment = build_doc_comment(doc_comment) if doc_comment else None + return (assign, doc_comment) if doc_comment else (assign,) def function_stub( @@ -207,7 +209,9 @@ def function_stub( body: List[ast.AST] = [] doc = inspect.getdoc(fn_def) if doc is not None: - body.append(build_doc_comment(doc)) + doc_comment = build_doc_comment(doc) + if doc_comment is not None: + body.append(doc_comment) decorator_list = [] if in_class and hasattr(fn_def, "__self__"): @@ -439,14 +443,15 @@ def parse_type_to_ast( return parse_sequence(stack[0]) -def build_doc_comment(doc: str) -> ast.Expr: +def build_doc_comment(doc: str) -> Optional[ast.Expr]: lines = [line.strip() for line in doc.split("\n")] clean_lines = [] for line in lines: if line.startswith((":type", ":rtype")): continue clean_lines.append(line) - return ast.Expr(value=ast.Constant("\n".join(clean_lines).strip())) + text = "\n".join(clean_lines).strip() + return ast.Expr(value=ast.Constant(text)) if text else None def format_with_black(code: str) -> str: diff --git a/python/src/model.rs b/python/src/model.rs index 701b800f..348ccb66 100644 --- a/python/src/model.rs +++ b/python/src/model.rs @@ -111,9 +111,22 @@ impl PyNamedNode { } } + /// :rtype: typing.Any fn __getnewargs__(&self) -> (&str,) { (self.value(),) } + + /// :rtype: NamedNode + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: NamedNode + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } /// An RDF `blank node `_. @@ -221,9 +234,22 @@ impl PyBlankNode { } } + /// :rtype: typing.Any fn __getnewargs__(&self) -> (&str,) { (self.value(),) } + + /// :rtype: BlankNode + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: BlankNode + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } /// An RDF `literal `_. @@ -361,6 +387,7 @@ impl PyLiteral { } } + /// :rtype: typing.Any fn __getnewargs_ex__<'a>(&'a self, py: Python<'a>) -> PyResult<((&'a str,), &'a PyDict)> { let kwargs = PyDict::new(py); if let Some(language) = self.language() { @@ -370,6 +397,18 @@ impl PyLiteral { } Ok(((self.value(),), kwargs)) } + + /// :rtype: Literal + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: Literal + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } /// The RDF `default graph name `_. @@ -425,9 +464,22 @@ impl PyDefaultGraph { } } + /// :rtype: typing.Any fn __getnewargs__<'a>(&self, py: Python<'a>) -> &'a PyTuple { PyTuple::empty(py) } + + /// :rtype: DefaultGraph + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: DefaultGraph + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } #[derive(FromPyObject)] @@ -674,9 +726,22 @@ impl PyTriple { } } + /// :rtype: typing.Any fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm) { (self.subject(), self.predicate(), self.object()) } + + /// :rtype: Triple + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: Triple + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } #[derive(FromPyObject)] @@ -889,6 +954,7 @@ impl PyQuad { } } + /// :rtype: typing.Any fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm, PyGraphName) { ( self.subject(), @@ -897,6 +963,18 @@ impl PyQuad { self.graph_name(), ) } + + /// :rtype: Quad + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: Quad + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } /// A SPARQL query variable. @@ -969,9 +1047,22 @@ impl PyVariable { eq_compare(self, other, op) } + /// :rtype: typing.Any fn __getnewargs__(&self) -> (&str,) { (self.value(),) } + + /// :rtype: Variable + fn __copy__(slf: PyRef<'_, Self>) -> PyRef { + slf + } + + /// :type memo: typing.Any + /// :rtype: Variable + #[allow(unused_variables)] + fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { + slf + } } pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>); diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 079ec08b..4c954350 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -1,3 +1,4 @@ +import copy import pickle import unittest @@ -29,8 +30,9 @@ class TestNamedNode(unittest.TestCase): def test_pickle(self) -> None: node = NamedNode("http://foo") - self.assertEqual(NamedNode(*node.__getnewargs__()), node) self.assertEqual(pickle.loads(pickle.dumps(node)), node) + self.assertEqual(copy.copy(node), node) + self.assertEqual(copy.deepcopy(node), node) class TestBlankNode(unittest.TestCase): @@ -50,8 +52,13 @@ class TestBlankNode(unittest.TestCase): def test_pickle(self) -> None: node = BlankNode("foo") self.assertEqual(pickle.loads(pickle.dumps(node)), node) + self.assertEqual(copy.copy(node), node) + self.assertEqual(copy.deepcopy(node), node) + auto = BlankNode() self.assertEqual(pickle.loads(pickle.dumps(auto)), auto) + self.assertEqual(copy.copy(auto), auto) + self.assertEqual(copy.deepcopy(auto), auto) class TestLiteral(unittest.TestCase): @@ -88,10 +95,18 @@ class TestLiteral(unittest.TestCase): def test_pickle(self) -> None: simple = Literal("foo") self.assertEqual(pickle.loads(pickle.dumps(simple)), simple) + self.assertEqual(copy.copy(simple), simple) + self.assertEqual(copy.deepcopy(simple), simple) + lang_tagged = Literal("foo", language="en") self.assertEqual(pickle.loads(pickle.dumps(lang_tagged)), lang_tagged) + self.assertEqual(copy.copy(lang_tagged), lang_tagged) + self.assertEqual(copy.deepcopy(lang_tagged), lang_tagged) + number = Literal("1", datatype=XSD_INTEGER) self.assertEqual(pickle.loads(pickle.dumps(number)), number) + self.assertEqual(copy.copy(number), number) + self.assertEqual(copy.deepcopy(number), number) class TestTriple(unittest.TestCase): @@ -176,6 +191,8 @@ class TestTriple(unittest.TestCase): NamedNode("http://example.com/o"), ) self.assertEqual(pickle.loads(pickle.dumps(triple)), triple) + self.assertEqual(copy.copy(triple), triple) + self.assertEqual(copy.deepcopy(triple), triple) class TestDefaultGraph(unittest.TestCase): @@ -185,6 +202,8 @@ class TestDefaultGraph(unittest.TestCase): def test_pickle(self) -> None: self.assertEqual(pickle.loads(pickle.dumps(DefaultGraph())), DefaultGraph()) + self.assertEqual(copy.copy(DefaultGraph()), DefaultGraph()) + self.assertEqual(copy.deepcopy(DefaultGraph()), DefaultGraph()) class TestQuad(unittest.TestCase): @@ -265,6 +284,8 @@ class TestQuad(unittest.TestCase): NamedNode("http://example.com/g"), ) self.assertEqual(pickle.loads(pickle.dumps(quad)), quad) + self.assertEqual(copy.copy(quad), quad) + self.assertEqual(copy.deepcopy(quad), quad) class TestVariable(unittest.TestCase): @@ -281,6 +302,8 @@ class TestVariable(unittest.TestCase): def test_pickle(self) -> None: v = Variable("foo") self.assertEqual(pickle.loads(pickle.dumps(v)), v) + self.assertEqual(copy.copy(v), v) + self.assertEqual(copy.deepcopy(v), v) if __name__ == "__main__":