diff --git a/python/src/io.rs b/python/src/io.rs index 3a761caa..6714ccea 100644 --- a/python/src/io.rs +++ b/python/src/io.rs @@ -9,6 +9,7 @@ use pyo3::exceptions::{PyIOError, PySyntaxError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use pyo3::wrap_pyfunction; +use std::fs::File; use std::io::{self, BufReader, BufWriter, Read, Write}; pub fn add_to_module(module: &PyModule) -> PyResult<()> { @@ -30,8 +31,8 @@ pub fn add_to_module(module: &PyModule) -> PyResult<()> { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` for `RDF/XML `_. /// -/// :param input: The binary I/O object to read from. For example, it could be a file opened in binary mode with ``open('my_file.ttl', 'rb')``. -/// :type input: io.RawIOBase or io.BufferedIOBase +/// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. +/// :type input: io.RawIOBase or io.BufferedIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -52,7 +53,7 @@ pub fn parse( base_iri: Option<&str>, py: Python<'_>, ) -> PyResult { - let input = BufReader::new(PyFileLike::new(input)); + let input = PyFileLike::open(input, py).map_err(map_io_err)?; if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { let mut parser = GraphParser::from_format(graph_format); if let Some(base_iri) = base_iri { @@ -99,8 +100,8 @@ pub fn parse( /// /// :param input: the RDF triples and quads to serialize. /// :type input: iter(Triple) or iter(Quad) -/// :param output: The binary I/O object to write to. For example, it could be a file opened in binary mode with ``open('my_file.ttl', 'wb')``. -/// :type output: io.RawIOBase or io.BufferedIOBase +/// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. +/// :type output: io.RawIOBase or io.BufferedIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :raises ValueError: if the MIME type is not supported. @@ -112,8 +113,8 @@ pub fn parse( /// b' "1" .\n' #[pyfunction] #[pyo3(text_signature = "(input, output, /, mime_type, *, base_iri = None)")] -pub fn serialize(input: &PyAny, output: PyObject, mime_type: &str) -> PyResult<()> { - let output = BufWriter::new(PyFileLike::new(output)); +pub fn serialize(input: &PyAny, output: PyObject, mime_type: &str, py: Python<'_>) -> PyResult<()> { + let output = PyFileLike::create(output, py).map_err(map_io_err)?; if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { let mut writer = GraphSerializer::from_format(graph_format) .triple_writer(output) @@ -186,48 +187,72 @@ impl PyQuadReader { } } -pub struct PyFileLike { - inner: PyObject, +pub(crate) enum PyFileLike { + Io(PyObject), + File(File), } impl PyFileLike { - pub fn new(inner: PyObject) -> Self { - Self { inner } + pub fn open(inner: PyObject, py: Python<'_>) -> io::Result> { + Ok(BufReader::new(match inner.extract::<&str>(py) { + Ok(path) => Self::File(py.allow_threads(|| File::open(path))?), + Err(_) => Self::Io(inner), + })) + } + + pub fn create(inner: PyObject, py: Python<'_>) -> io::Result> { + Ok(BufWriter::new(match inner.extract::<&str>(py) { + Ok(path) => Self::File(py.allow_threads(|| File::create(path))?), + Err(_) => Self::Io(inner), + })) } } impl Read for PyFileLike { fn read(&mut self, mut buf: &mut [u8]) -> io::Result { - let gil = Python::acquire_gil(); - let py = gil.python(); - let read = self - .inner - .call_method(py, "read", (buf.len(),), None) - .map_err(to_io_err)?; - let bytes: &PyBytes = read.cast_as(py).map_err(to_io_err)?; - buf.write_all(bytes.as_bytes())?; - Ok(bytes.len()?) + match self { + Self::Io(io) => { + let gil = Python::acquire_gil(); + let py = gil.python(); + let read = io + .call_method(py, "read", (buf.len(),), None) + .map_err(to_io_err)?; + let bytes: &[u8] = read.extract(py).map_err(to_io_err)?; + buf.write_all(bytes)?; + Ok(bytes.len()) + } + Self::File(file) => file.read(buf), + } } } impl Write for PyFileLike { fn write(&mut self, buf: &[u8]) -> io::Result { - let gil = Python::acquire_gil(); - let py = gil.python(); - usize::extract( - self.inner - .call_method(py, "write", (PyBytes::new(py, buf),), None) - .map_err(to_io_err)? - .as_ref(py), - ) - .map_err(to_io_err) + match self { + Self::Io(io) => { + let gil = Python::acquire_gil(); + let py = gil.python(); + usize::extract( + io.call_method(py, "write", (PyBytes::new(py, buf),), None) + .map_err(to_io_err)? + .as_ref(py), + ) + .map_err(to_io_err) + } + Self::File(file) => file.write(buf), + } } fn flush(&mut self) -> io::Result<()> { - let gil = Python::acquire_gil(); - let py = gil.python(); - self.inner.call_method(py, "flush", (), None)?; - Ok(()) + match self { + Self::Io(io) => { + let gil = Python::acquire_gil(); + let py = gil.python(); + io.call_method(py, "flush", (), None)?; + Ok(()) + } + Self::File(file) => file.flush(), + } } } diff --git a/python/src/store.rs b/python/src/store.rs index f059c479..ef4b7538 100644 --- a/python/src/store.rs +++ b/python/src/store.rs @@ -1,6 +1,6 @@ #![allow(clippy::needless_option_as_deref)] -use crate::io::{allow_threads_unsafe, map_parse_error, PyFileLike}; +use crate::io::{allow_threads_unsafe, map_io_err, map_parse_error, PyFileLike}; use crate::model::*; use crate::sparql::*; use oxigraph::io::{DatasetFormat, GraphFormat}; @@ -10,7 +10,6 @@ use oxigraph::store::{self, LoaderError, SerializerError, StorageError, Store}; use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::{Py, PyRef}; -use std::io::{BufReader, BufWriter}; /// RDF store. /// @@ -263,8 +262,8 @@ impl PyStore { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` for `RDF/XML `_. /// - /// :param input: The binary I/O object to read from. For example, it could be a file opened in binary mode with ``open('my_file.ttl', 'rb')``. - /// :type input: io.RawIOBase or io.BufferedIOBase + /// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. + /// :type input: io.RawIOBase or io.BufferedIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -294,8 +293,8 @@ impl PyStore { } else { None }; + let input = PyFileLike::open(input, py).map_err(map_io_err)?; py.allow_threads(|| { - let input = BufReader::new(PyFileLike::new(input)); if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner .load_graph( @@ -342,8 +341,8 @@ impl PyStore { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` for `RDF/XML `_. /// - /// :param input: The binary I/O object to read from. For example, it could be a file opened in binary mode with ``open('my_file.ttl', 'rb')``. - /// :type input: io.RawIOBase or io.BufferedIOBase + /// :param input: The binary I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. + /// :type input: io.RawIOBase or io.BufferedIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param base_iri: the base IRI used to resolve the relative IRIs in the file or :py:const:`None` if relative IRI resolution should not be done. @@ -373,8 +372,8 @@ impl PyStore { } else { None }; + let input = PyFileLike::open(input, py).map_err(map_io_err)?; py.allow_threads(|| { - let input = BufReader::new(PyFileLike::new(input)); if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner .bulk_load_graph( @@ -416,8 +415,8 @@ impl PyStore { /// For example, ``application/turtle`` could also be used for `Turtle `_ /// and ``application/xml`` for `RDF/XML `_. /// - /// :param output: The binary I/O object to write to. For example, it could be a file opened in binary mode with ``open('my_file.ttl', 'wb')``. - /// :type input: io.RawIOBase or io.BufferedIOBase + /// :param output: The binary I/O object or file path to write to. For example, it could be a file path as a string or a file writer opened in binary mode with ``open('my_file.ttl', 'wb')``. + /// :type output: io.RawIOBase or io.BufferedIOBase or str /// :param mime_type: the MIME type of the RDF serialization. /// :type mime_type: str /// :param from_graph: if a triple based format is requested, the store graph from which dump the triples. By default, the default graph is used. @@ -445,8 +444,8 @@ impl PyStore { } else { None }; + let output = PyFileLike::create(output, py).map_err(map_io_err)?; py.allow_threads(|| { - let output = BufWriter::new(PyFileLike::new(output)); if let Some(graph_format) = GraphFormat::from_media_type(mime_type) { self.inner .dump_graph( diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 6e3d1b7f..8e27178c 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -1,7 +1,9 @@ +import os import unittest from io import BytesIO, RawIOBase from pyoxigraph import * +from tempfile import NamedTemporaryFile foo = NamedNode("http://foo") bar = NamedNode("http://bar") @@ -221,6 +223,15 @@ class TestStore(unittest.TestCase): ) self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + def test_load_file(self): + with NamedTemporaryFile(delete=False) as fp: + file_name = fp.name + fp.write(b" .") + store = Store() + store.load(file_name, mime_type="application/n-quads") + os.remove(file_name) + self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + def test_load_with_io_error(self): class BadIO(RawIOBase): pass @@ -247,6 +258,19 @@ class TestStore(unittest.TestCase): b" .\n", ) + def test_dump_file(self): + with NamedTemporaryFile(delete=False) as fp: + file_name = fp.name + store = Store() + store.add(Quad(foo, bar, baz, graph)) + store.dump(file_name, "application/n-quads") + with open(file_name, 'rt') as fp: + file_content = fp.read() + self.assertEqual( + file_content, + " .\n", + ) + def test_dump_with_io_error(self): class BadIO(RawIOBase): pass