From 7b1c4e0ad570da114ae20d98c5de1d9263712d37 Mon Sep 17 00:00:00 2001 From: Tpt Date: Sun, 21 Nov 2021 18:12:04 +0100 Subject: [PATCH] Returns a clean error on transaction read after commit --- lib/src/storage/backend/rocksdb.rs | 43 ++++++++++++++++++++++-------- lib/src/storage/mod.rs | 6 ++--- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/lib/src/storage/backend/rocksdb.rs b/lib/src/storage/backend/rocksdb.rs index 5ded1a00..4ea7a805 100644 --- a/lib/src/storage/backend/rocksdb.rs +++ b/lib/src/storage/backend/rocksdb.rs @@ -542,6 +542,9 @@ impl Reader { key.len() )), InnerReader::Transaction(inner) => { + if inner.is_ended.get() { + return Err(invalid_input_error("The transaction is already ended")); + } ffi_result!(rocksdb_transaction_get_pinned_cf( inner.transaction, self.options, @@ -564,12 +567,12 @@ impl Reader { } #[must_use] - pub fn iter(&self, column_family: &ColumnFamily) -> Iter { + pub fn iter(&self, column_family: &ColumnFamily) -> Result { self.scan_prefix(column_family, &[]) } #[must_use] - pub fn scan_prefix(&self, column_family: &ColumnFamily, prefix: &[u8]) -> Iter { + pub fn scan_prefix(&self, column_family: &ColumnFamily, prefix: &[u8]) -> Result { //We generate the upper bound let upper_bound = { let mut bound = prefix.to_vec(); @@ -608,11 +611,16 @@ impl Reader { InnerReader::Snapshot(inner) => { rocksdb_transactiondb_create_iterator_cf(inner.db.db, options, column_family.0) } - InnerReader::Transaction(inner) => rocksdb_transaction_create_iterator_cf( - inner.transaction, - options, - column_family.0, - ), + InnerReader::Transaction(inner) => { + if inner.is_ended.get() { + return Err(invalid_input_error("The transaction is already ended")); + } + rocksdb_transaction_create_iterator_cf( + inner.transaction, + options, + column_family.0, + ) + } }; assert!(!iter.is_null(), "rocksdb_create_iterator returned null"); if prefix.is_empty() { @@ -621,19 +629,19 @@ impl Reader { rocksdb_iter_seek(iter, prefix.as_ptr() as *const c_char, prefix.len()); } let is_currently_valid = rocksdb_iter_valid(iter) != 0; - Iter { + Ok(Iter { iter, options, _upper_bound: upper_bound, _reader: self.clone(), is_currently_valid, - } + }) } } pub fn len(&self, column_family: &ColumnFamily) -> Result { let mut count = 0; - let mut iter = self.iter(column_family); + let mut iter = self.iter(column_family)?; while iter.is_valid() { count += 1; iter.next(); @@ -643,7 +651,7 @@ impl Reader { } pub fn is_empty(&self, column_family: &ColumnFamily) -> Result { - let iter = self.iter(column_family); + let iter = self.iter(column_family)?; iter.status()?; // We makes sure there is no read problem Ok(!iter.is_valid()) } @@ -1095,3 +1103,16 @@ fn path_to_cstring(path: &Path) -> Result { ) .map_err(invalid_input_error) } + +#[test] +fn test_transaction_read_after_commit() -> Result<()> { + let db = Db::new(vec![])?; + let cf = db.column_family("default").unwrap(); + let mut tr = db.transaction(); + let reader = tr.reader(); + tr.insert(&cf, b"test", b"foo")?; + assert_eq!(reader.get(&cf, b"test")?.as_deref(), Some(b"foo".as_ref())); + tr.commit()?; + assert!(reader.get(&cf, b"test").is_err()); + Ok(()) +} diff --git a/lib/src/storage/mod.rs b/lib/src/storage/mod.rs index 331166f2..f2940710 100644 --- a/lib/src/storage/mod.rs +++ b/lib/src/storage/mod.rs @@ -244,7 +244,7 @@ impl Storage { let mut transaction = this.db.transaction(); let reader = this.db.snapshot(); let mut size = 0; - let mut iter = reader.iter(&this.id2str_cf); + let mut iter = reader.iter(&this.id2str_cf)?; while let (Some(key), Some(value)) = (iter.key(), iter.value()) { let mut new_value = Vec::with_capacity(value.len() + 4); new_value.extend_from_slice(&i32::MAX.to_be_bytes()); @@ -615,7 +615,7 @@ impl StorageReader { pub fn named_graphs(&self) -> DecodingGraphIterator { DecodingGraphIterator { - iter: self.reader.iter(&self.storage.graphs_cf), + iter: self.reader.iter(&self.storage.graphs_cf).unwrap(), //TODO: propagate error? } } @@ -667,7 +667,7 @@ impl StorageReader { encoding: QuadEncoding, ) -> DecodingQuadIterator { DecodingQuadIterator { - iter: self.reader.scan_prefix(column_family, prefix), + iter: self.reader.scan_prefix(column_family, prefix).unwrap(), // TODO: propagate error? encoding, } }