Integrity protection for live updates to WriteBatch (#7748)

Summary:
This PR adds the foundation classes for key-value integrity protection and the first use case: protecting live updates from the source buffers added to `WriteBatch` through the destination buffer in `MemTable`. The width of the protection info is not yet configurable -- only eight bytes per key is supported. This PR allows users to enable protection by constructing `WriteBatch` with `protection_bytes_per_key == 8`. It does not yet expose a way for users to get integrity protection via other write APIs (e.g., `Put()`, `Merge()`, `Delete()`, etc.).

The foundation classes (`ProtectionInfo.*`) embed the coverage info in their type, and provide `Protect.*()` and `Strip.*()` functions to navigate between types with different coverage. For making bytes per key configurable (for powers of two up to eight) in the future, these classes are templated on the unsigned integer type used to store the protection info. That integer contains the XOR'd result of hashes with independent seeds for all covered fields. For integer fields, the hash is computed on the raw unadjusted bytes, so the result is endian-dependent. The most significant bytes are truncated when the hash value (8 bytes) is wider than the protection integer.

When `WriteBatch` is constructed with `protection_bytes_per_key == 8`, we hold a `ProtectionInfoKVOTC` (i.e., one that covers key, value, optype aka `ValueType`, timestamp, and CF ID) for each entry added to the batch. The protection info is generated from the original buffers passed by the user, as well as the original metadata generated internally. When writing to memtable, each entry is transformed to a `ProtectionInfoKVOTS` (i.e., dropping coverage of CF ID and adding coverage of sequence number), since at that point we know the sequence number, and have already selected a memtable corresponding to a particular CF. This protection info is verified once the entry is encoded in the `MemTable` buffer.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/7748

Test Plan:
- an integration test to verify a wide variety of single-byte changes to the encoded `MemTable` buffer are caught
- add to stress/crash test to verify it works in variety of configs/operations without intentional corruption
- [deferred] unit tests for `ProtectionInfo.*` classes for edge cases like KV swap, `SliceParts` and `Slice` APIs are interchangeable, etc.

Reviewed By: pdillinger

Differential Revision: D25754492

Pulled By: ajkr

fbshipit-source-id: e481bac6c03c2ab268be41359730f1ceb9964866
main
Andrew Kryczka 3 years ago committed by Facebook GitHub Bot
parent 4a09d632c4
commit 78ee8564ad
  1. 1
      CMakeLists.txt
  2. 4
      HISTORY.md
  3. 4
      Makefile
  4. 7
      TARGETS
  5. 2
      db/blob/blob_file_cache.cc
  6. 199
      db/db_kv_checksum_test.cc
  7. 52
      db/db_memtable_test.cc
  8. 6
      db/dbformat.h
  9. 19
      db/flush_job_test.cc
  10. 424
      db/kv_checksum.h
  11. 97
      db/memtable.cc
  12. 14
      db/memtable.h
  13. 63
      db/memtable_list_test.cc
  14. 2
      db/table_cache.cc
  15. 369
      db/write_batch.cc
  16. 13
      db/write_batch_internal.h
  17. 5
      db/write_thread.cc
  18. 3
      db/write_thread.h
  19. 6
      db_stress_tool/batched_ops_stress.cc
  20. 1
      db_stress_tool/db_stress_common.h
  21. 5
      db_stress_tool/db_stress_gflags.cc
  22. 7
      db_stress_tool/db_stress_tool.cc
  23. 12
      include/rocksdb/write_batch.h
  24. 1
      src.mk
  25. 3
      table/table_test.cc
  26. 3
      tools/db_crashtest.py
  27. 17
      util/hash.cc
  28. 11
      util/hash.h

@ -1088,6 +1088,7 @@ if(WITH_TESTS)
db/db_iter_test.cc
db/db_iter_stress_test.cc
db/db_iterator_test.cc
db/db_kv_checksum_test.cc
db/db_log_iter_test.cc
db/db_memtable_test.cc
db/db_merge_operator_test.cc

@ -3,6 +3,9 @@
### Behavior Changes
* When retryable IO error occurs during compaction, it is mapped to soft error and set the BG error. However, auto resume is not called to clean the soft error since compaction will reschedule by itself. In this change, When retryable IO error occurs during compaction, BG error is not set. User will be informed the error via EventHelper.
### New Features
* Add support for key-value integrity protection in live updates from the user buffers provided to `WriteBatch` through the write to RocksDB's in-memory update buffer (memtable). This is intended to detect some cases of in-memory data corruption, due to either software or hardware errors. Users can enable protection by constructing their `WriteBatch` with `protection_bytes_per_key == 8`.
## 6.17.0 (01/15/2021)
### Behavior Changes
* When verifying full file checksum with `DB::VerifyFileChecksums()`, we now fail with `Status::InvalidArgument` if the name of the checksum generator used for verification does not match the name of the checksum generator used for protecting the file when it was created.
@ -16,6 +19,7 @@
* Add a public API WriteBufferManager::dummy_entries_in_cache_usage() which reports the size of dummy entries stored in cache (passed to WriteBufferManager). Dummy entries are used to account for DataBlocks.
* Add a SystemClock class that contains the time-related methods from Env. The original methods in Env may be deprecated in a future release. This class will allow easier testing, development, and expansion of time-related features.
* Add a public API GetRocksBuildProperties and GetRocksBuildInfoAsString to get properties about the current build. These properties may include settings related to the GIT settings (branch, timestamp). This change also sets the "build date" based on the GIT properties, rather than the actual build time, thereby enabling more reproducible builds.
## 6.16.0 (12/18/2020)
### Behavior Changes
* Attempting to write a merge operand without explicitly configuring `merge_operator` now fails immediately, causing the DB to enter read-only mode. Previously, failure was deferred until the `merge_operator` was needed by a user read or a background operation.

@ -594,6 +594,7 @@ ifdef ASSERT_STATUS_CHECKED
db_inplace_update_test \
db_io_failure_test \
db_iterator_test \
db_kv_checksum_test \
db_logical_block_size_cache_test \
db_memtable_test \
db_merge_operand_test \
@ -1608,6 +1609,9 @@ db_inplace_update_test: $(OBJ_DIR)/db/db_inplace_update_test.o $(TEST_LIBRARY) $
db_iterator_test: $(OBJ_DIR)/db/db_iterator_test.o $(TEST_LIBRARY) $(LIBRARY)
$(AM_LINK)
db_kv_checksum_test: $(OBJ_DIR)/db/db_kv_checksum_test.o $(TEST_LIBRARY) $(LIBRARY)
$(AM_LINK)
db_memtable_test: $(OBJ_DIR)/db/db_memtable_test.o $(TEST_LIBRARY) $(LIBRARY)
$(AM_LINK)

@ -1240,6 +1240,13 @@ ROCKS_TESTS = [
[],
[],
],
[
"db_kv_checksum_test",
"db/db_kv_checksum_test.cc",
"serial",
[],
[],
],
[
"db_log_iter_test",
"db/db_log_iter_test.cc",

@ -23,7 +23,7 @@ BlobFileCache::BlobFileCache(Cache* cache,
uint32_t column_family_id,
HistogramImpl* blob_file_read_hist)
: cache_(cache),
mutex_(kNumberOfMutexStripes, GetSliceNPHash64),
mutex_(kNumberOfMutexStripes, kGetSliceNPHash64UnseededFnPtr),
immutable_cf_options_(immutable_cf_options),
file_options_(file_options),
column_family_id_(column_family_id),

@ -0,0 +1,199 @@
// Copyright (c) 2020-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#include "db/db_test_util.h"
#include "rocksdb/rocksdb_namespace.h"
namespace ROCKSDB_NAMESPACE {
enum class WriteBatchOpType {
kPut = 0,
kDelete,
kSingleDelete,
kDeleteRange,
kMerge,
kBlobIndex,
kNum,
};
// Integer addition is needed for `::testing::Range()` to take the enum type.
WriteBatchOpType operator+(WriteBatchOpType lhs, const int rhs) {
using T = std::underlying_type<WriteBatchOpType>::type;
return static_cast<WriteBatchOpType>(static_cast<T>(lhs) + rhs);
}
class DbKvChecksumTest
: public DBTestBase,
public ::testing::WithParamInterface<std::tuple<WriteBatchOpType, char>> {
public:
DbKvChecksumTest()
: DBTestBase("/db_kv_checksum_test", /*env_do_fsync=*/false) {
op_type_ = std::get<0>(GetParam());
corrupt_byte_addend_ = std::get<1>(GetParam());
}
std::pair<WriteBatch, Status> GetWriteBatch(size_t ts_sz,
ColumnFamilyHandle* cf_handle) {
Status s;
WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */, ts_sz,
8 /* protection_bytes_per_entry */);
switch (op_type_) {
case WriteBatchOpType::kPut:
s = wb.Put(cf_handle, "key", "val");
break;
case WriteBatchOpType::kDelete:
s = wb.Delete(cf_handle, "key");
break;
case WriteBatchOpType::kSingleDelete:
s = wb.SingleDelete(cf_handle, "key");
break;
case WriteBatchOpType::kDeleteRange:
s = wb.DeleteRange(cf_handle, "begin", "end");
break;
case WriteBatchOpType::kMerge:
s = wb.Merge(cf_handle, "key", "val");
break;
case WriteBatchOpType::kBlobIndex:
// TODO(ajkr): use public API once available.
uint32_t cf_id;
if (cf_handle == nullptr) {
cf_id = 0;
} else {
cf_id = cf_handle->GetID();
}
s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", "val");
break;
case WriteBatchOpType::kNum:
assert(false);
}
return {std::move(wb), std::move(s)};
}
void CorruptNextByteCallBack(void* arg) {
Slice encoded = *static_cast<Slice*>(arg);
if (entry_len_ == port::kMaxSizet) {
// We learn the entry size on the first attempt
entry_len_ = encoded.size();
}
// All entries should be the same size
assert(entry_len_ == encoded.size());
char* buf = const_cast<char*>(encoded.data());
buf[corrupt_byte_offset_] += corrupt_byte_addend_;
++corrupt_byte_offset_;
}
bool MoreBytesToCorrupt() { return corrupt_byte_offset_ < entry_len_; }
protected:
WriteBatchOpType op_type_;
char corrupt_byte_addend_;
size_t corrupt_byte_offset_ = 0;
size_t entry_len_ = port::kMaxSizet;
};
std::string GetTestNameSuffix(
::testing::TestParamInfo<std::tuple<WriteBatchOpType, char>> info) {
std::ostringstream oss;
switch (std::get<0>(info.param)) {
case WriteBatchOpType::kPut:
oss << "Put";
break;
case WriteBatchOpType::kDelete:
oss << "Delete";
break;
case WriteBatchOpType::kSingleDelete:
oss << "SingleDelete";
break;
case WriteBatchOpType::kDeleteRange:
oss << "DeleteRange";
break;
case WriteBatchOpType::kMerge:
oss << "Merge";
break;
case WriteBatchOpType::kBlobIndex:
oss << "BlobIndex";
break;
case WriteBatchOpType::kNum:
assert(false);
}
oss << "Add"
<< static_cast<int>(static_cast<unsigned char>(std::get<1>(info.param)));
return oss.str();
}
INSTANTIATE_TEST_CASE_P(
DbKvChecksumTest, DbKvChecksumTest,
::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
WriteBatchOpType::kNum),
::testing::Values(2, 103, 251)),
GetTestNameSuffix);
TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
// This test repeatedly attempts to write `WriteBatch`es containing a single
// entry of type `op_type_`. Each attempt has one byte corrupted in its
// memtable entry by adding `corrupt_byte_addend_` to its original value. The
// test repeats until an attempt has been made on each byte in the encoded
// memtable entry. All attempts are expected to fail with `Status::Corruption`
SyncPoint::GetInstance()->SetCallBack(
"MemTable::Add:Encoded",
std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
std::placeholders::_1));
while (MoreBytesToCorrupt()) {
// Failed memtable insert always leads to read-only mode, so we have to
// reopen for every attempt.
Options options = CurrentOptions();
if (op_type_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
Reopen(options);
SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status =
GetWriteBatch(0 /* ts_sz */, nullptr /* cf_handle */);
ASSERT_OK(batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
SyncPoint::GetInstance()->DisableProcessing();
}
}
TEST_P(DbKvChecksumTest, MemTableAddWithColumnFamilyCorrupted) {
// This test repeatedly attempts to write `WriteBatch`es containing a single
// entry of type `op_type_` to a non-default column family. Each attempt has
// one byte corrupted in its memtable entry by adding `corrupt_byte_addend_`
// to its original value. The test repeats until an attempt has been made on
// each byte in the encoded memtable entry. All attempts are expected to fail
// with `Status::Corruption`.
Options options = CurrentOptions();
if (op_type_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
CreateAndReopenWithCF({"pikachu"}, options);
SyncPoint::GetInstance()->SetCallBack(
"MemTable::Add:Encoded",
std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
std::placeholders::_1));
while (MoreBytesToCorrupt()) {
// Failed memtable insert always leads to read-only mode, so we have to
// reopen for every attempt.
ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status = GetWriteBatch(0 /* ts_sz */, handles_[1]);
ASSERT_OK(batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
SyncPoint::GetInstance()->DisableProcessing();
}
}
} // namespace ROCKSDB_NAMESPACE
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -145,15 +145,25 @@ TEST_F(DBMemTableTest, DuplicateSeq) {
kMaxSequenceNumber, 0 /* column_family_id */);
// Write some keys and make sure it returns false on duplicates
ASSERT_OK(mem->Add(seq, kTypeValue, "key", "value2"));
ASSERT_TRUE(mem->Add(seq, kTypeValue, "key", "value2").IsTryAgain());
ASSERT_OK(
mem->Add(seq, kTypeValue, "key", "value2", nullptr /* kv_prot_info */));
ASSERT_TRUE(
mem->Add(seq, kTypeValue, "key", "value2", nullptr /* kv_prot_info */)
.IsTryAgain());
// Changing the type should still cause the duplicatae key
ASSERT_TRUE(mem->Add(seq, kTypeMerge, "key", "value2").IsTryAgain());
ASSERT_TRUE(
mem->Add(seq, kTypeMerge, "key", "value2", nullptr /* kv_prot_info */)
.IsTryAgain());
// Changing the seq number will make the key fresh
ASSERT_OK(mem->Add(seq + 1, kTypeMerge, "key", "value2"));
ASSERT_OK(mem->Add(seq + 1, kTypeMerge, "key", "value2",
nullptr /* kv_prot_info */));
// Test with different types for duplicate keys
ASSERT_TRUE(mem->Add(seq, kTypeDeletion, "key", "").IsTryAgain());
ASSERT_TRUE(mem->Add(seq, kTypeSingleDeletion, "key", "").IsTryAgain());
ASSERT_TRUE(
mem->Add(seq, kTypeDeletion, "key", "", nullptr /* kv_prot_info */)
.IsTryAgain());
ASSERT_TRUE(
mem->Add(seq, kTypeSingleDeletion, "key", "", nullptr /* kv_prot_info */)
.IsTryAgain());
// Test the duplicate keys under stress
for (int i = 0; i < 10000; i++) {
@ -161,7 +171,8 @@ TEST_F(DBMemTableTest, DuplicateSeq) {
if (!insert_dup) {
seq++;
}
Status s = mem->Add(seq, kTypeValue, "foo", "value" + ToString(seq));
Status s = mem->Add(seq, kTypeValue, "foo", "value" + ToString(seq),
nullptr /* kv_prot_info */);
if (insert_dup) {
ASSERT_TRUE(s.IsTryAgain());
} else {
@ -177,8 +188,11 @@ TEST_F(DBMemTableTest, DuplicateSeq) {
mem = new MemTable(cmp, ioptions, MutableCFOptions(options), &wb,
kMaxSequenceNumber, 0 /* column_family_id */);
// Insert a duplicate key with _ in it
ASSERT_OK(mem->Add(seq, kTypeValue, "key_1", "value"));
ASSERT_TRUE(mem->Add(seq, kTypeValue, "key_1", "value").IsTryAgain());
ASSERT_OK(
mem->Add(seq, kTypeValue, "key_1", "value", nullptr /* kv_prot_info */));
ASSERT_TRUE(
mem->Add(seq, kTypeValue, "key_1", "value", nullptr /* kv_prot_info */)
.IsTryAgain());
delete mem;
// Test when InsertConcurrently will be invoked
@ -187,11 +201,11 @@ TEST_F(DBMemTableTest, DuplicateSeq) {
mem = new MemTable(cmp, ioptions, MutableCFOptions(options), &wb,
kMaxSequenceNumber, 0 /* column_family_id */);
MemTablePostProcessInfo post_process_info;
ASSERT_OK(
mem->Add(seq, kTypeValue, "key", "value", true, &post_process_info));
ASSERT_TRUE(
mem->Add(seq, kTypeValue, "key", "value", true, &post_process_info)
.IsTryAgain());
ASSERT_OK(mem->Add(seq, kTypeValue, "key", "value",
nullptr /* kv_prot_info */, true, &post_process_info));
ASSERT_TRUE(mem->Add(seq, kTypeValue, "key", "value",
nullptr /* kv_prot_info */, true, &post_process_info)
.IsTryAgain());
delete mem;
}
@ -217,7 +231,7 @@ TEST_F(DBMemTableTest, ConcurrentMergeWrite) {
// Put 0 as the base
PutFixed64(&value, static_cast<uint64_t>(0));
ASSERT_OK(mem->Add(0, kTypeValue, "key", value));
ASSERT_OK(mem->Add(0, kTypeValue, "key", value, nullptr /* kv_prot_info */));
value.clear();
// Write Merge concurrently
@ -226,8 +240,8 @@ TEST_F(DBMemTableTest, ConcurrentMergeWrite) {
std::string v1;
for (int seq = 1; seq < num_ops / 2; seq++) {
PutFixed64(&v1, seq);
ASSERT_OK(
mem->Add(seq, kTypeMerge, "key", v1, true, &post_process_info1));
ASSERT_OK(mem->Add(seq, kTypeMerge, "key", v1, nullptr /* kv_prot_info */,
true, &post_process_info1));
v1.clear();
}
});
@ -236,8 +250,8 @@ TEST_F(DBMemTableTest, ConcurrentMergeWrite) {
std::string v2;
for (int seq = num_ops / 2; seq < num_ops; seq++) {
PutFixed64(&v2, seq);
ASSERT_OK(
mem->Add(seq, kTypeMerge, "key", v2, true, &post_process_info2));
ASSERT_OK(mem->Add(seq, kTypeMerge, "key", v2, nullptr /* kv_prot_info */,
true, &post_process_info2));
v2.clear();
}
});

@ -146,8 +146,10 @@ inline void UnPackSequenceAndType(uint64_t packed, uint64_t* seq,
*seq = packed >> 8;
*t = static_cast<ValueType>(packed & 0xff);
assert(*seq <= kMaxSequenceNumber);
assert(IsExtendedValueType(*t));
// Commented the following two assertions in order to test key-value checksum
// on corrupted keys without crashing ("DbKvChecksumTest").
// assert(*seq <= kMaxSequenceNumber);
// assert(IsExtendedValueType(*t));
}
EntryType GetEntryType(ValueType value_type);

@ -190,7 +190,8 @@ TEST_F(FlushJobTest, NonEmpty) {
for (int i = 1; i < 10000; ++i) {
std::string key(ToString((i + 1000) % 10000));
std::string value("value" + key);
ASSERT_OK(new_mem->Add(SequenceNumber(i), kTypeValue, key, value));
ASSERT_OK(new_mem->Add(SequenceNumber(i), kTypeValue, key, value,
nullptr /* kv_prot_info */));
if ((i + 1000) % 10000 < 9995) {
InternalKey internal_key(key, SequenceNumber(i), kTypeValue);
inserted_keys.push_back({internal_key.Encode().ToString(), value});
@ -199,7 +200,7 @@ TEST_F(FlushJobTest, NonEmpty) {
{
ASSERT_OK(new_mem->Add(SequenceNumber(10000), kTypeRangeDeletion, "9995",
"9999a"));
"9999a", nullptr /* kv_prot_info */));
InternalKey internal_key("9995", SequenceNumber(10000), kTypeRangeDeletion);
inserted_keys.push_back({internal_key.Encode().ToString(), "9999a"});
}
@ -226,7 +227,8 @@ TEST_F(FlushJobTest, NonEmpty) {
}
const SequenceNumber seq(i + 10001);
ASSERT_OK(new_mem->Add(seq, kTypeBlobIndex, key, blob_index));
ASSERT_OK(new_mem->Add(seq, kTypeBlobIndex, key, blob_index,
nullptr /* kv_prot_info */));
InternalKey internal_key(key, seq, kTypeBlobIndex);
inserted_keys.push_back({internal_key.Encode().ToString(), blob_index});
@ -288,7 +290,7 @@ TEST_F(FlushJobTest, FlushMemTablesSingleColumnFamily) {
std::string key(ToString(j + i * num_keys_per_table));
std::string value("value" + key);
ASSERT_OK(mem->Add(SequenceNumber(j + i * num_keys_per_table), kTypeValue,
key, value));
key, value, nullptr /* kv_prot_info */));
}
}
@ -360,7 +362,8 @@ TEST_F(FlushJobTest, FlushMemtablesMultipleColumnFamilies) {
for (size_t j = 0; j != num_keys_per_memtable; ++j) {
std::string key(ToString(j + i * num_keys_per_memtable));
std::string value("value" + key);
ASSERT_OK(mem->Add(curr_seqno++, kTypeValue, key, value));
ASSERT_OK(mem->Add(curr_seqno++, kTypeValue, key, value,
nullptr /* kv_prot_info */));
}
cfd->imm()->Add(mem, &to_delete);
@ -471,7 +474,8 @@ TEST_F(FlushJobTest, Snapshots) {
for (int j = 0; j < insertions; ++j) {
std::string value(rnd.HumanReadableString(10));
auto seqno = ++current_seqno;
ASSERT_OK(new_mem->Add(SequenceNumber(seqno), kTypeValue, key, value));
ASSERT_OK(new_mem->Add(SequenceNumber(seqno), kTypeValue, key, value,
nullptr /* kv_prot_info */));
// a key is visible only if:
// 1. it's the last one written (j == insertions - 1)
// 2. there's a snapshot pointing at it
@ -523,7 +527,8 @@ class FlushJobTimestampTest : public FlushJobTestBase {
Slice value) {
std::string key_str(std::move(key));
PutFixed64(&key_str, ts);
ASSERT_OK(memtable->Add(seq, value_type, key_str, value));
ASSERT_OK(memtable->Add(seq, value_type, key_str, value,
nullptr /* kv_prot_info */));
}
protected:

@ -0,0 +1,424 @@
// Copyright (c) 2020-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
//
// This file contains classes containing fields to protect individual entries.
// The classes are named "ProtectionInfo<suffix>", where <suffix> indicates the
// combination of fields that are covered. Each field has a single letter
// abbreviation as follows.
//
// K = key
// V = value
// O = optype aka value type
// T = timestamp
// S = seqno
// C = CF ID
//
// Then, for example, a class that protects an entry consisting of key, value,
// optype, timestamp, and CF ID (i.e., a `WriteBatch` entry) would be named
// `ProtectionInfoKVOTC`.
//
// The `ProtectionInfo.*` classes are templated on the integer type used to hold
// the XOR of hashes for each field. Only unsigned integer types are supported,
// and the maximum supported integer width is 64 bits. When the integer type is
// narrower than the hash values, we lop off the most significant bits to make
// them fit.
//
// The `ProtectionInfo.*` classes are all intended to be non-persistent. We do
// not currently make the byte order consistent for integer fields before
// hashing them, so the resulting values are endianness-dependent.
#pragma once
#include <type_traits>
#include "db/dbformat.h"
#include "rocksdb/types.h"
#include "util/hash.h"
namespace ROCKSDB_NAMESPACE {
template <typename T>
class ProtectionInfo;
template <typename T>
class ProtectionInfoKVOT;
template <typename T>
class ProtectionInfoKVOTC;
template <typename T>
class ProtectionInfoKVOTS;
// Aliases for 64-bit protection infos.
typedef ProtectionInfo<uint64_t> ProtectionInfo64;
typedef ProtectionInfoKVOT<uint64_t> ProtectionInfoKVOT64;
typedef ProtectionInfoKVOTC<uint64_t> ProtectionInfoKVOTC64;
typedef ProtectionInfoKVOTS<uint64_t> ProtectionInfoKVOTS64;
template <typename T>
class ProtectionInfo {
public:
ProtectionInfo<T>() = default;
Status GetStatus() const;
ProtectionInfoKVOT<T> ProtectKVOT(const Slice& key, const Slice& value,
ValueType op_type,
const Slice& timestamp) const;
ProtectionInfoKVOT<T> ProtectKVOT(const SliceParts& key,
const SliceParts& value, ValueType op_type,
const Slice& timestamp) const;
private:
friend class ProtectionInfoKVOT<T>;
friend class ProtectionInfoKVOTS<T>;
friend class ProtectionInfoKVOTC<T>;
// Each field is hashed with an independent value so we can catch fields being
// swapped. Per the `NPHash64()` docs, using consecutive seeds is a pitfall,
// and we should instead vary our seeds by a large odd number. This value by
// which we increment (0xD28AAD72F49BD50B) was taken from
// `head -c8 /dev/urandom | hexdump`, run repeatedly until it yielded an odd
// number. The values are computed manually since the Windows C++ compiler
// complains about the overflow when adding constants.
static const uint64_t kSeedK = 0;
static const uint64_t kSeedV = 0xD28AAD72F49BD50B;
static const uint64_t kSeedO = 0xA5155AE5E937AA16;
static const uint64_t kSeedT = 0x77A00858DDD37F21;
static const uint64_t kSeedS = 0x4A2AB5CBD26F542C;
static const uint64_t kSeedC = 0x1CB5633EC70B2937;
ProtectionInfo<T>(T val) : val_(val) {
static_assert(sizeof(ProtectionInfo<T>) == sizeof(T), "");
}
T GetVal() const { return val_; }
void SetVal(T val) { val_ = val; }
T val_ = 0;
};
template <typename T>
class ProtectionInfoKVOT {
public:
ProtectionInfoKVOT<T>() = default;
ProtectionInfo<T> StripKVOT(const Slice& key, const Slice& value,
ValueType op_type, const Slice& timestamp) const;
ProtectionInfo<T> StripKVOT(const SliceParts& key, const SliceParts& value,
ValueType op_type, const Slice& timestamp) const;
ProtectionInfoKVOTC<T> ProtectC(ColumnFamilyId column_family_id) const;
ProtectionInfoKVOTS<T> ProtectS(SequenceNumber sequence_number) const;
void UpdateK(const Slice& old_key, const Slice& new_key);
void UpdateK(const SliceParts& old_key, const SliceParts& new_key);
void UpdateV(const Slice& old_value, const Slice& new_value);
void UpdateV(const SliceParts& old_value, const SliceParts& new_value);
void UpdateO(ValueType old_op_type, ValueType new_op_type);
void UpdateT(const Slice& old_timestamp, const Slice& new_timestamp);
private:
friend class ProtectionInfo<T>;
friend class ProtectionInfoKVOTS<T>;
friend class ProtectionInfoKVOTC<T>;
ProtectionInfoKVOT<T>(T val) : info_(val) {
static_assert(sizeof(ProtectionInfoKVOT<T>) == sizeof(T), "");
}
T GetVal() const { return info_.GetVal(); }
void SetVal(T val) { info_.SetVal(val); }
ProtectionInfo<T> info_;
};
template <typename T>
class ProtectionInfoKVOTC {
public:
ProtectionInfoKVOTC<T>() = default;
ProtectionInfoKVOT<T> StripC(ColumnFamilyId column_family_id) const;
void UpdateK(const Slice& old_key, const Slice& new_key) {
kvot_.UpdateK(old_key, new_key);
}
void UpdateK(const SliceParts& old_key, const SliceParts& new_key) {
kvot_.UpdateK(old_key, new_key);
}
void UpdateV(const Slice& old_value, const Slice& new_value) {
kvot_.UpdateV(old_value, new_value);
}
void UpdateV(const SliceParts& old_value, const SliceParts& new_value) {
kvot_.UpdateV(old_value, new_value);
}
void UpdateO(ValueType old_op_type, ValueType new_op_type) {
kvot_.UpdateO(old_op_type, new_op_type);
}
void UpdateT(const Slice& old_timestamp, const Slice& new_timestamp) {
kvot_.UpdateT(old_timestamp, new_timestamp);
}
void UpdateC(ColumnFamilyId old_column_family_id,
ColumnFamilyId new_column_family_id);
private:
friend class ProtectionInfoKVOT<T>;
ProtectionInfoKVOTC<T>(T val) : kvot_(val) {
static_assert(sizeof(ProtectionInfoKVOTC<T>) == sizeof(T), "");
}
T GetVal() const { return kvot_.GetVal(); }
void SetVal(T val) { kvot_.SetVal(val); }
ProtectionInfoKVOT<T> kvot_;
};
template <typename T>
class ProtectionInfoKVOTS {
public:
ProtectionInfoKVOTS<T>() = default;
ProtectionInfoKVOT<T> StripS(SequenceNumber sequence_number) const;
void UpdateK(const Slice& old_key, const Slice& new_key) {
kvot_.UpdateK(old_key, new_key);
}
void UpdateK(const SliceParts& old_key, const SliceParts& new_key) {
kvot_.UpdateK(old_key, new_key);
}
void UpdateV(const Slice& old_value, const Slice& new_value) {
kvot_.UpdateV(old_value, new_value);
}
void UpdateV(const SliceParts& old_value, const SliceParts& new_value) {
kvot_.UpdateV(old_value, new_value);
}
void UpdateO(ValueType old_op_type, ValueType new_op_type) {
kvot_.UpdateO(old_op_type, new_op_type);
}
void UpdateT(const Slice& old_timestamp, const Slice& new_timestamp) {
kvot_.UpdateT(old_timestamp, new_timestamp);
}
void UpdateS(SequenceNumber old_sequence_number,
SequenceNumber new_sequence_number);
private:
friend class ProtectionInfoKVOT<T>;
ProtectionInfoKVOTS<T>(T val) : kvot_(val) {
static_assert(sizeof(ProtectionInfoKVOTS<T>) == sizeof(T), "");
}
T GetVal() const { return kvot_.GetVal(); }
void SetVal(T val) { kvot_.SetVal(val); }
ProtectionInfoKVOT<T> kvot_;
};
template <typename T>
Status ProtectionInfo<T>::GetStatus() const {
if (val_ != 0) {
return Status::Corruption("ProtectionInfo mismatch");
}
return Status::OK();
}
template <typename T>
ProtectionInfoKVOT<T> ProtectionInfo<T>::ProtectKVOT(
const Slice& key, const Slice& value, ValueType op_type,
const Slice& timestamp) const {
T val = GetVal();
val = val ^ static_cast<T>(GetSliceNPHash64(key, ProtectionInfo<T>::kSeedK));
val =
val ^ static_cast<T>(GetSliceNPHash64(value, ProtectionInfo<T>::kSeedV));
val = val ^
static_cast<T>(NPHash64(reinterpret_cast<char*>(&op_type),
sizeof(op_type), ProtectionInfo<T>::kSeedO));
val = val ^
static_cast<T>(GetSliceNPHash64(timestamp, ProtectionInfo<T>::kSeedT));
return ProtectionInfoKVOT<T>(val);
}
template <typename T>
ProtectionInfoKVOT<T> ProtectionInfo<T>::ProtectKVOT(
const SliceParts& key, const SliceParts& value, ValueType op_type,
const Slice& timestamp) const {
T val = GetVal();
val = val ^
static_cast<T>(GetSlicePartsNPHash64(key, ProtectionInfo<T>::kSeedK));
val = val ^
static_cast<T>(GetSlicePartsNPHash64(value, ProtectionInfo<T>::kSeedV));
val = val ^
static_cast<T>(NPHash64(reinterpret_cast<char*>(&op_type),
sizeof(op_type), ProtectionInfo<T>::kSeedO));
val = val ^
static_cast<T>(GetSliceNPHash64(timestamp, ProtectionInfo<T>::kSeedT));
return ProtectionInfoKVOT<T>(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateK(const Slice& old_key,
const Slice& new_key) {
T val = GetVal();
val = val ^
static_cast<T>(GetSliceNPHash64(old_key, ProtectionInfo<T>::kSeedK));
val = val ^
static_cast<T>(GetSliceNPHash64(new_key, ProtectionInfo<T>::kSeedK));
SetVal(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateK(const SliceParts& old_key,
const SliceParts& new_key) {
T val = GetVal();
val = val ^ static_cast<T>(
GetSlicePartsNPHash64(old_key, ProtectionInfo<T>::kSeedK));
val = val ^ static_cast<T>(
GetSlicePartsNPHash64(new_key, ProtectionInfo<T>::kSeedK));
SetVal(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateV(const Slice& old_value,
const Slice& new_value) {
T val = GetVal();
val = val ^
static_cast<T>(GetSliceNPHash64(old_value, ProtectionInfo<T>::kSeedV));
val = val ^
static_cast<T>(GetSliceNPHash64(new_value, ProtectionInfo<T>::kSeedV));
SetVal(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateV(const SliceParts& old_value,
const SliceParts& new_value) {
T val = GetVal();
val = val ^ static_cast<T>(
GetSlicePartsNPHash64(old_value, ProtectionInfo<T>::kSeedV));
val = val ^ static_cast<T>(
GetSlicePartsNPHash64(new_value, ProtectionInfo<T>::kSeedV));
SetVal(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateO(ValueType old_op_type,
ValueType new_op_type) {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(reinterpret_cast<char*>(&old_op_type),
sizeof(old_op_type),
ProtectionInfo<T>::kSeedO));
val = val ^ static_cast<T>(NPHash64(reinterpret_cast<char*>(&new_op_type),
sizeof(new_op_type),
ProtectionInfo<T>::kSeedO));
SetVal(val);
}
template <typename T>
void ProtectionInfoKVOT<T>::UpdateT(const Slice& old_timestamp,
const Slice& new_timestamp) {
T val = GetVal();
val = val ^ static_cast<T>(
GetSliceNPHash64(old_timestamp, ProtectionInfo<T>::kSeedT));
val = val ^ static_cast<T>(
GetSliceNPHash64(new_timestamp, ProtectionInfo<T>::kSeedT));
SetVal(val);
}
template <typename T>
ProtectionInfo<T> ProtectionInfoKVOT<T>::StripKVOT(
const Slice& key, const Slice& value, ValueType op_type,
const Slice& timestamp) const {
T val = GetVal();
val = val ^ static_cast<T>(GetSliceNPHash64(key, ProtectionInfo<T>::kSeedK));
val =
val ^ static_cast<T>(GetSliceNPHash64(value, ProtectionInfo<T>::kSeedV));
val = val ^
static_cast<T>(NPHash64(reinterpret_cast<char*>(&op_type),
sizeof(op_type), ProtectionInfo<T>::kSeedO));
val = val ^
static_cast<T>(GetSliceNPHash64(timestamp, ProtectionInfo<T>::kSeedT));
return ProtectionInfo<T>(val);
}
template <typename T>
ProtectionInfo<T> ProtectionInfoKVOT<T>::StripKVOT(
const SliceParts& key, const SliceParts& value, ValueType op_type,
const Slice& timestamp) const {
T val = GetVal();
val = val ^
static_cast<T>(GetSlicePartsNPHash64(key, ProtectionInfo<T>::kSeedK));
val = val ^
static_cast<T>(GetSlicePartsNPHash64(value, ProtectionInfo<T>::kSeedV));
val = val ^
static_cast<T>(NPHash64(reinterpret_cast<char*>(&op_type),
sizeof(op_type), ProtectionInfo<T>::kSeedO));
val = val ^
static_cast<T>(GetSliceNPHash64(timestamp, ProtectionInfo<T>::kSeedT));
return ProtectionInfo<T>(val);
}
template <typename T>
ProtectionInfoKVOTC<T> ProtectionInfoKVOT<T>::ProtectC(
ColumnFamilyId column_family_id) const {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&column_family_id),
sizeof(column_family_id), ProtectionInfo<T>::kSeedC));
return ProtectionInfoKVOTC<T>(val);
}
template <typename T>
ProtectionInfoKVOT<T> ProtectionInfoKVOTC<T>::StripC(
ColumnFamilyId column_family_id) const {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&column_family_id),
sizeof(column_family_id), ProtectionInfo<T>::kSeedC));
return ProtectionInfoKVOT<T>(val);
}
template <typename T>
void ProtectionInfoKVOTC<T>::UpdateC(ColumnFamilyId old_column_family_id,
ColumnFamilyId new_column_family_id) {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&old_column_family_id),
sizeof(old_column_family_id), ProtectionInfo<T>::kSeedC));
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&new_column_family_id),
sizeof(new_column_family_id), ProtectionInfo<T>::kSeedC));
SetVal(val);
}
template <typename T>
ProtectionInfoKVOTS<T> ProtectionInfoKVOT<T>::ProtectS(
SequenceNumber sequence_number) const {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(reinterpret_cast<char*>(&sequence_number),
sizeof(sequence_number),
ProtectionInfo<T>::kSeedS));
return ProtectionInfoKVOTS<T>(val);
}
template <typename T>
ProtectionInfoKVOT<T> ProtectionInfoKVOTS<T>::StripS(
SequenceNumber sequence_number) const {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(reinterpret_cast<char*>(&sequence_number),
sizeof(sequence_number),
ProtectionInfo<T>::kSeedS));
return ProtectionInfoKVOT<T>(val);
}
template <typename T>
void ProtectionInfoKVOTS<T>::UpdateS(SequenceNumber old_sequence_number,
SequenceNumber new_sequence_number) {
T val = GetVal();
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&old_sequence_number),
sizeof(old_sequence_number), ProtectionInfo<T>::kSeedS));
val = val ^ static_cast<T>(NPHash64(
reinterpret_cast<char*>(&new_sequence_number),
sizeof(new_sequence_number), ProtectionInfo<T>::kSeedS));
SetVal(val);
}
} // namespace ROCKSDB_NAMESPACE

