diff --git a/python/src/io.rs b/python/src/io.rs
index c3032990..681a25fa 100644
--- a/python/src/io.rs
+++ b/python/src/io.rs
@@ -8,7 +8,8 @@ use oxigraph::io::{
 use pyo3::exceptions::{PyIOError, PySyntaxError, PyValueError};
 use pyo3::prelude::*;
 use pyo3::types::PyBytes;
-use pyo3::wrap_pyfunction;
+use pyo3::{intern, wrap_pyfunction};
+use std::cmp::max;
 use std::error::Error;
 use std::fs::File;
 use std::io::{self, BufRead, BufReader, BufWriter, Cursor, Read, Write};
@@ -282,17 +283,35 @@ impl Write for PyWritable {
 pub struct PyIo(PyObject);
 
 impl Read for PyIo {
-    fn read(&mut self, mut buf: &mut [u8]) -> io::Result<usize> {
+    /// Reads from the wrapped Python object by calling its `read()` method, accepting
+    /// either `bytes` or `str` results, and copies the data into `buf`.
+    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
         Python::with_gil(|py| {
+            if buf.is_empty() {
+                return Ok(0);
+            }
+            // Text streams count characters, not bytes, and one UTF-8 character takes
+            // up to 4 bytes, so only a quarter of the buffer length is requested.
+            let to_read = max(1, buf.len() / 4);
             let read = self
                 .0
-                .call_method(py, "read", (buf.len(),), None)
+                .as_ref(py)
+                .call_method1(intern!(py, "read"), (to_read,))
                 .map_err(to_io_err)?;
             let bytes = read
-                .extract::<&[u8]>(py)
-                .or_else(|e| read.extract::<&str>(py).map(str::as_bytes).map_err(|_| e))
+                .extract::<&[u8]>()
+                .or_else(|e| read.extract::<&str>().map(str::as_bytes).map_err(|_| e))
                 .map_err(to_io_err)?;
-            buf.write_all(bytes)?;
+            // The stream may hand back more than `buf` can hold (a multi-byte character
+            // when `buf` is shorter than 4 bytes, or a `read()` that ignores its size
+            // argument): report an error instead of panicking on the slice index.
+            let target = buf.get_mut(..bytes.len()).ok_or_else(|| {
+                io::Error::new(
+                    io::ErrorKind::InvalidData,
+                    "The read() method returned more data than the buffer can hold",
+                )
+            })?;
+            target.copy_from_slice(bytes);
             Ok(bytes.len())
         })
     }
@@ -302,16 +321,17 @@ impl Write for PyIo {
     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
         Python::with_gil(|py| {
             self.0
-                .call_method(py, "write", (PyBytes::new(py, buf),), None)
+                .as_ref(py)
+                .call_method1(intern!(py, "write"), (PyBytes::new(py, buf),))
                 .map_err(to_io_err)?
-                .extract::<usize>(py)
+                .extract::<usize>()
                 .map_err(to_io_err)
         })
     }
 
     fn flush(&mut self) -> io::Result<()> {
         Python::with_gil(|py| {
-            self.0.call_method(py, "flush", (), None)?;
+            self.0.as_ref(py).call_method0(intern!(py, "flush"))?;
             Ok(())
         })
     }
diff --git a/python/tests/test_io.py b/python/tests/test_io.py
index 5dda57ca..e7519f5d 100644
--- a/python/tests/test_io.py
+++ b/python/tests/test_io.py
@@ -5,7 +5,9 @@ from tempfile import NamedTemporaryFile, TemporaryFile
 from pyoxigraph import Literal, NamedNode, Quad, Triple, parse, serialize
 
 EXAMPLE_TRIPLE = Triple(
-    NamedNode("http://example.com/foo"), NamedNode("http://example.com/p"), Literal("1")
+    NamedNode("http://example.com/foo"),
+    NamedNode("http://example.com/p"),
+    Literal("éù"),
 )
 EXAMPLE_QUAD = Quad(
     NamedNode("http://example.com/foo"),
@@ -18,7 +20,7 @@ EXAMPLE_QUAD = Quad(
 class TestParse(unittest.TestCase):
     def test_parse_file(self) -> None:
         with NamedTemporaryFile() as fp:
-            fp.write(b'<foo> <p> "1" .')
+            fp.write('<foo> <p> "éù" .'.encode())
             fp.flush()
             self.assertEqual(
                 list(parse(fp.name, "text/turtle", base_iri="http://example.com/")),
@@ -33,7 +35,7 @@ class TestParse(unittest.TestCase):
         self.assertEqual(
             list(
                 parse(
-                    StringIO('<foo> <p> "1" .'),
+                    StringIO('<foo> <p> "éù" .'),
                     "text/turtle",
                     base_iri="http://example.com/",
                 )
@@ -41,11 +43,23 @@ class TestParse(unittest.TestCase):
             [EXAMPLE_TRIPLE],
         )
 
+    def test_parse_long_str_io(self) -> None:
+        self.assertEqual(
+            list(
+                parse(
+                    StringIO('<foo> <p> "éù" .\n' * 1024),
+                    "text/turtle",
+                    base_iri="http://example.com/",
+                )
+            ),
+            [EXAMPLE_TRIPLE] * 1024,
+        )
+
     def test_parse_bytes_io(self) -> None:
         self.assertEqual(
             list(
                 parse(
-                    BytesIO(b'<foo> <p> "1" .'),
+                    BytesIO('<foo> <p> "éù" .'.encode()),
                     "text/turtle",
                     base_iri="http://example.com/",
                 )
@@ -75,15 +89,16 @@ class TestSerialize(unittest.TestCase):
         output = BytesIO()
         serialize([EXAMPLE_TRIPLE], output, "text/turtle")
         self.assertEqual(
-            output.getvalue(),
-            b'<http://example.com/foo> <http://example.com/p> "1" .\n',
+            output.getvalue().decode(),
+            '<http://example.com/foo> <http://example.com/p> "éù" .\n',
         )
 
     def test_serialize_to_file(self) -> None:
         with NamedTemporaryFile() as fp:
             serialize([EXAMPLE_TRIPLE], fp.name, "text/turtle")
             self.assertEqual(
-                fp.read(), b'<http://example.com/foo> <http://example.com/p> "1" .\n'
+                fp.read().decode(),
+                '<http://example.com/foo> <http://example.com/p> "éù" .\n',
             )
 
     def test_serialize_io_error(self) -> None: