diff --git a/python/src/io.rs b/python/src/io.rs index 34c00828..e0bda5cd 100644 --- a/python/src/io.rs +++ b/python/src/io.rs @@ -199,8 +199,8 @@ impl Read for PyFileLike { let read = self .inner .call_method(py, "read", (buf.len(),), None) - .map_err(|e| to_io_err(e, py))?; - let bytes: &PyBytes = read.cast_as(py).map_err(|e| to_io_err(e, py))?; + .map_err(to_io_err)?; + let bytes: &PyBytes = read.cast_as(py).map_err(to_io_err)?; buf.write_all(bytes.as_bytes())?; Ok(bytes.len()?) } @@ -213,10 +213,10 @@ impl Write for PyFileLike { usize::extract( self.inner .call_method(py, "write", (PyBytes::new(py, buf),), None) - .map_err(|e| to_io_err(e, py))? + .map_err(to_io_err)? .as_ref(py), ) - .map_err(|e| to_io_err(e, py)) + .map_err(to_io_err) } fn flush(&mut self) -> io::Result<()> { @@ -227,24 +227,16 @@ impl Write for PyFileLike { } } -fn to_io_err(error: impl Into, py: Python<'_>) -> io::Error { - if let Ok(message) = error - .into() - .to_object(py) - .call_method(py, "__str__", (), None) - { - if let Ok(message) = message.extract::(py) { - io::Error::new(io::ErrorKind::Other, message) - } else { - io::Error::new(io::ErrorKind::Other, "An unknown error has occurred") - } - } else { - io::Error::new(io::ErrorKind::Other, "An unknown error has occurred") - } +fn to_io_err(error: impl Into) -> io::Error { + io::Error::new(io::ErrorKind::Other, error.into()) } pub(crate) fn map_io_err(error: io::Error) -> PyErr { - PyIOError::new_err(error.to_string()) + if error.get_ref().map_or(false, |s| s.is::()) { + *error.into_inner().unwrap().downcast().unwrap() + } else { + PyIOError::new_err(error.to_string()) + } } pub(crate) fn map_parse_error(error: ParseError) -> PyErr { diff --git a/python/tests/test_store.py b/python/tests/test_store.py index 8df9330a..6e3d1b7f 100644 --- a/python/tests/test_store.py +++ b/python/tests/test_store.py @@ -1,5 +1,5 @@ import unittest -from io import BytesIO +from io import BytesIO, RawIOBase from pyoxigraph import * @@ -103,7 +103,7 @@ class TestStore(unittest.TestCase): self.assertEqual(solution["o"], baz) self.assertEqual(solution[Variable("s")], foo) self.assertEqual(solution[Variable("o")], baz) - s,o = solution + s, o = solution self.assertEqual(s, foo) self.assertEqual(o, baz) @@ -221,6 +221,13 @@ class TestStore(unittest.TestCase): ) self.assertEqual(set(store), {Quad(foo, bar, baz, graph)}) + def test_load_with_io_error(self): + class BadIO(RawIOBase): + pass + + with self.assertRaises(NotImplementedError) as _: + Store().load(BadIO(), mime_type="application/n-triples") + def test_dump_ntriples(self): store = Store() store.add(Quad(foo, bar, baz, graph)) @@ -240,6 +247,13 @@ class TestStore(unittest.TestCase): b" .\n", ) + def test_dump_with_io_error(self): + class BadIO(RawIOBase): + pass + + with self.assertRaises(OSError) as _: + Store().dump(BadIO(), mime_type="application/rdf+xml") + def test_write_in_read(self): store = Store() store.add(Quad(foo, bar, bar))