# Review notes (commentary only; patch tools skip text that precedes the first "diff --git" header).
#
# Purpose: extend the `FromPyObject` extraction for `InputValue` in python/src/store.rs so that,
# besides `str` and `bytes`, it accepts instances of `io.RawIOBase`, `io.BufferedIOBase` and
# `io.TextIOBase`; widen the PyStore doc comment and its `:type input:` line accordingly; add a
# one-line `python/tests/data.nq` fixture and three tests (BytesIO, StringIO, FileIO) to
# python/tests/test_store.py.
#
# Behaviour of the added Rust branches:
#   - raw/buffered I/O objects: `read()` is called with no arguments and `decode(encoding="utf-8")`
#     is called on its result; the decoded value is stored via `to_string()`.
#   - text I/O objects: `read()` is called with no arguments and its result is stored via
#     `to_string()`.
#   - any other type still falls through to the existing `PyTypeError` branch.
#
# NOTE(review): apart from `PyModule::import(py, "io")?`, every added fallible call (`getattr`,
#   `extract`, `is_instance`, `call_method0`, `call_method`) is followed by `.unwrap()`, so a Python
#   exception raised by e.g. `read()` or `decode()` panics instead of being returned through the
#   function's `PyResult`; consider propagating with `?`.
# NOTE(review): the `io` import and the three `getattr` lookups are placed before the
#   `PyString`/`PyBytes` checks, so they also run when the argument is a plain `str` or `bytes`.
# NOTE(review): `test_load_data_fileio` binds `input_io` inside the `try` body; if
#   `io.FileIO("data.nq", "r")` raises, the `finally` clause's `input_io.close()` hits an unbound
#   local name. The path "data.nq" is relative, so the test presumably relies on being run from
#   python/tests -- confirm.
# NOTE(review): the test-file hunk adds `import io` although its context already shows
#   `from io import BytesIO, RawIOBase`; the new BytesIO test could use the existing import.
# NOTE(review): doc comment nit: "python I/O object" should read "Python I/O object".
# NOTE(review): in this copy the hunks appear to have lost their line breaks and some
#   angle-bracketed text (the quad literals read `" ."`); verify against the original commit
#   before applying.
#
diff --git a/python/src/store.rs b/python/src/store.rs index 3bb7fca1..9cc5bdae 100644 --- a/python/src/store.rs +++ b/python/src/store.rs @@ -8,7 +8,7 @@ use oxigraph::model::{GraphName, GraphNameRef}; use oxigraph::sparql::Update; use oxigraph::store::{self, LoaderError, SerializerError, StorageError, Store}; use pyo3::exceptions::{self, PyIOError, PyRuntimeError, PyValueError}; -use pyo3::types::{IntoPyDict, PyBytes, PyString}; +use pyo3::types::{IntoPyDict, PyBytes, PyString, PyType}; use pyo3::{prelude::*, PyTypeInfo}; use pyo3::{Py, PyRef}; @@ -18,6 +18,24 @@ impl<'source> FromPyObject<'source> for InputValue { fn extract(ob: &'source pyo3::PyAny) -> PyResult { let gil = Python::acquire_gil(); let py = gil.python(); + + let py_io_module = PyModule::import(py, "io")?; + let py_rawiobase: &PyType = py_io_module + .getattr("RawIOBase") + .unwrap() + .extract() + .unwrap(); + let py_bufferediobase: &PyType = py_io_module + .getattr("BufferedIOBase") + .unwrap() + .extract() + .unwrap(); + let py_textiobase: &PyType = py_io_module + .getattr("TextIOBase") + .unwrap() + .extract() + .unwrap(); + if PyString::is_type_of(ob) { Ok(InputValue(ob.to_string())) } else if PyBytes::is_type_of(ob) { @@ -27,6 +45,19 @@ impl<'source> FromPyObject<'source> for InputValue { .unwrap() .to_string(), )) + } else if ob.is_instance(py_rawiobase).unwrap() + || ob.is_instance(py_bufferediobase).unwrap() + { + let kwargs = vec![("encoding", "utf-8")].into_py_dict(py); + Ok(InputValue( + ob.call_method0("read") + .unwrap() + .call_method("decode", (), Some(kwargs)) + .unwrap() + .to_string(), + )) + } else if ob.is_instance(py_textiobase).unwrap() { + Ok(InputValue(ob.call_method0("read").unwrap().to_string())) } else { Err(exceptions::PyTypeError::new_err(format!( "'{}' type is unsupported.", @@ -426,7 +457,7 @@ impl PyStore { }) } - /// Loads an RDF serialization into the store from a str or bytes buffer. 
+ /// Loads an RDF serialization into the store from a buffer or a python I/O object. /// /// Loads are applied in a transactional manner: either the full operation succeeds or nothing is written to the database. /// The :py:func:`bulk_load` method is also available for much faster loading of big files but without transactional guarantees. @@ -446,7 +477,7 @@ impl PyStore { /// and ``application/xml`` for `RDF/XML `_. /// /// :param input: the data to be loaded into the store. - /// :type input: str or bytes + /// :type input: str or bytes or io.RawIOBase or io.BufferedIOBase or io.TextIOBase /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. diff --git a/python/tests/data.nq b/python/tests/data.nq new file mode 100644 index 00000000..97f52c15 --- /dev/null +++ b/python/tests/data.nq @@ -0,0 +1 @@ + . \ No newline at end of file diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 7ab744ad..d30f4bcc 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -1,3 +1,4 @@ +import io import os import unittest from io import BytesIO, RawIOBase @@ -253,6 +254,29 @@ class TestStore(unittest.TestCase): store.load_data(data, mime_type="application/n-quads") self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + def test_load_data_bytesio(self): + data = b" ." + input_io = io.BytesIO(data) + store = Store() + store.load_data(input_io, mime_type="application/n-quads") + self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + + def test_load_data_stringio(self): + data = " ." 
+ input_io = io.StringIO(data) + store = Store() + store.load_data(input_io, mime_type="application/n-quads") + self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + + def test_load_data_fileio(self): + try: + input_io = io.FileIO("data.nq", "r") + store = Store() + store.load_data(input_io, mime_type="application/n-quads") + self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + finally: + input_io.close() + def test_load_data_int(self): data = 123 with self.assertRaises(TypeError) as _: