From a8f98a00560166a9326442e3d4c1183433440247 Mon Sep 17 00:00:00 2001 From: Tpt Date: Tue, 12 Sep 2023 08:26:13 +0200 Subject: [PATCH] Python: makes serialization method output bytes if no output is specified --- python/src/io.rs | 56 ++++++++++++++++++++++++++------------ python/src/store.rs | 41 +++++++++++++++++----------- python/tests/test_io.py | 6 ++++ python/tests/test_store.py | 4 +-- 4 files changed, 70 insertions(+), 37 deletions(-) diff --git a/python/src/io.rs b/python/src/io.rs index 97ca9b3d..28e7d81b 100644 --- a/python/src/io.rs +++ b/python/src/io.rs @@ -105,31 +105,39 @@ pub fn parse( /// /// :param input: the RDF triples and quads to serialize. /// :type input: iterable(Triple) or iterable(Quad) -/// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. -/// :type output: io(bytes) or str or pathlib.Path +/// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. If :py:const:`None`, a :py:func:`bytes` buffer is returned with the serialized content. +/// :type output: io(bytes) or str or pathlib.Path or None, optional /// :param format: the format of the RDF serialization using a media type like ``text/turtle`` or an extension like `ttl`. If :py:const:`None`, the format is guessed from the file name extension. /// :type format: str or None, optional -/// :rtype: None +/// :rtype: bytes or None /// :raises ValueError: if the format is not supported. /// :raises TypeError: if a triple is given during a quad format serialization or reverse. /// +/// >>> serialize([Triple(NamedNode('http://example.com'), NamedNode('http://example.com/p'), Literal('1'))], format="ttl") +/// b' "1" .\n' +/// /// >>> output = io.BytesIO() /// >>> serialize([Triple(NamedNode('http://example.com'), NamedNode('http://example.com/p'), Literal('1'))], output, "text/turtle") /// >>> output.getvalue() /// b' "1" .\n' #[pyfunction] -pub fn serialize( +#[pyo3(signature = (input, output = None, /, format = None))] +pub fn serialize<'a>( input: &PyAny, - output: &PyAny, + output: Option<&PyAny>, format: Option<&str>, - py: Python<'_>, -) -> PyResult<()> { - let file_path = output.extract::().ok(); + py: Python<'a>, +) -> PyResult> { + let file_path = output.and_then(|output| output.extract::().ok()); let format = rdf_format(format, file_path.as_deref())?; - let output = if let Some(file_path) = &file_path { - PyWritable::from_file(file_path, py).map_err(map_io_err)? + let output = if let Some(output) = output { + if let Some(file_path) = &file_path { + PyWritable::from_file(file_path, py).map_err(map_io_err)? + } else { + PyWritable::from_data(output) + } } else { - PyWritable::from_data(output) + PyWritable::Bytes(Vec::new()) }; let mut writer = RdfSerializer::from_format(format).serialize_to_write(BufWriter::new(output)); for i in input.iter()? { @@ -153,8 +161,7 @@ pub fn serialize( .map_err(map_io_err)? .into_inner() .map_err(|e| map_io_err(e.into_error()))? - .close() - .map_err(map_io_err) + .close(py) } #[pyclass(name = "QuadReader", module = "pyoxigraph")] @@ -215,6 +222,7 @@ impl Read for PyReadable { } pub enum PyWritable { + Bytes(Vec), Io(PyIo), File(File), } @@ -228,18 +236,29 @@ impl PyWritable { Self::Io(PyIo(data.into())) } - pub fn close(mut self) -> io::Result<()> { - self.flush()?; - if let Self::File(file) = self { - file.sync_all()?; + pub fn close(self, py: Python<'_>) -> PyResult> { + match self { + Self::Bytes(bytes) => Ok(Some(PyBytes::new(py, &bytes))), + Self::File(mut file) => { + py.allow_threads(|| { + file.flush()?; + file.sync_all() + }) + .map_err(map_io_err)?; + Ok(None) + } + Self::Io(mut io) => { + py.allow_threads(|| io.flush()).map_err(map_io_err)?; + Ok(None) + } } - Ok(()) } } impl Write for PyWritable { fn write(&mut self, buf: &[u8]) -> io::Result { match self { + Self::Bytes(bytes) => bytes.write(buf), Self::Io(io) => io.write(buf), Self::File(file) => file.write(buf), } @@ -247,6 +266,7 @@ impl Write for PyWritable { fn flush(&mut self) -> io::Result<()> { match self { + Self::Bytes(_) => Ok(()), Self::Io(io) => io.flush(), Self::File(file) => file.flush(), } diff --git a/python/src/store.rs b/python/src/store.rs index 42dbc533..80a04709 100644 --- a/python/src/store.rs +++ b/python/src/store.rs @@ -10,6 +10,7 @@ use oxigraph::sparql::Update; use oxigraph::store::{self, LoaderError, SerializerError, StorageError, Store}; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; +use pyo3::types::PyBytes; use std::io::BufWriter; use std::path::PathBuf; @@ -496,41 +497,50 @@ impl PyStore { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` or ``xml`` for `RDF/XML `_. /// - /// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. - /// :type output: io(bytes) or str or pathlib.Path + /// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. If :py:const:`None`, a :py:func:`bytes` buffer is returned with the serialized content. + /// :type output: io(bytes) or str or pathlib.Path or None, optional /// :param format: the format of the RDF serialization using a media type like ``text/turtle`` or an extension like `ttl`. If :py:const:`None`, the format is guessed from the file name extension. /// :type format: str or None, optional /// :param from_graph: the store graph from which dump the triples. Required if the serialization format does not support named graphs. If it does supports named graphs the full dataset is written. /// :type from_graph: NamedNode or BlankNode or DefaultGraph or None, optional - /// :rtype: None + /// :rtype: bytes or None /// :raises ValueError: if the format is not supported or the `from_graph` parameter is not given with a syntax not supporting named graphs. /// :raises OSError: if an error happens during a quad lookup /// /// >>> store = Store() + /// >>> store.add(Quad(NamedNode('http://example.com'), NamedNode('http://example.com/p'), Literal('1'))) + /// >>> store.dump(format="trig") + /// b' "1" .\n' + /// + /// >>> store = Store() /// >>> store.add(Quad(NamedNode('http://example.com'), NamedNode('http://example.com/p'), Literal('1'), NamedNode('http://example.com/g'))) /// >>> output = io.BytesIO() /// >>> store.dump(output, "text/turtle", from_graph=NamedNode("http://example.com/g")) /// >>> output.getvalue() /// b' "1" .\n' - #[pyo3(signature = (output, /, format = None, *, from_graph = None))] - fn dump( + #[pyo3(signature = (output = None, /, format = None, *, from_graph = None))] + fn dump<'a>( &self, - output: &PyAny, + output: Option<&PyAny>, format: Option<&str>, from_graph: Option<&PyAny>, - py: Python<'_>, - ) -> PyResult<()> { + py: Python<'a>, + ) -> PyResult> { let from_graph_name = if let Some(graph_name) = from_graph { Some(GraphName::from(&PyGraphNameRef::try_from(graph_name)?)) } else { None }; - let file_path = output.extract::().ok(); + let file_path = output.and_then(|output| output.extract::().ok()); let format = rdf_format(format, file_path.as_deref())?; - let output = if let Some(file_path) = &file_path { - PyWritable::from_file(file_path, py).map_err(map_io_err)? + let output = if let Some(output) = output { + if let Some(file_path) = &file_path { + PyWritable::from_file(file_path, py).map_err(map_io_err)? + } else { + PyWritable::from_data(output) + } } else { - PyWritable::from_data(output) + PyWritable::Bytes(Vec::new()) }; py.allow_threads(|| { let output = BufWriter::new(output); @@ -541,10 +551,9 @@ impl PyStore { } .map_err(map_serializer_error)? .into_inner() - .map_err(|e| map_io_err(e.into_error()))? - .close() - .map_err(map_io_err) - }) + .map_err(|e| map_io_err(e.into_error())) + })? + .close(py) } /// Returns an iterator over all the store named graphs. diff --git a/python/tests/test_io.py b/python/tests/test_io.py index 006fc436..851b66b5 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -129,6 +129,12 @@ class TestParse(unittest.TestCase): class TestSerialize(unittest.TestCase): + def test_serialize_to_bytes(self) -> None: + self.assertEqual( + serialize([EXAMPLE_TRIPLE.triple], None, "text/turtle").decode(), + ' "éù" .\n', + ) + def test_serialize_to_bytes_io(self) -> None: output = BytesIO() serialize([EXAMPLE_TRIPLE.triple], output, "text/turtle") diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 56f30b4b..4df54406 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -289,10 +289,8 @@ class TestStore(unittest.TestCase): def test_dump_nquads(self) -> None: store = Store() store.add(Quad(foo, bar, baz, graph)) - output = BytesIO() - store.dump(output, "nq") self.assertEqual( - output.getvalue(), + store.dump(format="nq"), b" .\n", )