Python: Adds __match_args__ definition where relevant

Allows positional pattern matching

Closes #449
speedb
Tpt 2 years ago committed by Thomas Tanon
parent 76dec0b6a8
commit ef65d53190
  1. 17
      python/generate_stubs.py
  2. 30
      python/src/model.rs
  3. 70
      python/tests/test_model.py

@ -106,6 +106,7 @@ def class_stubs(
attributes: List[ast.AST] = [] attributes: List[ast.AST] = []
methods: List[ast.AST] = [] methods: List[ast.AST] = []
magic_methods: List[ast.AST] = [] magic_methods: List[ast.AST] = []
constants: List[ast.AST] = []
for member_name, member_value in inspect.getmembers(cls_def): for member_name, member_value in inspect.getmembers(cls_def):
current_element_path = [*element_path, member_name] current_element_path = [*element_path, member_name]
if member_name == "__init__": if member_name == "__init__":
@ -147,6 +148,21 @@ def class_stubs(
in_class=True, in_class=True,
) )
) )
elif member_name == "__match_args__":
constants.append(
ast.AnnAssign(
target=ast.Name(id=member_name, ctx=AST_STORE),
annotation=ast.Subscript(
value=_path_to_type("typing", "Tuple"),
slice=ast.Tuple(
elts=[_path_to_type("str"), ast.Ellipsis()], ctx=AST_LOAD
),
ctx=AST_LOAD,
),
value=ast.Constant(member_value),
simple=1,
)
)
else: else:
logging.warning( logging.warning(
f"Unsupported member {member_name} of class {'.'.join(element_path)}" f"Unsupported member {member_name} of class {'.'.join(element_path)}"
@ -163,6 +179,7 @@ def class_stubs(
+ attributes + attributes
+ methods + methods
+ magic_methods + magic_methods
+ constants
) )
or [AST_ELLIPSIS], or [AST_ELLIPSIS],
decorator_list=[_path_to_type("typing", "final")], decorator_list=[_path_to_type("typing", "final")],

@ -127,6 +127,11 @@ impl PyNamedNode {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str,) {
("value",)
}
} }
/// An RDF `blank node <https://www.w3.org/TR/rdf11-concepts/#dfn-blank-node>`_. /// An RDF `blank node <https://www.w3.org/TR/rdf11-concepts/#dfn-blank-node>`_.
@ -250,6 +255,11 @@ impl PyBlankNode {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str,) {
("value",)
}
} }
/// An RDF `literal <https://www.w3.org/TR/rdf11-concepts/#dfn-literal>`_. /// An RDF `literal <https://www.w3.org/TR/rdf11-concepts/#dfn-literal>`_.
@ -409,6 +419,11 @@ impl PyLiteral {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str,) {
("value",)
}
} }
/// The RDF `default graph name <https://www.w3.org/TR/rdf11-concepts/#dfn-default-graph>`_. /// The RDF `default graph name <https://www.w3.org/TR/rdf11-concepts/#dfn-default-graph>`_.
@ -742,6 +757,11 @@ impl PyTriple {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str, &'static str, &'static str) {
("subject", "predicate", "object")
}
} }
#[derive(FromPyObject)] #[derive(FromPyObject)]
@ -975,6 +995,11 @@ impl PyQuad {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str, &'static str, &'static str, &'static str) {
("subject", "predicate", "object", "graph_name")
}
} }
/// A SPARQL query variable. /// A SPARQL query variable.
@ -1063,6 +1088,11 @@ impl PyVariable {
fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> { fn __deepcopy__<'a>(slf: PyRef<'a, Self>, memo: &'_ PyAny) -> PyRef<'a, Self> {
slf slf
} }
#[classattr]
fn __match_args__() -> (&'static str,) {
("value",)
}
} }
pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>); pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>);

