write-prepared txn: call IsInSnapshot

Summary:
This patch instruments the read path to verify each read value against an optional ReadCallback class. If the value is rejected, the reader moves on to the next value. The WritePreparedTxn makes use of this feature to skip sequence numbers that are not in the read snapshot.
Closes https://github.com/facebook/rocksdb/pull/2850

Differential Revision: D5787375

Pulled By: maysamyabandeh

fbshipit-source-id: 49d808b3062ab35e7ae98ad388f659757794184c
main
Maysam Yabandeh 7 years ago committed by Facebook Github Bot
parent 9a4df72994
commit f46464d383
  1. 13
      db/db_impl.cc
  2. 5
      db/db_impl.h
  3. 77
      db/db_test2.cc
  4. 20
      db/memtable.cc
  5. 9
      db/memtable.h
  6. 21
      db/memtable_list.cc
  7. 11
      db/memtable_list.h
  8. 21
      db/read_callback.h
  9. 6
      db/version_set.cc
  10. 4
      db/version_set.h
  11. 5
      include/rocksdb/utilities/write_batch_with_index.h
  12. 23
      table/get_context.cc
  13. 12
      table/get_context.h
  14. 17
      utilities/transactions/pessimistic_transaction_db.h
  15. 12
      utilities/transactions/write_prepared_txn.cc
  16. 5
      utilities/transactions/write_prepared_txn.h
  17. 16
      utilities/write_batch_with_index/write_batch_with_index.cc