@ -13,7 +13,9 @@
#include <array>
#include <limits>
#include <memory>
#include "db/dbformat.h"
#include "db/kv_checksum.h"
#include "db/merge_context.h"
#include "db/merge_helper.h"
#include "db/pinned_iterators_manager.h"
@ -484,9 +486,54 @@ MemTable::MemTableStats MemTable::ApproximateStats(const Slice& start_ikey,
return {entry_count * (data_size / n), entry_count};
}
Status MemTable::VerifyEncodedEntry(Slice encoded,
const ProtectionInfoKVOTS64& kv_prot_info) {
uint32_t ikey_len = 0;
if (!GetVarint32(&encoded, &ikey_len)) {
return Status::Corruption("Unable to parse internal key length");
}
size_t ts_sz = GetInternalKeyComparator().user_comparator()->timestamp_size();
if (ikey_len < 8 + ts_sz) {
return Status::Corruption("Internal key length too short");
}
if (ikey_len > encoded.size()) {
return Status::Corruption("Internal key length too long");
}
uint32_t value_len = 0;
const size_t key_without_ts_len = ikey_len - ts_sz - 8;
Slice key(encoded.data(), key_without_ts_len);
encoded.remove_prefix(key_without_ts_len);
Slice timestamp(encoded.data(), ts_sz);
encoded.remove_prefix(ts_sz);
uint64_t packed = DecodeFixed64(encoded.data());
ValueType value_type = kMaxValue;
SequenceNumber sequence_number = kMaxSequenceNumber;
UnPackSequenceAndType(packed, &sequence_number, &value_type);
encoded.remove_prefix(8);
if (!GetVarint32(&encoded, &value_len)) {
return Status::Corruption("Unable to parse value length");
}
if (value_len < encoded.size()) {
return Status::Corruption("Value length too short");
}
if (value_len > encoded.size()) {
return Status::Corruption("Value length too long");
}
Slice value(encoded.data(), value_len);
return kv_prot_info.StripS(sequence_number)
.StripKVOT(key, value, value_type, timestamp)
.GetStatus();
}
Status MemTable::Add(SequenceNumber s, ValueType type,
const Slice& key, /* user key */
const Slice& value, bool allow_concurrent,
const Slice& value,
const ProtectionInfoKVOTS64* kv_prot_info,
bool allow_concurrent,
MemTablePostProcessInfo* post_process_info, void** hint) {
// Format of an entry is concatenation of:
// key_size : varint32 of internal_key.size()
@ -514,6 +561,15 @@ Status MemTable::Add(SequenceNumber s, ValueType type,
p = EncodeVarint32(p, val_size);
memcpy(p, value.data(), val_size);
assert((unsigned)(p + val_size - buf) == (unsigned)encoded_len);
if (kv_prot_info != nullptr) {
Slice encoded(buf, encoded_len);
TEST_SYNC_POINT_CALLBACK("MemTable::Add:Encoded", &encoded);
Status status = VerifyEncodedEntry(encoded, *kv_prot_info);
if (!status.ok()) {
return status;
}
}
size_t ts_sz = GetInternalKeyComparator().user_comparator()->timestamp_size();
Slice key_without_ts = StripTimestampFromUserKey(key, ts_sz);
@ -979,7 +1035,8 @@ void MemTable::MultiGet(const ReadOptions& read_options, MultiGetRange* range,
}
Status MemTable::Update(SequenceNumber seq, const Slice& key,
const Slice& value) {
const Slice& value,
const ProtectionInfoKVOTS64* kv_prot_info) {
LookupKey lkey(key, seq);
Slice mem_key = lkey.memtable_key();
@ -1023,6 +1080,13 @@ Status MemTable::Update(SequenceNumber seq, const Slice& key,
(unsigned)(VarintLength(key_length) + key_length +
VarintLength(value.size()) + value.size()));
RecordTick(moptions_.statistics, NUMBER_KEYS_UPDATED);
if (kv_prot_info != nullptr) {
ProtectionInfoKVOTS64 updated_kv_prot_info(*kv_prot_info);
// `seq` is swallowed and `existing_seq` prevails.
updated_kv_prot_info.UpdateS(seq, existing_seq);
Slice encoded(entry, p + value.size() - entry);
return VerifyEncodedEntry(encoded, updated_kv_prot_info);
}
return Status::OK();
}
}
@ -1030,11 +1094,12 @@ Status MemTable::Update(SequenceNumber seq, const Slice& key,
}
// The latest value is not `kTypeValue` or key doesn't exist
return Add(seq, kTypeValue, key, value);
return Add(seq, kTypeValue, key, value, kv_prot_info);
}
Status MemTable::UpdateCallback(SequenceNumber seq, const Slice& key,
const Slice& delta) {
const Slice& delta,
const ProtectionInfoKVOTS64* kv_prot_info) {
LookupKey lkey(key, seq);
Slice memkey = lkey.memtable_key();
@ -1060,8 +1125,8 @@ Status MemTable::UpdateCallback(SequenceNumber seq, const Slice& key,
// Correct user key
const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8);
ValueType type;
uint64_t unused;
UnPackSequenceAndType(tag, &unused, &type);
uint64_t existing_seq;
UnPackSequenceAndType(tag, &existing_seq, &type);
switch (type) {
case kTypeValue: {
Slice prev_value = GetLengthPrefixedSlice(key_ptr + key_length);
@ -1088,9 +1153,27 @@ Status MemTable::UpdateCallback(SequenceNumber seq, const Slice& key,
}
RecordTick(moptions_.statistics, NUMBER_KEYS_UPDATED);
UpdateFlushState();
if (kv_prot_info != nullptr) {
ProtectionInfoKVOTS64 updated_kv_prot_info(*kv_prot_info);
// `seq` is swallowed and `existing_seq` prevails.
updated_kv_prot_info.UpdateS(seq, existing_seq);
updated_kv_prot_info.UpdateV(delta,
Slice(prev_buffer, new_prev_size));
Slice encoded(entry, prev_buffer + new_prev_size - entry);
return VerifyEncodedEntry(encoded, updated_kv_prot_info);
}
return Status::OK();
} else if (status == UpdateStatus::UPDATED) {
Status s = Add(seq, kTypeValue, key, Slice(str_value));
Status s;
if (kv_prot_info != nullptr) {
ProtectionInfoKVOTS64 updated_kv_prot_info(*kv_prot_info);
updated_kv_prot_info.UpdateV(delta, str_value);
s = Add(seq, kTypeValue, key, Slice(str_value),
&updated_kv_prot_info);
} else {
s = Add(seq, kTypeValue, key, Slice(str_value),
nullptr /* kv_prot_info */);
}
RecordTick(moptions_.statistics, NUMBER_KEYS_WRITTEN);
UpdateFlushState();
return s;

@ -15,7 +15,9 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "db/dbformat.h"
#include "db/kv_checksum.h"
#include "db/range_tombstone_fragmenter.h"
#include "db/read_callback.h"
#include "db/version_edit.h"
@ -175,6 +177,9 @@ class MemTable {
FragmentedRangeTombstoneIterator* NewRangeTombstoneIterator(
const ReadOptions& read_options, SequenceNumber read_seq);
Status VerifyEncodedEntry(Slice encoded,
const ProtectionInfoKVOTS64& kv_prot_info);
// Add an entry into memtable that maps key to value at the
// specified sequence number and with the specified type.
// Typically value will be empty if type==kTypeDeletion.
@ -186,7 +191,8 @@ class MemTable {
// in the memtable and `MemTableRepFactory::CanHandleDuplicatedKey()` is true.
// The next attempt should try a larger value for `seq`.
Status Add(SequenceNumber seq, ValueType type, const Slice& key,
const Slice& value, bool allow_concurrent = false,
const Slice& value, const ProtectionInfoKVOTS64* kv_prot_info,
bool allow_concurrent = false,
MemTablePostProcessInfo* post_process_info = nullptr,
void** hint = nullptr);
@ -250,7 +256,8 @@ class MemTable {
//
// REQUIRES: external synchronization to prevent simultaneous
// operations on the same MemTable.
Status Update(SequenceNumber seq, const Slice& key, const Slice& value);
Status Update(SequenceNumber seq, const Slice& key, const Slice& value,
const ProtectionInfoKVOTS64* kv_prot_info);
// If `key` exists in current memtable with type `kTypeValue` and the existing
// value is at least as large as the new value, updates it in-place. Otherwise
@ -267,7 +274,8 @@ class MemTable {
// REQUIRES: external synchronization to prevent simultaneous
// operations on the same MemTable.
Status UpdateCallback(SequenceNumber seq, const Slice& key,
const Slice& delta);
const Slice& delta,
const ProtectionInfoKVOTS64* kv_prot_info);
// Returns the number of successive merge entries starting from the newest
// entry for the key up to the last non-merge entry or last entry for the

@ -243,10 +243,14 @@ TEST_F(MemTableListTest, GetTest) {
mem->Ref();
// Write some keys to this memtable.
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "key1", ""));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", "value1"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2.2"));
ASSERT_OK(
mem->Add(++seq, kTypeDeletion, "key1", "", nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", "value1",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2.2",
nullptr /* kv_prot_info */));
// Fetch the newly written keys
merge_context.Clear();
@ -284,8 +288,10 @@ TEST_F(MemTableListTest, GetTest) {
kMaxSequenceNumber, 0 /* column_family_id */);
mem2->Ref();
ASSERT_OK(mem2->Add(++seq, kTypeDeletion, "key1", ""));
ASSERT_OK(mem2->Add(++seq, kTypeValue, "key2", "value2.3"));
ASSERT_OK(
mem2->Add(++seq, kTypeDeletion, "key1", "", nullptr /* kv_prot_info */));
ASSERT_OK(mem2->Add(++seq, kTypeValue, "key2", "value2.3",
nullptr /* kv_prot_info */));
// Add second memtable to list
list.Add(mem2, &to_delete);
@ -360,9 +366,12 @@ TEST_F(MemTableListTest, GetFromHistoryTest) {
mem->Ref();
// Write some keys to this memtable.
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "key1", ""));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2.2"));
ASSERT_OK(
mem->Add(++seq, kTypeDeletion, "key1", "", nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key2", "value2.2",
nullptr /* kv_prot_info */));
// Fetch the newly written keys
merge_context.Clear();
@ -444,8 +453,10 @@ TEST_F(MemTableListTest, GetFromHistoryTest) {
kMaxSequenceNumber, 0 /* column_family_id */);
mem2->Ref();
ASSERT_OK(mem2->Add(++seq, kTypeDeletion, "key1", ""));
ASSERT_OK(mem2->Add(++seq, kTypeValue, "key3", "value3"));
ASSERT_OK(
mem2->Add(++seq, kTypeDeletion, "key1", "", nullptr /* kv_prot_info */));
ASSERT_OK(mem2->Add(++seq, kTypeValue, "key3", "value3",
nullptr /* kv_prot_info */));
// Add second memtable to list
list.Add(mem2, &to_delete);
@ -555,11 +566,16 @@ TEST_F(MemTableListTest, FlushPendingTest) {
std::string value;
MergeContext merge_context;
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", ToString(i)));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyN" + ToString(i), "valueN"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyX" + ToString(i), "value"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyM" + ToString(i), "valueM"));
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "keyX" + ToString(i), ""));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", ToString(i),
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyN" + ToString(i), "valueN",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyX" + ToString(i), "value",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyM" + ToString(i), "valueM",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "keyX" + ToString(i), "",
nullptr /* kv_prot_info */));
tables.push_back(mem);
}
@ -824,11 +840,16 @@ TEST_F(MemTableListTest, AtomicFlusTest) {
std::string value;
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", ToString(i)));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyN" + ToString(i), "valueN"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyX" + ToString(i), "value"));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyM" + ToString(i), "valueM"));
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "keyX" + ToString(i), ""));
ASSERT_OK(mem->Add(++seq, kTypeValue, "key1", ToString(i),
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyN" + ToString(i), "valueN",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyX" + ToString(i), "value",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeValue, "keyM" + ToString(i), "valueM",
nullptr /* kv_prot_info */));
ASSERT_OK(mem->Add(++seq, kTypeDeletion, "keyX" + ToString(i), "",
nullptr /* kv_prot_info */));
elem.push_back(mem);
}

