From 68597ef35ad108f40605783596873903d0921414 Mon Sep 17 00:00:00 2001 From: Tpt Date: Mon, 10 Aug 2020 08:20:03 +0200 Subject: [PATCH] Avoids copy in Python bindings and adds better __eq__ implementations --- python/src/memory_store.rs | 23 ++-- python/src/model.rs | 265 ++++++++++++++++++++++++++++--------- python/src/sled_store.rs | 22 +-- python/src/sparql.rs | 14 +- python/src/store_utils.rs | 30 ++--- python/tests/test_model.py | 12 +- 6 files changed, 258 insertions(+), 108 deletions(-) diff --git a/python/src/memory_store.rs b/python/src/memory_store.rs index 26f0e6df..4fdbc7c7 100644 --- a/python/src/memory_store.rs +++ b/python/src/memory_store.rs @@ -3,12 +3,15 @@ use crate::model::*; use crate::sparql::*; use crate::store_utils::*; use oxigraph::io::{DatasetFormat, GraphFormat}; -use oxigraph::model::*; use oxigraph::store::memory::*; use pyo3::basic::CompareOp; use pyo3::exceptions::{NotImplementedError, ValueError}; -use pyo3::prelude::*; +use pyo3::prelude::{ + pyclass, pymethods, pyproto, Py, PyAny, PyCell, PyObject, PyRef, PyRefMut, PyResult, Python, + ToPyObject, +}; use pyo3::{PyIterProtocol, PyObjectProtocol, PySequenceProtocol}; +use std::convert::TryFrom; use std::io::BufReader; /// In-memory store. @@ -96,10 +99,10 @@ impl PyMemoryStore { extract_quads_pattern(subject, predicate, object, graph_name)?; Ok(QuadIter { inner: self.inner.quads_for_pattern( - subject.as_ref().map(|t| t.into()), - predicate.as_ref().map(|t| t.into()), - object.as_ref().map(|t| t.into()), - graph_name.as_ref().map(|t| t.into()), + subject.as_ref().map(|p| p.into()), + predicate.as_ref().map(|p| p.into()), + object.as_ref().map(|p| p.into()), + graph_name.as_ref().map(|p| p.into()), ), }) } @@ -253,7 +256,7 @@ impl PyMemoryStore { py: Python<'_>, ) -> PyResult<()> { let to_graph_name = if let Some(graph_name) = to_graph { - Some(extract_graph_name(graph_name)?) + Some(PyGraphNameRef::try_from(graph_name)?) } else { None }; @@ -263,7 +266,7 @@ impl PyMemoryStore { .load_graph( input, graph_format, - &to_graph_name.unwrap_or(GraphName::DefaultGraph), + &to_graph_name.unwrap_or(PyGraphNameRef::DefaultGraph), base_iri, ) .map_err(map_io_err) @@ -322,7 +325,7 @@ impl PyMemoryStore { py: Python<'_>, ) -> PyResult<()> { let from_graph_name = if let Some(graph_name) = from_graph { - Some(extract_graph_name(graph_name)?) + Some(PyGraphNameRef::try_from(graph_name)?) } else { None }; @@ -332,7 +335,7 @@ impl PyMemoryStore { .dump_graph( output, graph_format, - &from_graph_name.unwrap_or(GraphName::DefaultGraph), + &from_graph_name.unwrap_or(PyGraphNameRef::DefaultGraph), ) .map_err(map_io_err) } else if let Some(dataset_format) = DatasetFormat::from_media_type(mime_type) { diff --git a/python/src/model.rs b/python/src/model.rs index b385c581..50282829 100644 --- a/python/src/model.rs +++ b/python/src/model.rs @@ -3,8 +3,9 @@ use oxigraph::sparql::Variable; use pyo3::basic::CompareOp; use pyo3::exceptions::{IndexError, NotImplementedError, TypeError, ValueError}; use pyo3::prelude::*; -use pyo3::{PyIterProtocol, PyMappingProtocol, PyObjectProtocol}; +use pyo3::{PyIterProtocol, PyMappingProtocol, PyObjectProtocol, PyTypeInfo}; use std::collections::hash_map::DefaultHasher; +use std::convert::TryFrom; use std::hash::Hash; use std::hash::Hasher; use std::vec::IntoIter; @@ -92,8 +93,19 @@ impl PyObjectProtocol for PyNamedNode { hash(&self.inner) } - fn __richcmp__(&self, other: &PyCell, op: CompareOp) -> bool { - eq_ord_compare(self, &other.borrow(), op) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult { + if let Ok(other) = other.downcast::>() { + Ok(eq_ord_compare(self, &other.borrow(), op)) + } else if PyBlankNode::is_instance(other) + || PyLiteral::is_instance(other) + || PyDefaultGraph::is_instance(other) + { + eq_compare_other_type(op) + } else { + Err(TypeError::py_err( + "NamedNode could only be compared with RDF terms", + )) + } } } @@ -182,8 +194,19 @@ impl PyObjectProtocol for PyBlankNode { hash(&self.inner) } - fn __richcmp__(&self, other: &PyCell, op: CompareOp) -> PyResult { - eq_compare(self, &other.borrow(), op) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult { + if let Ok(other) = other.downcast::>() { + eq_compare(self, &other.borrow(), op) + } else if PyNamedNode::is_instance(other) + || PyLiteral::is_instance(other) + || PyDefaultGraph::is_instance(other) + { + eq_compare_other_type(op) + } else { + Err(TypeError::py_err( + "BlankNode could only be compared with RDF terms", + )) + } } } @@ -310,8 +333,19 @@ impl PyObjectProtocol for PyLiteral { hash(&self.inner) } - fn __richcmp__(&self, other: &PyCell, op: CompareOp) -> PyResult { - eq_compare(self, &other.borrow(), op) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult { + if let Ok(other) = other.downcast::>() { + eq_compare(self, &other.borrow(), op) + } else if PyNamedNode::is_instance(other) + || PyBlankNode::is_instance(other) + || PyDefaultGraph::is_instance(other) + { + eq_compare_other_type(op) + } else { + Err(TypeError::py_err( + "Literal could only be compared with RDF terms", + )) + } } } @@ -353,8 +387,19 @@ impl PyObjectProtocol for PyDefaultGraph { 0 } - fn __richcmp__(&self, other: &PyCell, op: CompareOp) -> PyResult { - eq_compare(self, &other.borrow(), op) + fn __richcmp__(&self, other: &PyAny, op: CompareOp) -> PyResult { + if let Ok(other) = other.downcast::>() { + eq_compare(self, &other.borrow(), op) + } else if PyNamedNode::is_instance(other) + || PyBlankNode::is_instance(other) + || PyLiteral::is_instance(other) + { + eq_compare_other_type(op) + } else { + Err(TypeError::py_err( + "DefaultGraph could only be compared with RDF terms", + )) + } } } @@ -403,11 +448,11 @@ impl<'a> From<&'a PyTriple> for TripleRef<'a> { #[pymethods] impl PyTriple { #[new] - fn new(subject: &PyAny, predicate: &PyAny, object: &PyAny) -> PyResult { + fn new(subject: &PyAny, predicate: PyNamedNode, object: &PyAny) -> PyResult { Ok(Triple::new( - extract_named_or_blank_node(subject)?, - extract_named_node(predicate)?, - extract_term(object)?, + &PyNamedOrBlankNodeRef::try_from(subject)?, + predicate, + &PyTermRef::try_from(object)?, ) .into()) } @@ -557,18 +602,18 @@ impl PyQuad { #[new] fn new( subject: &PyAny, - predicate: &PyAny, + predicate: PyNamedNode, object: &PyAny, graph_name: Option<&PyAny>, ) -> PyResult { Ok(Quad::new( - extract_named_or_blank_node(subject)?, - extract_named_node(predicate)?, - extract_term(object)?, - if let Some(graph_name) = graph_name { - extract_graph_name(graph_name)? + &PyNamedOrBlankNodeRef::try_from(subject)?, + predicate, + &PyTermRef::try_from(object)?, + &if let Some(graph_name) = graph_name { + PyGraphNameRef::try_from(graph_name)? } else { - GraphName::DefaultGraph + PyGraphNameRef::DefaultGraph }, ) .into()) @@ -768,27 +813,63 @@ impl PyObjectProtocol for PyVariable { } } -pub fn extract_named_node(py: &PyAny) -> PyResult { - if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else { - Err(TypeError::py_err(format!( - "{} is not an RDF named node", - py.get_type().name(), - ))) +pub struct PyNamedNodeRef<'a>(PyRef<'a, PyNamedNode>); + +impl<'a> From<&'a PyNamedNodeRef<'a>> for NamedNodeRef<'a> { + fn from(value: &'a PyNamedNodeRef<'a>) -> Self { + value.0.inner.as_ref() } } -pub fn extract_named_or_blank_node(py: &PyAny) -> PyResult { - if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else { - Err(TypeError::py_err(format!( - "{} is not an RDF named or blank node", - py.get_type().name(), - ))) +impl<'a> TryFrom<&'a PyAny> for PyNamedNodeRef<'a> { + type Error = PyErr; + + fn try_from(value: &'a PyAny) -> PyResult { + if let Ok(node) = value.downcast::>() { + Ok(Self(node.borrow())) + } else { + Err(TypeError::py_err(format!( + "{} is not an RDF named node", + value.get_type().name(), + ))) + } + } +} + +pub enum PyNamedOrBlankNodeRef<'a> { + NamedNode(PyRef<'a, PyNamedNode>), + BlankNode(PyRef<'a, PyBlankNode>), +} + +impl<'a> From<&'a PyNamedOrBlankNodeRef<'a>> for NamedOrBlankNodeRef<'a> { + fn from(value: &'a PyNamedOrBlankNodeRef<'a>) -> Self { + match value { + PyNamedOrBlankNodeRef::NamedNode(value) => value.inner.as_ref().into(), + PyNamedOrBlankNodeRef::BlankNode(value) => value.inner.as_ref().into(), + } + } +} + +impl<'a> From<&'a PyNamedOrBlankNodeRef<'a>> for NamedOrBlankNode { + fn from(value: &'a PyNamedOrBlankNodeRef<'a>) -> Self { + NamedOrBlankNodeRef::from(value).into() + } +} + +impl<'a> TryFrom<&'a PyAny> for PyNamedOrBlankNodeRef<'a> { + type Error = PyErr; + + fn try_from(value: &'a PyAny) -> PyResult { + if let Ok(node) = value.downcast::>() { + Ok(Self::NamedNode(node.borrow())) + } else if let Ok(node) = value.downcast::>() { + Ok(Self::BlankNode(node.borrow())) + } else { + Err(TypeError::py_err(format!( + "{} is not an RDF named or blank node", + value.get_type().name(), + ))) + } } } @@ -799,18 +880,44 @@ pub fn named_or_blank_node_to_python(py: Python<'_>, node: NamedOrBlankNode) -> } } -pub fn extract_term(py: &PyAny) -> PyResult { - if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else if let Ok(literal) = py.downcast::>() { - Ok(literal.borrow().clone().into()) - } else { - Err(TypeError::py_err(format!( - "{} is not an RDF named or blank node", - py.get_type().name(), - ))) +pub enum PyTermRef<'a> { + NamedNode(PyRef<'a, PyNamedNode>), + BlankNode(PyRef<'a, PyBlankNode>), + Literal(PyRef<'a, PyLiteral>), +} + +impl<'a> From<&'a PyTermRef<'a>> for TermRef<'a> { + fn from(value: &'a PyTermRef<'a>) -> Self { + match value { + PyTermRef::NamedNode(value) => value.inner.as_ref().into(), + PyTermRef::BlankNode(value) => value.inner.as_ref().into(), + PyTermRef::Literal(value) => value.inner.as_ref().into(), + } + } +} + +impl<'a> From<&'a PyTermRef<'a>> for Term { + fn from(value: &'a PyTermRef<'a>) -> Self { + TermRef::from(value).into() + } +} + +impl<'a> TryFrom<&'a PyAny> for PyTermRef<'a> { + type Error = PyErr; + + fn try_from(value: &'a PyAny) -> PyResult { + if let Ok(node) = value.downcast::>() { + Ok(Self::NamedNode(node.borrow())) + } else if let Ok(node) = value.downcast::>() { + Ok(Self::BlankNode(node.borrow())) + } else if let Ok(node) = value.downcast::>() { + Ok(Self::Literal(node.borrow())) + } else { + Err(TypeError::py_err(format!( + "{} is not an RDF term", + value.get_type().name(), + ))) + } } } @@ -822,18 +929,44 @@ pub fn term_to_python(py: Python<'_>, term: Term) -> PyObject { } } -pub fn extract_graph_name(py: &PyAny) -> PyResult { - if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else if let Ok(node) = py.downcast::>() { - Ok(node.borrow().clone().into()) - } else { - Err(TypeError::py_err(format!( - "{} is not a valid RDF graph name", - py.get_type().name(), - ))) +pub enum PyGraphNameRef<'a> { + NamedNode(PyRef<'a, PyNamedNode>), + BlankNode(PyRef<'a, PyBlankNode>), + DefaultGraph, +} + +impl<'a> From<&'a PyGraphNameRef<'a>> for GraphNameRef<'a> { + fn from(value: &'a PyGraphNameRef<'a>) -> Self { + match value { + PyGraphNameRef::NamedNode(value) => value.inner.as_ref().into(), + PyGraphNameRef::BlankNode(value) => value.inner.as_ref().into(), + PyGraphNameRef::DefaultGraph => Self::DefaultGraph, + } + } +} + +impl<'a> From<&'a PyGraphNameRef<'a>> for GraphName { + fn from(value: &'a PyGraphNameRef<'a>) -> Self { + GraphNameRef::from(value).into() + } +} + +impl<'a> TryFrom<&'a PyAny> for PyGraphNameRef<'a> { + type Error = PyErr; + + fn try_from(value: &'a PyAny) -> PyResult { + if let Ok(node) = value.downcast::>() { + Ok(Self::NamedNode(node.borrow())) + } else if let Ok(node) = value.downcast::>() { + Ok(Self::BlankNode(node.borrow())) + } else if value.downcast::>().is_ok() { + Ok(Self::DefaultGraph) + } else { + Err(TypeError::py_err(format!( + "{} is not an RDF graph name", + value.get_type().name(), + ))) + } } } @@ -853,6 +986,14 @@ fn eq_compare(a: &T, b: &T, op: CompareOp) -> PyResult { } } +fn eq_compare_other_type(op: CompareOp) -> PyResult { + match op { + CompareOp::Eq => Ok(false), + CompareOp::Ne => Ok(true), + _ => Err(NotImplementedError::py_err("Ordering is not implemented")), + } +} + fn eq_ord_compare(a: &T, b: &T, op: CompareOp) -> bool { match op { CompareOp::Lt => a < b, diff --git a/python/src/sled_store.rs b/python/src/sled_store.rs index 4d9a818c..129bc573 100644 --- a/python/src/sled_store.rs +++ b/python/src/sled_store.rs @@ -3,11 +3,13 @@ use crate::model::*; use crate::sparql::*; use crate::store_utils::*; use oxigraph::io::{DatasetFormat, GraphFormat}; -use oxigraph::model::*; use oxigraph::store::sled::*; use pyo3::exceptions::ValueError; -use pyo3::prelude::*; +use pyo3::prelude::{ + pyclass, pymethods, pyproto, Py, PyAny, PyObject, PyRef, PyRefMut, PyResult, Python, ToPyObject, +}; use pyo3::{PyIterProtocol, PyObjectProtocol, PySequenceProtocol}; +use std::convert::TryFrom; use std::io::BufReader; /// Store based on the `Sled `_ key-value database. @@ -110,10 +112,10 @@ impl PySledStore { extract_quads_pattern(subject, predicate, object, graph_name)?; Ok(QuadIter { inner: self.inner.quads_for_pattern( - subject.as_ref().map(|t| t.into()), - predicate.as_ref().map(|t| t.into()), - object.as_ref().map(|t| t.into()), - graph_name.as_ref().map(|t| t.into()), + subject.as_ref().map(|p| p.into()), + predicate.as_ref().map(|p| p.into()), + object.as_ref().map(|p| p.into()), + graph_name.as_ref().map(|p| p.into()), ), }) } @@ -270,7 +272,7 @@ impl PySledStore { py: Python<'_>, ) -> PyResult<()> { let to_graph_name = if let Some(graph_name) = to_graph { - Some(extract_graph_name(graph_name)?) + Some(PyGraphNameRef::try_from(graph_name)?) } else { None }; @@ -280,7 +282,7 @@ impl PySledStore { .load_graph( input, graph_format, - &to_graph_name.unwrap_or(GraphName::DefaultGraph), + &to_graph_name.unwrap_or(PyGraphNameRef::DefaultGraph), base_iri, ) .map_err(map_io_err) @@ -340,7 +342,7 @@ impl PySledStore { py: Python<'_>, ) -> PyResult<()> { let from_graph_name = if let Some(graph_name) = from_graph { - Some(extract_graph_name(graph_name)?) + Some(PyGraphNameRef::try_from(graph_name)?) } else { None }; @@ -350,7 +352,7 @@ impl PySledStore { .dump_graph( output, graph_format, - &from_graph_name.unwrap_or(GraphName::DefaultGraph), + &from_graph_name.unwrap_or(PyGraphNameRef::DefaultGraph), ) .map_err(map_io_err) } else if let Some(dataset_format) = DatasetFormat::from_media_type(mime_type) { diff --git a/python/src/sparql.rs b/python/src/sparql.rs index 7a909039..181cb112 100644 --- a/python/src/sparql.rs +++ b/python/src/sparql.rs @@ -2,8 +2,12 @@ use crate::model::*; use crate::store_utils::*; use oxigraph::sparql::*; use pyo3::exceptions::{RuntimeError, SyntaxError, TypeError, ValueError}; -use pyo3::prelude::*; +use pyo3::prelude::{ + pyclass, pymethods, pyproto, FromPyObject, IntoPy, Py, PyAny, PyCell, PyErr, PyObject, + PyRefMut, PyResult, Python, +}; use pyo3::{PyIterProtocol, PyMappingProtocol, PyNativeType, PyObjectProtocol}; +use std::convert::TryFrom; pub fn build_query_options( use_default_graph_as_union: bool, @@ -36,10 +40,10 @@ pub fn build_query_options( )); } for default_graph in default_graphs { - options = options.with_default_graph(extract_graph_name(default_graph?)?); + options = options.with_default_graph(&PyGraphNameRef::try_from(default_graph?)?); } - } else if let Ok(default_graph) = extract_graph_name(default_graph) { - options = options.with_default_graph(default_graph); + } else if let Ok(default_graph) = PyGraphNameRef::try_from(default_graph) { + options = options.with_default_graph(&default_graph); } else { return Err(ValueError::py_err( format!("The query() method default_graph argument should be a NamedNode, a BlankNode, the DefaultGraph or a not empty list of them. {} found", default_graph.get_type() @@ -54,7 +58,7 @@ pub fn build_query_options( )); } for named_graph in named_graphs.iter()? { - options = options.with_named_graph(extract_named_or_blank_node(named_graph?)?); + options = options.with_named_graph(&PyNamedOrBlankNodeRef::try_from(named_graph?)?); } } diff --git a/python/src/store_utils.rs b/python/src/store_utils.rs index c483f719..2e4e9681 100644 --- a/python/src/store_utils.rs +++ b/python/src/store_utils.rs @@ -1,41 +1,41 @@ use crate::model::*; -use oxigraph::model::*; use pyo3::exceptions::{IOError, SyntaxError, ValueError}; -use pyo3::prelude::*; +use pyo3::{PyAny, PyErr, PyResult}; +use std::convert::TryInto; use std::io; -pub fn extract_quads_pattern( - subject: &PyAny, - predicate: &PyAny, - object: &PyAny, - graph_name: Option<&PyAny>, +pub fn extract_quads_pattern<'a>( + subject: &'a PyAny, + predicate: &'a PyAny, + object: &'a PyAny, + graph_name: Option<&'a PyAny>, ) -> PyResult<( - Option, - Option, - Option, - Option, + Option>, + Option>, + Option>, + Option>, )> { Ok(( if subject.is_none() { None } else { - Some(extract_named_or_blank_node(subject)?) + Some(subject.try_into()?) }, if predicate.is_none() { None } else { - Some(extract_named_node(predicate)?) + Some(predicate.try_into()?) }, if object.is_none() { None } else { - Some(extract_term(object)?) + Some(object.try_into()?) }, if let Some(graph_name) = graph_name { if graph_name.is_none() { None } else { - Some(extract_graph_name(graph_name)?) + Some(graph_name.try_into()?) } } else { None diff --git a/python/tests/test_model.py b/python/tests/test_model.py index 70f932f7..c878f2a5 100644 --- a/python/tests/test_model.py +++ b/python/tests/test_model.py @@ -29,8 +29,8 @@ class TestBlankNode(unittest.TestCase): def test_equal(self): self.assertEqual(BlankNode("foo"), BlankNode("foo")) self.assertNotEqual(BlankNode("foo"), BlankNode("bar")) - # TODO self.assertNotEqual(BlankNode('foo'), NamedNode('http://foo')) - # TODO self.assertNotEqual(NamedNode('http://foo'), BlankNode('foo')) + self.assertNotEqual(BlankNode('foo'), NamedNode('http://foo')) + self.assertNotEqual(NamedNode('http://foo'), BlankNode('foo')) class TestLiteral(unittest.TestCase): @@ -59,10 +59,10 @@ class TestLiteral(unittest.TestCase): Literal("foo", language="en", datatype=RDF_LANG_STRING), Literal("foo", language="en"), ) - # TODO self.assertNotEqual(NamedNode('http://foo'), Literal('foo')) - # TODO self.assertNotEqual(Literal('foo'), NamedNode('http://foo')) - # TODO self.assertNotEqual(BlankNode('foo'), Literal('foo')) - # TODO self.assertNotEqual(Literal('foo'), BlankNode('foo')) + self.assertNotEqual(NamedNode('http://foo'), Literal('foo')) + self.assertNotEqual(Literal('foo'), NamedNode('http://foo')) + self.assertNotEqual(BlankNode('foo'), Literal('foo')) + self.assertNotEqual(Literal('foo'), BlankNode('foo')) class TestTriple(unittest.TestCase):