@ -11,14 +11,12 @@
#ifndef __STDC_FORMAT_MACROS #ifndef __STDC_FORMAT_MACROS
#define __STDC_FORMAT_MACROS #define __STDC_FORMAT_MACROS
#endif #endif
#include <inttypes.h>
#include <stdint.h> #include <stdint.h>
#ifdef OS_SOLARIS #ifdef OS_SOLARIS
#include <alloca.h> #include <alloca.h>
#endif #endif
#include <algorithm> #include <algorithm>
#include <climits>
#include <cstdio> #include <cstdio>
#include <map> #include <map>
#include <set> #include <set>
@ -63,7 +61,6 @@
#include "options/cf_options.h" #include "options/cf_options.h"
#include "options/options_helper.h" #include "options/options_helper.h"
#include "options/options_parser.h" #include "options/options_parser.h"
#include "port/likely.h"
#include "port/port.h" #include "port/port.h"
#include "rocksdb/cache.h" #include "rocksdb/cache.h"
#include "rocksdb/compaction_filter.h" #include "rocksdb/compaction_filter.h"
@ -74,7 +71,6 @@
#include "rocksdb/statistics.h" #include "rocksdb/statistics.h"
#include "rocksdb/status.h" #include "rocksdb/status.h"
#include "rocksdb/table.h" #include "rocksdb/table.h"
#include "rocksdb/version.h"
#include "rocksdb/write_buffer_manager.h" #include "rocksdb/write_buffer_manager.h"
#include "table/block.h" #include "table/block.h"
#include "table/block_based_table_factory.h" #include "table/block_based_table_factory.h"
@ -909,7 +905,8 @@ Status DBImpl::Get(const ReadOptions& read_options,
Status DBImpl::GetImpl(const ReadOptions& read_options, Status DBImpl::GetImpl(const ReadOptions& read_options,
ColumnFamilyHandle* column_family, const Slice& key, ColumnFamilyHandle* column_family, const Slice& key,
PinnableSlice* pinnable_val, bool* value_found) { PinnableSlice* pinnable_val, bool* value_found,
ReadCallback* callback) {
assert(pinnable_val != nullptr); assert(pinnable_val != nullptr);
StopWatch sw(env_, stats_, DB_GET); StopWatch sw(env_, stats_, DB_GET);
PERF_TIMER_GUARD(get_snapshot_time); PERF_TIMER_GUARD(get_snapshot_time);
@ -959,13 +956,13 @@ Status DBImpl::GetImpl(const ReadOptions& read_options,
bool done = false; bool done = false;
if (!skip_memtable) { if (!skip_memtable) {
if (sv->mem->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context, if (sv->mem->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context,
&range_del_agg, read_options)) { &range_del_agg, read_options, callback)) {
done = true; done = true;
pinnable_val->PinSelf(); pinnable_val->PinSelf();
RecordTick(stats_, MEMTABLE_HIT); RecordTick(stats_, MEMTABLE_HIT);
} else if ((s.ok() || s.IsMergeInProgress()) && } else if ((s.ok() || s.IsMergeInProgress()) &&
sv->imm->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context, sv->imm->Get(lkey, pinnable_val->GetSelf(), &s, &merge_context,
&range_del_agg, read_options)) { &range_del_agg, read_options, callback)) {
done = true; done = true;
pinnable_val->PinSelf(); pinnable_val->PinSelf();
RecordTick(stats_, MEMTABLE_HIT); RecordTick(stats_, MEMTABLE_HIT);
@ -977,7 +974,7 @@ Status DBImpl::GetImpl(const ReadOptions& read_options,
if (!done) { if (!done) {
PERF_TIMER_GUARD(get_from_output_files_time); PERF_TIMER_GUARD(get_from_output_files_time);
sv->current->Get(read_options, lkey, pinnable_val, &s, &merge_context, sv->current->Get(read_options, lkey, pinnable_val, &s, &merge_context,
&range_del_agg, value_found); &range_del_agg, value_found, nullptr, nullptr, callback);
RecordTick(stats_, MEMTABLE_MISS); RecordTick(stats_, MEMTABLE_MISS);
} }

@ -28,6 +28,7 @@
#include "db/flush_scheduler.h" #include "db/flush_scheduler.h"
#include "db/internal_stats.h" #include "db/internal_stats.h"
#include "db/log_writer.h" #include "db/log_writer.h"
#include "db/read_callback.h"
#include "db/snapshot_impl.h" #include "db/snapshot_impl.h"
#include "db/version_edit.h" #include "db/version_edit.h"
#include "db/wal_manager.h" #include "db/wal_manager.h"
@ -634,10 +635,12 @@ class DBImpl : public DB {
private: private:
friend class DB; friend class DB;
friend class DBTest2_ReadCallbackTest_Test;
friend class InternalStats; friend class InternalStats;
friend class PessimisticTransaction; friend class PessimisticTransaction;
friend class WriteCommittedTxn; friend class WriteCommittedTxn;
friend class WritePreparedTxn; friend class WritePreparedTxn;
friend class WriteBatchWithIndex;
#ifndef ROCKSDB_LITE #ifndef ROCKSDB_LITE
friend class ForwardIterator; friend class ForwardIterator;
#endif #endif
@ -1244,7 +1247,7 @@ class DBImpl : public DB {
// Note: 'value_found' from KeyMayExist propagates here // Note: 'value_found' from KeyMayExist propagates here
Status GetImpl(const ReadOptions& options, ColumnFamilyHandle* column_family, Status GetImpl(const ReadOptions& options, ColumnFamilyHandle* column_family,
const Slice& key, PinnableSlice* value, const Slice& key, PinnableSlice* value,
bool* value_found = nullptr); bool* value_found = nullptr, ReadCallback* callback = nullptr);
bool GetIntPropertyInternal(ColumnFamilyData* cfd, bool GetIntPropertyInternal(ColumnFamilyData* cfd,
const DBPropertyInfo& property_info, const DBPropertyInfo& property_info,

@ -11,6 +11,7 @@
#include <functional> #include <functional>
#include "db/db_test_util.h" #include "db/db_test_util.h"
#include "db/read_callback.h"
#include "port/port.h" #include "port/port.h"
#include "port/stack_trace.h" #include "port/stack_trace.h"
#include "rocksdb/persistent_cache.h" #include "rocksdb/persistent_cache.h"
@ -2325,6 +2326,82 @@ TEST_F(DBTest2, ReduceLevel) {
Reopen(options); Reopen(options);
ASSERT_EQ("0,1", FilesPerLevel()); ASSERT_EQ("0,1", FilesPerLevel());
} }
// Test that ReadCallback is actually used in both memtbale and sst tables
TEST_F(DBTest2, ReadCallbackTest) {
Options options;
options.disable_auto_compactions = true;
options.num_levels = 7;
Reopen(options);
std::vector<const Snapshot*> snapshots;
// Try to create a db with multiple layers and a memtable
const std::string key = "foo";
const std::string value = "bar";
// This test assumes that the seq start with 1 and increased by 1 after each
// write batch of size 1. If that behavior changes, the test needs to be
// updated as well.
// TODO(myabandeh): update this test to use the seq number that is returned by
// the DB instead of assuming what seq the DB used.
int i = 1;
for (; i < 10; i++) {
Put(key, value + std::to_string(i));
// Take a snapshot to avoid the value being removed during compaction
auto snapshot = dbfull()->GetSnapshot();
snapshots.push_back(snapshot);
}
Flush();
for (; i < 20; i++) {
Put(key, value + std::to_string(i));
// Take a snapshot to avoid the value being removed during compaction
auto snapshot = dbfull()->GetSnapshot();
snapshots.push_back(snapshot);
}
Flush();
MoveFilesToLevel(6);
ASSERT_EQ("0,0,0,0,0,0,2", FilesPerLevel());
for (; i < 30; i++) {
Put(key, value + std::to_string(i));
auto snapshot = dbfull()->GetSnapshot();
snapshots.push_back(snapshot);
}
Flush();
ASSERT_EQ("1,0,0,0,0,0,2", FilesPerLevel());
// And also add some values to the memtable
for (; i < 40; i++) {
Put(key, value + std::to_string(i));
auto snapshot = dbfull()->GetSnapshot();
snapshots.push_back(snapshot);
}
class TestReadCallback : public ReadCallback {
public:
explicit TestReadCallback(SequenceNumber snapshot) : snapshot_(snapshot) {}
virtual bool IsCommitted(SequenceNumber seq) override {
return seq <= snapshot_;
}
private:
SequenceNumber snapshot_;
};
for (int seq = 1; seq < i; seq++) {
PinnableSlice pinnable_val;
ReadOptions roptions;
TestReadCallback callback(seq);
bool dont_care = true;
Status s = dbfull()->GetImpl(roptions, dbfull()->DefaultColumnFamily(), key,
&pinnable_val, &dont_care, &callback);
ASSERT_TRUE(s.ok());
// Assuming that after each Put the DB increased seq by one, the value and
// seq number must be equal since we also inc value by 1 after each Put.
ASSERT_EQ(value + std::to_string(seq), pinnable_val.ToString());
}
for (auto snapshot : snapshots) {
dbfull()->ReleaseSnapshot(snapshot);
}
}
} // namespace rocksdb } // namespace rocksdb
int main(int argc, char** argv) { int main(int argc, char** argv) {

@ -16,6 +16,7 @@
#include "db/merge_context.h" #include "db/merge_context.h"
#include "db/merge_helper.h" #include "db/merge_helper.h"
#include "db/pinned_iterators_manager.h" #include "db/pinned_iterators_manager.h"
#include "db/read_callback.h"
#include "monitoring/perf_context_imp.h" #include "monitoring/perf_context_imp.h"
#include "monitoring/statistics.h" #include "monitoring/statistics.h"
#include "port/port.h" #include "port/port.h"
@ -537,6 +538,13 @@ struct Saver {
Statistics* statistics; Statistics* statistics;
bool inplace_update_support; bool inplace_update_support;
Env* env_; Env* env_;
ReadCallback* callback_;
bool CheckCallback(SequenceNumber _seq) {
if (callback_) {
return callback_->IsCommitted(_seq);
}
return true;
}
}; };
} // namespace } // namespace
@ -564,7 +572,14 @@ static bool SaveValue(void* arg, const char* entry) {
// Correct user key // Correct user key
const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
ValueType type; ValueType type;
UnPackSequenceAndType(tag, &s->seq, &type); SequenceNumber seq;
UnPackSequenceAndType(tag, &seq, &type);
// If the value is not in the snapshot, skip it
if (!s->CheckCallback(seq)) {
return true; // to continue to the next seq
}
s->seq = seq;
if ((type == kTypeValue || type == kTypeMerge) && if ((type == kTypeValue || type == kTypeMerge) &&
range_del_agg->ShouldDelete(Slice(key_ptr, key_length))) { range_del_agg->ShouldDelete(Slice(key_ptr, key_length))) {
@ -635,7 +650,7 @@ static bool SaveValue(void* arg, const char* entry) {
bool MemTable::Get(const LookupKey& key, std::string* value, Status* s, bool MemTable::Get(const LookupKey& key, std::string* value, Status* s,
MergeContext* merge_context, MergeContext* merge_context,
RangeDelAggregator* range_del_agg, SequenceNumber* seq, RangeDelAggregator* range_del_agg, SequenceNumber* seq,
const ReadOptions& read_opts) { const ReadOptions& read_opts, ReadCallback* callback) {
// The sequence number is updated synchronously in version_set.h // The sequence number is updated synchronously in version_set.h
if (IsEmpty()) { if (IsEmpty()) {
// Avoiding recording stats for speed. // Avoiding recording stats for speed.
@ -681,6 +696,7 @@ bool MemTable::Get(const LookupKey& key, std::string* value, Status* s,
saver.inplace_update_support = moptions_.inplace_update_support; saver.inplace_update_support = moptions_.inplace_update_support;
saver.statistics = moptions_.statistics; saver.statistics = moptions_.statistics;
saver.env_ = env_; saver.env_ = env_;
saver.callback_ = callback;
table_->Get(key, &saver, SaveValue); table_->Get(key, &saver, SaveValue);
*seq = saver.seq; *seq = saver.seq;

@ -17,6 +17,7 @@
#include <vector> #include <vector>
#include "db/dbformat.h" #include "db/dbformat.h"
#include "db/range_del_aggregator.h" #include "db/range_del_aggregator.h"
#include "db/read_callback.h"
#include "db/version_edit.h" #include "db/version_edit.h"
#include "monitoring/instrumented_mutex.h" #include "monitoring/instrumented_mutex.h"
#include "options/cf_options.h" #include "options/cf_options.h"
@ -187,13 +188,15 @@ class MemTable {
// status returned indicates a corruption or other unexpected error. // status returned indicates a corruption or other unexpected error.
bool Get(const LookupKey& key, std::string* value, Status* s, bool Get(const LookupKey& key, std::string* value, Status* s,
MergeContext* merge_context, RangeDelAggregator* range_del_agg, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
SequenceNumber* seq, const ReadOptions& read_opts); SequenceNumber* seq, const ReadOptions& read_opts,
ReadCallback* callback = nullptr);
bool Get(const LookupKey& key, std::string* value, Status* s, bool Get(const LookupKey& key, std::string* value, Status* s,
MergeContext* merge_context, RangeDelAggregator* range_del_agg, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
const ReadOptions& read_opts) { const ReadOptions& read_opts, ReadCallback* callback = nullptr) {
SequenceNumber seq; SequenceNumber seq;
return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts); return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts,
callback);
} }
// Attempts to update the new_value inplace, else does normal Add // Attempts to update the new_value inplace, else does normal Add

@ -103,10 +103,10 @@ int MemTableList::NumFlushed() const {
bool MemTableListVersion::Get(const LookupKey& key, std::string* value, bool MemTableListVersion::Get(const LookupKey& key, std::string* value,
Status* s, MergeContext* merge_context, Status* s, MergeContext* merge_context,
RangeDelAggregator* range_del_agg, RangeDelAggregator* range_del_agg,
SequenceNumber* seq, SequenceNumber* seq, const ReadOptions& read_opts,
const ReadOptions& read_opts) { ReadCallback* callback) {
return GetFromList(&memlist_, key, value, s, merge_context, range_del_agg, return GetFromList(&memlist_, key, value, s, merge_context, range_del_agg,
seq, read_opts); seq, read_opts, callback);
} }
bool MemTableListVersion::GetFromHistory(const LookupKey& key, bool MemTableListVersion::GetFromHistory(const LookupKey& key,
@ -119,24 +119,25 @@ bool MemTableListVersion::GetFromHistory(const LookupKey& key,
range_del_agg, seq, read_opts); range_del_agg, seq, read_opts);
} }
bool MemTableListVersion::GetFromList(std::list<MemTable*>* list, bool MemTableListVersion::GetFromList(
const LookupKey& key, std::string* value, std::list<MemTable*>* list, const LookupKey& key, std::string* value,
Status* s, MergeContext* merge_context, Status* s, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
RangeDelAggregator* range_del_agg, SequenceNumber* seq, const ReadOptions& read_opts, ReadCallback* callback) {
SequenceNumber* seq,
const ReadOptions& read_opts) {
*seq = kMaxSequenceNumber; *seq = kMaxSequenceNumber;
for (auto& memtable : *list) { for (auto& memtable : *list) {
SequenceNumber current_seq = kMaxSequenceNumber; SequenceNumber current_seq = kMaxSequenceNumber;
bool done = memtable->Get(key, value, s, merge_context, range_del_agg, bool done = memtable->Get(key, value, s, merge_context, range_del_agg,
&current_seq, read_opts); &current_seq, read_opts, callback);
if (*seq == kMaxSequenceNumber) { if (*seq == kMaxSequenceNumber) {
// Store the most recent sequence number of any operation on this key. // Store the most recent sequence number of any operation on this key.
// Since we only care about the most recent change, we only need to // Since we only care about the most recent change, we only need to
// return the first operation found when searching memtables in // return the first operation found when searching memtables in
// reverse-chronological order. // reverse-chronological order.
// current_seq would be equal to kMaxSequenceNumber if the value was to be
// skipped. This allows seq to be assigned again when the next value is
// read.
*seq = current_seq; *seq = current_seq;
} }

@ -54,13 +54,15 @@ class MemTableListVersion {
// returned). Otherwise, *seq will be set to kMaxSequenceNumber. // returned). Otherwise, *seq will be set to kMaxSequenceNumber.
bool Get(const LookupKey& key, std::string* value, Status* s, bool Get(const LookupKey& key, std::string* value, Status* s,
MergeContext* merge_context, RangeDelAggregator* range_del_agg, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
SequenceNumber* seq, const ReadOptions& read_opts); SequenceNumber* seq, const ReadOptions& read_opts,
ReadCallback* callback = nullptr);
bool Get(const LookupKey& key, std::string* value, Status* s, bool Get(const LookupKey& key, std::string* value, Status* s,
MergeContext* merge_context, RangeDelAggregator* range_del_agg, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
const ReadOptions& read_opts) { const ReadOptions& read_opts, ReadCallback* callback = nullptr) {
SequenceNumber seq; SequenceNumber seq;
return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts); return Get(key, value, s, merge_context, range_del_agg, &seq, read_opts,
callback);
} }
// Similar to Get(), but searches the Memtable history of memtables that // Similar to Get(), but searches the Memtable history of memtables that
@ -117,7 +119,8 @@ class MemTableListVersion {
bool GetFromList(std::list<MemTable*>* list, const LookupKey& key, bool GetFromList(std::list<MemTable*>* list, const LookupKey& key,
std::string* value, Status* s, MergeContext* merge_context, std::string* value, Status* s, MergeContext* merge_context,
RangeDelAggregator* range_del_agg, SequenceNumber* seq, RangeDelAggregator* range_del_agg, SequenceNumber* seq,
const ReadOptions& read_opts); const ReadOptions& read_opts,
ReadCallback* callback = nullptr);
void AddMemTable(MemTable* m); void AddMemTable(MemTable* m);

@ -0,0 +1,21 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#pragma once
#include "rocksdb/types.h"
namespace rocksdb {
class ReadCallback {
public:
virtual ~ReadCallback() {}
// Will be called to see if the seq number accepted; if not it moves on to the
// next seq number.
virtual bool IsCommitted(SequenceNumber seq) = 0;
};
} // namespace rocksdb

@ -16,7 +16,6 @@
#include <inttypes.h> #include <inttypes.h>
#include <stdio.h> #include <stdio.h>
#include <algorithm> #include <algorithm>
#include <climits>
#include <map> #include <map>
#include <set> #include <set>
#include <string> #include <string>
@ -965,7 +964,8 @@ void Version::Get(const ReadOptions& read_options, const LookupKey& k,
PinnableSlice* value, Status* status, PinnableSlice* value, Status* status,
MergeContext* merge_context, MergeContext* merge_context,
RangeDelAggregator* range_del_agg, bool* value_found, RangeDelAggregator* range_del_agg, bool* value_found,
bool* key_exists, SequenceNumber* seq) { bool* key_exists, SequenceNumber* seq,
ReadCallback* callback) {
Slice ikey = k.internal_key(); Slice ikey = k.internal_key();
Slice user_key = k.user_key(); Slice user_key = k.user_key();
@ -981,7 +981,7 @@ void Version::Get(const ReadOptions& read_options, const LookupKey& k,
user_comparator(), merge_operator_, info_log_, db_statistics_, user_comparator(), merge_operator_, info_log_, db_statistics_,
status->ok() ? GetContext::kNotFound : GetContext::kMerge, user_key, status->ok() ? GetContext::kNotFound : GetContext::kMerge, user_key,
value, value_found, merge_context, range_del_agg, this->env_, seq, value, value_found, merge_context, range_del_agg, this->env_, seq,
merge_operator_ ? &pinned_iters_mgr : nullptr); merge_operator_ ? &pinned_iters_mgr : nullptr, callback);
// Pin blocks that we read to hold merge operands // Pin blocks that we read to hold merge operands
if (merge_operator_) { if (merge_operator_) {

@ -35,6 +35,7 @@
#include "db/file_indexer.h" #include "db/file_indexer.h"
#include "db/log_reader.h" #include "db/log_reader.h"
#include "db/range_del_aggregator.h" #include "db/range_del_aggregator.h"
#include "db/read_callback.h"
#include "db/table_cache.h" #include "db/table_cache.h"
#include "db/version_builder.h" #include "db/version_builder.h"
#include "db/version_edit.h" #include "db/version_edit.h"
@ -485,7 +486,8 @@ class Version {
void Get(const ReadOptions&, const LookupKey& key, PinnableSlice* value, void Get(const ReadOptions&, const LookupKey& key, PinnableSlice* value,
Status* status, MergeContext* merge_context, Status* status, MergeContext* merge_context,
RangeDelAggregator* range_del_agg, bool* value_found = nullptr, RangeDelAggregator* range_del_agg, bool* value_found = nullptr,
bool* key_exists = nullptr, SequenceNumber* seq = nullptr); bool* key_exists = nullptr, SequenceNumber* seq = nullptr,
ReadCallback* callback = nullptr);
// Loads some stats information from files. Call without mutex held. It needs // Loads some stats information from files. Call without mutex held. It needs
// to be called before applying the version to the version set. // to be called before applying the version to the version set.

@ -27,6 +27,7 @@ namespace rocksdb {
class ColumnFamilyHandle; class ColumnFamilyHandle;
class Comparator; class Comparator;
class DB; class DB;
class ReadCallback;
struct ReadOptions; struct ReadOptions;
struct DBOptions; struct DBOptions;
@ -226,6 +227,10 @@ class WriteBatchWithIndex : public WriteBatchBase {
void SetMaxBytes(size_t max_bytes) override; void SetMaxBytes(size_t max_bytes) override;
private: private:
friend class WritePreparedTxn;
Status GetFromBatchAndDB(DB* db, const ReadOptions& read_options,
ColumnFamilyHandle* column_family, const Slice& key,
PinnableSlice* value, ReadCallback* callback);
struct Rep; struct Rep;
std::unique_ptr<Rep> rep; std::unique_ptr<Rep> rep;
}; };

@ -6,6 +6,7 @@
#include "table/get_context.h" #include "table/get_context.h"
#include "db/merge_helper.h" #include "db/merge_helper.h"
#include "db/pinned_iterators_manager.h" #include "db/pinned_iterators_manager.h"
#include "db/read_callback.h"
#include "monitoring/file_read_sample.h" #include "monitoring/file_read_sample.h"
#include "monitoring/perf_context_imp.h" #include "monitoring/perf_context_imp.h"
#include "monitoring/statistics.h" #include "monitoring/statistics.h"
@ -33,14 +34,12 @@ void appendToReplayLog(std::string* replay_log, ValueType type, Slice value) {
} // namespace } // namespace
GetContext::GetContext(const Comparator* ucmp, GetContext::GetContext(
const MergeOperator* merge_operator, Logger* logger, const Comparator* ucmp, const MergeOperator* merge_operator, Logger* logger,
Statistics* statistics, GetState init_state, Statistics* statistics, GetState init_state, const Slice& user_key,
const Slice& user_key, PinnableSlice* pinnable_val, PinnableSlice* pinnable_val, bool* value_found, MergeContext* merge_context,
bool* value_found, MergeContext* merge_context, RangeDelAggregator* _range_del_agg, Env* env, SequenceNumber* seq,
RangeDelAggregator* _range_del_agg, Env* env, PinnedIteratorsManager* _pinned_iters_mgr, ReadCallback* callback)
SequenceNumber* seq,
PinnedIteratorsManager* _pinned_iters_mgr)
: ucmp_(ucmp), : ucmp_(ucmp),
merge_operator_(merge_operator), merge_operator_(merge_operator),
logger_(logger), logger_(logger),
@ -54,7 +53,8 @@ GetContext::GetContext(const Comparator* ucmp,
env_(env), env_(env),
seq_(seq), seq_(seq),
replay_log_(nullptr), replay_log_(nullptr),
pinned_iters_mgr_(_pinned_iters_mgr) { pinned_iters_mgr_(_pinned_iters_mgr),
callback_(callback) {
if (seq_) { if (seq_) {
*seq_ = kMaxSequenceNumber; *seq_ = kMaxSequenceNumber;
} }
@ -88,6 +88,11 @@ bool GetContext::SaveValue(const ParsedInternalKey& parsed_key,
assert((state_ != kMerge && parsed_key.type != kTypeMerge) || assert((state_ != kMerge && parsed_key.type != kTypeMerge) ||
merge_context_ != nullptr); merge_context_ != nullptr);
if (ucmp_->Equal(parsed_key.user_key, user_key_)) { if (ucmp_->Equal(parsed_key.user_key, user_key_)) {
// If the value is not in the snapshot, skip it
if (!CheckCallback(parsed_key.sequence)) {
return true; // to continue to the next seq
}
appendToReplayLog(replay_log_, parsed_key.type, value); appendToReplayLog(replay_log_, parsed_key.type, value);
if (seq_ != nullptr) { if (seq_ != nullptr) {

@ -7,6 +7,7 @@
#include <string> #include <string>
#include "db/merge_context.h" #include "db/merge_context.h"
#include "db/range_del_aggregator.h" #include "db/range_del_aggregator.h"
#include "db/read_callback.h"
#include "rocksdb/env.h" #include "rocksdb/env.h"
#include "rocksdb/types.h" #include "rocksdb/types.h"
#include "table/block.h" #include "table/block.h"
@ -30,7 +31,8 @@ class GetContext {
const Slice& user_key, PinnableSlice* value, bool* value_found, const Slice& user_key, PinnableSlice* value, bool* value_found,
MergeContext* merge_context, RangeDelAggregator* range_del_agg, MergeContext* merge_context, RangeDelAggregator* range_del_agg,
Env* env, SequenceNumber* seq = nullptr, Env* env, SequenceNumber* seq = nullptr,
PinnedIteratorsManager* _pinned_iters_mgr = nullptr); PinnedIteratorsManager* _pinned_iters_mgr = nullptr,
ReadCallback* callback = nullptr);
void MarkKeyMayExist(); void MarkKeyMayExist();
@ -62,6 +64,13 @@ class GetContext {
bool sample() const { return sample_; } bool sample() const { return sample_; }
bool CheckCallback(SequenceNumber seq) {
if (callback_) {
return callback_->IsCommitted(seq);
}
return true;
}
private: private:
const Comparator* ucmp_; const Comparator* ucmp_;
const MergeOperator* merge_operator_; const MergeOperator* merge_operator_;
@ -82,6 +91,7 @@ class GetContext {
std::string* replay_log_; std::string* replay_log_;
// Used to temporarily pin blocks when state_ == GetContext::kMerge // Used to temporarily pin blocks when state_ == GetContext::kMerge
PinnedIteratorsManager* pinned_iters_mgr_; PinnedIteratorsManager* pinned_iters_mgr_;
ReadCallback* callback_;
bool sample_; bool sample_;
}; };

@ -13,6 +13,7 @@
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "db/read_callback.h"
#include "rocksdb/db.h" #include "rocksdb/db.h"
#include "rocksdb/options.h" #include "rocksdb/options.h"
#include "rocksdb/utilities/transaction_db.h" #include "rocksdb/utilities/transaction_db.h"
@ -369,5 +370,21 @@ class WritePreparedTxnDB : public PessimisticTransactionDB {
port::RWMutex snapshots_mutex_; port::RWMutex snapshots_mutex_;
}; };
class WritePreparedTxnReadCallback : public ReadCallback {
public:
WritePreparedTxnReadCallback(WritePreparedTxnDB* db, SequenceNumber snapshot)
: db_(db), snapshot_(snapshot) {}
// Will be called to see if the seq number accepted; if not it moves on to the
// next seq number.
virtual bool IsCommitted(SequenceNumber seq) override {
return db_->IsInSnapshot(seq, snapshot_);
}
private:
WritePreparedTxnDB* db_;
SequenceNumber snapshot_;
};
} // namespace rocksdb } // namespace rocksdb
#endif // ROCKSDB_LITE #endif // ROCKSDB_LITE

@ -29,6 +29,18 @@ WritePreparedTxn::WritePreparedTxn(WritePreparedTxnDB* txn_db,
PessimisticTransaction::Initialize(txn_options); PessimisticTransaction::Initialize(txn_options);
} }
Status WritePreparedTxn::Get(const ReadOptions& read_options,
ColumnFamilyHandle* column_family,
const Slice& key, PinnableSlice* pinnable_val) {
auto snapshot = GetSnapshot();
auto snap_seq =
snapshot != nullptr ? snapshot->GetSequenceNumber() : kMaxSequenceNumber;
WritePreparedTxnReadCallback callback(wpt_db_, snap_seq);
return write_batch_.GetFromBatchAndDB(db_, read_options, column_family, key,
pinnable_val, &callback);
}
Status WritePreparedTxn::CommitBatch(WriteBatch* /* unused */) { Status WritePreparedTxn::CommitBatch(WriteBatch* /* unused */) {
// TODO(myabandeh) Implement this // TODO(myabandeh) Implement this
throw std::runtime_error("CommitBatch not Implemented"); throw std::runtime_error("CommitBatch not Implemented");

@ -45,6 +45,11 @@ class WritePreparedTxn : public PessimisticTransaction {
virtual ~WritePreparedTxn() {} virtual ~WritePreparedTxn() {}
using Transaction::Get;
virtual Status Get(const ReadOptions& options,
ColumnFamilyHandle* column_family, const Slice& key,
PinnableSlice* value) override;
Status CommitBatch(WriteBatch* batch) override; Status CommitBatch(WriteBatch* batch) override;
Status Rollback() override; Status Rollback() override;

@ -783,6 +783,17 @@ Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
ColumnFamilyHandle* column_family, ColumnFamilyHandle* column_family,
const Slice& key, const Slice& key,
PinnableSlice* pinnable_val) { PinnableSlice* pinnable_val) {
return GetFromBatchAndDB(db, read_options, column_family, key, pinnable_val,
nullptr);
}
Status WriteBatchWithIndex::GetFromBatchAndDB(
DB* db, const ReadOptions& read_options, ColumnFamilyHandle* column_family,
const Slice& key, PinnableSlice* pinnable_val, ReadCallback* callback) {
if (UNLIKELY(db->GetRootDB() != db)) {
return Status::NotSupported("The DB must be of DBImpl type");
// Otherwise the cast below would fail
}
Status s; Status s;
MergeContext merge_context; MergeContext merge_context;
const ImmutableDBOptions& immuable_db_options = const ImmutableDBOptions& immuable_db_options =
@ -819,7 +830,12 @@ Status WriteBatchWithIndex::GetFromBatchAndDB(DB* db,
result == WriteBatchWithIndexInternal::Result::kNotFound); result == WriteBatchWithIndexInternal::Result::kNotFound);
// Did not find key in batch OR could not resolve Merges. Try DB. // Did not find key in batch OR could not resolve Merges. Try DB.
if (!callback) {
s = db->Get(read_options, column_family, key, pinnable_val); s = db->Get(read_options, column_family, key, pinnable_val);
} else {
s = reinterpret_cast<DBImpl*>(db)->GetImpl(read_options, column_family, key,
pinnable_val, nullptr, callback);
}
if (s.ok() || s.IsNotFound()) { // DB Get Succeeded if (s.ok() || s.IsNotFound()) { // DB Get Succeeded
if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress) { if (result == WriteBatchWithIndexInternal::Result::kMergeInProgress) {

Loading…
Cancel
Save