diff --git a/server/src/main.rs b/server/src/main.rs index 326d46c8..49f593c0 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -1,4 +1,4 @@ -use anyhow::{anyhow, bail}; +use anyhow::bail; use clap::{Parser, Subcommand}; use flate2::read::MultiGzDecoder; use oxhttp::model::{Body, HeaderName, HeaderValue, Request, Response, Status}; @@ -116,10 +116,17 @@ pub fn main() -> anyhow::Result<()> { bulk_load( loader, MultiGzDecoder::new(fp), - &file.with_extension(""), + GraphOrDatasetFormat::from_path(&file.with_extension("")) + .unwrap(), + None, ) } else { - bulk_load(loader, fp, &file) + bulk_load( + loader, + fp, + GraphOrDatasetFormat::from_path(&file).unwrap(), + None, + ) } } { eprintln!("Error while loading file {}: {}", file.display(), error) @@ -142,28 +149,71 @@ pub fn main() -> anyhow::Result<()> { } } -fn bulk_load(loader: BulkLoader, reader: impl Read, file: &Path) -> anyhow::Result<()> { - let extension = file - .extension() - .and_then(|extension| extension.to_str()) - .ok_or_else(|| { - anyhow!( - "Not able to guess the file format of {} because the file name as no extension", - file.display() - ) - })?; +fn bulk_load( + loader: BulkLoader, + reader: impl Read, + format: GraphOrDatasetFormat, + base_iri: Option<&str>, +) -> anyhow::Result<()> { let reader = BufReader::new(reader); - if let Some(format) = DatasetFormat::from_extension(extension) { - loader.load_dataset(reader, format, None)?; - Ok(()) - } else if let Some(format) = GraphFormat::from_extension(extension) { - loader.load_graph(reader, format, GraphNameRef::DefaultGraph, None)?; - Ok(()) - } else { - bail!( - "Not able to guess the file format from the extension {}", - extension - ) + match format { + GraphOrDatasetFormat::Graph(format) => { + loader.load_graph(reader, format, GraphNameRef::DefaultGraph, base_iri) + } + GraphOrDatasetFormat::Dataset(format) => loader.load_dataset(reader, format, base_iri), + }?; + Ok(()) +} + +#[derive(Copy, Clone)] +enum GraphOrDatasetFormat { + Graph(GraphFormat), + Dataset(DatasetFormat), +} + +impl GraphOrDatasetFormat { + fn from_path(path: &Path) -> anyhow::Result { + if let Some(ext) = path.extension().and_then(|ext| ext.to_str()) { + Self::from_name(ext).map_err(|e| { + e.context(format!( + "Not able to guess the file format from file name extension '{}'", + ext + )) + }) + } else { + bail!( + "The path {} has no extension to guess a file format from", + path.display() + ) + } + } + + fn from_name(name: &str) -> anyhow::Result { + let mut candidates = Vec::with_capacity(4); + if let Some(f) = GraphFormat::from_extension(name) { + candidates.push(GraphOrDatasetFormat::Graph(f)); + } + if let Some(f) = DatasetFormat::from_extension(name) { + candidates.push(GraphOrDatasetFormat::Dataset(f)); + } + if let Some(f) = GraphFormat::from_media_type(name) { + candidates.push(GraphOrDatasetFormat::Graph(f)); + } + if let Some(f) = DatasetFormat::from_media_type(name) { + candidates.push(GraphOrDatasetFormat::Dataset(f)); + } + if candidates.is_empty() { + bail!("The format '{}' is unknown", name) + } else if candidates.len() == 1 { + Ok(candidates[0]) + } else { + bail!("The format '{}' can be resolved to multiple known formats, not sure what to pick ({})", name, candidates.iter().fold(String::new(), |a, f| { + a + " " + match f { + GraphOrDatasetFormat::Graph(f) => f.file_extension(), + GraphOrDatasetFormat::Dataset(f) => f.file_extension(), + } + }).trim()) + } } } @@ -1093,6 +1143,8 @@ mod tests { use anyhow::Result; use assert_cmd::Command; use assert_fs::prelude::*; + use flate2::write::GzEncoder; + use flate2::Compression; use oxhttp::model::Method; use predicates::prelude::*; @@ -1146,6 +1198,24 @@ mod tests { Ok(()) } + #[test] + fn cli_load_gzip_dataset() -> Result<()> { + let file = assert_fs::NamedTempFile::new("sample.nq.gz")?; + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(b" .")?; + file.write_binary(&encoder.finish()?)?; + cli_command()? + .arg("load") + .arg("-f") + .arg(file.path()) + .assert() + .success() + .stdout("") + .stderr(predicate::str::starts_with("1 triples loaded")); + Ok(()) + } + #[test] fn get_ui() { ServerTest::new().test_status(