@ -74,7 +74,7 @@ TableCache::TableCache(const ImmutableCFOptions& ioptions,
cache_(cache),
immortal_tables_(false),
block_cache_tracer_(block_cache_tracer),
loader_mutex_(kLoadConcurency, GetSliceNPHash64),
loader_mutex_(kLoadConcurency, kGetSliceNPHash64UnseededFnPtr),
io_tracer_(io_tracer) {
if (ioptions_.row_cache) {
// If the same cache is shared by multiple instances, we need to

@ -46,6 +46,7 @@
#include "db/db_impl/db_impl.h"
#include "db/dbformat.h"
#include "db/flush_scheduler.h"
#include "db/kv_checksum.h"
#include "db/memtable.h"
#include "db/merge_context.h"
#include "db/snapshot_impl.h"
@ -141,10 +142,14 @@ struct BatchContentClassifier : public WriteBatch::Handler {
class TimestampAssigner : public WriteBatch::Handler {
public:
explicit TimestampAssigner(const Slice& ts)
: timestamp_(ts), timestamps_(kEmptyTimestampList) {}
explicit TimestampAssigner(const std::vector<Slice>& ts_list)
: timestamps_(ts_list) {
explicit TimestampAssigner(const Slice& ts,
WriteBatch::ProtectionInfo* prot_info)
: timestamp_(ts),
timestamps_(kEmptyTimestampList),
prot_info_(prot_info) {}
explicit TimestampAssigner(const std::vector<Slice>& ts_list,
WriteBatch::ProtectionInfo* prot_info)
: timestamps_(ts_list), prot_info_(prot_info) {
SanityCheck();
}
~TimestampAssigner() override {}
@ -168,9 +173,8 @@ class TimestampAssigner : public WriteBatch::Handler {
}
Status DeleteRangeCF(uint32_t, const Slice& begin_key,
const Slice& end_key) override {
const Slice& /* end_key */) override {
AssignTimestamp(begin_key);
AssignTimestamp(end_key);
++idx_;
return Status::OK();
}
@ -222,12 +226,17 @@ class TimestampAssigner : public WriteBatch::Handler {
const Slice& ts = timestamps_.empty() ? timestamp_ : timestamps_[idx_];
size_t ts_sz = ts.size();
char* ptr = const_cast<char*>(key.data() + key.size() - ts_sz);
if (prot_info_ != nullptr) {
Slice old_ts(ptr, ts_sz), new_ts(ts.data(), ts_sz);
prot_info_->entries_[idx_].UpdateT(old_ts, new_ts);
}
memcpy(ptr, ts.data(), ts_sz);
}
static const std::vector<Slice> kEmptyTimestampList;
const Slice timestamp_;
const std::vector<Slice>& timestamps_;
WriteBatch::ProtectionInfo* const prot_info_;
size_t idx_ = 0;
// No copy or move.
@ -259,6 +268,21 @@ WriteBatch::WriteBatch(size_t reserved_bytes, size_t max_bytes, size_t ts_sz)
rep_.resize(WriteBatchInternal::kHeader);
}
WriteBatch::WriteBatch(size_t reserved_bytes, size_t max_bytes, size_t ts_sz,
size_t protection_bytes_per_key)
: content_flags_(0), max_bytes_(max_bytes), rep_(), timestamp_size_(ts_sz) {
// Currently `protection_bytes_per_key` can only be enabled at 8 bytes per
// entry.
assert(protection_bytes_per_key == 0 || protection_bytes_per_key == 8);
if (protection_bytes_per_key != 0) {
prot_info_.reset(new WriteBatch::ProtectionInfo());
}
rep_.reserve((reserved_bytes > WriteBatchInternal::kHeader)
? reserved_bytes
: WriteBatchInternal::kHeader);
rep_.resize(WriteBatchInternal::kHeader);
}
WriteBatch::WriteBatch(const std::string& rep)
: content_flags_(ContentFlags::DEFERRED),
max_bytes_(0),
@ -281,6 +305,10 @@ WriteBatch::WriteBatch(const WriteBatch& src)
save_points_.reset(new SavePoints());
save_points_->stack = src.save_points_->stack;
}
if (src.prot_info_ != nullptr) {
prot_info_.reset(new WriteBatch::ProtectionInfo());
prot_info_->entries_ = src.prot_info_->entries_;
}
}
WriteBatch::WriteBatch(WriteBatch&& src) noexcept
@ -288,6 +316,7 @@ WriteBatch::WriteBatch(WriteBatch&& src) noexcept
wal_term_point_(std::move(src.wal_term_point_)),
content_flags_(src.content_flags_.load(std::memory_order_relaxed)),
max_bytes_(src.max_bytes_),
prot_info_(std::move(src.prot_info_)),
rep_(std::move(src.rep_)),
timestamp_size_(src.timestamp_size_) {}
@ -332,6 +361,9 @@ void WriteBatch::Clear() {
}
}
if (prot_info_ != nullptr) {
prot_info_->entries_.clear();
}
wal_term_point_.clear();
}
@ -360,6 +392,13 @@ void WriteBatch::MarkWalTerminationPoint() {
wal_term_point_.content_flags = content_flags_;
}
size_t WriteBatch::GetProtectionBytesPerKey() const {
if (prot_info_ != nullptr) {
return prot_info_->GetBytesPerKey();
}
return 0;
}
bool WriteBatch::HasPut() const {
return (ComputeContentFlags() & ContentFlags::HAS_PUT) != 0;
}
@ -778,18 +817,31 @@ Status WriteBatchInternal::Put(WriteBatch* b, uint32_t column_family_id,
b->rep_.push_back(static_cast<char>(kTypeColumnFamilyValue));
PutVarint32(&b->rep_, column_family_id);
}
std::string timestamp(b->timestamp_size_, '\0');
if (0 == b->timestamp_size_) {
PutLengthPrefixedSlice(&b->rep_, key);
} else {
PutVarint32(&b->rep_,
static_cast<uint32_t>(key.size() + b->timestamp_size_));
b->rep_.append(key.data(), key.size());
b->rep_.append(b->timestamp_size_, '\0');
b->rep_.append(timestamp);
}
PutLengthPrefixedSlice(&b->rep_, value);
b->content_flags_.store(
b->content_flags_.load(std::memory_order_relaxed) | ContentFlags::HAS_PUT,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// Technically the optype could've been `kTypeColumnFamilyValue` with the
// CF ID encoded in the `WriteBatch`. That distinction is unimportant
// however since we verify CF ID is correct, as well as all other fields
// (a missing/extra encoded CF ID would corrupt another field). It is
// convenient to consolidate on `kTypeValue` here as that is what will be
// inserted into memtable.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, value, kTypeValue, timestamp)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -834,6 +886,7 @@ Status WriteBatchInternal::Put(WriteBatch* b, uint32_t column_family_id,
b->rep_.push_back(static_cast<char>(kTypeColumnFamilyValue));
PutVarint32(&b->rep_, column_family_id);
}
std::string timestamp(b->timestamp_size_, '\0');
if (0 == b->timestamp_size_) {
PutLengthPrefixedSliceParts(&b->rep_, key);
} else {
@ -843,6 +896,14 @@ Status WriteBatchInternal::Put(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(
b->content_flags_.load(std::memory_order_relaxed) | ContentFlags::HAS_PUT,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, value, kTypeValue, timestamp)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -917,17 +978,26 @@ Status WriteBatchInternal::Delete(WriteBatch* b, uint32_t column_family_id,
b->rep_.push_back(static_cast<char>(kTypeColumnFamilyDeletion));
PutVarint32(&b->rep_, column_family_id);
}
std::string timestamp(b->timestamp_size_, '\0');
if (0 == b->timestamp_size_) {
PutLengthPrefixedSlice(&b->rep_, key);
} else {
PutVarint32(&b->rep_,
static_cast<uint32_t>(key.size() + b->timestamp_size_));
b->rep_.append(key.data(), key.size());
b->rep_.append(b->timestamp_size_, '\0');
b->rep_.append(timestamp);
}
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_DELETE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, "" /* value */, kTypeDeletion, timestamp)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -946,6 +1016,7 @@ Status WriteBatchInternal::Delete(WriteBatch* b, uint32_t column_family_id,
b->rep_.push_back(static_cast<char>(kTypeColumnFamilyDeletion));
PutVarint32(&b->rep_, column_family_id);
}
std::string timestamp(b->timestamp_size_, '\0');
if (0 == b->timestamp_size_) {
PutLengthPrefixedSliceParts(&b->rep_, key);
} else {
@ -954,6 +1025,16 @@ Status WriteBatchInternal::Delete(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_DELETE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key,
SliceParts(nullptr /* _parts */, 0 /* _num_parts */),
kTypeDeletion, timestamp)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -978,6 +1059,15 @@ Status WriteBatchInternal::SingleDelete(WriteBatch* b,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_SINGLE_DELETE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(ProtectionInfo64()
.ProtectKVOT(key, "" /* value */,
kTypeSingleDeletion,
"" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1002,6 +1092,17 @@ Status WriteBatchInternal::SingleDelete(WriteBatch* b,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_SINGLE_DELETE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key,
SliceParts(nullptr /* _parts */,
0 /* _num_parts */) /* value */,
kTypeSingleDeletion, "" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1027,6 +1128,16 @@ Status WriteBatchInternal::DeleteRange(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_DELETE_RANGE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
// In `DeleteRange()`, the end key is treated as the value.
b->prot_info_->entries_.emplace_back(ProtectionInfo64()
.ProtectKVOT(begin_key, end_key,
kTypeRangeDeletion,
"" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1052,6 +1163,16 @@ Status WriteBatchInternal::DeleteRange(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_DELETE_RANGE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
// In `DeleteRange()`, the end key is treated as the value.
b->prot_info_->entries_.emplace_back(ProtectionInfo64()
.ProtectKVOT(begin_key, end_key,
kTypeRangeDeletion,
"" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1084,6 +1205,14 @@ Status WriteBatchInternal::Merge(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_MERGE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, value, kTypeMerge, "" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1114,6 +1243,14 @@ Status WriteBatchInternal::Merge(WriteBatch* b, uint32_t column_family_id,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_MERGE,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, value, kTypeMerge, "" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1139,6 +1276,14 @@ Status WriteBatchInternal::PutBlobIndex(WriteBatch* b,
b->content_flags_.store(b->content_flags_.load(std::memory_order_relaxed) |
ContentFlags::HAS_BLOB_INDEX,
std::memory_order_relaxed);
if (b->prot_info_ != nullptr) {
// See comment in first `WriteBatchInternal::Put()` overload concerning the
// `ValueType` argument passed to `ProtectKVOT()`.
b->prot_info_->entries_.emplace_back(
ProtectionInfo64()
.ProtectKVOT(key, value, kTypeBlobIndex, "" /* timestamp */)
.ProtectC(column_family_id));
}
return save.commit();
}
@ -1177,6 +1322,9 @@ Status WriteBatch::RollbackToSavePoint() {
Clear();
} else {
rep_.resize(savepoint.size);
if (prot_info_ != nullptr) {
prot_info_->entries_.resize(savepoint.count);
}
WriteBatchInternal::SetCount(this, savepoint.count);
content_flags_.store(savepoint.content_flags, std::memory_order_relaxed);
}
@ -1196,12 +1344,12 @@ Status WriteBatch::PopSavePoint() {
}
Status WriteBatch::AssignTimestamp(const Slice& ts) {
TimestampAssigner ts_assigner(ts);
TimestampAssigner ts_assigner(ts, prot_info_.get());
return Iterate(&ts_assigner);
}
Status WriteBatch::AssignTimestamps(const std::vector<Slice>& ts_list) {
TimestampAssigner ts_assigner(ts_list);
TimestampAssigner ts_assigner(ts_list, prot_info_.get());
return Iterate(&ts_assigner);
}
@ -1218,6 +1366,8 @@ class MemTableInserter : public WriteBatch::Handler {
DBImpl* db_;
const bool concurrent_memtable_writes_;
bool post_info_created_;
const WriteBatch::ProtectionInfo* prot_info_;
size_t prot_info_idx_;
bool* has_valid_writes_;
// On some (!) platforms just default creating
@ -1280,6 +1430,16 @@ class MemTableInserter : public WriteBatch::Handler {
(&duplicate_detector_)->IsDuplicateKeySeq(column_family_id, key, sequence_);
}
const ProtectionInfoKVOTC64* NextProtectionInfo() {
const ProtectionInfoKVOTC64* res = nullptr;
if (prot_info_ != nullptr) {
assert(prot_info_idx_ < prot_info_->entries_.size());
res = &prot_info_->entries_[prot_info_idx_];
++prot_info_idx_;
}
return res;
}
protected:
bool WriteBeforePrepare() const override { return write_before_prepare_; }
bool WriteAfterCommit() const override { return write_after_commit_; }
@ -1292,6 +1452,7 @@ class MemTableInserter : public WriteBatch::Handler {
bool ignore_missing_column_families,
uint64_t recovering_log_number, DB* db,
bool concurrent_memtable_writes,
const WriteBatch::ProtectionInfo* prot_info,
bool* has_valid_writes = nullptr, bool seq_per_batch = false,
bool batch_per_txn = true, bool hint_per_batch = false)
: sequence_(_sequence),
@ -1304,6 +1465,8 @@ class MemTableInserter : public WriteBatch::Handler {
db_(static_cast_with_check<DBImpl>(db)),
concurrent_memtable_writes_(concurrent_memtable_writes),
post_info_created_(false),
prot_info_(prot_info),
prot_info_idx_(0),
has_valid_writes_(has_valid_writes),
rebuilding_trx_(nullptr),
rebuilding_trx_seq_(0),
@ -1361,6 +1524,10 @@ class MemTableInserter : public WriteBatch::Handler {
}
void set_log_number_ref(uint64_t log) { log_number_ref_ = log; }
void set_prot_info(const WriteBatch::ProtectionInfo* prot_info) {
prot_info_ = prot_info;
prot_info_idx_ = 0;
}
SequenceNumber sequence() const { return sequence_; }
@ -1416,9 +1583,11 @@ class MemTableInserter : public WriteBatch::Handler {
}
Status PutCFImpl(uint32_t column_family_id, const Slice& key,
const Slice& value, ValueType value_type) {
const Slice& value, ValueType value_type,
const ProtectionInfoKVOTS64* kv_prot_info) {
// optimize for non-recovery mode
if (UNLIKELY(write_after_commit_ && rebuilding_trx_ != nullptr)) {
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
return WriteBatchInternal::Put(rebuilding_trx_, column_family_id, key,
value);
// else insert the values to the memtable right away
@ -1430,6 +1599,7 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!write_after_commit_);
// The CF is probably flushed and hence no need for insert but we still
// need to keep track of the keys for upcoming rollback/commit.
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::Put(rebuilding_trx_, column_family_id,
key, value);
if (ret_status.ok()) {
@ -1449,15 +1619,15 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!seq_per_batch_ || !moptions->inplace_update_support);
if (!moptions->inplace_update_support) {
ret_status =
mem->Add(sequence_, value_type, key, value,
mem->Add(sequence_, value_type, key, value, kv_prot_info,
concurrent_memtable_writes_, get_post_process_info(mem),
hint_per_batch_ ? &GetHintMap()[mem] : nullptr);
} else if (moptions->inplace_callback == nullptr) {
assert(!concurrent_memtable_writes_);
ret_status = mem->Update(sequence_, key, value);
ret_status = mem->Update(sequence_, key, value, kv_prot_info);
} else {
assert(!concurrent_memtable_writes_);
ret_status = mem->UpdateCallback(sequence_, key, value);
ret_status = mem->UpdateCallback(sequence_, key, value, kv_prot_info);
if (ret_status.IsNotFound()) {
// key not found in memtable. Do sst get, update, add
SnapshotImpl read_from_snapshot;
@ -1485,7 +1655,6 @@ class MemTableInserter : public WriteBatch::Handler {
} else {
ret_status = Status::OK();
}
if (ret_status.ok()) {
UpdateStatus update_status;
char* prev_buffer = const_cast<char*>(prev_value.c_str());
@ -1500,16 +1669,35 @@ class MemTableInserter : public WriteBatch::Handler {
}
if (update_status == UpdateStatus::UPDATED_INPLACE) {
assert(get_status.ok());
// prev_value is updated in-place with final value.
ret_status = mem->Add(sequence_, value_type, key,
Slice(prev_buffer, prev_size));
if (kv_prot_info != nullptr) {
ProtectionInfoKVOTS64 updated_kv_prot_info(*kv_prot_info);
updated_kv_prot_info.UpdateV(value,
Slice(prev_buffer, prev_size));
// prev_value is updated in-place with final value.
ret_status = mem->Add(sequence_, value_type, key,
Slice(prev_buffer, prev_size),
&updated_kv_prot_info);
} else {
ret_status = mem->Add(sequence_, value_type, key,
Slice(prev_buffer, prev_size),
nullptr /* kv_prot_info */);
}
if (ret_status.ok()) {
RecordTick(moptions->statistics, NUMBER_KEYS_WRITTEN);
}
} else if (update_status == UpdateStatus::UPDATED) {
// merged_value contains the final value.
ret_status =
mem->Add(sequence_, value_type, key, Slice(merged_value));
if (kv_prot_info != nullptr) {
ProtectionInfoKVOTS64 updated_kv_prot_info(*kv_prot_info);
updated_kv_prot_info.UpdateV(value, merged_value);
// merged_value contains the final value.
ret_status = mem->Add(sequence_, value_type, key,
Slice(merged_value), &updated_kv_prot_info);
} else {
// merged_value contains the final value.
ret_status =
mem->Add(sequence_, value_type, key, Slice(merged_value),
nullptr /* kv_prot_info */);
}
if (ret_status.ok()) {
RecordTick(moptions->statistics, NUMBER_KEYS_WRITTEN);
}
@ -1532,6 +1720,7 @@ class MemTableInserter : public WriteBatch::Handler {
// away. So we only need to add to it when `ret_status.ok()`.
if (UNLIKELY(ret_status.ok() && rebuilding_trx_ != nullptr)) {
assert(!write_after_commit_);
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::Put(rebuilding_trx_, column_family_id,
key, value);
}
@ -1540,15 +1729,25 @@ class MemTableInserter : public WriteBatch::Handler {
Status PutCF(uint32_t column_family_id, const Slice& key,
const Slice& value) override {
return PutCFImpl(column_family_id, key, value, kTypeValue);
const auto* kv_prot_info = NextProtectionInfo();
if (kv_prot_info != nullptr) {
// Memtable needs seqno, doesn't need CF ID
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
return PutCFImpl(column_family_id, key, value, kTypeValue,
&mem_kv_prot_info);
}
return PutCFImpl(column_family_id, key, value, kTypeValue,
nullptr /* kv_prot_info */);
}
Status DeleteImpl(uint32_t /*column_family_id*/, const Slice& key,
const Slice& value, ValueType delete_type) {
const Slice& value, ValueType delete_type,
const ProtectionInfoKVOTS64* kv_prot_info) {
Status ret_status;
MemTable* mem = cf_mems_->GetMemTable();
ret_status =
mem->Add(sequence_, delete_type, key, value,
mem->Add(sequence_, delete_type, key, value, kv_prot_info,
concurrent_memtable_writes_, get_post_process_info(mem),
hint_per_batch_ ? &GetHintMap()[mem] : nullptr);
if (UNLIKELY(ret_status.IsTryAgain())) {
@ -1563,8 +1762,10 @@ class MemTableInserter : public WriteBatch::Handler {
}
Status DeleteCF(uint32_t column_family_id, const Slice& key) override {
const auto* kv_prot_info = NextProtectionInfo();
// optimize for non-recovery mode
if (UNLIKELY(write_after_commit_ && rebuilding_trx_ != nullptr)) {
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
return WriteBatchInternal::Delete(rebuilding_trx_, column_family_id, key);
// else insert the values to the memtable right away
}
@ -1575,6 +1776,7 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!write_after_commit_);
// The CF is probably flushed and hence no need for insert but we still
// need to keep track of the keys for upcoming rollback/commit.
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status =
WriteBatchInternal::Delete(rebuilding_trx_, column_family_id, key);
if (ret_status.ok()) {
@ -1593,7 +1795,16 @@ class MemTableInserter : public WriteBatch::Handler {
: 0;
const ValueType delete_type =
(0 == ts_sz) ? kTypeDeletion : kTypeDeletionWithTimestamp;
ret_status = DeleteImpl(column_family_id, key, Slice(), delete_type);
if (kv_prot_info != nullptr) {
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
mem_kv_prot_info.UpdateO(kTypeDeletion, delete_type);
ret_status = DeleteImpl(column_family_id, key, Slice(), delete_type,
&mem_kv_prot_info);
} else {
ret_status = DeleteImpl(column_family_id, key, Slice(), delete_type,
nullptr /* kv_prot_info */);
}
// optimize for non-recovery mode
// If `ret_status` is `TryAgain` then the next (successful) try will add
// the key to the rebuilding transaction object. If `ret_status` is
@ -1601,6 +1812,7 @@ class MemTableInserter : public WriteBatch::Handler {
// away. So we only need to add to it when `ret_status.ok()`.
if (UNLIKELY(ret_status.ok() && rebuilding_trx_ != nullptr)) {
assert(!write_after_commit_);
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status =
WriteBatchInternal::Delete(rebuilding_trx_, column_family_id, key);
}
@ -1608,8 +1820,10 @@ class MemTableInserter : public WriteBatch::Handler {
}
Status SingleDeleteCF(uint32_t column_family_id, const Slice& key) override {
const auto* kv_prot_info = NextProtectionInfo();
// optimize for non-recovery mode
if (UNLIKELY(write_after_commit_ && rebuilding_trx_ != nullptr)) {
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
return WriteBatchInternal::SingleDelete(rebuilding_trx_, column_family_id,
key);
// else insert the values to the memtable right away
@ -1621,6 +1835,7 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!write_after_commit_);
// The CF is probably flushed and hence no need for insert but we still
// need to keep track of the keys for upcoming rollback/commit.
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::SingleDelete(rebuilding_trx_,
column_family_id, key);
if (ret_status.ok()) {
@ -1633,8 +1848,15 @@ class MemTableInserter : public WriteBatch::Handler {
}
assert(ret_status.ok());
ret_status =
DeleteImpl(column_family_id, key, Slice(), kTypeSingleDeletion);
if (kv_prot_info != nullptr) {
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
ret_status = DeleteImpl(column_family_id, key, Slice(),
kTypeSingleDeletion, &mem_kv_prot_info);
} else {
ret_status = DeleteImpl(column_family_id, key, Slice(),
kTypeSingleDeletion, nullptr /* kv_prot_info */);
}
// optimize for non-recovery mode
// If `ret_status` is `TryAgain` then the next (successful) try will add
// the key to the rebuilding transaction object. If `ret_status` is
@ -1642,6 +1864,7 @@ class MemTableInserter : public WriteBatch::Handler {
// away. So we only need to add to it when `ret_status.ok()`.
if (UNLIKELY(ret_status.ok() && rebuilding_trx_ != nullptr)) {
assert(!write_after_commit_);
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::SingleDelete(rebuilding_trx_,
column_family_id, key);
}
@ -1650,8 +1873,10 @@ class MemTableInserter : public WriteBatch::Handler {
Status DeleteRangeCF(uint32_t column_family_id, const Slice& begin_key,
const Slice& end_key) override {
const auto* kv_prot_info = NextProtectionInfo();
// optimize for non-recovery mode
if (UNLIKELY(write_after_commit_ && rebuilding_trx_ != nullptr)) {
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
return WriteBatchInternal::DeleteRange(rebuilding_trx_, column_family_id,
begin_key, end_key);
// else insert the values to the memtable right away
@ -1663,6 +1888,7 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!write_after_commit_);
// The CF is probably flushed and hence no need for insert but we still
// need to keep track of the keys for upcoming rollback/commit.
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::DeleteRange(
rebuilding_trx_, column_family_id, begin_key, end_key);
if (ret_status.ok()) {
@ -1705,8 +1931,15 @@ class MemTableInserter : public WriteBatch::Handler {
}
}
ret_status =
DeleteImpl(column_family_id, begin_key, end_key, kTypeRangeDeletion);
if (kv_prot_info != nullptr) {
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
ret_status = DeleteImpl(column_family_id, begin_key, end_key,
kTypeRangeDeletion, &mem_kv_prot_info);
} else {
ret_status = DeleteImpl(column_family_id, begin_key, end_key,
kTypeRangeDeletion, nullptr /* kv_prot_info */);
}
// optimize for non-recovery mode
// If `ret_status` is `TryAgain` then the next (successful) try will add
// the key to the rebuilding transaction object. If `ret_status` is
@ -1714,6 +1947,7 @@ class MemTableInserter : public WriteBatch::Handler {
// away. So we only need to add to it when `ret_status.ok()`.
if (UNLIKELY(!ret_status.IsTryAgain() && rebuilding_trx_ != nullptr)) {
assert(!write_after_commit_);
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::DeleteRange(
rebuilding_trx_, column_family_id, begin_key, end_key);
}
@ -1722,8 +1956,10 @@ class MemTableInserter : public WriteBatch::Handler {
Status MergeCF(uint32_t column_family_id, const Slice& key,
const Slice& value) override {
const auto* kv_prot_info = NextProtectionInfo();
// optimize for non-recovery mode
if (UNLIKELY(write_after_commit_ && rebuilding_trx_ != nullptr)) {
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
return WriteBatchInternal::Merge(rebuilding_trx_, column_family_id, key,
value);
// else insert the values to the memtable right away
@ -1735,6 +1971,7 @@ class MemTableInserter : public WriteBatch::Handler {
assert(!write_after_commit_);
// The CF is probably flushed and hence no need for insert but we still
// need to keep track of the keys for upcoming rollback/commit.
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::Merge(rebuilding_trx_,
column_family_id, key, value);
if (ret_status.ok()) {
@ -1802,7 +2039,6 @@ class MemTableInserter : public WriteBatch::Handler {
assert(merge_operator);
std::string new_value;
Status merge_status = MergeHelper::TimedFullMerge(
merge_operator, key, &get_value_slice, {value}, &new_value,
moptions->info_log, moptions->statistics, SystemClock::Default());
@ -1814,16 +2050,35 @@ class MemTableInserter : public WriteBatch::Handler {
} else {
// 3) Add value to memtable
assert(!concurrent_memtable_writes_);
ret_status = mem->Add(sequence_, kTypeValue, key, new_value);
if (kv_prot_info != nullptr) {
auto merged_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
merged_kv_prot_info.UpdateV(value, new_value);
merged_kv_prot_info.UpdateO(kTypeMerge, kTypeValue);
ret_status = mem->Add(sequence_, kTypeValue, key, new_value,
&merged_kv_prot_info);
} else {
ret_status = mem->Add(sequence_, kTypeValue, key, new_value,
nullptr /* kv_prot_info */);
}
}
}
}
if (!perform_merge) {
assert(ret_status.ok());
// Add merge operand to memtable
ret_status =
mem->Add(sequence_, kTypeMerge, key, value,
concurrent_memtable_writes_, get_post_process_info(mem));
if (kv_prot_info != nullptr) {
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
ret_status =
mem->Add(sequence_, kTypeMerge, key, value, &mem_kv_prot_info,
concurrent_memtable_writes_, get_post_process_info(mem));
} else {
ret_status = mem->Add(
sequence_, kTypeMerge, key, value, nullptr /* kv_prot_info */,
concurrent_memtable_writes_, get_post_process_info(mem));
}
}
if (UNLIKELY(ret_status.IsTryAgain())) {
@ -1841,6 +2096,7 @@ class MemTableInserter : public WriteBatch::Handler {
// away. So we only need to add to it when `ret_status.ok()`.
if (UNLIKELY(ret_status.ok() && rebuilding_trx_ != nullptr)) {
assert(!write_after_commit_);
// TODO(ajkr): propagate `ProtectionInfoKVOTS64`.
ret_status = WriteBatchInternal::Merge(rebuilding_trx_, column_family_id,
key, value);
}
@ -1849,8 +2105,18 @@ class MemTableInserter : public WriteBatch::Handler {
Status PutBlobIndexCF(uint32_t column_family_id, const Slice& key,
const Slice& value) override {
// Same as PutCF except for value type.
return PutCFImpl(column_family_id, key, value, kTypeBlobIndex);
const auto* kv_prot_info = NextProtectionInfo();
if (kv_prot_info != nullptr) {
// Memtable needs seqno, doesn't need CF ID
auto mem_kv_prot_info =
kv_prot_info->StripC(column_family_id).ProtectS(sequence_);
// Same as PutCF except for value type.
return PutCFImpl(column_family_id, key, value, kTypeBlobIndex,
&mem_kv_prot_info);
} else {
return PutCFImpl(column_family_id, key, value, kTypeBlobIndex,
nullptr /* kv_prot_info */);
}
}
void CheckMemtableFull() {
@ -2056,8 +2322,8 @@ Status WriteBatchInternal::InsertInto(
MemTableInserter inserter(
sequence, memtables, flush_scheduler, trim_history_scheduler,
ignore_missing_column_families, recovery_log_number, db,
concurrent_memtable_writes, nullptr /*has_valid_writes*/, seq_per_batch,
batch_per_txn);
concurrent_memtable_writes, nullptr /* prot_info */,
nullptr /*has_valid_writes*/, seq_per_batch, batch_per_txn);
for (auto w : write_group) {
if (w->CallbackFailed()) {
continue;
@ -2070,6 +2336,7 @@ Status WriteBatchInternal::InsertInto(
}
SetSequence(w->batch, inserter.sequence());
inserter.set_log_number_ref(w->log_ref);
inserter.set_prot_info(w->batch->prot_info_.get());
w->status = w->batch->Iterate(&inserter);
if (!w->status.ok()) {
return w->status;
@ -2091,13 +2358,15 @@ Status WriteBatchInternal::InsertInto(
(void)batch_cnt;
#endif
assert(writer->ShouldWriteToMemtable());
MemTableInserter inserter(
sequence, memtables, flush_scheduler, trim_history_scheduler,
ignore_missing_column_families, log_number, db,
concurrent_memtable_writes, nullptr /*has_valid_writes*/, seq_per_batch,
batch_per_txn, hint_per_batch);
MemTableInserter inserter(sequence, memtables, flush_scheduler,
trim_history_scheduler,
ignore_missing_column_families, log_number, db,
concurrent_memtable_writes, nullptr /* prot_info */,
nullptr /*has_valid_writes*/, seq_per_batch,
batch_per_txn, hint_per_batch);
SetSequence(writer->batch, sequence);
inserter.set_log_number_ref(writer->log_ref);
inserter.set_prot_info(writer->batch->prot_info_.get());
Status s = writer->batch->Iterate(&inserter);
assert(!seq_per_batch || batch_cnt != 0);
assert(!seq_per_batch || inserter.sequence() - sequence == batch_cnt);
@ -2117,8 +2386,8 @@ Status WriteBatchInternal::InsertInto(
MemTableInserter inserter(Sequence(batch), memtables, flush_scheduler,
trim_history_scheduler,
ignore_missing_column_families, log_number, db,
concurrent_memtable_writes, has_valid_writes,
seq_per_batch, batch_per_txn);
concurrent_memtable_writes, batch->prot_info_.get(),
has_valid_writes, seq_per_batch, batch_per_txn);
Status s = batch->Iterate(&inserter);
if (next_seq != nullptr) {
*next_seq = inserter.sequence();
@ -2131,6 +2400,7 @@ Status WriteBatchInternal::InsertInto(
Status WriteBatchInternal::SetContents(WriteBatch* b, const Slice& contents) {
assert(contents.size() >= WriteBatchInternal::kHeader);
assert(b->prot_info_ == nullptr);
b->rep_.assign(contents.data(), contents.size());
b->content_flags_.store(ContentFlags::DEFERRED, std::memory_order_relaxed);
return Status::OK();
@ -2138,6 +2408,8 @@ Status WriteBatchInternal::SetContents(WriteBatch* b, const Slice& contents) {
Status WriteBatchInternal::Append(WriteBatch* dst, const WriteBatch* src,
const bool wal_only) {
assert(dst->Count() == 0 ||
(dst->prot_info_ == nullptr) == (src->prot_info_ == nullptr));
size_t src_len;
int src_count;
uint32_t src_flags;
@ -2154,6 +2426,13 @@ Status WriteBatchInternal::Append(WriteBatch* dst, const WriteBatch* src,
src_flags = src->content_flags_.load(std::memory_order_relaxed);
}
if (dst->prot_info_ != nullptr) {
std::copy(src->prot_info_->entries_.begin(),
src->prot_info_->entries_.begin() + src_count,
std::back_inserter(dst->prot_info_->entries_));
} else if (src->prot_info_ != nullptr) {
dst->prot_info_.reset(new WriteBatch::ProtectionInfo(*src->prot_info_));
}
SetCount(dst, Count(dst) + src_count);
assert(src->rep_.size() >= WriteBatchInternal::kHeader);
dst->rep_.append(src->rep_.data() + WriteBatchInternal::kHeader, src_len);

@ -9,7 +9,9 @@
#pragma once
#include <vector>
#include "db/flush_scheduler.h"
#include "db/kv_checksum.h"
#include "db/trim_history_scheduler.h"
#include "db/write_thread.h"
#include "rocksdb/db.h"
@ -61,6 +63,14 @@ class ColumnFamilyMemTablesDefault : public ColumnFamilyMemTables {
MemTable* mem_;
};
struct WriteBatch::ProtectionInfo {
// `WriteBatch` usually doesn't contain a huge number of keys so protecting
// with a fixed, non-configurable eight bytes per key may work well enough.
autovector<ProtectionInfoKVOTC64> entries_;
size_t GetBytesPerKey() const { return 8; }
};
// WriteBatchInternal provides static methods for manipulating a
// WriteBatch that we don't want in the public WriteBatch interface.
class WriteBatchInternal {
@ -232,6 +242,9 @@ class LocalSavePoint {
if (batch_->max_bytes_ && batch_->rep_.size() > batch_->max_bytes_) {
batch_->rep_.resize(savepoint_.size);
WriteBatchInternal::SetCount(batch_, savepoint_.count);
if (batch_->prot_info_ != nullptr) {
batch_->prot_info_->entries_.resize(savepoint_.count);
}
batch_->content_flags_.store(savepoint_.content_flags,
std::memory_order_relaxed);
return Status::MemoryLimit();

@ -465,6 +465,11 @@ size_t WriteThread::EnterAsBatchGroupLeader(Writer* leader,
break;
}
if (w->protection_bytes_per_key != leader->protection_bytes_per_key) {
// Do not mix writes with different levels of integrity protection.
break;
}
if (w->batch == nullptr) {
// Do not include those writes with nullptr batch. Those are not writes,
// those are something else. They want to be alone

@ -119,6 +119,7 @@ class WriteThread {
bool disable_wal;
bool disable_memtable;
size_t batch_cnt; // if non-zero, number of sub-batches in the write batch
size_t protection_bytes_per_key;
PreReleaseCallback* pre_release_callback;
uint64_t log_used; // log number that this batch was inserted into
uint64_t log_ref; // log number that memtable insert should reference
@ -142,6 +143,7 @@ class WriteThread {
disable_wal(false),
disable_memtable(false),
batch_cnt(0),
protection_bytes_per_key(0),
pre_release_callback(nullptr),
log_used(0),
log_ref(0),
@ -163,6 +165,7 @@ class WriteThread {
disable_wal(write_options.disableWAL),
disable_memtable(_disable_memtable),
batch_cnt(_batch_cnt),
protection_bytes_per_key(_batch->GetProtectionBytesPerKey()),
pre_release_callback(_pre_release_callback),
log_used(0),
log_ref(_log_ref),

@ -31,7 +31,8 @@ class BatchedOpsStressTest : public StressTest {
std::string keys[10] = {"9", "8", "7", "6", "5", "4", "3", "2", "1", "0"};
std::string values[10] = {"9", "8", "7", "6", "5", "4", "3", "2", "1", "0"};
Slice value_slices[10];
WriteBatch batch;
WriteBatch batch(0 /* reserved_bytes */, 0 /* max_bytes */, 0 /* ts_sz */,
FLAGS_batch_protection_bytes_per_key);
Status s;
auto cfh = column_families_[rand_column_families[0]];
std::string key_str = Key(rand_keys[0]);
@ -66,7 +67,8 @@ class BatchedOpsStressTest : public StressTest {
std::unique_ptr<MutexLock>& /* lock */) override {
std::string keys[10] = {"9", "7", "5", "3", "1", "8", "6", "4", "2", "0"};
WriteBatch batch;
WriteBatch batch(0 /* reserved_bytes */, 0 /* max_bytes */, 0 /* ts_sz */,
FLAGS_batch_protection_bytes_per_key);
Status s;
auto cfh = column_families_[rand_column_families[0]];
std::string key_str = Key(rand_keys[0]);

@ -246,6 +246,7 @@ DECLARE_bool(best_efforts_recovery);
DECLARE_bool(skip_verifydb);
DECLARE_bool(enable_compaction_filter);
DECLARE_bool(paranoid_file_checks);
DECLARE_uint64(batch_protection_bytes_per_key);
const long KB = 1024;
const int kRandomValueMaxFactor = 3;

@ -753,6 +753,11 @@ DEFINE_bool(paranoid_file_checks, true,
"After writing every SST file, reopen it and read all the keys "
"and validate checksums");
DEFINE_uint64(batch_protection_bytes_per_key, 0,
"If nonzero, enables integrity protection in `WriteBatch` at the "
"specified number of bytes per key. Currently the only supported "
"nonzero value is eight.");
DEFINE_string(file_checksum_impl, "none",
"Name of an implementation for file_checksum_gen_factory, or "
"\"none\" for null.");

@ -286,6 +286,13 @@ int db_stress_tool(int argc, char** argv) {
"test_batches_snapshots must all be 0 when using compaction filter\n");
exit(1);
}
if (FLAGS_batch_protection_bytes_per_key > 0 &&
!FLAGS_test_batches_snapshots) {
fprintf(stderr,
"Error: test_batches_snapshots must be enabled when "
"batch_protection_bytes_per_key > 0\n");
exit(1);
}
rocksdb_kill_odds = FLAGS_kill_random_test;
rocksdb_kill_exclude_prefixes = SplitString(FLAGS_kill_exclude_prefixes);

@ -62,6 +62,11 @@ class WriteBatch : public WriteBatchBase {
public:
explicit WriteBatch(size_t reserved_bytes = 0, size_t max_bytes = 0);
explicit WriteBatch(size_t reserved_bytes, size_t max_bytes, size_t ts_sz);
// `protection_bytes_per_key` is the number of bytes used to store
// protection information for each key entry. Currently supported values are
// zero (disabled) and eight.
explicit WriteBatch(size_t reserved_bytes, size_t max_bytes, size_t ts_sz,
size_t protection_bytes_per_key);
~WriteBatch() override;
using WriteBatchBase::Put;
@ -338,6 +343,9 @@ class WriteBatch : public WriteBatchBase {
void SetMaxBytes(size_t max_bytes) override { max_bytes_ = max_bytes; }
struct ProtectionInfo;
size_t GetProtectionBytesPerKey() const;
private:
friend class WriteBatchInternal;
friend class LocalSavePoint;
@ -367,11 +375,11 @@ class WriteBatch : public WriteBatchBase {
// more details.
bool is_latest_persistent_state_ = false;
std::unique_ptr<ProtectionInfo> prot_info_;
protected:
std::string rep_; // See comment in write_batch.cc for the format of rep_
const size_t timestamp_size_;
// Intentionally copyable
};
} // namespace ROCKSDB_NAMESPACE

@ -402,6 +402,7 @@ TEST_MAIN_SOURCES = \
db/db_iter_test.cc \
db/db_iter_stress_test.cc \
db/db_iterator_test.cc \
db/db_kv_checksum_test.cc \
db/db_log_iter_test.cc \
db/db_memtable_test.cc \
db/db_merge_operator_test.cc \

@ -508,7 +508,8 @@ class MemTableConstructor: public Constructor {
memtable_->Ref();
int seq = 1;
for (const auto& kv : kv_map) {
Status s = memtable_->Add(seq, kTypeValue, kv.first, kv.second);
Status s = memtable_->Add(seq, kTypeValue, kv.first, kv.second,
nullptr /* kv_prot_info */);
if (!s.ok()) {
return s;
}

@ -31,6 +31,7 @@ default_params = {
"backup_max_size": 100 * 1024 * 1024,
# Consider larger number when backups considered more stable
"backup_one_in": 100000,
"batch_protection_bytes_per_key": lambda: random.choice([0, 8]),
"block_size": 16384,
"bloom_bits": lambda: random.choice([random.randint(0,19),
random.lognormvariate(2.3, 1.3)]),
@ -330,6 +331,8 @@ def finalize_and_sanitize(src_params):
dest_params["readpercent"] += dest_params.get("iterpercent", 10)
dest_params["iterpercent"] = 0
dest_params["test_batches_snapshots"] = 0
if dest_params.get("test_batches_snapshots") == 0:
dest_params["batch_protection_bytes_per_key"] = 0
return dest_params
def gen_cmd_params(args):

@ -15,6 +15,8 @@
namespace ROCKSDB_NAMESPACE {
uint64_t (*kGetSliceNPHash64UnseededFnPtr)(const Slice&) = &GetSliceHash64;
uint32_t Hash(const char* data, size_t n, uint32_t seed) {
// MurmurHash1 - fast but mediocre quality
// https://github.com/aappleby/smhasher/wiki/MurmurHash1
@ -80,4 +82,19 @@ uint64_t Hash64(const char* data, size_t n) {
return XXH3p_64bits(data, n);
}
uint64_t GetSlicePartsNPHash64(const SliceParts& data, uint64_t seed) {
// TODO(ajkr): use XXH3 streaming APIs to avoid the copy/allocation.
size_t concat_len = 0;
for (int i = 0; i < data.num_parts; ++i) {
concat_len += data.parts[i].size();
}
std::string concat_data;
concat_data.reserve(concat_len);
for (int i = 0; i < data.num_parts; ++i) {
concat_data.append(data.parts[i].data(), data.parts[i].size());
}
assert(concat_data.size() == concat_len);
return NPHash64(concat_data.data(), concat_len, seed);
}
} // namespace ROCKSDB_NAMESPACE

@ -78,11 +78,22 @@ inline uint32_t BloomHash(const Slice& key) {
inline uint64_t GetSliceHash64(const Slice& key) {
return Hash64(key.data(), key.size());
}
// Provided for convenience for use with template argument deduction, where a
// specific overload needs to be used.
extern uint64_t (*kGetSliceNPHash64UnseededFnPtr)(const Slice&);
inline uint64_t GetSliceNPHash64(const Slice& s) {
return NPHash64(s.data(), s.size());
}
inline uint64_t GetSliceNPHash64(const Slice& s, uint64_t seed) {
return NPHash64(s.data(), s.size(), seed);
}
// Similar to `GetSliceNPHash64()` with `seed`, but input comes from
// concatenation of `Slice`s in `data`.
extern uint64_t GetSlicePartsNPHash64(const SliceParts& data, uint64_t seed);
inline size_t GetSliceRangedNPHash(const Slice& s, size_t range) {
return FastRange64(NPHash64(s.data(), s.size()), range);
}

Loading…
Cancel
Save