diff --git a/python/src/io.rs b/python/src/io.rs index 44cc48e9..8e24a24d 100644 --- a/python/src/io.rs +++ b/python/src/io.rs @@ -190,26 +190,13 @@ impl PyReadable { (Some(_), Some(_)) => Err(PyValueError::new_err( "input and file_path can't be both set at the same time", )), - (Some(path), None) => Ok(PyReadable::from_file(path, py)?), + (Some(path), None) => Ok(Self::File(py.allow_threads(|| File::open(path))?)), (None, Some(input)) => Ok(input.into()), (None, None) => Err(PyValueError::new_err( "Either input or file_path must be set", )), } } - pub fn from_file(file: &Path, py: Python<'_>) -> io::Result { - Ok(Self::File(py.allow_threads(|| File::open(file))?)) - } - - pub fn from_data(data: &PyAny) -> Self { - if let Ok(bytes) = data.extract::>() { - Self::Bytes(Cursor::new(bytes)) - } else if let Ok(string) = data.extract::() { - Self::Bytes(Cursor::new(string.into_bytes())) - } else { - Self::Io(PyIo(data.into())) - } - } } impl Read for PyReadable { diff --git a/python/src/sparql.rs b/python/src/sparql.rs index 207a4818..863bda9b 100644 --- a/python/src/sparql.rs +++ b/python/src/sparql.rs @@ -12,7 +12,7 @@ use oxigraph::sparql::{ Variable, }; use pyo3::basic::CompareOp; -use pyo3::exceptions::{PyRuntimeError, PySyntaxError, PyTypeError, PyValueError}; +use pyo3::exceptions::{PyRuntimeError, PySyntaxError, PyValueError}; use pyo3::prelude::*; use pyo3::types::PyBytes; use std::io; @@ -132,22 +132,13 @@ impl PyQuerySolution { self.inner.len() } - fn __getitem__(&self, input: &PyAny) -> PyResult> { - if let Ok(key) = usize::extract(input) { - Ok(self.inner.get(key).map(|term| PyTerm::from(term.clone()))) - } else if let Ok(key) = <&str>::extract(input) { - Ok(self.inner.get(key).map(|term| PyTerm::from(term.clone()))) - } else if let Ok(key) = input.extract::>() { - Ok(self - .inner - .get(<&Variable>::from(&*key)) - .map(|term| PyTerm::from(term.clone()))) - } else { - Err(PyTypeError::new_err(format!( - "{} is not an integer of a string", - input.get_type().name()?, - ))) + fn __getitem__(&self, key: PySolutionKey<'_>) -> Option { + match key { + PySolutionKey::Usize(key) => self.inner.get(key), + PySolutionKey::Str(key) => self.inner.get(key), + PySolutionKey::Variable(key) => self.inner.get(<&Variable>::from(&*key)), } + .map(|term| PyTerm::from(term.clone())) } #[allow(clippy::unnecessary_to_owned)] @@ -158,6 +149,13 @@ impl PyQuerySolution { } } +#[derive(FromPyObject)] +pub enum PySolutionKey<'a> { + Usize(usize), + Str(&'a str), + Variable(PyRef<'a, PyVariable>), +} + #[pyclass(module = "pyoxigraph")] pub struct SolutionValueIter { inner: IntoIter>, @@ -460,43 +458,42 @@ impl PyQueryTriples { /// It supports also some media type and extension aliases. /// For example, ``application/json`` could also be used for `JSON `_. /// -/// :param input: The I/O object or file path to read from. For example, it could be a file path as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. -/// :type input: typing.IO[bytes] or typing.IO[str] or str or os.PathLike[str] +/// :param input: The :py:class:`str`, :py:class:`bytes` or I/O object to read from. For example, it could be the file content as a string or a file reader opened in binary mode with ``open('my_file.ttl', 'rb')``. +/// :type input: bytes or str or typing.IO[bytes] or typing.IO[str] or None, optional /// :param format: the format of the RDF serialization using a media type like ``text/turtle`` or an extension like `ttl`. If :py:const:`None`, the format is guessed from the file name extension. /// :type format: str or None, optional +/// :param path: The file path to read from. Replaces the ``input`` parameter. +/// :type path: str or os.PathLike[str] or None, optional /// :return: an iterator of :py:class:`QuerySolution` or a :py:class:`bool`. /// :rtype: QuerySolutions or QueryBoolean /// :raises ValueError: if the format is not supported. /// :raises SyntaxError: if the provided data is invalid. /// :raises OSError: if a system error happens while reading the file. /// -/// >>> input = io.BytesIO(b'?s\t?p\t?o\n\t\t1\n') -/// >>> list(parse_query_results(input, "text/tsv")) +/// >>> list(parse_query_results('?s\t?p\t?o\n\t\t1\n', "text/tsv")) /// [ p= o=>>] /// -/// >>> input = io.BytesIO(b'{"head":{},"boolean":true}') -/// >>> parse_query_results(input, "application/sparql-results+json") +/// >>> parse_query_results('{"head":{},"boolean":true}', "application/sparql-results+json") /// #[pyfunction] -#[pyo3(signature = (input, /, format = None))] +#[pyo3(signature = (input = None, format = None, *, path = None))] pub fn parse_query_results( - input: &PyAny, + input: Option, format: Option<&str>, + path: Option, py: Python<'_>, ) -> PyResult { - let file_path = input.extract::().ok(); - let format = parse_format(format, file_path.as_deref())?; - let input = if let Some(file_path) = &file_path { - PyReadable::from_file(file_path, py)? - } else { - PyReadable::from_data(input) - }; + let input = PyReadable::from_args(&path, input, py)?; + let format = parse_format(format, path.as_deref())?; let results = QueryResultsParser::from_format(format) .parse_read(input) - .map_err(|e| map_query_results_parse_error(e, file_path.clone()))?; + .map_err(|e| map_query_results_parse_error(e, path.clone()))?; Ok(match results { FromReadQueryResultsReader::Solutions(iter) => PyQuerySolutions { - inner: PyQuerySolutionsVariant::Reader { iter, file_path }, + inner: PyQuerySolutionsVariant::Reader { + iter, + file_path: path, + }, } .into_py(py), FromReadQueryResultsReader::Boolean(inner) => PyQueryBoolean { inner }.into_py(py), diff --git a/python/tests/test_io.py b/python/tests/test_io.py index 596c6fd3..9c8d4047 100644 --- a/python/tests/test_io.py +++ b/python/tests/test_io.py @@ -202,7 +202,7 @@ class TestParseQuerySolutions(unittest.TestCase): with NamedTemporaryFile(suffix=".tsv") as fp: fp.write(b'?s\t?p\t?o\n\t\t"1"\n') fp.flush() - r = parse_query_results(fp.name) + r = parse_query_results(path=fp.name) self.assertIsInstance(r, QuerySolutions) results = list(r) # type: ignore[arg-type] self.assertEqual(results[0]["s"], NamedNode("http://example.com/s")) @@ -210,10 +210,20 @@ class TestParseQuerySolutions(unittest.TestCase): def test_parse_not_existing_file(self) -> None: with self.assertRaises(IOError) as _: - parse_query_results("/tmp/not-existing-oxigraph-file.ttl", "application/json") + parse_query_results(path="/tmp/not-existing-oxigraph-file.ttl", format="application/json") + + def test_parse_str(self) -> None: + result = parse_query_results("true", "tsv") + self.assertIsInstance(result, QueryBoolean) + self.assertTrue(result) + + def test_parse_bytes(self) -> None: + result = parse_query_results(b"false", "tsv") + self.assertIsInstance(result, QueryBoolean) + self.assertFalse(result) def test_parse_str_io(self) -> None: - result = parse_query_results(StringIO("true"), "tsv") + result = parse_query_results("true", "tsv") self.assertIsInstance(result, QueryBoolean) self.assertTrue(result) @@ -231,7 +241,7 @@ class TestParseQuerySolutions(unittest.TestCase): fp.write(b"{]") fp.flush() with self.assertRaises(SyntaxError) as ctx: - list(parse_query_results(fp.name, "srj")) # type: ignore[arg-type] + list(parse_query_results(path=fp.name, format="srj")) # type: ignore[arg-type] self.assertEqual(ctx.exception.filename, fp.name) self.assertEqual(ctx.exception.lineno, 1) self.assertEqual(ctx.exception.offset, 2) @@ -245,7 +255,7 @@ class TestParseQuerySolutions(unittest.TestCase): fp.write(b"1\t\n") fp.flush() with self.assertRaises(SyntaxError) as ctx: - list(parse_query_results(fp.name, "tsv")) # type: ignore[arg-type] + list(parse_query_results(path=fp.name, format="tsv")) # type: ignore[arg-type] self.assertEqual(ctx.exception.filename, fp.name) self.assertEqual(ctx.exception.lineno, 2) self.assertEqual(ctx.exception.offset, 3)