diff --git a/src/db_iterator.rs b/src/db_iterator.rs index 968e354..1c6111d 100644 --- a/src/db_iterator.rs +++ b/src/db_iterator.rs @@ -377,22 +377,26 @@ pub type DBIterator<'a> = DBIteratorWithThreadMode<'a, DB>; /// { /// let db = DB::open_default(path).unwrap(); /// let mut iter = db.iterator(IteratorMode::Start); // Always iterates forward -/// for (key, value) in iter { +/// for item in iter { +/// let (key, value) = item.unwrap(); /// println!("Saw {:?} {:?}", key, value); /// } /// iter = db.iterator(IteratorMode::End); // Always iterates backward -/// for (key, value) in iter { +/// for item in iter { +/// let (key, value) = item.unwrap(); /// println!("Saw {:?} {:?}", key, value); /// } /// iter = db.iterator(IteratorMode::From(b"my key", Direction::Forward)); // From a key in Direction::{forward,reverse} -/// for (key, value) in iter { +/// for item in iter { +/// let (key, value) = item.unwrap(); /// println!("Saw {:?} {:?}", key, value); /// } /// /// // You can seek with an existing Iterator instance, too /// iter = db.iterator(IteratorMode::Start); /// iter.set_mode(IteratorMode::From(b"another key", Direction::Reverse)); -/// for (key, value) in iter { +/// for item in iter { +/// let (key, value) = item.unwrap(); /// println!("Saw {:?} {:?}", key, value); /// } /// } @@ -401,9 +405,10 @@ pub type DBIterator<'a> = DBIteratorWithThreadMode<'a, DB>; pub struct DBIteratorWithThreadMode<'a, D: DBAccess> { raw: DBRawIteratorWithThreadMode<'a, D>, direction: Direction, - just_seeked: bool, + done: bool, } +#[derive(Copy, Clone)] pub enum Direction { Forward, Reverse, @@ -411,6 +416,7 @@ pub enum Direction { pub type KVBytes = (Box<[u8]>, Box<[u8]>); +#[derive(Copy, Clone)] pub enum IteratorMode<'a> { Start, End, @@ -419,13 +425,7 @@ pub enum IteratorMode<'a> { impl<'a, D: DBAccess> DBIteratorWithThreadMode<'a, D> { pub(crate) fn new(db: &D, readopts: ReadOptions, mode: IteratorMode) -> Self { - let mut rv = DBIteratorWithThreadMode { - raw: DBRawIteratorWithThreadMode::new(db, readopts), - direction: Direction::Forward, // blown away by set_mode() - just_seeked: false, - }; - rv.set_mode(mode); - rv + Self::from_raw(DBRawIteratorWithThreadMode::new(db, readopts), mode) } pub(crate) fn new_cf( @@ -434,76 +434,67 @@ impl<'a, D: DBAccess> DBIteratorWithThreadMode<'a, D> { readopts: ReadOptions, mode: IteratorMode, ) -> Self { + Self::from_raw( + DBRawIteratorWithThreadMode::new_cf(db, cf_handle, readopts), + mode, + ) + } + + fn from_raw(raw: DBRawIteratorWithThreadMode<'a, D>, mode: IteratorMode) -> Self { let mut rv = DBIteratorWithThreadMode { - raw: DBRawIteratorWithThreadMode::new_cf(db, cf_handle, readopts), + raw, direction: Direction::Forward, // blown away by set_mode() - just_seeked: false, + done: false, }; rv.set_mode(mode); rv } pub fn set_mode(&mut self, mode: IteratorMode) { - match mode { + self.done = false; + self.direction = match mode { IteratorMode::Start => { self.raw.seek_to_first(); - self.direction = Direction::Forward; + Direction::Forward } IteratorMode::End => { self.raw.seek_to_last(); - self.direction = Direction::Reverse; + Direction::Reverse } IteratorMode::From(key, Direction::Forward) => { self.raw.seek(key); - self.direction = Direction::Forward; + Direction::Forward } IteratorMode::From(key, Direction::Reverse) => { self.raw.seek_for_prev(key); - self.direction = Direction::Reverse; + Direction::Reverse } }; - - self.just_seeked = true; - } - - /// See [`valid`](DBRawIteratorWithThreadMode::valid) - pub fn valid(&self) -> bool { - self.raw.valid() - } - - /// See [`status`](DBRawIteratorWithThreadMode::status) - pub fn status(&self) -> Result<(), Error> { - self.raw.status() } } impl<'a, D: DBAccess> Iterator for DBIteratorWithThreadMode<'a, D> { - type Item = KVBytes; - - fn next(&mut self) -> Option { - if !self.raw.valid() { - return None; - } + type Item = Result; - // Initial call to next() after seeking should not move the iterator - // or the first item will not be returned - if self.just_seeked { - self.just_seeked = false; - } else { + fn next(&mut self) -> Option> { + if self.done { + None + } else if let Some((key, value)) = self.raw.item() { + let item = (Box::from(key), Box::from(value)); match self.direction { Direction::Forward => self.raw.next(), Direction::Reverse => self.raw.prev(), } - } - - if let Some((key, value)) = self.raw.item() { - Some((Box::from(key), Box::from(value))) + Some(Ok(item)) } else { - None + self.done = true; + self.raw.status().err().map(Result::Err) } } } +impl<'a, D: DBAccess> std::iter::FusedIterator for DBIteratorWithThreadMode<'a, D> {} + impl<'a, D: DBAccess> Into> for DBIteratorWithThreadMode<'a, D> { fn into(self) -> DBRawIteratorWithThreadMode<'a, D> { self.raw @@ -548,9 +539,9 @@ impl DBWALIterator { } impl Iterator for DBWALIterator { - type Item = (u64, WriteBatch); + type Item = Result<(u64, WriteBatch), Error>; - fn next(&mut self) -> Option<(u64, WriteBatch)> { + fn next(&mut self) -> Option { if !self.valid() { return None; } @@ -562,9 +553,9 @@ impl Iterator for DBWALIterator { if self.valid() { let mut seq: u64 = 0; let inner = unsafe { ffi::rocksdb_wal_iter_get_batch(self.inner, &mut seq) }; - Some((seq, WriteBatch { inner })) + Some(Ok((seq, WriteBatch { inner }))) } else { - None + self.status().err().map(Result::Err) } } } diff --git a/tests/test_db.rs b/tests/test_db.rs index 15ede99..a60b199 100644 --- a/tests/test_db.rs +++ b/tests/test_db.rs @@ -147,7 +147,7 @@ fn iterator_test() { let iter = db.iterator(IteratorMode::Start); - for (idx, (db_key, db_value)) in iter.enumerate() { + for (idx, (db_key, db_value)) in iter.map(Result::unwrap).enumerate() { let (key, value) = data[idx]; assert_eq!((&key[..], &value[..]), (db_key.as_ref(), db_value.as_ref())); } @@ -188,7 +188,7 @@ fn iterator_test_tailing() { } let mut tot = 0; - for (i, (k, v)) in tail_iter.enumerate() { + for (i, (k, v)) in tail_iter.map(Result::unwrap).enumerate() { assert_eq!( (k.to_vec(), v.to_vec()), (data[i].0.to_vec(), data[i].1.to_vec()) @@ -424,13 +424,13 @@ fn test_get_updates_since_multiple_batches() { puts: 0, deletes: 0, }; - let (seq, batch) = iter.next().unwrap(); + let (seq, batch) = iter.next().unwrap().unwrap(); assert_eq!(seq, 2); batch.iterate(&mut counts); - let (seq, batch) = iter.next().unwrap(); + let (seq, batch) = iter.next().unwrap().unwrap(); assert_eq!(seq, 3); batch.iterate(&mut counts); - let (seq, batch) = iter.next().unwrap(); + let (seq, batch) = iter.next().unwrap().unwrap(); assert_eq!(seq, 4); batch.iterate(&mut counts); assert!(iter.next().is_none()); @@ -457,7 +457,7 @@ fn test_get_updates_since_one_batch() { puts: 0, deletes: 0, }; - let (seq, batch) = iter.next().unwrap(); + let (seq, batch) = iter.next().unwrap().unwrap(); assert_eq!(seq, 2); batch.iterate(&mut counts); assert!(iter.next().is_none()); @@ -857,7 +857,7 @@ fn get_with_cache_and_bulkload_test() { // try to get key let iter = db.iterator(IteratorMode::Start); - for (expected, (k, _)) in iter.enumerate() { + for (expected, (k, _)) in iter.map(Result::unwrap).enumerate() { assert_eq!(k.as_ref(), format!("{:0>4}", expected).as_bytes()); } @@ -918,7 +918,7 @@ fn get_with_cache_and_bulkload_test() { // try to get key let iter = db.iterator(IteratorMode::Start); - for (expected, (k, _)) in iter.enumerate() { + for (expected, (k, _)) in iter.map(Result::unwrap).enumerate() { assert_eq!(k.as_ref(), format!("{:0>4}", expected).as_bytes()); } } @@ -992,7 +992,7 @@ fn get_with_cache_and_bulkload_and_blobs_test() { // try to get key let iter = db.iterator(IteratorMode::Start); - for (expected, (k, _)) in iter.enumerate() { + for (expected, (k, _)) in iter.map(Result::unwrap).enumerate() { assert_eq!(k.as_ref(), format!("{:0>4}", expected).as_bytes()); } @@ -1053,7 +1053,7 @@ fn get_with_cache_and_bulkload_and_blobs_test() { // try to get key let iter = db.iterator(IteratorMode::Start); - for (expected, (k, _)) in iter.enumerate() { + for (expected, (k, _)) in iter.map(Result::unwrap).enumerate() { assert_eq!(k.as_ref(), format!("{:0>4}", expected).as_bytes()); } } diff --git a/tests/test_iterator.rs b/tests/test_iterator.rs index c49a233..ffc6b97 100644 --- a/tests/test_iterator.rs +++ b/tests/test_iterator.rs @@ -73,29 +73,39 @@ fn test_iterator() { ); { - let iterator1 = db.iterator(IteratorMode::From(b"k0", Direction::Forward)); - assert!(iterator1.valid()); - let iterator2 = db.iterator(IteratorMode::From(b"k1", Direction::Forward)); - assert!(iterator2.valid()); - let iterator3 = db.iterator(IteratorMode::From(b"k11", Direction::Forward)); - assert!(iterator3.valid()); - let iterator4 = db.iterator(IteratorMode::From(b"k5", Direction::Forward)); - assert!(!iterator4.valid()); - let iterator5 = db.iterator(IteratorMode::From(b"k0", Direction::Reverse)); - assert!(!iterator5.valid()); - let iterator6 = db.iterator(IteratorMode::From(b"k1", Direction::Reverse)); - assert!(iterator6.valid()); - let iterator7 = db.iterator(IteratorMode::From(b"k11", Direction::Reverse)); - assert!(iterator7.valid()); - let iterator8 = db.iterator(IteratorMode::From(b"k5", Direction::Reverse)); - assert!(iterator8.valid()); + let test = |valid, key, dir| { + let mut it = db.iterator(IteratorMode::From(key, dir)); + let value = it.next(); + if valid { + assert!(matches!(value, Some(Ok(_))), "{:?}", value); + } else { + assert_eq!(None, value); + assert_eq!(None, it.next()); // Iterator is fused + } + }; + + test(true, b"k0", Direction::Forward); + test(true, b"k1", Direction::Forward); + test(true, b"k11", Direction::Forward); + test(false, b"k5", Direction::Forward); + test(false, b"k0", Direction::Reverse); + test(true, b"k1", Direction::Reverse); + test(true, b"k11", Direction::Reverse); + test(true, b"k5", Direction::Reverse); } { let mut iterator1 = db.iterator(IteratorMode::From(b"k4", Direction::Forward)); - iterator1.next(); - assert!(iterator1.valid()); - iterator1.next(); - assert!(!iterator1.valid()); + iterator1.next().unwrap().unwrap(); + assert_eq!(None, iterator1.next()); + assert_eq!(None, iterator1.next()); + } + { + // Check that set_mode resets the iterator + let mode = IteratorMode::From(K3, Direction::Forward); + let mut iterator = db.iterator(mode); + assert_iter(&mut iterator, &expected2[2..]); + iterator.set_mode(mode); + assert_iter(&mut iterator, &expected2[2..]); } } } @@ -217,6 +227,7 @@ fn test_full_iterator() { fn custom_iter(db: &'_ DB) -> impl Iterator + '_ { db.iterator(IteratorMode::Start) + .map(Result::unwrap) .map(|(_, db_value)| db_value.len()) } @@ -275,6 +286,7 @@ fn test_iter_range() { ro.set_iterate_range(range); let got = db .iterator_opt(mode, ro) + .map(Result::unwrap) .map(|(key, _value)| key) .collect::>(); let mut got = got.iter().map(Box::as_ref).collect::>(); diff --git a/tests/util/mod.rs b/tests/util/mod.rs index c25118c..6a077e8 100644 --- a/tests/util/mod.rs +++ b/tests/util/mod.rs @@ -2,7 +2,7 @@ use std::path::{Path, PathBuf}; -use rocksdb::{Options, DB}; +use rocksdb::{Error, Options, DB}; /// Temporary database path which calls DB::Destroy when DBPath is dropped. pub struct DBPath { @@ -47,19 +47,14 @@ pub fn pair(left: &[u8], right: &[u8]) -> Pair { } #[track_caller] -pub fn assert_iter( - iter: rocksdb::DBIteratorWithThreadMode<'_, D>, - want: &[Pair], -) { - assert_eq!(iter.collect::>().as_slice(), want); +pub fn assert_iter(iter: impl Iterator>, want: &[Pair]) { + let got = iter.collect::, _>>().unwrap(); + assert_eq!(got.as_slice(), want); } #[track_caller] -pub fn assert_iter_reversed( - iter: rocksdb::DBIteratorWithThreadMode<'_, D>, - want: &[Pair], -) { - let mut got = iter.collect::>(); +pub fn assert_iter_reversed(iter: impl Iterator>, want: &[Pair]) { + let mut got = iter.collect::, _>>().unwrap(); got.reverse(); assert_eq!(got.as_slice(), want); }