@ -1,5 +1,6 @@
import copy import copy
import pickle import pickle
import sys
import unittest import unittest
from pyoxigraph import ( from pyoxigraph import (
@ -17,6 +18,22 @@ XSD_INTEGER = NamedNode("http://www.w3.org/2001/XMLSchema#integer")
RDF_LANG_STRING = NamedNode("http://www.w3.org/1999/02/22-rdf-syntax-ns#langString") RDF_LANG_STRING = NamedNode("http://www.w3.org/1999/02/22-rdf-syntax-ns#langString")
def match_works(test: unittest.TestCase, matched_value: str, constraint: str) -> None:
"""Hack for Python < 3.10 compatibility"""
if sys.version_info < (3, 10):
return test.skipTest("match has been introduced by Python 3.10")
found = True
exec(
f"""
match {matched_value}:
case {constraint}:
found = True
"""
)
test.assertTrue(found)
return None
class TestNamedNode(unittest.TestCase): class TestNamedNode(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
self.assertEqual(NamedNode("http://foo").value, "http://foo") self.assertEqual(NamedNode("http://foo").value, "http://foo")
@ -34,6 +51,12 @@ class TestNamedNode(unittest.TestCase):
self.assertEqual(copy.copy(node), node) self.assertEqual(copy.copy(node), node)
self.assertEqual(copy.deepcopy(node), node) self.assertEqual(copy.deepcopy(node), node)
def test_basic_match(self) -> None:
match_works(self, 'NamedNode("http://foo")', 'NamedNode("http://foo")')
def test_wildcard_match(self) -> None:
match_works(self, 'NamedNode("http://foo")', "NamedNode(x)")
class TestBlankNode(unittest.TestCase): class TestBlankNode(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
@ -60,6 +83,12 @@ class TestBlankNode(unittest.TestCase):
self.assertEqual(copy.copy(auto), auto) self.assertEqual(copy.copy(auto), auto)
self.assertEqual(copy.deepcopy(auto), auto) self.assertEqual(copy.deepcopy(auto), auto)
def test_basic_match(self) -> None:
match_works(self, 'BlankNode("foo")', 'BlankNode("foo")')
def test_wildcard_match(self) -> None:
match_works(self, 'BlankNode("foo")', "BlankNode(x)")
class TestLiteral(unittest.TestCase): class TestLiteral(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
@ -108,6 +137,22 @@ class TestLiteral(unittest.TestCase):
self.assertEqual(copy.copy(number), number) self.assertEqual(copy.copy(number), number)
self.assertEqual(copy.deepcopy(number), number) self.assertEqual(copy.deepcopy(number), number)
def test_basic_match(self) -> None:
match_works(
self, 'Literal("foo", language="en")', 'Literal("foo", language="en")'
)
match_works(
self,
'Literal("1", datatype=XSD_INTEGER)',
'Literal("1", datatype=NamedNode("http://www.w3.org/2001/XMLSchema#integer"))',
)
def test_wildcard_match(self) -> None:
match_works(self, 'Literal("foo", language="en")', "Literal(v, language=l)")
match_works(
self, 'Literal("1", datatype=XSD_INTEGER)', "Literal(v, datatype=d)"
)
class TestTriple(unittest.TestCase): class TestTriple(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
@ -194,6 +239,14 @@ class TestTriple(unittest.TestCase):
self.assertEqual(copy.copy(triple), triple) self.assertEqual(copy.copy(triple), triple)
self.assertEqual(copy.deepcopy(triple), triple) self.assertEqual(copy.deepcopy(triple), triple)
def test_match(self) -> None:
match_works(
self,
'Triple(NamedNode("http://example.com/s"), NamedNode("http://example.com/p"), '
'NamedNode("http://example.com/o"))',
'Triple(NamedNode("http://example.com/s"), NamedNode(p), o)',
)
class TestDefaultGraph(unittest.TestCase): class TestDefaultGraph(unittest.TestCase):
def test_equal(self) -> None: def test_equal(self) -> None:
@ -205,6 +258,9 @@ class TestDefaultGraph(unittest.TestCase):
self.assertEqual(copy.copy(DefaultGraph()), DefaultGraph()) self.assertEqual(copy.copy(DefaultGraph()), DefaultGraph())
self.assertEqual(copy.deepcopy(DefaultGraph()), DefaultGraph()) self.assertEqual(copy.deepcopy(DefaultGraph()), DefaultGraph())
def test_match(self) -> None:
match_works(self, "DefaultGraph()", "DefaultGraph()")
class TestQuad(unittest.TestCase): class TestQuad(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
@ -287,6 +343,14 @@ class TestQuad(unittest.TestCase):
self.assertEqual(copy.copy(quad), quad) self.assertEqual(copy.copy(quad), quad)
self.assertEqual(copy.deepcopy(quad), quad) self.assertEqual(copy.deepcopy(quad), quad)
def test_match(self) -> None:
match_works(
self,
'Quad(NamedNode("http://example.com/s"), NamedNode("http://example.com/p"), '
'NamedNode("http://example.com/o"), NamedNode("http://example.com/g"))',
'Quad(NamedNode("http://example.com/s"), NamedNode(p), o, NamedNode("http://example.com/g"))',
)
class TestVariable(unittest.TestCase): class TestVariable(unittest.TestCase):
def test_constructor(self) -> None: def test_constructor(self) -> None:
@ -305,6 +369,12 @@ class TestVariable(unittest.TestCase):
self.assertEqual(copy.copy(v), v) self.assertEqual(copy.copy(v), v)
self.assertEqual(copy.deepcopy(v), v) self.assertEqual(copy.deepcopy(v), v)
def test_basic_match(self) -> None:
match_works(self, 'Variable("foo")', 'Variable("foo")')
def test_wildcard_match(self) -> None:
match_works(self, 'Variable("foo")', "Variable(x)")
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

Loading…
Cancel
Save