BulkLoader: Uses thread::scope

pull/553/head
Tpt 2 years ago committed by Thomas Tanon
parent 2281575c14
commit 76deca135c
  1. 92
      lib/src/storage/mod.rs

@ -28,11 +28,7 @@ use std::path::{Path, PathBuf};
#[cfg(not(target_family = "wasm"))] #[cfg(not(target_family = "wasm"))]
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(target_family = "wasm"))] #[cfg(not(target_family = "wasm"))]
use std::sync::Arc; use std::thread;
#[cfg(not(target_family = "wasm"))]
use std::thread::spawn;
#[cfg(not(target_family = "wasm"))]
use std::thread::JoinHandle;
mod backend; mod backend;
mod binary_encoder; mod binary_encoder;
@ -1253,44 +1249,49 @@ impl StorageBulkLoader {
) )
.into()); .into());
} }
let mut threads = VecDeque::with_capacity(num_threads - 1); let done_counter = AtomicU64::new(0);
let mut buffer = Vec::with_capacity(batch_size);
let done_counter = Arc::new(AtomicU64::new(0));
let mut done_and_displayed_counter = 0; let mut done_and_displayed_counter = 0;
for quad in quads { thread::scope(|thread_scope| {
let quad = quad?; let mut threads = VecDeque::with_capacity(num_threads - 1);
buffer.push(quad); let mut buffer = Vec::with_capacity(batch_size);
if buffer.len() >= batch_size { for quad in quads {
self.spawn_load_thread( let quad = quad?;
&mut buffer, buffer.push(quad);
&mut threads, if buffer.len() >= batch_size {
&done_counter, self.spawn_load_thread(
&mut done_and_displayed_counter, &mut buffer,
num_threads, &mut threads,
batch_size, thread_scope,
)?; &done_counter,
&mut done_and_displayed_counter,
num_threads,
batch_size,
)?;
}
} }
} self.spawn_load_thread(
self.spawn_load_thread( &mut buffer,
&mut buffer, &mut threads,
&mut threads, thread_scope,
&done_counter, &done_counter,
&mut done_and_displayed_counter, &mut done_and_displayed_counter,
num_threads, num_threads,
batch_size, batch_size,
)?; )?;
for thread in threads { for thread in threads {
thread.join().unwrap()?; thread.join().unwrap()?;
self.on_possible_progress(&done_counter, &mut done_and_displayed_counter); self.on_possible_progress(&done_counter, &mut done_and_displayed_counter);
} }
Ok(()) Ok(())
})
} }
fn spawn_load_thread( fn spawn_load_thread<'scope>(
&self, &'scope self,
buffer: &mut Vec<Quad>, buffer: &mut Vec<Quad>,
threads: &mut VecDeque<JoinHandle<Result<(), StorageError>>>, threads: &mut VecDeque<thread::ScopedJoinHandle<'scope, Result<(), StorageError>>>,
done_counter: &Arc<AtomicU64>, thread_scope: &'scope thread::Scope<'scope, '_>,
done_counter: &'scope AtomicU64,
done_and_displayed_counter: &mut u64, done_and_displayed_counter: &mut u64,
num_threads: usize, num_threads: usize,
batch_size: usize, batch_size: usize,
@ -1305,10 +1306,9 @@ impl StorageBulkLoader {
} }
let mut buffer_to_load = Vec::with_capacity(batch_size); let mut buffer_to_load = Vec::with_capacity(batch_size);
swap(buffer, &mut buffer_to_load); swap(buffer, &mut buffer_to_load);
let storage = self.storage.clone(); let storage = &self.storage;
let done_counter_clone = Arc::clone(done_counter); threads.push_back(thread_scope.spawn(move || {
threads.push_back(spawn(move || { FileBulkLoader::new(storage, batch_size).load(buffer_to_load, done_counter)
FileBulkLoader::new(storage, batch_size).load(buffer_to_load, &done_counter_clone)
})); }));
Ok(()) Ok(())
} }
@ -1326,8 +1326,8 @@ impl StorageBulkLoader {
} }
#[cfg(not(target_family = "wasm"))] #[cfg(not(target_family = "wasm"))]
struct FileBulkLoader { struct FileBulkLoader<'a> {
storage: Storage, storage: &'a Storage,
id2str: HashMap<StrHash, Box<str>>, id2str: HashMap<StrHash, Box<str>>,
quads: HashSet<EncodedQuad>, quads: HashSet<EncodedQuad>,
triples: HashSet<EncodedQuad>, triples: HashSet<EncodedQuad>,
@ -1335,8 +1335,8 @@ struct FileBulkLoader {
} }
#[cfg(not(target_family = "wasm"))] #[cfg(not(target_family = "wasm"))]
impl FileBulkLoader { impl<'a> FileBulkLoader<'a> {
fn new(storage: Storage, batch_size: usize) -> Self { fn new(storage: &'a Storage, batch_size: usize) -> Self {
Self { Self {
storage, storage,
id2str: HashMap::with_capacity(3 * batch_size), id2str: HashMap::with_capacity(3 * batch_size),

Loading…
Cancel
Save