From f46464d38363c1efe52fc7ba41b43dfbf75625c3 Mon Sep 17 00:00:00 2001 From: Maysam Yabandeh Date: Mon, 11 Sep 2017 08:58:52 -0700 Subject: [PATCH] write-prepared txn: call IsInSnapshot Summary: This patch instruments the read path to verify each read value against an optional ReadCallback class. If the value is rejected, the reader moves on to the next value. The WritePreparedTxn makes use of this feature to skip sequence numbers that are not in the read snapshot. Closes https://github.com/facebook/rocksdb/pull/2850 Differential Revision: D5787375 Pulled By: maysamyabandeh fbshipit-source-id: 49d808b3062ab35e7ae98ad388f659757794184c --- db/db_impl.cc | 13 ++-- db/db_impl.h | 5 +- db/db_test2.cc | 77 +++++++++++++++++++ db/memtable.cc | 20 ++++- db/memtable.h | 9 ++- db/memtable_list.cc | 21 ++--- db/memtable_list.h | 11 ++- db/read_callback.h | 21 +++++ db/version_set.cc | 6 +- db/version_set.h | 4 +- .../utilities/write_batch_with_index.h | 5 ++ table/get_context.cc | 23 +++--- table/get_context.h | 12 ++- .../transactions/pessimistic_transaction_db.h | 17 ++++ utilities/transactions/write_prepared_txn.cc | 12 +++ utilities/transactions/write_prepared_txn.h | 5 ++ .../write_batch_with_index.cc | 18 ++++- 17 files changed, 236 insertions(+), 43 deletions(-) create mode 100644 db/read_callback.h diff --git a/db/db_impl.cc b/db/db_impl.cc index 0e1a821b0..d1a343a76 100644 --- a/db/db_impl.cc +++ b/db/db_impl.cc @@ -11,14 +11,12 @@ #ifndef __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS #endif -#include #include #ifdef OS_SOLARIS #include #endif #include -#include #include #include #include @@ -63,7 +61,6 @@ #include "options/cf_options.h" #include "options/options_helper.h" #include "options/options_parser.h" -#include "port/likely.h" #include "port/port.h" #include "rocksdb/cache.h" #include "rocksdb/compaction_filter.h" @@ -74,7 +71,6 @@ #include "rocksdb/statistics.h" #include "rocksdb/status.h" #include "rocksdb/table.h" -#include "rocksdb/version.h" #include "rocksdb/write_buffer_manager.h" #include "table/block.h" #include "table/block_based_table_factory.h" @@ -909,7 +905,8 @@ Status DBImpl::Get(const ReadOptions& read_options, Status DBImpl::GetImpl(const ReadOptions& read_options, ColumnFamilyHandle* column_family, const Slice& key, - PinnableSlice* pinnable_val, bool* value_found) { + PinnableSlice* pinnable_val, bool* value_found, + ReadCallback* callback) { assert(pinnable_val != nullptr); StopWatch sw(env_, stats_, DB_GET); PERF_TIMER_GUARD(get_snapshot_time); @@ -959,13 +956,13 @@ Status DBImpl::GetImpl(const ReadOptions& read_options, bool done = false; if (!skip_memtable) { if (sv->mem->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context, - &range_del_agg, read_options)) { + &range_del_agg, read_options, callback)) { done = true; pinnable_val->PinSelf(); RecordTick(stats_, MEMTABLE_HIT); } else if ((s.ok() || s.IsMergeInProgress()) && sv->imm->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context, - &range_del_agg, read_options)) { + &range_del_agg, read_options, callback)) { done = true; pinnable_val->PinSelf(); RecordTick(stats_, MEMTABLE_HIT); @@ -977,7 +974,7 @@ Status DBImpl::GetImpl(const ReadOptions& read_options, if (!done) { PERF_TIMER_GUARD(get_from_output_files_time); sv->current->Get(read_options, lkey, pinnable_val, &s, &merge_context, - &range_del_agg, value_found); + &range_del_agg, value_found, nullptr, nullptr, callback); RecordTick(stats_, MEMTABLE_MISS); } diff --git a/db/db_impl.h b/db/db_impl.h index 5115ac6f1..c5a7e2493 100644 --- a/db/db_impl.h +++ b/db/db_impl.h @@ -28,6 +28,7 @@ #include "db/flush_scheduler.h" #include "db/internal_stats.h" #include "db/log_writer.h" +#include "db/read_callback.h" #include "db/snapshot_impl.h" #include "db/version_edit.h" #include "db/wal_manager.h" @@ -634,10 +635,12 @@ class DBImpl : public DB { private: friend class DB; + friend class DBTest2_ReadCallbackTest_Test; friend class InternalStats; friend class PessimisticTransaction; friend class WriteCommittedTxn; friend class WritePreparedTxn; + friend class WriteBatchWithIndex; #ifndef ROCKSDB_LITE friend class ForwardIterator; #endif @@ -1244,7 +1247,7 @@ class DBImpl : public DB { // Note: 'value_found' from KeyMayExist propagates here Status GetImpl(const ReadOptions& options, ColumnFamilyHandle* column_family, const Slice& key, PinnableSlice* value, - bool* value_found = nullptr); + bool* value_found = nullptr, ReadCallback* callback = nullptr); bool GetIntPropertyInternal(ColumnFamilyData* cfd, const DBPropertyInfo& property_info, diff --git a/db/db_test2.cc b/db/db_test2.cc index 02923cebc..4b635f7fd 100644 --- a/db/db_test2.cc +++ b/db/db_test2.cc @@ -11,6 +11,7 @@ #include #include "db/db_test_util.h" +#include "db/read_callback.h" #include "port/port.h" #include "port/stack_trace.h" #include "rocksdb/persistent_cache.h" @@ -2325,6 +2326,82 @@ TEST_F(DBTest2, ReduceLevel) { Reopen(options); ASSERT_EQ("0,1", FilesPerLevel()); } + +// Test that ReadCallback is actually used in both memtbale and sst tables +TEST_F(DBTest2, ReadCallbackTest) { + Options options; + options.disable_auto_compactions = true; + options.num_levels = 7; + Reopen(options); + std::vector snapshots; + // Try to create a db with multiple layers and a memtable + const std::string key = "foo"; + const std::string value = "bar"; + // This test assumes that the seq start with 1 and increased by 1 after each + // write batch of size 1. If that behavior changes, the test needs to be + // updated as well. + // TODO(myabandeh): update this test to use the seq number that is returned by + // the DB instead of assuming what seq the DB used. + int i = 1; + for (; i < 10; i++) { + Put(key, value + std::to_string(i)); + // Take a snapshot to avoid the value being removed during compaction + auto snapshot = dbfull()->GetSnapshot(); + snapshots.push_back(snapshot); + } + Flush(); + for (; i < 20; i++) { + Put(key, value + std::to_string(i)); + // Take a snapshot to avoid the value being removed during compaction + auto snapshot = dbfull()->GetSnapshot(); + snapshots.push_back(snapshot); + } + Flush(); + MoveFilesToLevel(6); + ASSERT_EQ("0,0,0,0,0,0,2", FilesPerLevel()); + for (; i < 30; i++) { + Put(key, value + std::to_string(i)); + auto snapshot = dbfull()->GetSnapshot(); + snapshots.push_back(snapshot); + } + Flush(); + ASSERT_EQ("1,0,0,0,0,0,2", FilesPerLevel()); + // And also add some values to the memtable + for (; i < 40; i++) { + Put(key, value + std::to_string(i)); + auto snapshot = dbfull()->GetSnapshot(); + snapshots.push_back(snapshot); + } + + class TestReadCallback : public ReadCallback { + public: + explicit TestReadCallback(SequenceNumber snapshot) : snapshot_(snapshot) {} + virtual bool IsCommitted(SequenceNumber seq) override { + return seq <= snapshot_; + } + + private: + SequenceNumber snapshot_; + }; + + for (int seq = 1; seq < i; seq++) { + PinnableSlice pinnable_val; + ReadOptions roptions; + TestReadCallback callback(seq); + bool dont_care = true; + Status s = dbfull()->GetImpl(roptions, dbfull()->DefaultColumnFamily(), key, + &pinnable_val, &dont_care, &callback); + ASSERT_TRUE(s.ok()); + // Assuming that after each Put the DB increased seq by one, the value and + // seq number must be equal since we also inc value by 1 after each Put. + ASSERT_EQ(value + std::to_string(seq), pinnable_val.ToString()); + } + + for (auto snapshot : snapshots) { + dbfull()->ReleaseSnapshot(snapshot); + } +} + } // namespace rocksdb int main(int argc, char** argv) { diff --git a/db/memtable.cc b/db/memtable.cc index a24989123..854816a95 100644 --- a/db/memtable.cc +++ b/db/memtable.cc @@ -16,6 +16,7 @@ #include "db/merge_context.h" #include "db/merge_helper.h" #include "db/pinned_iterators_manager.h" +#include "db/read_callback.h" #include "monitoring/perf_context_imp.h" #include "monitoring/statistics.h" #include "port/port.h" @@ -537,6 +538,13 @@ struct Saver { Statistics* statistics; bool inplace_update_support; Env* env_; + ReadCallback* callback_; + bool CheckCallback(SequenceNumber _seq) { + if (callback_) { + return callback_->IsCommitted(_seq); + } + return true; + } }; } // namespace @@ -564,7 +572,14 @@ static bool SaveValue(void* arg, const char* entry) { // Correct user key const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); ValueType type; - UnPackSequenceAndType(tag, &s->seq, &type); + SequenceNumber seq; + UnPackSequenceAndType(tag, &seq, &type); + // If the value is not in the snapshot, skip it + if (!s->CheckCallback(seq)) { + return true; // to continue to the next seq + } + + s->seq = seq; if ((type == kTypeValue || type == kTypeMerge) && range_del_agg->ShouldDelete(Slice(key_ptr, key_length))) { @@ -635,7 +650,7 @@ static bool SaveValue(void* arg, const char* entry) { bool MemTable::Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, SequenceNumber* seq, - const ReadOptions& read_opts) { + const ReadOptions& read_opts, ReadCallback* callback) { // The sequence number is updated synchronously in version_set.h if (IsEmpty()) { // Avoiding recording stats for speed. @@ -681,6 +696,7 @@ bool MemTable::Get(const LookupKey& key, std::string* value, Status* s, saver.inplace_update_support = moptions_.inplace_update_support; saver.statistics = moptions_.statistics; saver.env_ = env_; + saver.callback_ = callback; table_->Get(key, &saver, SaveValue); *seq = saver.seq; diff --git a/db/memtable.h b/db/memtable.h index fe9feaf57..9c98705bc 100644 --- a/db/memtable.h +++ b/db/memtable.h @@ -17,6 +17,7 @@ #include #include "db/dbformat.h" #include "db/range_del_aggregator.h" +#include "db/read_callback.h" #include "db/version_edit.h" #include "monitoring/instrumented_mutex.h" #include "options/cf_options.h" @@ -187,13 +188,15 @@ class MemTable { // status returned indicates a corruption or other unexpected error. bool Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, - SequenceNumber* seq, const ReadOptions& read_opts); + SequenceNumber* seq, const ReadOptions& read_opts, + ReadCallback* callback = nullptr); bool Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, - const ReadOptions& read_opts) { + const ReadOptions& read_opts, ReadCallback* callback = nullptr) { SequenceNumber seq; - return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts); + return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts, + callback); } // Attempts to update the new_value inplace, else does normal Add diff --git a/db/memtable_list.cc b/db/memtable_list.cc index 8f710c2e9..f0fb4843b 100644 --- a/db/memtable_list.cc +++ b/db/memtable_list.cc @@ -103,10 +103,10 @@ int MemTableList::NumFlushed() const { bool MemTableListVersion::Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, - SequenceNumber* seq, - const ReadOptions& read_opts) { + SequenceNumber* seq, const ReadOptions& read_opts, + ReadCallback* callback) { return GetFromList(&memlist_, key, value, s, merge_context, range_del_agg, - seq, read_opts); + seq, read_opts, callback); } bool MemTableListVersion::GetFromHistory(const LookupKey& key, @@ -119,24 +119,25 @@ bool MemTableListVersion::GetFromHistory(const LookupKey& key, range_del_agg, seq, read_opts); } -bool MemTableListVersion::GetFromList(std::list* list, - const LookupKey& key, std::string* value, - Status* s, MergeContext* merge_context, - RangeDelAggregator* range_del_agg, - SequenceNumber* seq, - const ReadOptions& read_opts) { +bool MemTableListVersion::GetFromList( + std::list* list, const LookupKey& key, std::string* value, + Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, + SequenceNumber* seq, const ReadOptions& read_opts, ReadCallback* callback) { *seq = kMaxSequenceNumber; for (auto& memtable : *list) { SequenceNumber current_seq = kMaxSequenceNumber; bool done = memtable->Get(key, value, s, merge_context, range_del_agg, - ¤t_seq, read_opts); + ¤t_seq, read_opts, callback); if (*seq == kMaxSequenceNumber) { // Store the most recent sequence number of any operation on this key. // Since we only care about the most recent change, we only need to // return the first operation found when searching memtables in // reverse-chronological order. + // current_seq would be equal to kMaxSequenceNumber if the value was to be + // skipped. This allows seq to be assigned again when the next value is + // read. *seq = current_seq; } diff --git a/db/memtable_list.h b/db/memtable_list.h index ed475b83a..1bec0debe 100644 --- a/db/memtable_list.h +++ b/db/memtable_list.h @@ -54,13 +54,15 @@ class MemTableListVersion { // returned). Otherwise, *seq will be set to kMaxSequenceNumber. bool Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, - SequenceNumber* seq, const ReadOptions& read_opts); + SequenceNumber* seq, const ReadOptions& read_opts, + ReadCallback* callback = nullptr); bool Get(const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, - const ReadOptions& read_opts) { + const ReadOptions& read_opts, ReadCallback* callback = nullptr) { SequenceNumber seq; - return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts); + return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts, + callback); } // Similar to Get(), but searches the Memtable history of memtables that @@ -117,7 +119,8 @@ class MemTableListVersion { bool GetFromList(std::list* list, const LookupKey& key, std::string* value, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg, SequenceNumber* seq, - const ReadOptions& read_opts); + const ReadOptions& read_opts, + ReadCallback* callback = nullptr); void AddMemTable(MemTable* m); diff --git a/db/read_callback.h b/db/read_callback.h new file mode 100644 index 000000000..f3fe35dfc --- /dev/null +++ b/db/read_callback.h @@ -0,0 +1,21 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include "rocksdb/types.h" + +namespace rocksdb { + +class ReadCallback { + public: + virtual ~ReadCallback() {} + + // Will be called to see if the seq number accepted; if not it moves on to the + // next seq number. + virtual bool IsCommitted(SequenceNumber seq) = 0; +}; + +} // namespace rocksdb diff --git a/db/version_set.cc b/db/version_set.cc index 2ff425d20..823818210 100644 --- a/db/version_set.cc +++ b/db/version_set.cc @@ -16,7 +16,6 @@ #include #include #include -#include #include #include #include @@ -965,7 +964,8 @@ void Version::Get(const ReadOptions& read_options, const LookupKey& k, PinnableSlice* value, Status* status, MergeContext* merge_context, RangeDelAggregator* range_del_agg, bool* value_found, - bool* key_exists, SequenceNumber* seq) { + bool* key_exists, SequenceNumber* seq, + ReadCallback* callback) { Slice ikey = k.internal_key(); Slice user_key = k.user_key(); @@ -981,7 +981,7 @@ void Version::Get(const ReadOptions& read_options, const LookupKey& k, user_comparator(), merge_operator_, info_log_, db_statistics_, status->ok() ? GetContext::kNotFound : GetContext::kMerge, user_key, value, value_found, merge_context, range_del_agg, this->env_, seq, - merge_operator_ ? &pinned_iters_mgr : nullptr); + merge_operator_ ? &pinned_iters_mgr : nullptr, callback); // Pin blocks that we read to hold merge operands if (merge_operator_) { diff --git a/db/version_set.h b/db/version_set.h index 9fb000c05..8c1324698 100644 --- a/db/version_set.h +++ b/db/version_set.h @@ -35,6 +35,7 @@ #include "db/file_indexer.h" #include "db/log_reader.h" #include "db/range_del_aggregator.h" +#include "db/read_callback.h" #include "db/table_cache.h" #include "db/version_builder.h" #include "db/version_edit.h" @@ -485,7 +486,8 @@ class Version { void Get(const ReadOptions&, const LookupKey& key, PinnableSlice* value, Status* status, MergeContext* merge_context, RangeDelAggregator* range_del_agg, bool* value_found = nullptr, - bool* key_exists = nullptr, SequenceNumber* seq = nullptr); + bool* key_exists = nullptr, SequenceNumber* seq = nullptr, + ReadCallback* callback = nullptr); // Loads some stats information from files. Call without mutex held. It needs // to be called before applying the version to the version set. diff --git a/include/rocksdb/utilities/write_batch_with_index.h b/include/rocksdb/utilities/write_batch_with_index.h index 24d8f30aa..96a9e5fb5 100644 --- a/include/rocksdb/utilities/write_batch_with_index.h +++ b/include/rocksdb/utilities/write_batch_with_index.h @@ -27,6 +27,7 @@ namespace rocksdb { class ColumnFamilyHandle; class Comparator; class DB; +class ReadCallback; struct ReadOptions; struct DBOptions; @@ -226,6 +227,10 @@ class WriteBatchWithIndex : public WriteBatchBase { void SetMaxBytes(size_t max_bytes) override; private: + friend class WritePreparedTxn; + Status GetFromBatchAndDB(DB* db, const ReadOptions& read_options, + ColumnFamilyHandle* column_family, const Slice& key, + PinnableSlice* value, ReadCallback* callback); struct Rep; std::unique_ptr rep; }; diff --git a/table/get_context.cc b/table/get_context.cc index 0d688fe46..c68aa3984 100644 --- a/table/get_context.cc +++ b/table/get_context.cc @@ -6,6 +6,7 @@ #include "table/get_context.h" #include "db/merge_helper.h" #include "db/pinned_iterators_manager.h" +#include "db/read_callback.h" #include "monitoring/file_read_sample.h" #include "monitoring/perf_context_imp.h" #include "monitoring/statistics.h" @@ -33,14 +34,12 @@ void appendToReplayLog(std::string* replay_log, ValueType type, Slice value) { } // namespace -GetContext::GetContext(const Comparator* ucmp, - const MergeOperator* merge_operator, Logger* logger, - Statistics* statistics, GetState init_state, - const Slice& user_key, PinnableSlice* pinnable_val, - bool* value_found, MergeContext* merge_context, - RangeDelAggregator* _range_del_agg, Env* env, - SequenceNumber* seq, - PinnedIteratorsManager* _pinned_iters_mgr) +GetContext::GetContext( + const Comparator* ucmp, const MergeOperator* merge_operator, Logger* logger, + Statistics* statistics, GetState init_state, const Slice& user_key, + PinnableSlice* pinnable_val, bool* value_found, MergeContext* merge_context, + RangeDelAggregator* _range_del_agg, Env* env, SequenceNumber* seq, + PinnedIteratorsManager* _pinned_iters_mgr, ReadCallback* callback) : ucmp_(ucmp), merge_operator_(merge_operator), logger_(logger), @@ -54,7 +53,8 @@ GetContext::GetContext(const Comparator* ucmp, env_(env), seq_(seq), replay_log_(nullptr), - pinned_iters_mgr_(_pinned_iters_mgr) { + pinned_iters_mgr_(_pinned_iters_mgr), + callback_(callback) { if (seq_) { *seq_ = kMaxSequenceNumber; } @@ -88,6 +88,11 @@ bool GetContext::SaveValue(const ParsedInternalKey& parsed_key, assert((state_ != kMerge && parsed_key.type != kTypeMerge) || merge_context_ != nullptr); if (ucmp_->Equal(parsed_key.user_key, user_key_)) { + // If the value is not in the snapshot, skip it + if (!CheckCallback(parsed_key.sequence)) { + return true; // to continue to the next seq + } + appendToReplayLog(replay_log_, parsed_key.type, value); if (seq_ != nullptr) { diff --git a/table/get_context.h b/table/get_context.h index ac50680b6..efea68b0a 100644 --- a/table/get_context.h +++ b/table/get_context.h @@ -7,6 +7,7 @@ #include #include "db/merge_context.h" #include "db/range_del_aggregator.h" +#include "db/read_callback.h" #include "rocksdb/env.h" #include "rocksdb/types.h" #include "table/block.h" @@ -30,7 +31,8 @@ class GetContext { const Slice& user_key, PinnableSlice* value, bool* value_found, MergeContext* merge_context, RangeDelAggregator* range_del_agg, Env* env, SequenceNumber* seq = nullptr, - PinnedIteratorsManager* _pinned_iters_mgr = nullptr); + PinnedIteratorsManager* _pinned_iters_mgr = nullptr, + ReadCallback* callback = nullptr); void MarkKeyMayExist(); @@ -62,6 +64,13 @@ class GetContext { bool sample() const { return sample_; } + bool CheckCallback(SequenceNumber seq) { + if (callback_) { + return callback_->IsCommitted(seq); + } + return true; + } + private: const Comparator* ucmp_; const MergeOperator* merge_operator_; @@ -82,6 +91,7 @@ class GetContext { std::string* replay_log_; // Used to temporarily pin blocks when state_ == GetContext::kMerge PinnedIteratorsManager* pinned_iters_mgr_; + ReadCallback* callback_; bool sample_; }; diff --git a/utilities/transactions/pessimistic_transaction_db.h b/utilities/transactions/pessimistic_transaction_db.h index 23ecdea29..903b135f2 100644 --- a/utilities/transactions/pessimistic_transaction_db.h +++ b/utilities/transactions/pessimistic_transaction_db.h @@ -13,6 +13,7 @@ #include #include +#include "db/read_callback.h" #include "rocksdb/db.h" #include "rocksdb/options.h" #include "rocksdb/utilities/transaction_db.h" @@ -369,5 +370,21 @@ class WritePreparedTxnDB : public PessimisticTransactionDB { port::RWMutex snapshots_mutex_; }; +class WritePreparedTxnReadCallback : public ReadCallback { + public: + WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot) + : db_(db), snapshot_(snapshot) {} + + // Will be called to see if the seq number accepted; if not it moves on to the + // next seq number. + virtual bool IsCommitted(SequenceNumber seq) override { + return db_->IsInSnapshot(seq, snapshot_); + } + + private: + WritePreparedTxnDB* db_; + SequenceNumber snapshot_; +}; + } // namespace rocksdb #endif // ROCKSDB_LITE diff --git a/utilities/transactions/write_prepared_txn.cc b/utilities/transactions/write_prepared_txn.cc index 07c66690a..243c91e2d 100644 --- a/utilities/transactions/write_prepared_txn.cc +++ b/utilities/transactions/write_prepared_txn.cc @@ -29,6 +29,18 @@ WritePreparedTxn::WritePreparedTxn(WritePreparedTxnDB* txn_db, PessimisticTransaction::Initialize(txn_options); } +Status WritePreparedTxn::Get(const ReadOptions& read_options, + ColumnFamilyHandle* column_family, + const Slice& key, PinnableSlice* pinnable_val) { + auto snapshot = GetSnapshot(); + auto snap_seq = + snapshot != nullptr ? snapshot->GetSequenceNumber() : kMaxSequenceNumber; + + WritePreparedTxnReadCallback callback(wpt_db_, snap_seq); + return write_batch_.GetFromBatchAndDB(db_, read_options, column_family, key, + pinnable_val, &callback); +} + Status WritePreparedTxn::CommitBatch(WriteBatch* /* unused */) { // TODO(myabandeh) Implement this throw std::runtime_error("CommitBatch not Implemented"); diff --git a/utilities/transactions/write_prepared_txn.h b/utilities/transactions/write_prepared_txn.h index 13afabe72..a2fe2ae4b 100644 --- a/utilities/transactions/write_prepared_txn.h +++ b/utilities/transactions/write_prepared_txn.h @@ -45,6 +45,11 @@ class WritePreparedTxn : public PessimisticTransaction { virtual ~WritePreparedTxn() {} + using Transaction::Get; + virtual Status Get(const ReadOptions& options, + ColumnFamilyHandle* column_family, const Slice& key, + PinnableSlice* value) override; + Status CommitBatch(WriteBatch* batch) override; Status Rollback() override; diff --git a/utilities/write_batch_with_index/write_batch_with_index.cc b/utilities/write_batch_with_index/write_batch_with_index.cc index b2820109c..b6a5d1dd3 100644 --- a/utilities/write_batch_with_index/write_batch_with_index.cc +++ b/utilities/write_batch_with_index/write_batch_with_index.cc @@ -783,6 +783,17 @@ Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db, ColumnFamilyHandle* column_family, const Slice& key, PinnableSlice* pinnable_val) { + return GetFromBatchAndDB(db, read_options, column_family, key, pinnable_val, + nullptr); +} + +Status WriteBatchWithIndex::GetFromBatchAndDB( + DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family, + const Slice& key, PinnableSlice* pinnable_val, ReadCallback* callback) { + if (UNLIKELY(db->GetRootDB() != db)) { + return Status::NotSupported("The DB must be of DBImpl type"); + // Otherwise the cast below would fail + } Status s; MergeContext merge_context; const ImmutableDBOptions& immuable_db_options = @@ -819,7 +830,12 @@ Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db, result == WriteBatchWithIndexInternal::Result::kNotFound); // Did not find key in batch OR could not resolve Merges. Try DB. - s = db->Get(read_options, column_family, key, pinnable_val); + if (!callback) { + s = db->Get(read_options, column_family, key, pinnable_val); + } else { + s = reinterpret_cast(db)->GetImpl(read_options, column_family, key, + pinnable_val, nullptr, callback); + } if (s.ok() || s.IsNotFound()) { // DB Get Succeeded if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress) {