diff --git a/python/src/io.rs b/python/src/io.rs index 6714ccea..da1c288f 100644 --- a/python/src/io.rs +++ b/python/src/io.rs @@ -10,7 +10,7 @@ use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::wrap_pyfunction; use std::fs::File; -use std::io::{self, BufReader, BufWriter, Read, Write}; +use std::io::{self, BufRead, BufReader, BufWriter, Cursor, Read, Write}; pub fn add_to_module(module: &PyModule) -> PyResult<()> { module.add_wrapped(wrap_pyfunction!(parse))?; @@ -32,7 +32,7 @@ pub fn add_to_module(module: &PyModule) -> PyResult<()> { /// and ``application/xml`` for `RDF/XML `_. /// /// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. -/// :type input: io.RawIOBase or io.BufferedIOBase or str +/// :type input: io.RawIOBase or io.BufferedIOBase or io.TextIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -53,7 +53,12 @@ pub fn parse( base_iri: Option<&str>, py: Python<'_>, ) -> PyResult { - let input = PyFileLike::open(input, py).map_err(map_io_err)?; + let input = if let Ok(path) = input.extract::<&str>(py) { + PyReadable::from_file(path, py) + } else { + PyReadable::from_data(input, py) + } + .map_err(map_io_err)?; if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { let mut parser = GraphParser::from_format(graph_format); if let Some(base_iri) = base_iri { @@ -114,7 +119,12 @@ pub fn parse( #[pyfunction] #[pyo3(text_signature = "(input, output, /, mime_type, *, base_iri = None)")] pub fn serialize(input: &PyAny, output: PyObject, mime_type: &str, py: Python<'_>) -> PyResult<()> { - let output = PyFileLike::create(output, py).map_err(map_io_err)?; + let output = if let Ok(path) = output.extract::<&str>(py) { + PyWritable::from_file(path, py) + } else { + PyWritable::from_data(output) + } + .map_err(map_io_err)?; if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { let mut writer = GraphSerializer::from_format(graph_format) .triple_writer(output) @@ -147,7 +157,7 @@ pub fn serialize(input: &PyAny, output: PyObject, mime_type: &str, py: Python<'_ #[pyclass(name = "TripleReader", module = "oxigraph")] pub struct PyTripleReader { - inner: TripleReader>, + inner: TripleReader, } #[pymethods] @@ -168,7 +178,7 @@ impl PyTripleReader { #[pyclass(name = "QuadReader", module = "oxigraph")] pub struct PyQuadReader { - inner: QuadReader>, + inner: QuadReader, } #[pymethods] @@ -187,75 +197,129 @@ impl PyQuadReader { } } -pub(crate) enum PyFileLike { - Io(PyObject), - File(File), +pub(crate) enum PyReadable { + Bytes(Cursor>), + Io(BufReader), + File(BufReader), } -impl PyFileLike { - pub fn open(inner: PyObject, py: Python<'_>) -> io::Result> { - Ok(BufReader::new(match inner.extract::<&str>(py) { - Ok(path) => Self::File(py.allow_threads(|| File::open(path))?), - Err(_) => Self::Io(inner), - })) +impl PyReadable { + pub fn from_file(file: &str, py: Python<'_>) -> io::Result { + Ok(Self::File(BufReader::new( + py.allow_threads(|| File::open(file))?, + ))) } - pub fn create(inner: PyObject, py: Python<'_>) -> io::Result> { - Ok(BufWriter::new(match inner.extract::<&str>(py) { - Ok(path) => Self::File(py.allow_threads(|| File::create(path))?), - Err(_) => Self::Io(inner), - })) + pub fn from_data(data: PyObject, py: Python<'_>) -> io::Result { + Ok(if let Ok(bytes) = data.extract::>(py) { + Self::Bytes(Cursor::new(bytes)) + } else if let Ok(string) = data.extract::(py) { + Self::Bytes(Cursor::new(string.into_bytes())) + } else { + Self::Io(BufReader::new(PyIo(data))) + }) } } -impl Read for PyFileLike { - fn read(&mut self, mut buf: &mut [u8]) -> io::Result { +impl Read for PyReadable { + fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { - Self::Io(io) => { - let gil = Python::acquire_gil(); - let py = gil.python(); - let read = io - .call_method(py, "read", (buf.len(),), None) - .map_err(to_io_err)?; - let bytes: &[u8] = read.extract(py).map_err(to_io_err)?; - buf.write_all(bytes)?; - Ok(bytes.len()) - } + Self::Bytes(bytes) => bytes.read(buf), + Self::Io(io) => io.read(buf), Self::File(file) => file.read(buf), } } } -impl Write for PyFileLike { +impl BufRead for PyReadable { + fn fill_buf(&mut self) -> io::Result<&[u8]> { + match self { + Self::Bytes(bytes) => bytes.fill_buf(), + Self::Io(io) => io.fill_buf(), + Self::File(file) => file.fill_buf(), + } + } + + fn consume(&mut self, amt: usize) { + match self { + Self::Bytes(bytes) => bytes.consume(amt), + Self::Io(io) => io.consume(amt), + Self::File(file) => file.consume(amt), + } + } +} + +pub(crate) enum PyWritable { + Io(BufWriter), + File(BufWriter), +} + +impl PyWritable { + pub fn from_file(file: &str, py: Python<'_>) -> io::Result { + Ok(Self::File(BufWriter::new( + py.allow_threads(|| File::create(file))?, + ))) + } + + pub fn from_data(data: PyObject) -> io::Result { + Ok(Self::Io(BufWriter::new(PyIo(data)))) + } +} + +impl Write for PyWritable { fn write(&mut self, buf: &[u8]) -> io::Result { match self { - Self::Io(io) => { - let gil = Python::acquire_gil(); - let py = gil.python(); - usize::extract( - io.call_method(py, "write", (PyBytes::new(py, buf),), None) - .map_err(to_io_err)? - .as_ref(py), - ) - .map_err(to_io_err) - } + Self::Io(io) => io.write(buf), Self::File(file) => file.write(buf), } } fn flush(&mut self) -> io::Result<()> { match self { - Self::Io(io) => { - let gil = Python::acquire_gil(); - let py = gil.python(); - io.call_method(py, "flush", (), None)?; - Ok(()) - } + Self::Io(io) => io.flush(), Self::File(file) => file.flush(), } } } +pub(crate) struct PyIo(PyObject); + +impl Read for PyIo { + fn read(&mut self, mut buf: &mut [u8]) -> io::Result { + let gil = Python::acquire_gil(); + let py = gil.python(); + let read = self + .0 + .call_method(py, "read", (buf.len(),), None) + .map_err(to_io_err)?; + let bytes = read + .extract::<&[u8]>(py) + .or_else(|e| read.extract::<&str>(py).map(|s| s.as_bytes()).or(Err(e))) + .map_err(to_io_err)?; + buf.write_all(bytes)?; + Ok(bytes.len()) + } +} + +impl Write for PyIo { + fn write(&mut self, buf: &[u8]) -> io::Result { + let gil = Python::acquire_gil(); + let py = gil.python(); + self.0 + .call_method(py, "write", (PyBytes::new(py, buf),), None) + .map_err(to_io_err)? + .extract::(py) + .map_err(to_io_err) + } + + fn flush(&mut self) -> io::Result<()> { + let gil = Python::acquire_gil(); + let py = gil.python(); + self.0.call_method(py, "flush", (), None)?; + Ok(()) + } +} + fn to_io_err(error: impl Into) -> io::Error { io::Error::new(io::ErrorKind::Other, error.into()) } diff --git a/python/src/store.rs b/python/src/store.rs index 2ae9af0c..6775a1e3 100644 --- a/python/src/store.rs +++ b/python/src/store.rs @@ -1,6 +1,6 @@ #![allow(clippy::needless_option_as_deref)] -use crate::io::{allow_threads_unsafe, map_io_err, map_parse_error, PyFileLike}; +use crate::io::{allow_threads_unsafe, map_io_err, map_parse_error, PyReadable, PyWritable}; use crate::model::*; use crate::sparql::*; use oxigraph::io::{DatasetFormat, GraphFormat}; @@ -263,7 +263,7 @@ impl PyStore { /// and ``application/xml`` for `RDF/XML `_. /// /// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. - /// :type input: io.RawIOBase or io.BufferedIOBase or str + /// :type input: io.RawIOBase or io.BufferedIOBase or io.TextIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -293,7 +293,12 @@ impl PyStore { } else { None }; - let input = PyFileLike::open(input, py).map_err(map_io_err)?; + let input = if let Ok(path) = input.extract::<&str>(py) { + PyReadable::from_file(path, py) + } else { + PyReadable::from_data(input, py) + } + .map_err(map_io_err)?; py.allow_threads(|| { if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner @@ -342,7 +347,7 @@ impl PyStore { /// and ``application/xml`` for `RDF/XML `_. /// /// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. - /// :type input: io.RawIOBase or io.BufferedIOBase or str + /// :type input: io.RawIOBase or io.BufferedIOBase or io.TextIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -372,7 +377,12 @@ impl PyStore { } else { None }; - let input = PyFileLike::open(input, py).map_err(map_io_err)?; + let input = if let Ok(path) = input.extract::<&str>(py) { + PyReadable::from_file(path, py) + } else { + PyReadable::from_data(input, py) + } + .map_err(map_io_err)?; py.allow_threads(|| { if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner @@ -441,12 +451,17 @@ impl PyStore { from_graph: Option<&PyAny>, py: Python<'_>, ) -> PyResult<()> { + let output = if let Ok(path) = output.extract::<&str>(py) { + PyWritable::from_file(path, py) + } else { + PyWritable::from_data(output) + } + .map_err(map_io_err)?; let from_graph_name = if let Some(graph_name) = from_graph { Some(GraphName::from(&PyGraphNameRef::try_from(graph_name)?)) } else { None }; - let output = PyFileLike::create(output, py).map_err(map_io_err)?; py.allow_threads(|| { if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner diff --git a/python/tests/test_io.py b/python/tests/test_io.py index d00c6d28..4e29bad9 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -1,44 +1,58 @@ import unittest -import io +from io import StringIO, BytesIO, RawIOBase +from tempfile import NamedTemporaryFile from pyoxigraph import * +EXAMPLE_TRIPLE = Triple( + NamedNode("http://example.com/foo"), + NamedNode("http://example.com/p"), + Literal("1") +) + + class TestParse(unittest.TestCase): - def test_parse(self): - input = io.BytesIO(b'

"1" .') - result = list(parse(input, "text/turtle", base_iri="http://example.com/")) + def test_parse_file(self): + with NamedTemporaryFile() as fp: + fp.write(b'

"1" .') + fp.flush() + self.assertEqual( + list(parse(fp.name, "text/turtle", base_iri="http://example.com/")), + [EXAMPLE_TRIPLE] + ) + + def test_parse_not_existing_file(self): + with self.assertRaises(IOError) as _: + parse("/tmp/not-existing-oxigraph-file.ttl", "text/turtle") + + def test_parse_str_io(self): + self.assertEqual( + list(parse(StringIO('

"1" .'), "text/turtle", base_iri="http://example.com/")), + [EXAMPLE_TRIPLE] + ) + def test_parse_bytes_io(self): self.assertEqual( - result, - [ - Triple( - NamedNode("http://example.com/foo"), - NamedNode("http://example.com/p"), - Literal( - "1", - datatype=NamedNode("http://www.w3.org/2001/XMLSchema#string"), - ), - ) - ], + list(parse(BytesIO(b'

"1" .'), "text/turtle", base_iri="http://example.com/")), + [EXAMPLE_TRIPLE] ) + + def test_parse_io_error(self): + class BadIO(RawIOBase): + pass + + with self.assertRaises(NotImplementedError) as _: + list(parse(BadIO(), mime_type="application/n-triples")) class TestSerialize(unittest.TestCase): - def test_serialize(self): - output = io.BytesIO() - serialize( - [ - Triple( - NamedNode("http://example.com"), - NamedNode("http://example.com/p"), - Literal("1"), - ) - ], - output, - "text/turtle", - ) + def test_serialize_to_bytes_io(self): + output = BytesIO() + serialize([EXAMPLE_TRIPLE], output, "text/turtle") + self.assertEqual(output.getvalue(), b' "1" .\n') - self.assertEqual( - output.getvalue(), b' "1" .\n' - ) + def test_serialize_to_file(self): + with NamedTemporaryFile() as fp: + serialize([EXAMPLE_TRIPLE], fp.name, "text/turtle") + self.assertEqual(fp.read(), b' "1" .\n')