Python: Allow to pickle all basic model classes

pull/430/head
Tpt 2 years ago committed by Thomas Tanon
parent 935e778db1
commit d4e964ac47
  1. 40
      python/src/model.rs
  2. 50
      python/tests/test_model.py

@ -3,6 +3,7 @@ use oxigraph::sparql::Variable;
use pyo3::basic::CompareOp;
use pyo3::exceptions::{PyIndexError, PyNotImplementedError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyTuple};
use pyo3::PyTypeInfo;
use std::collections::hash_map::DefaultHasher;
use std::hash::Hash;
@ -109,6 +110,10 @@ impl PyNamedNode {
))
}
}
fn __getnewargs__(&self) -> (&str,) {
(self.value(),)
}
}
/// An RDF `blank node <https://www.w3.org/TR/rdf11-concepts/#dfn-blank-node>`_.
@ -215,6 +220,10 @@ impl PyBlankNode {
))
}
}
fn __getnewargs__(&self) -> (&str,) {
(self.value(),)
}
}
/// An RDF `literal <https://www.w3.org/TR/rdf11-concepts/#dfn-literal>`_.
@ -351,6 +360,16 @@ impl PyLiteral {
))
}
}
fn __getnewargs_ex__<'a>(&'a self, py: Python<'a>) -> PyResult<((&'a str,), &'a PyDict)> {
let kwargs = PyDict::new(py);
if let Some(language) = self.language() {
kwargs.set_item("language", language)?;
} else {
kwargs.set_item("datatype", self.datatype().into_py(py))?;
}
Ok(((self.value(),), kwargs))
}
}
/// The RDF `default graph name <https://www.w3.org/TR/rdf11-concepts/#dfn-default-graph>`_.
@ -405,6 +424,10 @@ impl PyDefaultGraph {
))
}
}
fn __getnewargs__<'a>(&self, py: Python<'a>) -> &'a PyTuple {
PyTuple::empty(py)
}
}
#[derive(FromPyObject)]
@ -650,6 +673,10 @@ impl PyTriple {
.into_iter(),
}
}
fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm) {
(self.subject(), self.predicate(), self.object())
}
}
#[derive(FromPyObject)]
@ -861,6 +888,15 @@ impl PyQuad {
.into_iter(),
}
}
fn __getnewargs__(&self) -> (PySubject, PyNamedNode, PyTerm, PyGraphName) {
(
self.subject(),
self.predicate(),
self.object(),
self.graph_name(),
)
}
}
/// A SPARQL query variable.
@ -932,6 +968,10 @@ impl PyVariable {
fn __richcmp__(&self, other: &Self, op: CompareOp) -> PyResult<bool> {
eq_compare(self, other, op)
}
fn __getnewargs__(&self) -> (&str,) {
(self.value(),)
}
}
pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>);

@ -1,3 +1,4 @@
import pickle
import unittest
from pyoxigraph import (
@ -26,6 +27,11 @@ class TestNamedNode(unittest.TestCase):
self.assertEqual(NamedNode("http://foo"), NamedNode("http://foo"))
self.assertNotEqual(NamedNode("http://foo"), NamedNode("http://bar"))
def test_pickle(self) -> None:
node = NamedNode("http://foo")
self.assertEqual(NamedNode(*node.__getnewargs__()), node)
self.assertEqual(pickle.loads(pickle.dumps(node)), node)
class TestBlankNode(unittest.TestCase):
def test_constructor(self) -> None:
@ -41,6 +47,12 @@ class TestBlankNode(unittest.TestCase):
self.assertNotEqual(BlankNode("foo"), NamedNode("http://foo"))
self.assertNotEqual(NamedNode("http://foo"), BlankNode("foo"))
def test_pickle(self) -> None:
node = BlankNode("foo")
self.assertEqual(pickle.loads(pickle.dumps(node)), node)
auto = BlankNode()
self.assertEqual(pickle.loads(pickle.dumps(auto)), auto)
class TestLiteral(unittest.TestCase):
def test_constructor(self) -> None:
@ -73,6 +85,14 @@ class TestLiteral(unittest.TestCase):
self.assertNotEqual(BlankNode("foo"), Literal("foo"))
self.assertNotEqual(Literal("foo"), BlankNode("foo"))
def test_pickle(self) -> None:
simple = Literal("foo")
self.assertEqual(pickle.loads(pickle.dumps(simple)), simple)
lang_tagged = Literal("foo", language="en")
self.assertEqual(pickle.loads(pickle.dumps(lang_tagged)), lang_tagged)
number = Literal("1", datatype=XSD_INTEGER)
self.assertEqual(pickle.loads(pickle.dumps(number)), number)
class TestTriple(unittest.TestCase):
def test_constructor(self) -> None:
@ -149,6 +169,23 @@ class TestTriple(unittest.TestCase):
"<http://example.com/s> <http://example.com/p> <http://example.com/o>",
)
def test_pickle(self) -> None:
triple = Triple(
NamedNode("http://example.com/s"),
NamedNode("http://example.com/p"),
NamedNode("http://example.com/o"),
)
self.assertEqual(pickle.loads(pickle.dumps(triple)), triple)
class TestDefaultGraph(unittest.TestCase):
def test_equal(self) -> None:
self.assertEqual(DefaultGraph(), DefaultGraph())
self.assertNotEqual(DefaultGraph(), NamedNode("http://bar"))
def test_pickle(self) -> None:
self.assertEqual(pickle.loads(pickle.dumps(DefaultGraph())), DefaultGraph())
class TestQuad(unittest.TestCase):
def test_constructor(self) -> None:
@ -220,6 +257,15 @@ class TestQuad(unittest.TestCase):
"<http://example.com/s> <http://example.com/p> <http://example.com/o>",
)
def test_pickle(self) -> None:
quad = Quad(
NamedNode("http://example.com/s"),
NamedNode("http://example.com/p"),
NamedNode("http://example.com/o"),
NamedNode("http://example.com/g"),
)
self.assertEqual(pickle.loads(pickle.dumps(quad)), quad)
class TestVariable(unittest.TestCase):
def test_constructor(self) -> None:
@ -232,6 +278,10 @@ class TestVariable(unittest.TestCase):
self.assertEqual(Variable("foo"), Variable("foo"))
self.assertNotEqual(Variable("foo"), Variable("bar"))
def test_pickle(self) -> None:
v = Variable("foo")
self.assertEqual(pickle.loads(pickle.dumps(v)), v)
if __name__ == "__main__":
unittest.main()

Loading…
Cancel
Save