Verify write batch checksum before WAL (#10114)

Summary:
Context: WriteBatch can have key-value checksums when it was created `with protection_bytes_per_key > 0`.
This PR added checksum verification for write batches before they are written to WAL.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/10114

Test Plan:
- Added new unit tests to db_kv_checksum_test.cc: `make check -j32`
- benchmark on performance regression: `./db_bench --benchmarks=fillrandom[-X20] -db=/dev/shm/test_rocksdb -write_batch_protection_bytes_per_key=8`
  - Pre-PR:
`
fillrandom [AVG    20 runs] : 198875 (± 3006) ops/sec;   22.0 (± 0.3) MB/sec
`
  - Post-PR:
`
fillrandom [AVG    20 runs] : 196487 (± 2279) ops/sec;   21.7 (± 0.3) MB/sec
`
  Mean regressed about 1% (198875 -> 196487 ops/sec).

Reviewed By: ajkr

Differential Revision: D36917464

Pulled By: cbi42

fbshipit-source-id: 29beb74edf65f04b1a890b4f650d873dc7ed790d
main
Changyu Bi 2 years ago committed by Facebook GitHub Bot
parent 2e5a323dbd
commit 9882652b0e
  1. 9
      db/db_impl/db_impl.h
  2. 86
      db/db_impl/db_impl_write.cc
  3. 487
      db/db_kv_checksum_test.cc
  4. 96
      db/write_batch.cc
  5. 4
      db/write_batch_internal.h
  6. 1
      db/write_thread.cc
  7. 6
      include/rocksdb/write_batch.h
  8. 15
      tools/db_bench_tool.cc

@ -1915,9 +1915,12 @@ class DBImpl : public DB {
Status PreprocessWrite(const WriteOptions& write_options, bool* need_log_sync, Status PreprocessWrite(const WriteOptions& write_options, bool* need_log_sync,
WriteContext* write_context); WriteContext* write_context);
WriteBatch* MergeBatch(const WriteThread::WriteGroup& write_group, // Merge write batches in the write group into merged_batch.
WriteBatch* tmp_batch, size_t* write_with_wal, // Returns OK if merge is successful.
WriteBatch** to_be_cached_state); // Returns Corruption if corruption in write batch is detected.
Status MergeBatch(const WriteThread::WriteGroup& write_group,
WriteBatch* tmp_batch, WriteBatch** merged_batch,
size_t* write_with_wal, WriteBatch** to_be_cached_state);
// rate_limiter_priority is used to charge `DBOptions::rate_limiter` // rate_limiter_priority is used to charge `DBOptions::rate_limiter`
// for automatic WAL flush (`Options::manual_wal_flush` == false) // for automatic WAL flush (`Options::manual_wal_flush` == false)

@ -533,15 +533,18 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
} }
PERF_TIMER_START(write_pre_and_post_process_time); PERF_TIMER_START(write_pre_and_post_process_time);
if (!io_s.ok()) {
// Check WriteToWAL status
IOStatusCheck(io_s);
}
if (!w.CallbackFailed()) { if (!w.CallbackFailed()) {
if (!io_s.ok()) { if (!io_s.ok()) {
assert(pre_release_cb_status.ok()); assert(pre_release_cb_status.ok());
IOStatusCheck(io_s);
} else { } else {
WriteStatusCheck(pre_release_cb_status); WriteStatusCheck(pre_release_cb_status);
} }
} else { } else {
assert(io_s.ok() && pre_release_cb_status.ok()); assert(pre_release_cb_status.ok());
} }
if (need_log_sync) { if (need_log_sync) {
@ -695,12 +698,11 @@ Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options,
w.status = io_s; w.status = io_s;
} }
if (!w.CallbackFailed()) { if (!io_s.ok()) {
if (!io_s.ok()) { // Check WriteToWAL status
IOStatusCheck(io_s); IOStatusCheck(io_s);
} else { } else if (!w.CallbackFailed()) {
WriteStatusCheck(w.status); WriteStatusCheck(w.status);
}
} }
if (need_log_sync) { if (need_log_sync) {
@ -936,11 +938,18 @@ Status DBImpl::WriteImplWALOnly(
seq_inc = total_batch_cnt; seq_inc = total_batch_cnt;
} }
Status status; Status status;
IOStatus io_s;
io_s.PermitUncheckedError(); // Allow io_s to be uninitialized
if (!write_options.disableWAL) { if (!write_options.disableWAL) {
io_s = ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc); IOStatus io_s =
ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc);
status = io_s; status = io_s;
// last_sequence may not be set if there is an error
// This error checking and return is moved up to avoid using uninitialized
// last_sequence.
if (!io_s.ok()) {
IOStatusCheck(io_s);
write_thread->ExitAsBatchGroupLeader(write_group, status);
return status;
}
} else { } else {
// Otherwise we inc seq number to do solely the seq allocation // Otherwise we inc seq number to do solely the seq allocation
last_sequence = versions_->FetchAddLastAllocatedSequence(seq_inc); last_sequence = versions_->FetchAddLastAllocatedSequence(seq_inc);
@ -975,11 +984,7 @@ Status DBImpl::WriteImplWALOnly(
PERF_TIMER_START(write_pre_and_post_process_time); PERF_TIMER_START(write_pre_and_post_process_time);
if (!w.CallbackFailed()) { if (!w.CallbackFailed()) {
if (!io_s.ok()) { WriteStatusCheck(status);
IOStatusCheck(io_s);
} else {
WriteStatusCheck(status);
}
} }
if (status.ok()) { if (status.ok()) {
size_t index = 0; size_t index = 0;
@ -1171,13 +1176,13 @@ Status DBImpl::PreprocessWrite(const WriteOptions& write_options,
return status; return status;
} }
WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group, Status DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
WriteBatch* tmp_batch, size_t* write_with_wal, WriteBatch* tmp_batch, WriteBatch** merged_batch,
WriteBatch** to_be_cached_state) { size_t* write_with_wal,
WriteBatch** to_be_cached_state) {
assert(write_with_wal != nullptr); assert(write_with_wal != nullptr);
assert(tmp_batch != nullptr); assert(tmp_batch != nullptr);
assert(*to_be_cached_state == nullptr); assert(*to_be_cached_state == nullptr);
WriteBatch* merged_batch = nullptr;
*write_with_wal = 0; *write_with_wal = 0;
auto* leader = write_group.leader; auto* leader = write_group.leader;
assert(!leader->disable_wal); // Same holds for all in the batch group assert(!leader->disable_wal); // Same holds for all in the batch group
@ -1186,22 +1191,24 @@ WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
// we simply write the first WriteBatch to WAL if the group only // we simply write the first WriteBatch to WAL if the group only
// contains one batch, that batch should be written to the WAL, // contains one batch, that batch should be written to the WAL,
// and the batch is not wanting to be truncated // and the batch is not wanting to be truncated
merged_batch = leader->batch; *merged_batch = leader->batch;
if (WriteBatchInternal::IsLatestPersistentState(merged_batch)) { if (WriteBatchInternal::IsLatestPersistentState(*merged_batch)) {
*to_be_cached_state = merged_batch; *to_be_cached_state = *merged_batch;
} }
*write_with_wal = 1; *write_with_wal = 1;
} else { } else {
// WAL needs all of the batches flattened into a single batch. // WAL needs all of the batches flattened into a single batch.
// We could avoid copying here with an iov-like AddRecord // We could avoid copying here with an iov-like AddRecord
// interface // interface
merged_batch = tmp_batch; *merged_batch = tmp_batch;
for (auto writer : write_group) { for (auto writer : write_group) {
if (!writer->CallbackFailed()) { if (!writer->CallbackFailed()) {
Status s = WriteBatchInternal::Append(merged_batch, writer->batch, Status s = WriteBatchInternal::Append(*merged_batch, writer->batch,
/*WAL_only*/ true); /*WAL_only*/ true);
// Always returns Status::OK. if (!s.ok()) {
assert(s.ok()); tmp_batch->Clear();
return s;
}
if (WriteBatchInternal::IsLatestPersistentState(writer->batch)) { if (WriteBatchInternal::IsLatestPersistentState(writer->batch)) {
// We only need to cache the last of such write batch // We only need to cache the last of such write batch
*to_be_cached_state = writer->batch; *to_be_cached_state = writer->batch;
@ -1210,7 +1217,8 @@ WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
} }
} }
} }
return merged_batch; // return merged_batch;
return Status::OK();
} }
// When two_write_queues_ is disabled, this function is called from the only // When two_write_queues_ is disabled, this function is called from the only
@ -1223,6 +1231,11 @@ IOStatus DBImpl::WriteToWAL(const WriteBatch& merged_batch,
assert(log_size != nullptr); assert(log_size != nullptr);
Slice log_entry = WriteBatchInternal::Contents(&merged_batch); Slice log_entry = WriteBatchInternal::Contents(&merged_batch);
TEST_SYNC_POINT_CALLBACK("DBImpl::WriteToWAL:log_entry", &log_entry);
auto s = merged_batch.VerifyChecksum();
if (!s.ok()) {
return status_to_io_status(std::move(s));
}
*log_size = log_entry.size(); *log_size = log_entry.size();
// When two_write_queues_ WriteToWAL has to be protected from concurretn calls // When two_write_queues_ WriteToWAL has to be protected from concurretn calls
// from the two queues anyway and log_write_mutex_ is already held. Otherwise // from the two queues anyway and log_write_mutex_ is already held. Otherwise
@ -1260,8 +1273,13 @@ IOStatus DBImpl::WriteToWAL(const WriteThread::WriteGroup& write_group,
// Same holds for all in the batch group // Same holds for all in the batch group
size_t write_with_wal = 0; size_t write_with_wal = 0;
WriteBatch* to_be_cached_state = nullptr; WriteBatch* to_be_cached_state = nullptr;
WriteBatch* merged_batch = MergeBatch(write_group, &tmp_batch_, WriteBatch* merged_batch;
&write_with_wal, &to_be_cached_state); io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch_, &merged_batch,
&write_with_wal, &to_be_cached_state));
if (UNLIKELY(!io_s.ok())) {
return io_s;
}
if (merged_batch == write_group.leader->batch) { if (merged_batch == write_group.leader->batch) {
write_group.leader->log_used = logfile_number_; write_group.leader->log_used = logfile_number_;
} else if (write_with_wal > 1) { } else if (write_with_wal > 1) {
@ -1351,8 +1369,12 @@ IOStatus DBImpl::ConcurrentWriteToWAL(
WriteBatch tmp_batch; WriteBatch tmp_batch;
size_t write_with_wal = 0; size_t write_with_wal = 0;
WriteBatch* to_be_cached_state = nullptr; WriteBatch* to_be_cached_state = nullptr;
WriteBatch* merged_batch = WriteBatch* merged_batch;
MergeBatch(write_group, &tmp_batch, &write_with_wal, &to_be_cached_state); io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch, &merged_batch,
&write_with_wal, &to_be_cached_state));
if (UNLIKELY(!io_s.ok())) {
return io_s;
}
// We need to lock log_write_mutex_ since logs_ and alive_log_files might be // We need to lock log_write_mutex_ since logs_ and alive_log_files might be
// pushed back concurrently // pushed back concurrently

@ -25,6 +25,49 @@ WriteBatchOpType operator+(WriteBatchOpType lhs, const int rhs) {
return static_cast<WriteBatchOpType>(static_cast<T>(lhs) + rhs); return static_cast<WriteBatchOpType>(static_cast<T>(lhs) + rhs);
} }
std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle,
WriteBatchOpType op_type) {
Status s;
WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
switch (op_type) {
case WriteBatchOpType::kPut:
s = wb.Put(cf_handle, "key", "val");
break;
case WriteBatchOpType::kDelete:
s = wb.Delete(cf_handle, "key");
break;
case WriteBatchOpType::kSingleDelete:
s = wb.SingleDelete(cf_handle, "key");
break;
case WriteBatchOpType::kDeleteRange:
s = wb.DeleteRange(cf_handle, "begin", "end");
break;
case WriteBatchOpType::kMerge:
s = wb.Merge(cf_handle, "key", "val");
break;
case WriteBatchOpType::kBlobIndex: {
// TODO(ajkr): use public API once available.
uint32_t cf_id;
if (cf_handle == nullptr) {
cf_id = 0;
} else {
cf_id = cf_handle->GetID();
}
std::string blob_index;
BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
"val");
s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
break;
}
case WriteBatchOpType::kNum:
assert(false);
}
return {std::move(wb), std::move(s)};
}
class DbKvChecksumTest class DbKvChecksumTest
: public DBTestBase, : public DBTestBase,
public ::testing::WithParamInterface<std::tuple<WriteBatchOpType, char>> { public ::testing::WithParamInterface<std::tuple<WriteBatchOpType, char>> {
@ -35,48 +78,6 @@ class DbKvChecksumTest
corrupt_byte_addend_ = std::get<1>(GetParam()); corrupt_byte_addend_ = std::get<1>(GetParam());
} }
std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle) {
Status s;
WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
switch (op_type_) {
case WriteBatchOpType::kPut:
s = wb.Put(cf_handle, "key", "val");
break;
case WriteBatchOpType::kDelete:
s = wb.Delete(cf_handle, "key");
break;
case WriteBatchOpType::kSingleDelete:
s = wb.SingleDelete(cf_handle, "key");
break;
case WriteBatchOpType::kDeleteRange:
s = wb.DeleteRange(cf_handle, "begin", "end");
break;
case WriteBatchOpType::kMerge:
s = wb.Merge(cf_handle, "key", "val");
break;
case WriteBatchOpType::kBlobIndex: {
// TODO(ajkr): use public API once available.
uint32_t cf_id;
if (cf_handle == nullptr) {
cf_id = 0;
} else {
cf_id = cf_handle->GetID();
}
std::string blob_index;
BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
"val");
s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
break;
}
case WriteBatchOpType::kNum:
assert(false);
}
return {std::move(wb), std::move(s)};
}
void CorruptNextByteCallBack(void* arg) { void CorruptNextByteCallBack(void* arg) {
Slice encoded = *static_cast<Slice*>(arg); Slice encoded = *static_cast<Slice*>(arg);
if (entry_len_ == std::numeric_limits<size_t>::max()) { if (entry_len_ == std::numeric_limits<size_t>::max()) {
@ -99,34 +100,28 @@ class DbKvChecksumTest
size_t entry_len_ = std::numeric_limits<size_t>::max(); size_t entry_len_ = std::numeric_limits<size_t>::max();
}; };
std::string GetTestNameSuffix( std::string GetOpTypeString(const WriteBatchOpType& op_type) {
::testing::TestParamInfo<std::tuple<WriteBatchOpType, char>> info) { switch (op_type) {
std::ostringstream oss;
switch (std::get<0>(info.param)) {
case WriteBatchOpType::kPut: case WriteBatchOpType::kPut:
oss << "Put"; return "Put";
break;
case WriteBatchOpType::kDelete: case WriteBatchOpType::kDelete:
oss << "Delete"; return "Delete";
break;
case WriteBatchOpType::kSingleDelete: case WriteBatchOpType::kSingleDelete:
oss << "SingleDelete"; return "SingleDelete";
break;
case WriteBatchOpType::kDeleteRange: case WriteBatchOpType::kDeleteRange:
oss << "DeleteRange"; return "DeleteRange";
break; break;
case WriteBatchOpType::kMerge: case WriteBatchOpType::kMerge:
oss << "Merge"; return "Merge";
break; break;
case WriteBatchOpType::kBlobIndex: case WriteBatchOpType::kBlobIndex:
oss << "BlobIndex"; return "BlobIndex";
break; break;
case WriteBatchOpType::kNum: case WriteBatchOpType::kNum:
assert(false); assert(false);
} }
oss << "Add" assert(false);
<< static_cast<int>(static_cast<unsigned char>(std::get<1>(info.param))); return "";
return oss.str();
} }
INSTANTIATE_TEST_CASE_P( INSTANTIATE_TEST_CASE_P(
@ -134,7 +129,13 @@ INSTANTIATE_TEST_CASE_P(
::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0), ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
WriteBatchOpType::kNum), WriteBatchOpType::kNum),
::testing::Values(2, 103, 251)), ::testing::Values(2, 103, 251)),
GetTestNameSuffix); [](const testing::TestParamInfo<std::tuple<WriteBatchOpType, char>>& args) {
std::ostringstream oss;
oss << GetOpTypeString(std::get<0>(args.param)) << "Add"
<< static_cast<int>(
static_cast<unsigned char>(std::get<1>(args.param)));
return oss.str();
});
TEST_P(DbKvChecksumTest, MemTableAddCorrupted) { TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
// This test repeatedly attempts to write `WriteBatch`es containing a single // This test repeatedly attempts to write `WriteBatch`es containing a single
@ -157,11 +158,16 @@ TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
Reopen(options); Reopen(options);
SyncPoint::GetInstance()->EnableProcessing(); SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */); auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
ASSERT_OK(batch_and_status.second); ASSERT_OK(batch_and_status.second);
ASSERT_TRUE( ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption()); db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
SyncPoint::GetInstance()->DisableProcessing(); SyncPoint::GetInstance()->DisableProcessing();
// In case the above callback is not invoked, this test will run
// numeric_limits<size_t>::max() times until it reports an error (or will
// exhaust disk space). Added this assert to report error early.
ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
} }
} }
@ -188,14 +194,373 @@ TEST_P(DbKvChecksumTest, MemTableAddWithColumnFamilyCorrupted) {
ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options); ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
SyncPoint::GetInstance()->EnableProcessing(); SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status = GetWriteBatch(handles_[1]); auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
ASSERT_OK(batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
SyncPoint::GetInstance()->DisableProcessing();
// In case the above callback is not invoked, this test will run
// numeric_limits<size_t>::max() times until it reports an error (or will
// exhaust disk space). Added this assert to report error early.
ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
}
}
TEST_P(DbKvChecksumTest, NoCorruptionCase) {
// If this test fails, we may have found a piece of malfunctioned hardware
auto batch_and_status = GetWriteBatch(nullptr, op_type_);
ASSERT_OK(batch_and_status.second);
ASSERT_OK(batch_and_status.first.VerifyChecksum());
}
TEST_P(DbKvChecksumTest, WriteToWALCorrupted) {
// This test repeatedly attempts to write `WriteBatch`es containing a single
// entry of type `op_type_`. Each attempt has one byte corrupted by adding
// `corrupt_byte_addend_` to its original value. The test repeats until an
// attempt has been made on each byte in the encoded write batch. All attempts
// are expected to fail with `Status::Corruption`
Options options = CurrentOptions();
if (op_type_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
SyncPoint::GetInstance()->SetCallBack(
"DBImpl::WriteToWAL:log_entry",
std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
std::placeholders::_1));
// First 8 bytes are for sequence number which is not protected in write batch
corrupt_byte_offset_ = 8;
while (MoreBytesToCorrupt()) {
// Corrupted write batch leads to read-only mode, so we have to
// reopen for every attempt.
Reopen(options);
auto log_size_pre_write = dbfull()->TEST_total_log_size();
SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
ASSERT_OK(batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
// Confirm that nothing was written to WAL
ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
SyncPoint::GetInstance()->DisableProcessing();
// In case the above callback is not invoked, this test will run
// numeric_limits<size_t>::max() times until it reports an error (or will
// exhaust disk space). Added this assert to report error early.
ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
}
}
TEST_P(DbKvChecksumTest, WriteToWALWithColumnFamilyCorrupted) {
// This test repeatedly attempts to write `WriteBatch`es containing a single
// entry of type `op_type_`. Each attempt has one byte corrupted by adding
// `corrupt_byte_addend_` to its original value. The test repeats until an
// attempt has been made on each byte in the encoded write batch. All attempts
// are expected to fail with `Status::Corruption`
Options options = CurrentOptions();
if (op_type_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
CreateAndReopenWithCF({"pikachu"}, options);
SyncPoint::GetInstance()->SetCallBack(
"DBImpl::WriteToWAL:log_entry",
std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
std::placeholders::_1));
// First 8 bytes are for sequence number which is not protected in write batch
corrupt_byte_offset_ = 8;
while (MoreBytesToCorrupt()) {
// Corrupted write batch leads to read-only mode, so we have to
// reopen for every attempt.
ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
auto log_size_pre_write = dbfull()->TEST_total_log_size();
SyncPoint::GetInstance()->EnableProcessing();
auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
ASSERT_OK(batch_and_status.second); ASSERT_OK(batch_and_status.second);
ASSERT_TRUE( ASSERT_TRUE(
db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption()); db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
// Confirm that nothing was written to WAL
ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
SyncPoint::GetInstance()->DisableProcessing(); SyncPoint::GetInstance()->DisableProcessing();
// In case the above callback is not invoked, this test will run
// numeric_limits<size_t>::max() times until it reports an error (or will
// exhaust disk space). Added this assert to report error early.
ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
}
}
class DbKvChecksumTestMergedBatch
: public DBTestBase,
public ::testing::WithParamInterface<
std::tuple<WriteBatchOpType, WriteBatchOpType, char>> {
public:
DbKvChecksumTestMergedBatch()
: DBTestBase("db_kv_checksum_test", /*env_do_fsync=*/false) {
op_type1_ = std::get<0>(GetParam());
op_type2_ = std::get<1>(GetParam());
corrupt_byte_addend_ = std::get<2>(GetParam());
} }
protected:
WriteBatchOpType op_type1_;
WriteBatchOpType op_type2_;
char corrupt_byte_addend_;
};
void CorruptWriteBatch(Slice* content, size_t offset,
char corrupt_byte_addend) {
ASSERT_TRUE(offset < content->size());
char* buf = const_cast<char*>(content->data());
buf[offset] += corrupt_byte_addend;
}
TEST_P(DbKvChecksumTestMergedBatch, NoCorruptionCase) {
// Veirfy write batch checksum after write batch append
auto batch1 = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
ASSERT_OK(batch1.second);
auto batch2 = GetWriteBatch(nullptr /* cf_handle */, op_type2_);
ASSERT_OK(batch2.second);
ASSERT_OK(WriteBatchInternal::Append(&batch1.first, &batch2.first));
ASSERT_OK(batch1.first.VerifyChecksum());
} }
TEST_P(DbKvChecksumTestMergedBatch, WriteToWALCorrupted) {
// This test has two writers repeatedly attempt to write `WriteBatch`es
// containing a single entry of type op_type1_ and op_type2_ respectively. The
// leader of the write group writes the batch containinng the entry of type
// op_type1_. One byte of the pre-merged write batches is corrupted by adding
// `corrupt_byte_addend_` to the batch's original value during each attempt.
// The test repeats until an attempt has been made on each byte in both
// pre-merged write batches. All attempts are expected to fail with
// `Status::Corruption`.
Options options = CurrentOptions();
if (op_type1_ == WriteBatchOpType::kMerge ||
op_type2_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
auto leader_batch_and_status =
GetWriteBatch(nullptr /* cf_handle */, op_type1_);
ASSERT_OK(leader_batch_and_status.second);
auto follower_batch_and_status =
GetWriteBatch(nullptr /* cf_handle */, op_type2_);
size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
size_t total_bytes =
leader_batch_size + follower_batch_and_status.first.GetDataSize();
// First 8 bytes are for sequence number which is not protected in write batch
size_t corrupt_byte_offset = 8;
std::atomic<bool> follower_joined{false};
std::atomic<int> leader_count{0};
port::Thread follower_thread;
// This callback should only be called by the leader thread
SyncPoint::GetInstance()->SetCallBack(
"WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
// This callback should only be called by the follower thread
SyncPoint::GetInstance()->SetCallBack(
"WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
auto* follower =
reinterpret_cast<WriteThread::Writer*>(arg_follower);
// The leader thread will wait on this bool and hence wait until
// this writer joins the write group
ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
if (corrupt_byte_offset >= leader_batch_size) {
Slice batch_content = follower->batch->Data();
CorruptWriteBatch(&batch_content,
corrupt_byte_offset - leader_batch_size,
corrupt_byte_addend_);
}
// Leader busy waits on this flag
follower_joined = true;
// So the follower does not enter the outer callback at
// WriteThread::JoinBatchGroup:Wait2
SyncPoint::GetInstance()->DisableProcessing();
});
// Start the other writer thread which will join the write group as
// follower
follower_thread = port::Thread([&]() {
follower_batch_and_status =
GetWriteBatch(nullptr /* cf_handle */, op_type2_);
ASSERT_OK(follower_batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &follower_batch_and_status.first)
.IsCorruption());
});
ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
if (corrupt_byte_offset < leader_batch_size) {
Slice batch_content = leader->batch->Data();
CorruptWriteBatch(&batch_content, corrupt_byte_offset,
corrupt_byte_addend_);
}
leader_count++;
while (!follower_joined) {
// busy waiting
}
});
while (corrupt_byte_offset < total_bytes) {
// Reopen DB since it failed WAL write which lead to read-only mode
Reopen(options);
SyncPoint::GetInstance()->EnableProcessing();
auto log_size_pre_write = dbfull()->TEST_total_log_size();
leader_batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
ASSERT_OK(leader_batch_and_status.second);
ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first)
.IsCorruption());
follower_thread.join();
// Prevent leader thread from entering this callback
SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait");
ASSERT_EQ(1, leader_count);
// Nothing should have been written to WAL
ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
corrupt_byte_offset++;
if (corrupt_byte_offset == leader_batch_size) {
// skip over the sequence number part of follower's write batch
corrupt_byte_offset += 8;
}
follower_joined = false;
leader_count = 0;
}
SyncPoint::GetInstance()->DisableProcessing();
}
TEST_P(DbKvChecksumTestMergedBatch, WriteToWALWithColumnFamilyCorrupted) {
// This test has two writers repeatedly attempt to write `WriteBatch`es
// containing a single entry of type op_type1_ and op_type2_ respectively. The
// leader of the write group writes the batch containinng the entry of type
// op_type1_. One byte of the pre-merged write batches is corrupted by adding
// `corrupt_byte_addend_` to the batch's original value during each attempt.
// The test repeats until an attempt has been made on each byte in both
// pre-merged write batches. All attempts are expected to fail with
// `Status::Corruption`.
Options options = CurrentOptions();
if (op_type1_ == WriteBatchOpType::kMerge ||
op_type2_ == WriteBatchOpType::kMerge) {
options.merge_operator = MergeOperators::CreateStringAppendOperator();
}
CreateAndReopenWithCF({"ramen"}, options);
auto leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_);
ASSERT_OK(leader_batch_and_status.second);
auto follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
size_t total_bytes =
leader_batch_size + follower_batch_and_status.first.GetDataSize();
// First 8 bytes are for sequence number which is not protected in write batch
size_t corrupt_byte_offset = 8;
std::atomic<bool> follower_joined{false};
std::atomic<int> leader_count{0};
port::Thread follower_thread;
// This callback should only be called by the leader thread
SyncPoint::GetInstance()->SetCallBack(
"WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
// This callback should only be called by the follower thread
SyncPoint::GetInstance()->SetCallBack(
"WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
auto* follower =
reinterpret_cast<WriteThread::Writer*>(arg_follower);
// The leader thread will wait on this bool and hence wait until
// this writer joins the write group
ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
if (corrupt_byte_offset >= leader_batch_size) {
Slice batch_content =
WriteBatchInternal::Contents(follower->batch);
CorruptWriteBatch(&batch_content,
corrupt_byte_offset - leader_batch_size,
corrupt_byte_addend_);
}
follower_joined = true;
// So the follower does not enter the outer callback at
// WriteThread::JoinBatchGroup:Wait2
SyncPoint::GetInstance()->DisableProcessing();
});
// Start the other writer thread which will join the write group as
// follower
follower_thread = port::Thread([&]() {
follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
ASSERT_OK(follower_batch_and_status.second);
ASSERT_TRUE(
db_->Write(WriteOptions(), &follower_batch_and_status.first)
.IsCorruption());
});
ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
if (corrupt_byte_offset < leader_batch_size) {
Slice batch_content = WriteBatchInternal::Contents(leader->batch);
CorruptWriteBatch(&batch_content, corrupt_byte_offset,
corrupt_byte_addend_);
}
leader_count++;
while (!follower_joined) {
// busy waiting
}
});
SyncPoint::GetInstance()->EnableProcessing();
while (corrupt_byte_offset < total_bytes) {
// Reopen DB since it failed WAL write which lead to read-only mode
ReopenWithColumnFamilies({kDefaultColumnFamilyName, "ramen"}, options);
SyncPoint::GetInstance()->EnableProcessing();
auto log_size_pre_write = dbfull()->TEST_total_log_size();
leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_);
ASSERT_OK(leader_batch_and_status.second);
ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first)
.IsCorruption());
follower_thread.join();
// Prevent leader thread from entering this callback
SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait");
ASSERT_EQ(1, leader_count);
// Nothing should have been written to WAL
ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
corrupt_byte_offset++;
if (corrupt_byte_offset == leader_batch_size) {
// skip over the sequence number part of follower's write batch
corrupt_byte_offset += 8;
}
follower_joined = false;
leader_count = 0;
}
SyncPoint::GetInstance()->DisableProcessing();
}
INSTANTIATE_TEST_CASE_P(
DbKvChecksumTestMergedBatch, DbKvChecksumTestMergedBatch,
::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
WriteBatchOpType::kNum),
::testing::Range(static_cast<WriteBatchOpType>(0),
WriteBatchOpType::kNum),
::testing::Values(2, 103, 251)),
[](const testing::TestParamInfo<
std::tuple<WriteBatchOpType, WriteBatchOpType, char>>& args) {
std::ostringstream oss;
oss << GetOpTypeString(std::get<0>(args.param))
<< GetOpTypeString(std::get<1>(args.param)) << "Add"
<< static_cast<int>(
static_cast<unsigned char>(std::get<2>(args.param)));
return oss.str();
});
// TODO: add test for transactions
// TODO: add test for corrupted write batch with WAL disabled
} // namespace ROCKSDB_NAMESPACE } // namespace ROCKSDB_NAMESPACE
int main(int argc, char** argv) { int main(int argc, char** argv) {

@ -1491,6 +1491,94 @@ Status WriteBatch::UpdateTimestamps(
return s; return s;
} }
Status WriteBatch::VerifyChecksum() const {
if (prot_info_ == nullptr) {
return Status::OK();
}
Slice input(rep_.data() + WriteBatchInternal::kHeader,
rep_.size() - WriteBatchInternal::kHeader);
Slice key, value, blob, xid;
char tag = 0;
uint32_t column_family = 0; // default
Status s;
size_t prot_info_idx = 0;
bool checksum_protected = true;
while (!input.empty() && prot_info_idx < prot_info_->entries_.size()) {
// In case key/value/column_family are not updated by
// ReadRecordFromWriteBatch
key.clear();
value.clear();
column_family = 0;
s = ReadRecordFromWriteBatch(&input, &tag, &column_family, &key, &value,
&blob, &xid);
if (!s.ok()) {
return s;
}
checksum_protected = true;
// Write batch checksum uses op_type without ColumnFamily (e.g., if op_type
// in the write batch is kTypeColumnFamilyValue, kTypeValue is used to
// compute the checksum), and encodes column family id separately. See
// comment in first `WriteBatchInternal::Put()` for more detail.
switch (tag) {
case kTypeColumnFamilyValue:
case kTypeValue:
tag = kTypeValue;
break;
case kTypeColumnFamilyDeletion:
case kTypeDeletion:
tag = kTypeDeletion;
break;
case kTypeColumnFamilySingleDeletion:
case kTypeSingleDeletion:
tag = kTypeSingleDeletion;
break;
case kTypeColumnFamilyRangeDeletion:
case kTypeRangeDeletion:
tag = kTypeRangeDeletion;
break;
case kTypeColumnFamilyMerge:
case kTypeMerge:
tag = kTypeMerge;
break;
case kTypeColumnFamilyBlobIndex:
case kTypeBlobIndex:
tag = kTypeBlobIndex;
break;
case kTypeLogData:
case kTypeBeginPrepareXID:
case kTypeEndPrepareXID:
case kTypeCommitXID:
case kTypeRollbackXID:
case kTypeNoop:
case kTypeBeginPersistedPrepareXID:
case kTypeBeginUnprepareXID:
case kTypeDeletionWithTimestamp:
case kTypeCommitXIDAndTimestamp:
checksum_protected = false;
break;
default:
return Status::Corruption(
"unknown WriteBatch tag",
std::to_string(static_cast<unsigned int>(tag)));
}
if (checksum_protected) {
s = prot_info_->entries_[prot_info_idx++]
.StripC(column_family)
.StripKVO(key, value, static_cast<ValueType>(tag))
.GetStatus();
if (!s.ok()) {
return s;
}
}
}
if (prot_info_idx != WriteBatchInternal::Count(this)) {
return Status::Corruption("WriteBatch has wrong count");
}
assert(WriteBatchInternal::Count(this) == prot_info_->entries_.size());
return Status::OK();
}
namespace { namespace {
class MemTableInserter : public WriteBatch::Handler { class MemTableInserter : public WriteBatch::Handler {
@ -2773,6 +2861,14 @@ Status WriteBatchInternal::Append(WriteBatch* dst, const WriteBatch* src,
const bool wal_only) { const bool wal_only) {
assert(dst->Count() == 0 || assert(dst->Count() == 0 ||
(dst->prot_info_ == nullptr) == (src->prot_info_ == nullptr)); (dst->prot_info_ == nullptr) == (src->prot_info_ == nullptr));
if ((src->prot_info_ != nullptr &&
src->prot_info_->entries_.size() != src->Count()) ||
(dst->prot_info_ != nullptr &&
dst->prot_info_->entries_.size() != dst->Count())) {
return Status::Corruption(
"Write batch has inconsistent count and number of checksums");
}
size_t src_len; size_t src_len;
int src_count; int src_count;
uint32_t src_flags; uint32_t src_flags;

@ -206,6 +206,10 @@ class WriteBatchInternal {
bool batch_per_txn = true, bool batch_per_txn = true,
bool hint_per_batch = false); bool hint_per_batch = false);
// Appends src write batch to dst write batch and updates count in dst
// write batch. Returns OK if the append is successful. Checks number of
// checksum against count in dst and src write batches, and returns Corruption
// if the count is inconsistent.
static Status Append(WriteBatch* dst, const WriteBatch* src, static Status Append(WriteBatch* dst, const WriteBatch* src,
const bool WAL_only = false); const bool WAL_only = false);

@ -389,6 +389,7 @@ void WriteThread::JoinBatchGroup(Writer* w) {
} }
TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait", w); TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait", w);
TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait2", w);
if (!linked_as_leader) { if (!linked_as_leader) {
/** /**

@ -391,6 +391,12 @@ class WriteBatch : public WriteBatchBase {
Status UpdateTimestamps(const Slice& ts, Status UpdateTimestamps(const Slice& ts,
std::function<size_t(uint32_t /*cf*/)> ts_sz_func); std::function<size_t(uint32_t /*cf*/)> ts_sz_func);
// Verify the per-key-value checksums of this write batch.
// Corruption status will be returned if the verification fails.
// If this write batch does not have per-key-value checksum,
// OK status will be returned.
Status VerifyChecksum() const;
using WriteBatchBase::GetWriteBatch; using WriteBatchBase::GetWriteBatch;
WriteBatch* GetWriteBatch() override { return this; } WriteBatch* GetWriteBatch() override { return this; }

@ -1656,6 +1656,10 @@ static const bool FLAGS_table_cache_numshardbits_dummy __attribute__((__unused__
RegisterFlagValidator(&FLAGS_table_cache_numshardbits, RegisterFlagValidator(&FLAGS_table_cache_numshardbits,
&ValidateTableCacheNumshardbits); &ValidateTableCacheNumshardbits);
DEFINE_uint32(write_batch_protection_bytes_per_key, 0,
"Size of per-key-value checksum in each write batch. Currently "
"only value 0 and 8 are supported.");
namespace ROCKSDB_NAMESPACE { namespace ROCKSDB_NAMESPACE {
namespace { namespace {
static Status CreateMemTableRepFactory( static Status CreateMemTableRepFactory(
@ -4910,7 +4914,8 @@ class Benchmark {
RandomGenerator gen; RandomGenerator gen;
WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0, WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
/*protection_bytes_per_key=*/0, user_timestamp_size_); FLAGS_write_batch_protection_bytes_per_key,
user_timestamp_size_);
Status s; Status s;
int64_t bytes = 0; int64_t bytes = 0;
@ -6699,7 +6704,8 @@ class Benchmark {
void DoDelete(ThreadState* thread, bool seq) { void DoDelete(ThreadState* thread, bool seq) {
WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0, WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
/*protection_bytes_per_key=*/0, user_timestamp_size_); FLAGS_write_batch_protection_bytes_per_key,
user_timestamp_size_);
Duration duration(seq ? 0 : FLAGS_duration, deletes_); Duration duration(seq ? 0 : FLAGS_duration, deletes_);
int64_t i = 0; int64_t i = 0;
std::unique_ptr<const char[]> key_guard; std::unique_ptr<const char[]> key_guard;
@ -6899,7 +6905,8 @@ class Benchmark {
std::string keys[3]; std::string keys[3];
WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0, WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
/*protection_bytes_per_key=*/0, user_timestamp_size_); FLAGS_write_batch_protection_bytes_per_key,
user_timestamp_size_);
Status s; Status s;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
keys[i] = key.ToString() + suffixes[i]; keys[i] = key.ToString() + suffixes[i];
@ -6931,7 +6938,7 @@ class Benchmark {
std::string suffixes[3] = {"1", "2", "0"}; std::string suffixes[3] = {"1", "2", "0"};
std::string keys[3]; std::string keys[3];
WriteBatch batch(0, 0, /*protection_bytes_per_key=*/0, WriteBatch batch(0, 0, FLAGS_write_batch_protection_bytes_per_key,
user_timestamp_size_); user_timestamp_size_);
Status s; Status s;
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {

Loading…
Cancel
Save