From be51f9035290696b9bc571e1a8b0b470a780eae3 Mon Sep 17 00:00:00 2001 From: Tpt Date: Sat, 5 Nov 2022 16:40:26 +0100 Subject: [PATCH] Server: Uses PathBuf instead of string for I/O Safer with paths that are not valid UTF-8 --- server/src/main.rs | 30 +++++++++++++++++------------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/server/src/main.rs b/server/src/main.rs index d4149b6c..13fc08a1 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -12,10 +12,11 @@ use rayon_core::ThreadPoolBuilder; use sparesults::{QueryResultsFormat, QueryResultsSerializer}; use std::cell::RefCell; use std::cmp::{max, min}; +use std::ffi::OsStr; use std::fmt; use std::fs::File; use std::io::{self, BufReader, Error, ErrorKind, Read, Write}; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::rc::Rc; use std::str::FromStr; use std::thread::available_parallelism; @@ -52,7 +53,7 @@ enum Command { /// /// If multiple files are provided they are loaded in parallel. #[arg(short, long, global = true, num_args = 0..)] - file: Vec<String>, + file: Vec<PathBuf>, /// Attempt to keep loading even if the data file is invalid. /// /// Only works with N-Triples and N-Quads for now. 
@@ -79,7 +80,6 @@ pub fn main() -> io::Result<()> { .scope(|s| { for file in file { let store = store.clone(); - let file = file.to_string(); s.spawn(move |_| { let f = file.clone(); let start = Instant::now(); @@ -90,35 +90,39 @@ pub fn main() -> io::Result<()> { size, elapsed.as_secs(), ((size as f64) / elapsed.as_secs_f64()).round(), - f + f.display() ) }); if lenient { let f = file.clone(); loader = loader.on_parse_error(move |e| { - eprintln!("Parsing error on file {}: {}", f, e); + eprintln!("Parsing error on file {}: {}", f.display(), e); Ok(()) }) } let fp = match File::open(&file) { Ok(fp) => fp, Err(error) => { - eprintln!("Error while opening file {}: {}", file, error); + eprintln!( + "Error while opening file {}: {}", + file.display(), + error + ); return; } }; if let Err(error) = { - if file.ends_with(".gz") { + if file.extension().map_or(false, |e| e == OsStr::new("gz")) { bulk_load( loader, - &file[..file.len() - 3], MultiGzDecoder::new(fp), + &file.with_extension(""), ) } else { - bulk_load(loader, &file, fp) + bulk_load(loader, fp, &file) } } { - eprintln!("Error while loading file {}: {}", file, error) + eprintln!("Error while loading file {}: {}", file.display(), error) } }) } @@ -138,10 +142,10 @@ pub fn main() -> io::Result<()> { } } -fn bulk_load(loader: BulkLoader, file: &str, reader: impl Read) -> io::Result<()> { - let (_, extension) = file.rsplit_once('.').ok_or_else(|| Error::new( +fn bulk_load(loader: BulkLoader, reader: impl Read, file: &Path) -> io::Result<()> { + let extension = file.extension().and_then(|extension| extension.to_str()).ok_or_else(|| Error::new( ErrorKind::InvalidInput, - format!("The server is not able to guess the file format of {} because the file name as no extension", file)))?; + format!("The server is not able to guess the file format of {} because the file name as no extension", file.display())))?; let reader = BufReader::new(reader); if let Some(format) = DatasetFormat::from_extension(extension) { 
loader.load_dataset(reader, format, None)?;