Python: let the underlying Python errors go through Oxigraph

pull/190/head
Tpt 3 years ago
parent 4efd193708
commit bfac7d3bbf
  1. 30
      python/src/io.rs
  2. 18
      python/tests/test_store.py

@ -199,8 +199,8 @@ impl Read for PyFileLike {
let read = self
.inner
.call_method(py, "read", (buf.len(),), None)
.map_err(|e| to_io_err(e, py))?;
let bytes: &PyBytes = read.cast_as(py).map_err(|e| to_io_err(e, py))?;
.map_err(to_io_err)?;
let bytes: &PyBytes = read.cast_as(py).map_err(to_io_err)?;
buf.write_all(bytes.as_bytes())?;
Ok(bytes.len()?)
}
@ -213,10 +213,10 @@ impl Write for PyFileLike {
usize::extract(
self.inner
.call_method(py, "write", (PyBytes::new(py, buf),), None)
.map_err(|e| to_io_err(e, py))?
.map_err(to_io_err)?
.as_ref(py),
)
.map_err(|e| to_io_err(e, py))
.map_err(to_io_err)
}
fn flush(&mut self) -> io::Result<()> {
@ -227,24 +227,16 @@ impl Write for PyFileLike {
}
}
fn to_io_err(error: impl Into<PyErr>, py: Python<'_>) -> io::Error {
if let Ok(message) = error
.into()
.to_object(py)
.call_method(py, "__str__", (), None)
{
if let Ok(message) = message.extract::<String>(py) {
io::Error::new(io::ErrorKind::Other, message)
} else {
io::Error::new(io::ErrorKind::Other, "An unknown error has occurred")
}
} else {
io::Error::new(io::ErrorKind::Other, "An unknown error has occurred")
}
fn to_io_err(error: impl Into<PyErr>) -> io::Error {
io::Error::new(io::ErrorKind::Other, error.into())
}
pub(crate) fn map_io_err(error: io::Error) -> PyErr {
PyIOError::new_err(error.to_string())
if error.get_ref().map_or(false, |s| s.is::<PyErr>()) {
*error.into_inner().unwrap().downcast().unwrap()
} else {
PyIOError::new_err(error.to_string())
}
}
pub(crate) fn map_parse_error(error: ParseError) -> PyErr {

@ -1,5 +1,5 @@
import unittest
from io import BytesIO
from io import BytesIO, RawIOBase
from pyoxigraph import *
@ -103,7 +103,7 @@ class TestStore(unittest.TestCase):
self.assertEqual(solution["o"], baz)
self.assertEqual(solution[Variable("s")], foo)
self.assertEqual(solution[Variable("o")], baz)
s,o = solution
s, o = solution
self.assertEqual(s, foo)
self.assertEqual(o, baz)
@ -221,6 +221,13 @@ class TestStore(unittest.TestCase):
)
self.assertEqual(set(store), {Quad(foo, bar, baz, graph)})
def test_load_with_io_error(self):
class BadIO(RawIOBase):
pass
with self.assertRaises(NotImplementedError) as _:
Store().load(BadIO(), mime_type="application/n-triples")
def test_dump_ntriples(self):
store = Store()
store.add(Quad(foo, bar, baz, graph))
@ -240,6 +247,13 @@ class TestStore(unittest.TestCase):
b"<http://foo> <http://bar> <http://baz> <http://graph> .\n",
)
def test_dump_with_io_error(self):
class BadIO(RawIOBase):
pass
with self.assertRaises(OSError) as _:
Store().dump(BadIO(), mime_type="application/rdf+xml")
def test_write_in_read(self):
store = Store()
store.add(Quad(foo, bar, bar))

Loading…
Cancel
Save