diff --git a/python/src/model.rs b/python/src/model.rs index 0ca1bbc4..701b800f 100644 --- a/python/src/model.rs +++ b/python/src/model.rs @@ -3,6 +3,7 @@ use oxigraph::sparql::Variable; use pyo3::basic::CompareOp; use pyo3::exceptions::{PyIndexError, PyNotImplementedError, PyTypeError, PyValueError}; use pyo3::prelude::*; +use pyo3::types::{PyDict, PyTuple}; use pyo3::PyTypeInfo; use std::collections::hash_map::DefaultHasher; use std::hash::Hash; @@ -109,6 +110,10 @@ impl PyNamedNode { )) } } + + fn __getnewargs__(&self) -> (&str,) { + (self.value(),) + } } /// An RDF `blank node `_. @@ -215,6 +220,10 @@ impl PyBlankNode { )) } } + + fn __getnewargs__(&self) -> (&str,) { + (self.value(),) + } } /// An RDF `literal `_. @@ -351,6 +360,16 @@ impl PyLiteral { )) } } + + fn __getnewargs_ex__<'a>(&'a self, py: Python<'a>) -> PyResult<((&'a str,), &'a PyDict)> { + let kwargs = PyDict::new(py); + if let Some(language) = self.language() { + kwargs.set_item("language", language)?; + } else { + kwargs.set_item("datatype", self.datatype().into_py(py))?; + } + Ok(((self.value(),), kwargs)) + } } /// The RDF `default graph name `_. @@ -405,6 +424,10 @@ impl PyDefaultGraph { )) } } + + fn __getnewargs__<'a>(&self, py: Python<'a>) -> &'a PyTuple { + PyTuple::empty(py) + } } #[derive(FromPyObject)] @@ -650,6 +673,10 @@ impl PyTriple { .into_iter(), } } + + fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm) { + (self.subject(), self.predicate(), self.object()) + } } #[derive(FromPyObject)] @@ -861,6 +888,15 @@ impl PyQuad { .into_iter(), } } + + fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm, PyGraphName) { + ( + self.subject(), + self.predicate(), + self.object(), + self.graph_name(), + ) + } } /// A SPARQL query variable. @@ -932,6 +968,10 @@ impl PyVariable { fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult { eq_compare(self, other, op) } + + fn __getnewargs__(&self) -> (&str,) { + (self.value(),) + } } pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>); diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 63619274..079ec08b 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -1,3 +1,4 @@ +import pickle import unittest from pyoxigraph import ( @@ -26,6 +27,11 @@ class TestNamedNode(unittest.TestCase): self.assertEqual(NamedNode("http://foo"), NamedNode("http://foo")) self.assertNotEqual(NamedNode("http://foo"), NamedNode("http://bar")) + def test_pickle(self) -> None: + node = NamedNode("http://foo") + self.assertEqual(NamedNode(*node.__getnewargs__()), node) + self.assertEqual(pickle.loads(pickle.dumps(node)), node) + class TestBlankNode(unittest.TestCase): def test_constructor(self) -> None: @@ -41,6 +47,12 @@ class TestBlankNode(unittest.TestCase): self.assertNotEqual(BlankNode("foo"), NamedNode("http://foo")) self.assertNotEqual(NamedNode("http://foo"), BlankNode("foo")) + def test_pickle(self) -> None: + node = BlankNode("foo") + self.assertEqual(pickle.loads(pickle.dumps(node)), node) + auto = BlankNode() + self.assertEqual(pickle.loads(pickle.dumps(auto)), auto) + class TestLiteral(unittest.TestCase): def test_constructor(self) -> None: @@ -73,6 +85,14 @@ class TestLiteral(unittest.TestCase): self.assertNotEqual(BlankNode("foo"), Literal("foo")) self.assertNotEqual(Literal("foo"), BlankNode("foo")) + def test_pickle(self) -> None: + simple = Literal("foo") + self.assertEqual(pickle.loads(pickle.dumps(simple)), simple) + lang_tagged = Literal("foo", language="en") + self.assertEqual(pickle.loads(pickle.dumps(lang_tagged)), lang_tagged) + number = Literal("1", datatype=XSD_INTEGER) + self.assertEqual(pickle.loads(pickle.dumps(number)), number) + class TestTriple(unittest.TestCase): def test_constructor(self) -> None: @@ -149,6 +169,23 @@ class TestTriple(unittest.TestCase): " ", ) + def test_pickle(self) -> None: + triple = Triple( + NamedNode("http://example.com/s"), + NamedNode("http://example.com/p"), + NamedNode("http://example.com/o"), + ) + self.assertEqual(pickle.loads(pickle.dumps(triple)), triple) + + +class TestDefaultGraph(unittest.TestCase): + def test_equal(self) -> None: + self.assertEqual(DefaultGraph(), DefaultGraph()) + self.assertNotEqual(DefaultGraph(), NamedNode("http://bar")) + + def test_pickle(self) -> None: + self.assertEqual(pickle.loads(pickle.dumps(DefaultGraph())), DefaultGraph()) + class TestQuad(unittest.TestCase): def test_constructor(self) -> None: @@ -220,6 +257,15 @@ class TestQuad(unittest.TestCase): " ", ) + def test_pickle(self) -> None: + quad = Quad( + NamedNode("http://example.com/s"), + NamedNode("http://example.com/p"), + NamedNode("http://example.com/o"), + NamedNode("http://example.com/g"), + ) + self.assertEqual(pickle.loads(pickle.dumps(quad)), quad) + class TestVariable(unittest.TestCase): def test_constructor(self) -> None: @@ -232,6 +278,10 @@ class TestVariable(unittest.TestCase): self.assertEqual(Variable("foo"), Variable("foo")) self.assertNotEqual(Variable("foo"), Variable("bar")) + def test_pickle(self) -> None: + v = Variable("foo") + self.assertEqual(pickle.loads(pickle.dumps(v)), v) + if __name__ == "__main__": unittest.main()