diff --git a/HISTORY.md b/HISTORY.md
index 9cd260be2..a2e536132 100644
--- a/HISTORY.md
+++ b/HISTORY.md
@@ -7,6 +7,7 @@
 * Improve subcompaction range partitioning so that it is likely to be more even. A more even distribution of subcompactions will improve compaction throughput for some workloads. All input files' index blocks are sampled to pick anchor key points, from which we choose positions to partition the input range. This introduces some CPU overhead in the compaction preparation phase when subcompaction is enabled, but it should be a small fraction of the CPU usage of the whole compaction process. This also brings a behavior change: the number of subcompactions is much more likely to be maxed out than before.
 * Add CompactionPri::kRoundRobin, a compaction picking mode that cycles through all the files with a compact cursor in a round-robin manner. This feature is available since 7.5.
 * Provide support for subcompactions for user_defined_timestamp.
+* Added an option `memtable_protection_bytes_per_key` that turns on memtable per key-value checksum protection. Each memtable entry will be suffixed by a checksum that is computed during writes and verified in reads/compaction. Detected corruption will be logged and a corruption status returned to the user.
 ### Public API changes
 * Removed Customizable support for RateLimiter and removed its CreateFromString() and Type() functions.
diff --git a/db/compaction/compaction_iterator.cc b/db/compaction/compaction_iterator.cc
index cfdd0b033..0ca9e75f6 100644
--- a/db/compaction/compaction_iterator.cc
+++ b/db/compaction/compaction_iterator.cc
@@ -937,6 +937,11 @@ void CompactionIterator::NextFromInput() {
   if (IsPausingManualCompaction()) {
     status_ = Status::Incomplete(Status::SubCode::kManualCompactionPaused);
   }
+
+  // Propagate corruption status from the memtable iterator
+  if (!input_.Valid() && input_.status().IsCorruption()) {
+    status_ = input_.status();
+  }
 }

 bool CompactionIterator::ExtractLargeValueIfNeededImpl() {
diff --git a/db/db_kv_checksum_test.cc b/db/db_kv_checksum_test.cc
index 7c8c0190e..64b910a50 100644
--- a/db/db_kv_checksum_test.cc
+++ b/db/db_kv_checksum_test.cc
@@ -13,9 +13,9 @@ enum class WriteBatchOpType {
   kPut = 0,
   kDelete,
   kSingleDelete,
-  kDeleteRange,
   kMerge,
   kPutEntity,
+  kDeleteRange,
   kNum,
 };

@@ -26,11 +26,14 @@ WriteBatchOpType operator+(WriteBatchOpType lhs, const int rhs) {
 }

 enum class WriteMode {
+  // `Write()` a `WriteBatch` constructed with `protection_bytes_per_key = 0`
+  // and `WriteOptions::protection_bytes_per_key = 0`
+  kWriteUnprotectedBatch = 0,
   // `Write()` a `WriteBatch` constructed with `protection_bytes_per_key > 0`.
-  kWriteProtectedBatch = 0,
+  kWriteProtectedBatch,
   // `Write()` a `WriteBatch` constructed with `protection_bytes_per_key == 0`.
   // Protection is enabled via `WriteOptions::protection_bytes_per_key > 0`.
-  kWriteUnprotectedBatch,
+  kWriteOptionProtectedBatch,
   // TODO(ajkr): add a mode that uses `Write()` wrappers, e.g., `Put()`.
   kNum,
 };

@@ -89,19 +92,30 @@ class DbKvChecksumTestBase : public DBTestBase {
   }
 };

-class DbKvChecksumTest : public DbKvChecksumTestBase,
-                         public ::testing::WithParamInterface<
-                             std::tuple<WriteBatchOpType, char, WriteMode>> {
+class DbKvChecksumTest
+    : public DbKvChecksumTestBase,
+      public ::testing::WithParamInterface<
+          std::tuple<WriteBatchOpType, char, WriteMode, uint32_t>> {
  public:
   DbKvChecksumTest()
       : DbKvChecksumTestBase("db_kv_checksum_test", /*env_do_fsync=*/false) {
     op_type_ = std::get<0>(GetParam());
     corrupt_byte_addend_ = std::get<1>(GetParam());
     write_mode_ = std::get<2>(GetParam());
+    memtable_protection_bytes_per_key_ = std::get<3>(GetParam());
   }

   Status ExecuteWrite(ColumnFamilyHandle* cf_handle) {
     switch (write_mode_) {
+      case WriteMode::kWriteUnprotectedBatch: {
+        auto batch_and_status =
+            GetWriteBatch(GetCFHandleToUse(cf_handle, op_type_),
+                          0 /* protection_bytes_per_key */, op_type_);
+        assert(batch_and_status.second.ok());
+        // Default write option has protection_bytes_per_key = 0
+        return db_->Write(WriteOptions(), &batch_and_status.first);
+      }
       case WriteMode::kWriteProtectedBatch: {
         auto batch_and_status =
             GetWriteBatch(GetCFHandleToUse(cf_handle, op_type_),
@@ -109,7 +123,7 @@ class DbKvChecksumTest : public DbKvChecksumTestBase,
         assert(batch_and_status.second.ok());
         return db_->Write(WriteOptions(), &batch_and_status.first);
       }
-      case WriteMode::kWriteUnprotectedBatch: {
+      case WriteMode::kWriteOptionProtectedBatch: {
         auto batch_and_status =
             GetWriteBatch(GetCFHandleToUse(cf_handle, op_type_),
                           0 /* protection_bytes_per_key */, op_type_);
@@ -131,8 +145,6 @@ class DbKvChecksumTest : public DbKvChecksumTestBase,
       // We learn the entry size on the first attempt
       entry_len_ = encoded.size();
     }
-    // All entries should be the same size
-    assert(entry_len_ == encoded.size());
     char* buf = const_cast<char*>(encoded.data());
     buf[corrupt_byte_offset_] += corrupt_byte_addend_;
     ++corrupt_byte_offset_;
@@ -144,6 +156,7 @@ class DbKvChecksumTest : public DbKvChecksumTestBase,
   WriteBatchOpType op_type_;
   char corrupt_byte_addend_;
   WriteMode write_mode_;
+  uint32_t memtable_protection_bytes_per_key_;
   size_t corrupt_byte_offset_ = 0;
   size_t entry_len_ = std::numeric_limits<size_t>::max();
 };
@@ -169,29 +182,36 @@ std::string GetOpTypeString(const WriteBatchOpType& op_type) {
   return "";
 }

+std::string GetWriteModeString(const WriteMode& mode) {
+  switch (mode) {
+    case WriteMode::kWriteUnprotectedBatch:
+      return "WriteUnprotectedBatch";
+    case WriteMode::kWriteProtectedBatch:
+      return "WriteProtectedBatch";
+    case WriteMode::kWriteOptionProtectedBatch:
+      return "WriteOptionProtectedBatch";
+    case WriteMode::kNum:
+      assert(false);
+  }
+  return "";
+}
+
 INSTANTIATE_TEST_CASE_P(
     DbKvChecksumTest, DbKvChecksumTest,
     ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
                                         WriteBatchOpType::kNum),
                        ::testing::Values(2, 103, 251),
-                       ::testing::Range(static_cast<WriteMode>(0),
-                                        WriteMode::kNum)),
+                       ::testing::Range(WriteMode::kWriteProtectedBatch,
+                                        WriteMode::kNum),
+                       ::testing::Values(0)),
     [](const testing::TestParamInfo<
-        std::tuple<WriteBatchOpType, char, WriteMode>>& args) {
+        std::tuple<WriteBatchOpType, char, WriteMode, uint32_t>>& args) {
      std::ostringstream oss;
       oss << GetOpTypeString(std::get<0>(args.param)) << "Add"
           << static_cast<int>(
-                 static_cast<unsigned char>(std::get<1>(args.param)));
-      switch (std::get<2>(args.param)) {
-        case WriteMode::kWriteProtectedBatch:
-          oss << "WriteProtectedBatch";
-          break;
-        case WriteMode::kWriteUnprotectedBatch:
-          oss << "WriteUnprotectedBatch";
-          break;
-        case WriteMode::kNum:
-          assert(false);
-      }
+                 static_cast<unsigned char>(std::get<1>(args.param)))
+          << GetWriteModeString(std::get<2>(args.param))
+          << static_cast<uint32_t>(std::get<3>(args.param));
       return oss.str();
     });
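Note on the scheme exercised below: both the write batch and memtable code paths share one 64-bit protection value derived from (user key, value, value type, sequence number), and the memtable stores only its low 1/2/4/8 bytes after the value (see the db/kv_checksum.h and db/memtable.cc hunks). A minimal standalone sketch of that truncate-and-compare step, with illustrative names (`Fold64To` and `Matches` are not RocksDB APIs):

    #include <cstdint>
    #include <cstring>

    // Illustrative only: RocksDB computes the 64-bit value via
    // ProtectionInfo64().ProtectKVO(key, value, type).ProtectS(seq).GetVal()
    // and stores its low n bytes (n in {1, 2, 4, 8}) right after the value.
    void Fold64To(uint64_t full, size_t n, char* dst) {
      // Little-endian truncation for brevity; the real code uses the
      // endian-safe EncodeFixed16/32/64 helpers.
      std::memcpy(dst, &full, n);
    }

    // On read, the 64-bit value is recomputed from the parsed entry and only
    // the stored prefix is compared.
    bool Matches(uint64_t recomputed, size_t n, const char* stored) {
      uint64_t truncated = 0;
      std::memcpy(&truncated, stored, n);  // assumes little-endian host
      uint64_t mask = (n == 8) ? ~uint64_t{0} : ((uint64_t{1} << (8 * n)) - 1);
      return (recomputed & mask) == truncated;
    }

Smaller widths trade detection probability (roughly 1 - 2^(-8n) for a random corruption) for per-entry space, which is presumably why the memtable test instantiation below skips the flaky 1-byte width.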
@@ -660,6 +680,202 @@ TEST_F(DbKVChecksumWALToWriteBatchTest, WriteBatchChecksumHandoff) {
   SyncPoint::GetInstance()->DisableProcessing();
 };

+// TODO (cbi): add DeleteRange coverage once it is implemented
+class DbMemtableKVChecksumTest : public DbKvChecksumTest {
+ public:
+  DbMemtableKVChecksumTest() : DbKvChecksumTest() {}
+
+ protected:
+  // Indices in the memtable entry that we will not corrupt.
+  // For memtable entry format, see comments in MemTable::Add().
+  // We do not corrupt key length and value length fields in this test
+  // case since corrupting them causes segfaults that ASAN will flag.
+  // For this test case, key and value are all of length 3, so
+  // key length field is at index 0 and value length field is at index 12.
+  const std::set<size_t> index_not_to_corrupt{0, 12};
+
+  void SkipNotToCorruptEntry() {
+    if (index_not_to_corrupt.find(corrupt_byte_offset_) !=
+        index_not_to_corrupt.end()) {
+      corrupt_byte_offset_++;
+    }
+  }
+};
+
+INSTANTIATE_TEST_CASE_P(
+    DbMemtableKVChecksumTest, DbMemtableKVChecksumTest,
+    ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
+                                        WriteBatchOpType::kDeleteRange),
+                       ::testing::Values(2, 103, 251),
+                       ::testing::Range(static_cast<WriteMode>(0),
+                                        WriteMode::kWriteOptionProtectedBatch),
+                       // skip 1 byte checksum as it makes test flaky
+                       ::testing::Values(2, 4, 8)),
+    [](const testing::TestParamInfo<
+        std::tuple<WriteBatchOpType, char, WriteMode, uint32_t>>& args) {
+      std::ostringstream oss;
+      oss << GetOpTypeString(std::get<0>(args.param)) << "Add"
+          << static_cast<int>(
+                 static_cast<unsigned char>(std::get<1>(args.param)))
+          << GetWriteModeString(std::get<2>(args.param))
+          << static_cast<uint32_t>(std::get<3>(args.param));
+      return oss.str();
+    });
+
+TEST_P(DbMemtableKVChecksumTest, GetWithCorruptAfterMemtableInsert) {
+  // Record memtable entry size.
+  // Not corrupting memtable entry here since it will segfault
+  // or fail some asserts inside memtablerep implementation
+  // e.g., when key_len is corrupted.
+  SyncPoint::GetInstance()->SetCallBack(
+      "MemTable::Add:BeforeReturn:Encoded", [&](void* arg) {
+        Slice encoded = *static_cast<Slice*>(arg);
+        entry_len_ = encoded.size();
+      });
+
+  SyncPoint::GetInstance()->SetCallBack(
+      "Memtable::SaveValue:Begin:entry", [&](void* entry) {
+        char* buf = *static_cast<char**>(entry);
+        buf[corrupt_byte_offset_] += corrupt_byte_addend_;
+        ++corrupt_byte_offset_;
+      });
+  SyncPoint::GetInstance()->EnableProcessing();
+  Options options = CurrentOptions();
+  options.memtable_protection_bytes_per_key =
+      memtable_protection_bytes_per_key_;
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  SkipNotToCorruptEntry();
+  while (MoreBytesToCorrupt()) {
+    Reopen(options);
+    ASSERT_OK(ExecuteWrite(nullptr));
+    std::string val;
+    ASSERT_TRUE(db_->Get(ReadOptions(), "key", &val).IsCorruption());
+    Destroy(options);
+    SkipNotToCorruptEntry();
+  }
+}
+
+TEST_P(DbMemtableKVChecksumTest,
+       GetWithColumnFamilyCorruptAfterMemtableInsert) {
+  // Record memtable entry size.
+  // Not corrupting memtable entry here since it will segfault
+  // or fail some asserts inside memtablerep implementation
+  // e.g., when key_len is corrupted.
+  SyncPoint::GetInstance()->SetCallBack(
+      "MemTable::Add:BeforeReturn:Encoded", [&](void* arg) {
+        Slice encoded = *static_cast<Slice*>(arg);
+        entry_len_ = encoded.size();
+      });
+
+  SyncPoint::GetInstance()->SetCallBack(
+      "Memtable::SaveValue:Begin:entry", [&](void* entry) {
+        char* buf = *static_cast<char**>(entry);
+        buf[corrupt_byte_offset_] += corrupt_byte_addend_;
+        ++corrupt_byte_offset_;
+      });
+  SyncPoint::GetInstance()->EnableProcessing();
+  Options options = CurrentOptions();
+  options.memtable_protection_bytes_per_key =
+      memtable_protection_bytes_per_key_;
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  SkipNotToCorruptEntry();
+  while (MoreBytesToCorrupt()) {
+    Reopen(options);
+    CreateAndReopenWithCF({"pikachu"}, options);
+    ASSERT_OK(ExecuteWrite(handles_[1]));
+    std::string val;
+    ASSERT_TRUE(
+        db_->Get(ReadOptions(), handles_[1], "key", &val).IsCorruption());
+    Destroy(options);
+    SkipNotToCorruptEntry();
+  }
+}
+
+TEST_P(DbMemtableKVChecksumTest, IteratorWithCorruptAfterMemtableInsert) {
+  SyncPoint::GetInstance()->SetCallBack(
+      "MemTable::Add:BeforeReturn:Encoded",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  SyncPoint::GetInstance()->EnableProcessing();
+  Options options = CurrentOptions();
+  options.memtable_protection_bytes_per_key =
+      memtable_protection_bytes_per_key_;
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  SkipNotToCorruptEntry();
+  while (MoreBytesToCorrupt()) {
+    Reopen(options);
+    ASSERT_OK(ExecuteWrite(nullptr));
+    Iterator* it = db_->NewIterator(ReadOptions());
+    it->SeekToFirst();
+    ASSERT_FALSE(it->Valid());
+    ASSERT_TRUE(it->status().IsCorruption());
+    delete it;
+    Destroy(options);
+    SkipNotToCorruptEntry();
+  }
+}
+
+TEST_P(DbMemtableKVChecksumTest,
+       IteratorWithColumnFamilyCorruptAfterMemtableInsert) {
+  SyncPoint::GetInstance()->SetCallBack(
+      "MemTable::Add:BeforeReturn:Encoded",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  SyncPoint::GetInstance()->EnableProcessing();
+  Options options = CurrentOptions();
+  options.memtable_protection_bytes_per_key =
+      memtable_protection_bytes_per_key_;
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  SkipNotToCorruptEntry();
+  while (MoreBytesToCorrupt()) {
+    Reopen(options);
+    CreateAndReopenWithCF({"pikachu"}, options);
+    ASSERT_OK(ExecuteWrite(handles_[1]));
+    Iterator* it = db_->NewIterator(ReadOptions(), handles_[1]);
+    it->SeekToFirst();
+    ASSERT_FALSE(it->Valid());
+    ASSERT_TRUE(it->status().IsCorruption());
+    delete it;
+    Destroy(options);
+    SkipNotToCorruptEntry();
+  }
+}
+
+TEST_P(DbMemtableKVChecksumTest, FlushWithCorruptAfterMemtableInsert) {
+  SyncPoint::GetInstance()->SetCallBack(
+      "MemTable::Add:BeforeReturn:Encoded",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  SyncPoint::GetInstance()->EnableProcessing();
+  Options options = CurrentOptions();
+  options.memtable_protection_bytes_per_key =
+      memtable_protection_bytes_per_key_;
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  SkipNotToCorruptEntry();
+  // Not corrupting each byte like other tests since Flush() is relatively slow.
+  Reopen(options);
+  ASSERT_OK(ExecuteWrite(nullptr));
+  ASSERT_TRUE(Flush().IsCorruption());
+  // DB enters read-only state when flush reads corrupted data
+  ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+  Destroy(options);
+}
+
 }  // namespace ROCKSDB_NAMESPACE

 int main(int argc, char** argv) {
diff --git a/db/kv_checksum.h b/db/kv_checksum.h
index 0729d1318..bce507fcf 100644
--- a/db/kv_checksum.h
+++ b/db/kv_checksum.h
@@ -65,6 +65,8 @@ class ProtectionInfo {
                                   const SliceParts& value,
                                   ValueType op_type) const;

+  T GetVal() const { return val_; }
+
  private:
   friend class ProtectionInfoKVO<T>;
   friend class ProtectionInfoKVOS<T>;
@@ -87,7 +89,6 @@ class ProtectionInfo {
     static_assert(sizeof(ProtectionInfo<T>) == sizeof(T), "");
   }

-  T GetVal() const { return val_; }
   void SetVal(T val) { val_ = val; }

   T val_ = 0;
@@ -112,6 +113,8 @@ class ProtectionInfoKVO {
   void UpdateV(const SliceParts& old_value, const SliceParts& new_value);
   void UpdateO(ValueType old_op_type, ValueType new_op_type);

+  T GetVal() const { return info_.GetVal(); }
+
  private:
   friend class ProtectionInfo<T>;
   friend class ProtectionInfoKVOS<T>;
@@ -121,7 +124,6 @@ class ProtectionInfoKVO {
     static_assert(sizeof(ProtectionInfoKVO<T>) == sizeof(T), "");
   }

-  T GetVal() const { return info_.GetVal(); }
   void SetVal(T val) { info_.SetVal(val); }

   ProtectionInfo<T> info_;
@@ -152,6 +154,8 @@ class ProtectionInfoKVOC {
   void UpdateC(ColumnFamilyId old_column_family_id,
                ColumnFamilyId new_column_family_id);

+  T GetVal() const { return kvo_.GetVal(); }
+
  private:
   friend class ProtectionInfoKVO<T>;

@@ -159,7 +163,6 @@ class ProtectionInfoKVOC {
     static_assert(sizeof(ProtectionInfoKVOC<T>) == sizeof(T), "");
   }

-  T GetVal() const { return kvo_.GetVal(); }
   void SetVal(T val) { kvo_.SetVal(val); }

   ProtectionInfoKVO<T> kvo_;
@@ -190,6 +193,8 @@ class ProtectionInfoKVOS {
   void UpdateS(SequenceNumber old_sequence_number,
                SequenceNumber new_sequence_number);

+  T GetVal() const { return kvo_.GetVal(); }
+
  private:
   friend class ProtectionInfoKVO<T>;

@@ -197,7 +202,6 @@ class ProtectionInfoKVOS {
     static_assert(sizeof(ProtectionInfoKVOS<T>) == sizeof(T), "");
   }

-  T GetVal() const { return kvo_.GetVal(); }
   void SetVal(T val) { kvo_.SetVal(val); }

   ProtectionInfoKVO<T> kvo_;
diff --git a/db/memtable.cc b/db/memtable.cc
index 3998df837..daf4c6720 100644
--- a/db/memtable.cc
+++ b/db/memtable.cc
@@ -64,7 +64,9 @@ ImmutableMemTableOptions::ImmutableMemTableOptions(
       statistics(ioptions.stats),
       merge_operator(ioptions.merge_operator.get()),
       info_log(ioptions.logger),
-      allow_data_in_errors(ioptions.allow_data_in_errors) {}
+      allow_data_in_errors(ioptions.allow_data_in_errors),
+      protection_bytes_per_key(
+          mutable_cf_options.memtable_protection_bytes_per_key) {}

 MemTable::MemTable(const InternalKeyComparator& cmp,
                    const ImmutableOptions& ioptions,
@@ -237,6 +239,73 @@ void MemTable::UpdateOldestKeyTime() {
   }
 }

+Status MemTable::VerifyEntryChecksum(const char* entry,
+                                     size_t protection_bytes_per_key,
+                                     bool allow_data_in_errors) {
+  if (protection_bytes_per_key == 0) {
+    return Status::OK();
+  }
+  uint32_t key_length;
+  const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
+  if (key_ptr == nullptr) {
+    return Status::Corruption("Unable to parse internal key length");
+  }
+  if (key_length < 8) {
+    return Status::Corruption("Memtable entry internal key length too short.");
+  }
+  Slice user_key = Slice(key_ptr, key_length - 8);
+
+  const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
+  ValueType type;
+  SequenceNumber seq;
+  UnPackSequenceAndType(tag, &seq, &type);
+
+  uint32_t value_length = 0;
+  const char* value_ptr = GetVarint32Ptr(
+      key_ptr + key_length, key_ptr + key_length + 5, &value_length);
+  if (value_ptr == nullptr) {
+    return Status::Corruption("Unable to parse value length");
+  }
+  Slice value = Slice(value_ptr, value_length);
+
+  const char* checksum_ptr = value_ptr + value_length;
+  uint64_t expected = ProtectionInfo64()
+                          .ProtectKVO(user_key, value, type)
+                          .ProtectS(seq)
+                          .GetVal();
+  bool match = true;
+  switch (protection_bytes_per_key) {
+    case 1:
+      match = static_cast<uint8_t>(checksum_ptr[0]) ==
+              static_cast<uint8_t>(expected);
+      break;
+    case 2:
+      match = DecodeFixed16(checksum_ptr) == static_cast<uint16_t>(expected);
+      break;
+    case 4:
+      match = DecodeFixed32(checksum_ptr) == static_cast<uint32_t>(expected);
+      break;
+    case 8:
+      match = DecodeFixed64(checksum_ptr) == expected;
+      break;
+    default:
+      assert(false);
+  }
+  if (!match) {
+    std::string msg(
+        "Corrupted memtable entry, per key-value checksum verification "
+        "failed. ");
+    if (allow_data_in_errors) {
+      msg.append("Value type: " + std::to_string(static_cast<int>(type)) +
+                 ". ");
+      msg.append("User key: " + user_key.ToString(/*hex=*/true) + ". ");
+      msg.append("seq: " + std::to_string(seq) + ".");
+    }
+    return Status::Corruption(msg.c_str());
+  }
+  return Status::OK();
+}
+
 int MemTable::KeyComparator::operator()(const char* prefix_len_key1,
                                         const char* prefix_len_key2) const {
   // Internal keys are encoded as length-prefixed strings.
@@ -291,7 +360,10 @@ class MemTableIterator : public InternalIterator {
         valid_(false),
         arena_mode_(arena != nullptr),
         value_pinned_(
-            !mem.GetImmutableMemTableOptions()->inplace_update_support) {
+            !mem.GetImmutableMemTableOptions()->inplace_update_support),
+        protection_bytes_per_key_(mem.moptions_.protection_bytes_per_key),
+        status_(Status::OK()),
+        logger_(mem.moptions_.info_log) {
     if (use_range_del_table) {
       iter_ = mem.range_del_table_->GetIterator(arena);
     } else if (prefix_extractor_ != nullptr && !read_options.total_order_seek &&
@@ -302,6 +374,7 @@ class MemTableIterator : public InternalIterator {
     } else {
       iter_ = mem.table_->GetIterator(arena);
     }
+    status_.PermitUncheckedError();
   }
   // No copying allowed
   MemTableIterator(const MemTableIterator&) = delete;
@@ -327,7 +400,7 @@ class MemTableIterator : public InternalIterator {
   PinnedIteratorsManager* pinned_iters_mgr_ = nullptr;
 #endif

-  bool Valid() const override { return valid_; }
+  bool Valid() const override { return valid_ && status_.ok(); }
   void Seek(const Slice& k) override {
     PERF_TIMER_GUARD(seek_on_memtable_time);
     PERF_COUNTER_ADD(seek_on_memtable_count, 1);
@@ -348,6 +421,7 @@ class MemTableIterator : public InternalIterator {
     }
     iter_->Seek(k, nullptr);
     valid_ = iter_->Valid();
+    VerifyEntryChecksum();
   }
   void SeekForPrev(const Slice& k) override {
     PERF_TIMER_GUARD(seek_on_memtable_time);
@@ -368,7 +442,8 @@ class MemTableIterator : public InternalIterator {
     }
     iter_->Seek(k, nullptr);
     valid_ = iter_->Valid();
-    if (!Valid()) {
+    VerifyEntryChecksum();
+    if (!Valid() && status().ok()) {
       SeekToLast();
     }
     while (Valid() && comparator_.comparator.Compare(k, key()) < 0) {
@@ -378,10 +453,12 @@ class MemTableIterator : public InternalIterator {
   void SeekToFirst() override {
     iter_->SeekToFirst();
     valid_ = iter_->Valid();
+    VerifyEntryChecksum();
   }
   void SeekToLast() override {
     iter_->SeekToLast();
     valid_ = iter_->Valid();
+    VerifyEntryChecksum();
   }
   void Next() override {
     PERF_COUNTER_ADD(next_on_memtable_count, 1);
@@ -389,10 +466,11 @@
     iter_->Next();
     TEST_SYNC_POINT_CALLBACK("MemTableIterator::Next:0", iter_);
     valid_ = iter_->Valid();
+    VerifyEntryChecksum();
   }
   bool NextAndGetResult(IterateResult* result) override {
     Next();
-    bool is_valid = valid_;
+    bool is_valid = Valid();
     if (is_valid) {
       result->key = key();
       result->bound_check_result = IterBoundCheck::kUnknown;
@@ -405,6 +483,7 @@ class MemTableIterator : public InternalIterator {
     assert(Valid());
     iter_->Prev();
     valid_ = iter_->Valid();
+    VerifyEntryChecksum();
   }
   Slice key() const override {
     assert(Valid());
@@ -416,7 +495,7 @@ class MemTableIterator : public InternalIterator {
     return GetLengthPrefixedSlice(key_slice.data() + key_slice.size());
   }

-  Status status() const override { return Status::OK(); }
+  Status status() const override { return status_; }

   bool IsKeyPinned() const override {
     // memtable data is always pinned
@@ -436,6 +515,19 @@ class MemTableIterator : public InternalIterator {
   bool valid_;
   bool arena_mode_;
   bool value_pinned_;
+  size_t protection_bytes_per_key_;
+  Status status_;
+  Logger* logger_;
+
+  void VerifyEntryChecksum() {
+    if (protection_bytes_per_key_ > 0 && Valid()) {
+      status_ = MemTable::VerifyEntryChecksum(iter_->key(),
+                                              protection_bytes_per_key_);
+      if (!status_.ok()) {
+        ROCKS_LOG_ERROR(logger_, "In MemtableIterator: %s", status_.getState());
+      }
+    }
+  }
 };

 InternalIterator* MemTable::NewIterator(const ReadOptions& read_options,
@@ -560,6 +652,39 @@ Status MemTable::VerifyEncodedEntry(Slice encoded,
       .GetStatus();
 }

+void MemTable::UpdateEntryChecksum(const ProtectionInfoKVOS64* kv_prot_info,
+                                   const Slice& key, const Slice& value,
+                                   ValueType type, SequenceNumber s,
+                                   char* checksum_ptr) {
+  if (moptions_.protection_bytes_per_key == 0) {
+    return;
+  }
+
+  uint64_t checksum = 0;
+  if (kv_prot_info == nullptr) {
+    checksum =
+        ProtectionInfo64().ProtectKVO(key, value, type).ProtectS(s).GetVal();
+  } else {
+    checksum = kv_prot_info->GetVal();
+  }
+  switch (moptions_.protection_bytes_per_key) {
+    case 1:
+      checksum_ptr[0] = static_cast<uint8_t>(checksum);
+      break;
+    case 2:
+      EncodeFixed16(checksum_ptr, static_cast<uint16_t>(checksum));
+      break;
+    case 4:
+      EncodeFixed32(checksum_ptr, static_cast<uint32_t>(checksum));
+      break;
+    case 8:
+      EncodeFixed64(checksum_ptr, checksum);
+      break;
+    default:
+      assert(false);
+  }
+}
+
 Status MemTable::Add(SequenceNumber s, ValueType type,
                      const Slice& key, /* user key */
                      const Slice& value,
@@ -571,12 +696,13 @@ Status MemTable::Add(SequenceNumber s, ValueType type,
   //  key bytes    : char[internal_key.size()]
   //  value_size   : varint32 of value.size()
   //  value bytes  : char[value.size()]
+  //  checksum     : char[moptions_.protection_bytes_per_key]
   uint32_t key_size = static_cast<uint32_t>(key.size());
   uint32_t val_size = static_cast<uint32_t>(value.size());
   uint32_t internal_key_size = key_size + 8;
   const uint32_t encoded_len = VarintLength(internal_key_size) +
                                internal_key_size + VarintLength(val_size) +
-                               val_size;
+                               val_size + moptions_.protection_bytes_per_key;
   char* buf = nullptr;
   std::unique_ptr<MemTableRep>& table = type == kTypeRangeDeletion ?
       range_del_table_ : table_;
@@ -591,9 +717,13 @@ Status MemTable::Add(SequenceNumber s, ValueType type,
   p += 8;
   p = EncodeVarint32(p, val_size);
   memcpy(p, value.data(), val_size);
-  assert((unsigned)(p + val_size - buf) == (unsigned)encoded_len);
+  assert((unsigned)(p + val_size - buf + moptions_.protection_bytes_per_key) ==
+         (unsigned)encoded_len);
+
+  UpdateEntryChecksum(kv_prot_info, key, value, type, s,
+                      buf + encoded_len - moptions_.protection_bytes_per_key);
+  Slice encoded(buf, encoded_len - moptions_.protection_bytes_per_key);
   if (kv_prot_info != nullptr) {
-    Slice encoded(buf, encoded_len);
     TEST_SYNC_POINT_CALLBACK("MemTable::Add:Encoded", &encoded);
     Status status = VerifyEncodedEntry(encoded, *kv_prot_info);
     if (!status.ok()) {
@@ -692,6 +822,8 @@ Status MemTable::Add(SequenceNumber s, ValueType type,
     is_range_del_table_empty_.store(false, std::memory_order_relaxed);
   }
   UpdateOldestKeyTime();
+
+  TEST_SYNC_POINT_CALLBACK("MemTable::Add:BeforeReturn:Encoded", &encoded);
   return Status::OK();
 }

@@ -720,6 +852,7 @@ struct Saver {
   ReadCallback* callback_;
   bool* is_blob_index;
   bool allow_data_in_errors;
+  size_t protection_bytes_per_key;
   bool CheckCallback(SequenceNumber _seq) {
     if (callback_) {
       return callback_->IsVisible(_seq);
     }
@@ -730,23 +863,28 @@ struct Saver {
   }
 }  // namespace

 static bool SaveValue(void* arg, const char* entry) {
+  TEST_SYNC_POINT_CALLBACK("Memtable::SaveValue:Begin:entry", &entry);
   Saver* s = reinterpret_cast<Saver*>(arg);
   assert(s != nullptr);
+
+  if (s->protection_bytes_per_key > 0) {
+    *(s->status) = MemTable::VerifyEntryChecksum(
+        entry, s->protection_bytes_per_key, s->allow_data_in_errors);
+    if (!s->status->ok()) {
+      ROCKS_LOG_ERROR(s->logger, "In SaveValue: %s", s->status->getState());
+      // Memtable entry corrupted
+      return false;
+    }
+  }
+
   MergeContext* merge_context = s->merge_context;
   SequenceNumber max_covering_tombstone_seq = s->max_covering_tombstone_seq;
   const MergeOperator* merge_operator = s->merge_operator;

   assert(merge_context != nullptr);

-  // entry format is:
-  //    klength  varint32
-  //    userkey  char[klength-8]
-  //    tag      uint64
-  //    vlength  varint32f
-  //    value    char[vlength]
-  // Check that it belongs to same user key.  We do not check the
-  // sequence number since the Seek() call above should have skipped
-  // all entries with overly large sequence numbers.
+  // Refer to comments under MemTable::Add() for entry format.
+  // Check that it belongs to same user key.
   uint32_t key_length = 0;
   const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length);
   assert(key_length >= 8);
@@ -972,7 +1110,8 @@ bool MemTable::Get(const LookupKey& key, std::string* value,
   }

   // No change to value, since we have not yet found a Put/Delete
-  if (!found_final_value && merge_in_progress) {
+  // Propagate corruption error
+  if (!found_final_value && merge_in_progress && !s->IsCorruption()) {
     *s = Status::MergeInProgress();
   }
   PERF_COUNTER_ADD(get_from_memtable_count, 1);
@@ -1006,6 +1145,7 @@ void MemTable::GetFromTable(const LookupKey& key,
   saver.is_blob_index = is_blob_index;
   saver.do_merge = do_merge;
   saver.allow_data_in_errors = moptions_.allow_data_in_errors;
+  saver.protection_bytes_per_key = moptions_.protection_bytes_per_key;
   table_->Get(key, &saver, SaveValue);
   *seq = saver.seq;
 }
@@ -1104,12 +1244,7 @@ Status MemTable::Update(SequenceNumber seq, ValueType value_type,
   iter->Seek(lkey.internal_key(), mem_key.data());

   if (iter->Valid()) {
-    // entry format is:
-    //    key_length  varint32
-    //    userkey     char[klength-8]
-    //    tag         uint64
-    //    vlength     varint32
-    //    value       char[vlength]
+    // Refer to comments under MemTable::Add() for entry format.
     // Check that it belongs to same user key.  We do not check the
     // sequence number since the Seek() call above should have skipped
     // all entries with overly large sequence numbers.
@@ -1143,8 +1278,13 @@ Status MemTable::Update(SequenceNumber seq, ValueType value_type,
           ProtectionInfoKVOS64 updated_kv_prot_info(*kv_prot_info);
           // `seq` is swallowed and `existing_seq` prevails.
           updated_kv_prot_info.UpdateS(seq, existing_seq);
+          UpdateEntryChecksum(&updated_kv_prot_info, key, value, type,
+                              existing_seq, p + value.size());
           Slice encoded(entry, p + value.size() - entry);
           return VerifyEncodedEntry(encoded, updated_kv_prot_info);
+        } else {
+          UpdateEntryChecksum(nullptr, key, value, type, existing_seq,
+                              p + value.size());
         }
         return Status::OK();
       }
@@ -1167,12 +1307,7 @@ Status MemTable::UpdateCallback(SequenceNumber seq, const Slice& key,
   iter->Seek(lkey.internal_key(), memkey.data());

   if (iter->Valid()) {
-    // entry format is:
-    //    key_length  varint32
-    //    userkey     char[klength-8]
-    //    tag         uint64
-    //    vlength     varint32
-    //    value       char[vlength]
+    // Refer to comments under MemTable::Add() for entry format.
     // Check that it belongs to same user key.  We do not check the
     // sequence number since the Seek() call above should have skipped
     // all entries with overly large sequence numbers.
@@ -1212,14 +1347,19 @@ Status MemTable::UpdateCallback(SequenceNumber seq, const Slice& key,
           }
           RecordTick(moptions_.statistics, NUMBER_KEYS_UPDATED);
           UpdateFlushState();
+          Slice new_value(prev_buffer, new_prev_size);
           if (kv_prot_info != nullptr) {
             ProtectionInfoKVOS64 updated_kv_prot_info(*kv_prot_info);
             // `seq` is swallowed and `existing_seq` prevails.
             updated_kv_prot_info.UpdateS(seq, existing_seq);
-            updated_kv_prot_info.UpdateV(delta,
-                                         Slice(prev_buffer, new_prev_size));
+            updated_kv_prot_info.UpdateV(delta, new_value);
             Slice encoded(entry, prev_buffer + new_prev_size - entry);
+            UpdateEntryChecksum(&updated_kv_prot_info, key, new_value, type,
+                                existing_seq, prev_buffer + new_prev_size);
             return VerifyEncodedEntry(encoded, updated_kv_prot_info);
+          } else {
+            UpdateEntryChecksum(nullptr, key, new_value, type, existing_seq,
+                                prev_buffer + new_prev_size);
           }
           return Status::OK();
         } else if (status == UpdateStatus::UPDATED) {
diff --git a/db/memtable.h b/db/memtable.h
index 80d23657c..bb32b5529 100644
--- a/db/memtable.h
+++ b/db/memtable.h
@@ -58,6 +58,7 @@ struct ImmutableMemTableOptions {
   MergeOperator* merge_operator;
   Logger* info_log;
   bool allow_data_in_errors;
+  uint32_t protection_bytes_per_key;
 };

 // Batched counters to be updated when inserting keys in one write batch.
@@ -539,6 +540,11 @@ class MemTable {
     }
   }

+  // Returns Corruption status if verification fails.
+  static Status VerifyEntryChecksum(const char* entry,
+                                    size_t protection_bytes_per_key,
+                                    bool allow_data_in_errors = false);
+
  private:
   enum FlushStateEnum { FLUSH_NOT_REQUESTED, FLUSH_REQUESTED, FLUSH_SCHEDULED };

@@ -650,6 +656,10 @@ class MemTable {
   // if !is_range_del_table_empty_.
   std::unique_ptr<FragmentedRangeTombstoneList>
       fragmented_range_tombstone_list_;
+
+  void UpdateEntryChecksum(const ProtectionInfoKVOS64* kv_prot_info,
+                           const Slice& key, const Slice& value, ValueType type,
+                           SequenceNumber s, char* checksum_ptr);
 };

 extern const char* EncodeKey(std::string* scratch, const Slice& target);
diff --git a/db_stress_tool/db_stress_common.h b/db_stress_tool/db_stress_common.h
index ad4b6bb7e..c3316ee0f 100644
--- a/db_stress_tool/db_stress_common.h
+++ b/db_stress_tool/db_stress_common.h
@@ -284,6 +284,7 @@ DECLARE_bool(enable_compaction_filter);
 DECLARE_bool(paranoid_file_checks);
 DECLARE_bool(fail_if_options_file_error);
 DECLARE_uint64(batch_protection_bytes_per_key);
+DECLARE_uint32(memtable_protection_bytes_per_key);
 DECLARE_uint64(user_timestamp_size);
 DECLARE_string(secondary_cache_uri);

diff --git a/db_stress_tool/db_stress_gflags.cc b/db_stress_tool/db_stress_gflags.cc
index a8733a52b..90bc9282e 100644
--- a/db_stress_tool/db_stress_gflags.cc
+++ b/db_stress_tool/db_stress_gflags.cc
@@ -941,6 +941,12 @@ DEFINE_uint64(batch_protection_bytes_per_key, 0,
               "specified number of bytes per key. Currently the only supported "
               "nonzero value is eight.");

+DEFINE_uint32(
+    memtable_protection_bytes_per_key, 0,
+    "If nonzero, enables integrity protection in memtable entries at the "
+    "specified number of bytes per key. Currently the supported "
+    "nonzero values are 1, 2, 4 and 8.");
+
 DEFINE_string(file_checksum_impl, "none",
               "Name of an implementation for file_checksum_gen_factory, or "
               "\"none\" for null.");
diff --git a/db_stress_tool/db_stress_test_base.cc b/db_stress_tool/db_stress_test_base.cc
index 24ea7cc1f..1f6b3794a 100644
--- a/db_stress_tool/db_stress_test_base.cc
+++ b/db_stress_tool/db_stress_test_base.cc
@@ -3016,6 +3016,8 @@ void InitializeOptionsFromFlags(
   options.track_and_verify_wals_in_manifest = true;
   options.verify_sst_unique_id_in_manifest =
       FLAGS_verify_sst_unique_id_in_manifest;
+  options.memtable_protection_bytes_per_key =
+      FLAGS_memtable_protection_bytes_per_key;

   // Integrated BlobDB
   options.enable_blob_files = FLAGS_enable_blob_files;
diff --git a/include/rocksdb/advanced_options.h b/include/rocksdb/advanced_options.h
index 9452e950d..8667f7d7a 100644
--- a/include/rocksdb/advanced_options.h
+++ b/include/rocksdb/advanced_options.h
@@ -1024,6 +1024,21 @@ struct AdvancedColumnFamilyOptions {
   // Dynamically changeable through the SetOptions() API
   PrepopulateBlobCache prepopulate_blob_cache = PrepopulateBlobCache::kDisable;

+  // Enable memtable per key-value checksum protection.
+  //
+  // Each entry in memtable will be suffixed by a per key-value checksum.
+  // This option determines the size of such checksums.
+  //
+  // It is suggested to turn on write batch per key-value
+  // checksum protection together with this option, so that the checksum
+  // computation is done outside of writer threads (memtable kv checksum can be
+  // computed from write batch checksum). See
+  // WriteOptions::protection_bytes_per_key for more detail.
+  //
+  // Default: 0 (no protection)
+  // Supported values: 0, 1, 2, 4, 8.
+  uint32_t memtable_protection_bytes_per_key = 0;
+
   // Create ColumnFamilyOptions with default values for all fields
   AdvancedColumnFamilyOptions();
   // Create ColumnFamilyOptions from Options
diff --git a/options/cf_options.cc b/options/cf_options.cc
index 2d3ac114d..53d705ecd 100644
--- a/options/cf_options.cc
+++ b/options/cf_options.cc
@@ -475,6 +475,10 @@ static std::unordered_map<std::string, OptionTypeInfo>
         {offsetof(struct MutableCFOptions, experimental_mempurge_threshold),
          OptionType::kDouble, OptionVerificationType::kNormal,
          OptionTypeFlags::kMutable}},
+        {"memtable_protection_bytes_per_key",
+         {offsetof(struct MutableCFOptions, memtable_protection_bytes_per_key),
+          OptionType::kUInt32T, OptionVerificationType::kNormal,
+          OptionTypeFlags::kMutable}},
        {kOptNameCompOpts,
         OptionTypeInfo::Struct(
             kOptNameCompOpts, &compression_options_type_info,
diff --git a/options/cf_options.h b/options/cf_options.h
index d29f969f0..47de8e7ae 100644
--- a/options/cf_options.h
+++ b/options/cf_options.h
@@ -162,6 +162,8 @@ struct MutableCFOptions {
                                   Temperature::kUnknown
                               ? options.bottommost_temperature
                               : options.last_level_temperature),
+        memtable_protection_bytes_per_key(
+            options.memtable_protection_bytes_per_key),
        sample_for_compression(
             options.sample_for_compression),  // TODO: is 0 fine here?
         compression_per_level(options.compression_per_level) {
@@ -210,6 +212,7 @@ struct MutableCFOptions {
         compression(Snappy_Supported() ?
                         kSnappyCompression : kNoCompression),
         bottommost_compression(kDisableCompressionOption),
         last_level_temperature(Temperature::kUnknown),
+        memtable_protection_bytes_per_key(0),
         sample_for_compression(0) {}

   explicit MutableCFOptions(const Options& options);

@@ -298,6 +301,7 @@ struct MutableCFOptions {
   CompressionOptions compression_opts;
   CompressionOptions bottommost_compression_opts;
   Temperature last_level_temperature;
+  uint32_t memtable_protection_bytes_per_key;
   uint64_t sample_for_compression;
   std::vector<CompressionType> compression_per_level;

diff --git a/options/options_helper.cc b/options/options_helper.cc
index 76f99a90e..efb1d382e 100644
--- a/options/options_helper.cc
+++ b/options/options_helper.cc
@@ -214,6 +214,8 @@ void UpdateColumnFamilyOptions(const MutableCFOptions& moptions,
   cf_opts->prefix_extractor = moptions.prefix_extractor;
   cf_opts->experimental_mempurge_threshold =
       moptions.experimental_mempurge_threshold;
+  cf_opts->memtable_protection_bytes_per_key =
+      moptions.memtable_protection_bytes_per_key;

   // Compaction related options
   cf_opts->disable_auto_compactions = moptions.disable_auto_compactions;
diff --git a/options/options_settable_test.cc b/options/options_settable_test.cc
index 15b20d4d4..49b80e17e 100644
--- a/options/options_settable_test.cc
+++ b/options/options_settable_test.cc
@@ -532,7 +532,8 @@ TEST_F(OptionsSettableTest, ColumnFamilyOptionsAllFieldsSettable) {
           "preclude_last_level_data_seconds=86400;"
           "compaction_options_fifo={max_table_files_size=3;allow_"
           "compaction=false;age_for_warm=1;};"
-          "blob_cache=1M;",
+          "blob_cache=1M;"
+          "memtable_protection_bytes_per_key=2;",
           new_options));

   ASSERT_NE(new_options->blob_cache.get(), nullptr);
diff --git a/tools/db_bench_tool.cc b/tools/db_bench_tool.cc
index 5920f5cc3..568a2c73a 100644
--- a/tools/db_bench_tool.cc
+++ b/tools/db_bench_tool.cc
@@ -1695,6 +1695,13 @@ DEFINE_uint32(write_batch_protection_bytes_per_key, 0,
               "Size of per-key-value checksum in each write batch. Currently "
               "only value 0 and 8 are supported.");

+DEFINE_uint32(
+    memtable_protection_bytes_per_key, 0,
+    "Enable memtable per key-value checksum protection. "
+    "Each entry in memtable will be suffixed by a per key-value checksum. "
+    "This option determines the size of such checksums. "
+    "Supported values: 0, 1, 2, 4, 8.");
+
 DEFINE_bool(build_info, false,
             "Print the build info via GetRocksBuildInfoAsString");

@@ -4586,6 +4593,8 @@ class Benchmark {
       exit(1);
     }
 #endif  // ROCKSDB_LITE
+    options.memtable_protection_bytes_per_key =
+        FLAGS_memtable_protection_bytes_per_key;
   }

   void InitializeOptionsGeneral(Options* opts) {
diff --git a/tools/db_crashtest.py b/tools/db_crashtest.py
index 22a677bad..53a2c5c28 100644
--- a/tools/db_crashtest.py
+++ b/tools/db_crashtest.py
@@ -35,6 +35,7 @@ default_params = {
     # Consider larger number when backups considered more stable
     "backup_one_in": 100000,
     "batch_protection_bytes_per_key": lambda: random.choice([0, 8]),
+    "memtable_protection_bytes_per_key": lambda: random.choice([0, 1, 2, 4, 8]),
     "block_size": 16384,
     "bloom_bits": lambda: random.choice([random.randint(0,19), random.lognormvariate(2.3, 1.3)]),
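For completeness, a minimal usage sketch of the new option against the public API (illustrative only: the DB path is a placeholder and error handling is reduced to asserts):

    #include <cassert>
    #include <string>

    #include "rocksdb/db.h"
    #include "rocksdb/options.h"

    int main() {
      rocksdb::Options options;
      options.create_if_missing = true;
      // Per key-value checksums in the memtable (this change); 0 disables.
      options.memtable_protection_bytes_per_key = 8;

      rocksdb::DB* db = nullptr;
      rocksdb::Status s =
          rocksdb::DB::Open(options, "/tmp/memtable_kv_checksum_demo", &db);
      assert(s.ok());

      // As suggested in the advanced_options.h comment above, also protect the
      // write batch so the memtable checksum can be derived from it instead of
      // being recomputed in the writer thread.
      rocksdb::WriteOptions write_options;
      write_options.protection_bytes_per_key = 8;
      s = db->Put(write_options, "key", "value");
      assert(s.ok());

      std::string value;
      s = db->Get(rocksdb::ReadOptions(), "key", &value);
      // A corrupted memtable entry would surface here as s.IsCorruption(),
      // as exercised by db_kv_checksum_test.cc above.
      assert(s.ok());

      delete db;
      return 0;
    }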