From c99f392fe6407e33f2f9e34b1b6b11270ebdd5a3 Mon Sep 17 00:00:00 2001 From: Edmond Date: Wed, 3 Aug 2022 03:50:13 +0000 Subject: [PATCH] Store.load_data supports input types str and bytes. Update tests and add a new test to check for TypeError --- python/src/store.rs | 55 ++++++++++++++++++++++++++++---------- python/tests/test_store.py | 21 +++++++++++---- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/python/src/store.rs b/python/src/store.rs index 1bacff99..a9accda2 100644 --- a/python/src/store.rs +++ b/python/src/store.rs @@ -7,10 +7,35 @@ use oxigraph::io::{DatasetFormat, GraphFormat}; use oxigraph::model::{GraphName, GraphNameRef}; use oxigraph::sparql::Update; use oxigraph::store::{self, LoaderError, SerializerError, StorageError, Store}; -use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; -use pyo3::prelude::*; +use pyo3::exceptions::{self, PyIOError, PyRuntimeError, PyValueError}; +use pyo3::types::{IntoPyDict, PyBytes, PyString}; +use pyo3::{prelude::*, PyTypeInfo}; use pyo3::{Py, PyRef}; +struct InputValue(String); + +impl<'source> FromPyObject<'source> for InputValue { + fn extract(ob: &'source pyo3::PyAny) -> PyResult { + let gil = Python::acquire_gil(); + let py = gil.python(); + if PyString::is_type_of(ob) { + Ok(InputValue(ob.to_string())) + } else if PyBytes::is_type_of(ob) { + let kwargs = vec![("encoding", "utf-8")].into_py_dict(py); + Ok(InputValue( + ob.call_method("decode", (), Some(kwargs)) + .unwrap() + .to_string(), + )) + } else { + Err(exceptions::PyTypeError::new_err(format!( + "'{}' type is unsupported.", + ob.get_type().name().unwrap() + ))) + } + } +} + /// RDF store. /// /// It encodes a `RDF dataset `_ and allows to query it using SPARQL. @@ -322,7 +347,7 @@ impl PyStore { }) } - /// Loads an RDF serialization into the store from a stream. + /// Loads an RDF serialization into the store from a file. /// /// Loads are applied in a transactional manner: either the full operation succeeds or nothing is written to the database. /// The :py:func:`bulk_load` method is also available for much faster loading of big files but without transactional guarantees. @@ -354,12 +379,12 @@ impl PyStore { /// :raises IOError: if an I/O error happens during a quad insertion. /// /// >>> store = Store() - /// >>> store.load_from_stream(io.BytesIO(b'

"1" .'), "text/turtle", base_iri="http://example.com/", to_graph=NamedNode("http://example.com/g")) + /// >>> store.load(io.BytesIO(b'

"1" .'), "text/turtle", base_iri="http://example.com/", to_graph=NamedNode("http://example.com/g")) /// >>> list(store) /// [ predicate= object=> graph_name=>] #[pyo3(text_signature = "($self, data, /, mime_type, *, base_iri = None, to_graph = None)")] #[args(input, mime_type, "*", base_iri = "None", to_graph = "None")] - fn load_from_stream( + fn load_file( &self, input: PyObject, mime_type: &str, @@ -401,7 +426,7 @@ impl PyStore { }) } - /// Loads an RDF serialization into the store from a string. + /// Loads an RDF serialization into the store from a str or bytes buffer. /// /// Loads are applied in a transactional manner: either the full operation succeeds or nothing is written to the database. /// The :py:func:`bulk_load` method is also available for much faster loading of big files but without transactional guarantees. @@ -420,8 +445,8 @@ impl PyStore { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` for `RDF/XML `_. /// - /// :param input: The RDF string data - /// :type input: str + /// :param input: the data to be loaded into the store. + /// :type input: str or bytes /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -432,16 +457,16 @@ impl PyStore { /// :raises SyntaxError: if the provided data is invalid. /// :raises IOError: if an I/O error happens during a quad insertion. /// - /// >>> store = Store() /// >>> data = '

"1" .' - /// >>> store.load_from_data(data, "text/turtle", base_iri="http://example.com/", to_graph=NamedNode("http://example.com/g")) + /// >>> store = Store() + /// >>> store.load_data(data, "text/turtle", base_iri="http://example.com/", to_graph=NamedNode("http://example.com/g")) /// >>> list(store) /// [ predicate= object=> graph_name=>] #[pyo3(text_signature = "($self, data, /, mime_type, *, base_iri = None, to_graph = None)")] #[args(input, mime_type, "*", base_iri = "None", to_graph = "None")] - fn load_from_data( + fn load_data( &self, - input: &str, + input: InputValue, mime_type: &str, base_iri: Option<&str>, to_graph: Option<&PyAny>, @@ -452,11 +477,13 @@ impl PyStore { } else { None }; + let InputValue(value) = input; + py.allow_threads(|| { if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner .load_graph( - input.as_ref(), + value.as_ref(), graph_format, to_graph_name.as_ref().unwrap_or(&GraphName::DefaultGraph), base_iri, @@ -469,7 +496,7 @@ impl PyStore { )); } self.inner - .load_dataset(input.as_ref(), dataset_format, base_iri) + .load_dataset(value.as_ref(), dataset_format, base_iri) .map_err(map_loader_error) } else { Err(PyValueError::new_err(format!( diff --git a/python/tests/test_store.py b/python/tests/test_store.py index f74f2ebf..7ab744ad 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -223,7 +223,7 @@ class TestStore(unittest.TestCase): ) self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) - def test_load_file(self): + def test_load(self): with NamedTemporaryFile(delete=False) as fp: file_name = fp.name fp.write(b" .") @@ -232,21 +232,32 @@ class TestStore(unittest.TestCase): os.remove(file_name) self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) - def test_load_from_stream(self): + def test_load_file(self): with NamedTemporaryFile(delete=False) as fp: file_name = fp.name fp.write(b" .") store = Store() - store.load_from_stream(file_name, mime_type="application/n-quads") + store.load_file(file_name, mime_type="application/n-quads") os.remove(file_name) self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) - def test_load_from_data(self): + def test_load_data_str(self): data = " ." store = Store() - store.load_from_data(data, mime_type="application/n-quads") + store.load_data(data, mime_type="application/n-quads") + self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + + def test_load_data_bytestr(self): + data = b" ." + store = Store() + store.load_data(data, mime_type="application/n-quads") self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + def test_load_data_int(self): + data = 123 + with self.assertRaises(TypeError) as _: + Store().load_data(data, mime_type="application/n-triples") + def test_load_with_io_error(self): class BadIO(RawIOBase): pass