From ef65d531905c7126dbe936947015336af9549465 Mon Sep 17 00:00:00 2001 From: Tpt Date: Sun, 26 Mar 2023 21:48:12 +0200 Subject: [PATCH] Python: Adds __match_args__ definition where relevant Allows positional pattern matching Closes #449 --- python/generate_stubs.py | 17 +++++++++ python/src/model.rs | 30 ++++++++++++++++ python/tests/test_model.py | 70 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 117 insertions(+) diff --git a/python/generate_stubs.py b/python/generate_stubs.py index c8c281a5..8a36f1b4 100644 --- a/python/generate_stubs.py +++ b/python/generate_stubs.py @@ -106,6 +106,7 @@ def class_stubs( attributes: List[ast.AST] = [] methods: List[ast.AST] = [] magic_methods: List[ast.AST] = [] + constants: List[ast.AST] = [] for member_name, member_value in inspect.getmembers(cls_def): current_element_path = [*element_path, member_name] if member_name == "__init__": @@ -147,6 +148,21 @@ def class_stubs( in_class=True, ) ) + elif member_name == "__match_args__": + constants.append( + ast.AnnAssign( + target=ast.Name(id=member_name, ctx=AST_STORE), + annotation=ast.Subscript( + value=_path_to_type("typing", "Tuple"), + slice=ast.Tuple( + elts=[_path_to_type("str"), ast.Ellipsis()], ctx=AST_LOAD + ), + ctx=AST_LOAD, + ), + value=ast.Constant(member_value), + simple=1, + ) + ) else: logging.warning( f"Unsupported member {member_name} of class {'.'.join(element_path)}" @@ -163,6 +179,7 @@ def class_stubs( + attributes + methods + magic_methods + + constants ) or [AST_ELLIPSIS], decorator_list=[_path_to_type("typing", "final")], diff --git a/python/src/model.rs b/python/src/model.rs index 348ccb66..48f57e14 100644 --- a/python/src/model.rs +++ b/python/src/model.rs @@ -127,6 +127,11 @@ impl PyNamedNode { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str,) { + ("value",) + } } /// An RDF `blank node `_. @@ -250,6 +255,11 @@ impl PyBlankNode { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str,) { + ("value",) + } } /// An RDF `literal `_. @@ -409,6 +419,11 @@ impl PyLiteral { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str,) { + ("value",) + } } /// The RDF `default graph name `_. @@ -742,6 +757,11 @@ impl PyTriple { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str, &'static str, &'static str) { + ("subject", "predicate", "object") + } } #[derive(FromPyObject)] @@ -975,6 +995,11 @@ impl PyQuad { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str, &'static str, &'static str, &'static str) { + ("subject", "predicate", "object", "graph_name") + } } /// A SPARQL query variable. @@ -1063,6 +1088,11 @@ impl PyVariable { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { slf } + + #[classattr] + fn __match_args__() -> (&'static str,) { + ("value",) + } } pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>); diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 4c954350..2931d06d 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -1,5 +1,6 @@ import copy import pickle +import sys import unittest from pyoxigraph import ( @@ -17,6 +18,22 @@ XSD_INTEGER = NamedNode("http://www.w3.org/2001/XMLSchema#integer") RDF_LANG_STRING = NamedNode("http://www.w3.org/1999/02/22-rdf-syntax-ns#langString") +def match_works(test: unittest.TestCase, matched_value: str, constraint: str) -> None: + """Hack for Python < 3.10 compatibility""" + if sys.version_info < (3, 10): + return test.skipTest("match has been introduced by Python 3.10") + found = True + exec( + f""" +match {matched_value}: + case {constraint}: + found = True +""" + ) + test.assertTrue(found) + return None + + class TestNamedNode(unittest.TestCase): def test_constructor(self) -> None: self.assertEqual(NamedNode("http://foo").value, "http://foo") @@ -34,6 +51,12 @@ class TestNamedNode(unittest.TestCase): self.assertEqual(copy.copy(node), node) self.assertEqual(copy.deepcopy(node), node) + def test_basic_match(self) -> None: + match_works(self, 'NamedNode("http://foo")', 'NamedNode("http://foo")') + + def test_wildcard_match(self) -> None: + match_works(self, 'NamedNode("http://foo")', "NamedNode(x)") + class TestBlankNode(unittest.TestCase): def test_constructor(self) -> None: @@ -60,6 +83,12 @@ class TestBlankNode(unittest.TestCase): self.assertEqual(copy.copy(auto), auto) self.assertEqual(copy.deepcopy(auto), auto) + def test_basic_match(self) -> None: + match_works(self, 'BlankNode("foo")', 'BlankNode("foo")') + + def test_wildcard_match(self) -> None: + match_works(self, 'BlankNode("foo")', "BlankNode(x)") + class TestLiteral(unittest.TestCase): def test_constructor(self) -> None: @@ -108,6 +137,22 @@ class TestLiteral(unittest.TestCase): self.assertEqual(copy.copy(number), number) self.assertEqual(copy.deepcopy(number), number) + def test_basic_match(self) -> None: + match_works( + self, 'Literal("foo", language="en")', 'Literal("foo", language="en")' + ) + match_works( + self, + 'Literal("1", datatype=XSD_INTEGER)', + 'Literal("1", datatype=NamedNode("http://www.w3.org/2001/XMLSchema#integer"))', + ) + + def test_wildcard_match(self) -> None: + match_works(self, 'Literal("foo", language="en")', "Literal(v, language=l)") + match_works( + self, 'Literal("1", datatype=XSD_INTEGER)', "Literal(v, datatype=d)" + ) + class TestTriple(unittest.TestCase): def test_constructor(self) -> None: @@ -194,6 +239,14 @@ class TestTriple(unittest.TestCase): self.assertEqual(copy.copy(triple), triple) self.assertEqual(copy.deepcopy(triple), triple) + def test_match(self) -> None: + match_works( + self, + 'Triple(NamedNode("http://example.com/s"), NamedNode("http://example.com/p"), ' + 'NamedNode("http://example.com/o"))', + 'Triple(NamedNode("http://example.com/s"), NamedNode(p), o)', + ) + class TestDefaultGraph(unittest.TestCase): def test_equal(self) -> None: @@ -205,6 +258,9 @@ class TestDefaultGraph(unittest.TestCase): self.assertEqual(copy.copy(DefaultGraph()), DefaultGraph()) self.assertEqual(copy.deepcopy(DefaultGraph()), DefaultGraph()) + def test_match(self) -> None: + match_works(self, "DefaultGraph()", "DefaultGraph()") + class TestQuad(unittest.TestCase): def test_constructor(self) -> None: @@ -287,6 +343,14 @@ class TestQuad(unittest.TestCase): self.assertEqual(copy.copy(quad), quad) self.assertEqual(copy.deepcopy(quad), quad) + def test_match(self) -> None: + match_works( + self, + 'Quad(NamedNode("http://example.com/s"), NamedNode("http://example.com/p"), ' + 'NamedNode("http://example.com/o"), NamedNode("http://example.com/g"))', + 'Quad(NamedNode("http://example.com/s"), NamedNode(p), o, NamedNode("http://example.com/g"))', + ) + class TestVariable(unittest.TestCase): def test_constructor(self) -> None: @@ -305,6 +369,12 @@ class TestVariable(unittest.TestCase): self.assertEqual(copy.copy(v), v) self.assertEqual(copy.deepcopy(v), v) + def test_basic_match(self) -> None: + match_works(self, 'Variable("foo")', 'Variable("foo")') + + def test_wildcard_match(self) -> None: + match_works(self, 'Variable("foo")', "Variable(x)") + if __name__ == "__main__": unittest.main()