From 9882652b0e2db21974aaa682ef7664c7ebe2f84e Mon Sep 17 00:00:00 2001
From: Changyu Bi
Date: Wed, 15 Jun 2022 13:43:58 -0700
Subject: [PATCH] Verify write batch checksum before WAL (#10114)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Summary:
Context: a WriteBatch carries key-value checksums when it is created with
`protection_bytes_per_key > 0`. This PR adds checksum verification for write
batches before they are written to the WAL.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/10114

Test Plan:
- Added new unit tests to db_kv_checksum_test.cc: `make check -j32`
- Benchmarked for performance regression:
  `./db_bench --benchmarks=fillrandom[-X20] -db=/dev/shm/test_rocksdb -write_batch_protection_bytes_per_key=8`
  - Pre-PR:  `fillrandom [AVG 20 runs] : 198875 (± 3006) ops/sec; 22.0 (± 0.3) MB/sec`
  - Post-PR: `fillrandom [AVG 20 runs] : 196487 (± 2279) ops/sec; 21.7 (± 0.3) MB/sec`
  The mean regressed about 1% (198875 -> 196487 ops/sec).

Reviewed By: ajkr

Differential Revision: D36917464

Pulled By: cbi42

fbshipit-source-id: 29beb74edf65f04b1a890b4f650d873dc7ed790d
---
 db/db_impl/db_impl.h          |   9 +-
 db/db_impl/db_impl_write.cc   |  86 +++---
 db/db_kv_checksum_test.cc     | 487 +++++++++++++++++++++++++++++-----
 db/write_batch.cc             |  96 +++++++
 db/write_batch_internal.h     |   4 +
 db/write_thread.cc            |   1 +
 include/rocksdb/write_batch.h |   6 +
 tools/db_bench_tool.cc        |  15 +-
 8 files changed, 604 insertions(+), 100 deletions(-)
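For orientation before the diff: everything below hinges on a WriteBatch created with nonzero `protection_bytes_per_key`, which makes it carry a per-key-value checksum for every entry. A minimal sketch of that contract, using only the public API this patch touches (the constructor arguments mirror the new tests; this is an illustration, not code from the patch):

    #include <cassert>
    #include "rocksdb/write_batch.h"

    int main() {
      // Same constructor the new tests use: reserved_bytes=0, max_bytes=0,
      // 8 protection bytes per key, default_cf_ts_sz=0.
      ROCKSDB_NAMESPACE::WriteBatch wb(0, 0, 8 /* protection_bytes_per_key */,
                                       0 /* default_cf_ts_sz */);
      assert(wb.Put("key", "val").ok());
      // VerifyChecksum() (added by this patch) recomputes each entry's
      // checksum; in-memory corruption surfaces as Status::Corruption.
      assert(wb.VerifyChecksum().ok());
      return 0;
    }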
diff --git a/db/db_impl/db_impl.h b/db/db_impl/db_impl.h
index 018d7904c..733a87a0c 100644
--- a/db/db_impl/db_impl.h
+++ b/db/db_impl/db_impl.h
@@ -1915,9 +1915,12 @@ class DBImpl : public DB {
   Status PreprocessWrite(const WriteOptions& write_options,
                          bool* need_log_sync, WriteContext* write_context);

-  WriteBatch* MergeBatch(const WriteThread::WriteGroup& write_group,
-                         WriteBatch* tmp_batch, size_t* write_with_wal,
-                         WriteBatch** to_be_cached_state);
+  // Merge write batches in the write group into merged_batch.
+  // Returns OK if the merge is successful.
+  // Returns Corruption if corruption is detected in a write batch.
+  Status MergeBatch(const WriteThread::WriteGroup& write_group,
+                    WriteBatch* tmp_batch, WriteBatch** merged_batch,
+                    size_t* write_with_wal, WriteBatch** to_be_cached_state);

   // rate_limiter_priority is used to charge `DBOptions::rate_limiter`
   // for automatic WAL flush (`Options::manual_wal_flush` == false)
diff --git a/db/db_impl/db_impl_write.cc b/db/db_impl/db_impl_write.cc
index c6ce801ae..787006d35 100644
--- a/db/db_impl/db_impl_write.cc
+++ b/db/db_impl/db_impl_write.cc
@@ -533,15 +533,18 @@ Status DBImpl::WriteImpl(const WriteOptions& write_options,
   }

   PERF_TIMER_START(write_pre_and_post_process_time);
+  if (!io_s.ok()) {
+    // Check WriteToWAL status
+    IOStatusCheck(io_s);
+  }
   if (!w.CallbackFailed()) {
     if (!io_s.ok()) {
       assert(pre_release_cb_status.ok());
-      IOStatusCheck(io_s);
     } else {
       WriteStatusCheck(pre_release_cb_status);
     }
   } else {
-    assert(io_s.ok() && pre_release_cb_status.ok());
+    assert(pre_release_cb_status.ok());
   }

   if (need_log_sync) {
@@ -695,12 +698,11 @@ Status DBImpl::PipelinedWriteImpl(const WriteOptions& write_options,
     w.status = io_s;
   }

-  if (!w.CallbackFailed()) {
-    if (!io_s.ok()) {
-      IOStatusCheck(io_s);
-    } else {
-      WriteStatusCheck(w.status);
-    }
+  if (!io_s.ok()) {
+    // Check WriteToWAL status
+    IOStatusCheck(io_s);
+  } else if (!w.CallbackFailed()) {
+    WriteStatusCheck(w.status);
   }

   if (need_log_sync) {
@@ -936,11 +938,18 @@ Status DBImpl::WriteImplWALOnly(
     seq_inc = total_batch_cnt;
   }
   Status status;
-  IOStatus io_s;
-  io_s.PermitUncheckedError();  // Allow io_s to be uninitialized
   if (!write_options.disableWAL) {
-    io_s = ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc);
+    IOStatus io_s =
+        ConcurrentWriteToWAL(write_group, log_used, &last_sequence, seq_inc);
     status = io_s;
+    // last_sequence may not be set if there is an error.
+    // This error checking and return is moved up to avoid using an
+    // uninitialized last_sequence.
+    if (!io_s.ok()) {
+      IOStatusCheck(io_s);
+      write_thread->ExitAsBatchGroupLeader(write_group, status);
+      return status;
+    }
   } else {
     // Otherwise we inc seq number to do solely the seq allocation
     last_sequence = versions_->FetchAddLastAllocatedSequence(seq_inc);
@@ -975,11 +984,7 @@ Status DBImpl::WriteImplWALOnly(
   PERF_TIMER_START(write_pre_and_post_process_time);

   if (!w.CallbackFailed()) {
-    if (!io_s.ok()) {
-      IOStatusCheck(io_s);
-    } else {
-      WriteStatusCheck(status);
-    }
+    WriteStatusCheck(status);
   }
   if (status.ok()) {
     size_t index = 0;
@@ -1171,13 +1176,13 @@ Status DBImpl::PreprocessWrite(const WriteOptions& write_options,
   return status;
 }

-WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
-                               WriteBatch* tmp_batch, size_t* write_with_wal,
-                               WriteBatch** to_be_cached_state) {
+Status DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
+                          WriteBatch* tmp_batch, WriteBatch** merged_batch,
+                          size_t* write_with_wal,
+                          WriteBatch** to_be_cached_state) {
   assert(write_with_wal != nullptr);
   assert(tmp_batch != nullptr);
   assert(*to_be_cached_state == nullptr);
-  WriteBatch* merged_batch = nullptr;
   *write_with_wal = 0;
   auto* leader = write_group.leader;
   assert(!leader->disable_wal);  // Same holds for all in the batch group
@@ -1186,22 +1191,24 @@ WriteBatch* DBImpl::MergeBatch(const WriteThread::WriteGroup& write_group,
     // we simply write the first WriteBatch to WAL if the group only
     // contains one batch, that batch should be written to the WAL,
     // and the batch is not wanting to be truncated
-    merged_batch = leader->batch;
-    if (WriteBatchInternal::IsLatestPersistentState(merged_batch)) {
-      *to_be_cached_state = merged_batch;
+    *merged_batch = leader->batch;
+    if (WriteBatchInternal::IsLatestPersistentState(*merged_batch)) {
+      *to_be_cached_state = *merged_batch;
     }
     *write_with_wal = 1;
   } else {
     // WAL needs all of the batches flattened into a single batch.
     // We could avoid copying here with an iov-like AddRecord
     // interface
-    merged_batch = tmp_batch;
+    *merged_batch = tmp_batch;
     for (auto writer : write_group) {
       if (!writer->CallbackFailed()) {
-        Status s = WriteBatchInternal::Append(merged_batch, writer->batch,
+        Status s = WriteBatchInternal::Append(*merged_batch, writer->batch,
                                               /*WAL_only*/ true);
-        // Always returns Status::OK.
-        assert(s.ok());
+        if (!s.ok()) {
+          tmp_batch->Clear();
+          return s;
+        }
         if (WriteBatchInternal::IsLatestPersistentState(writer->batch)) {
           // We only need to cache the last of such write batch
           *to_be_cached_state = writer->batch;
@@ -1210,7 +1217,8 @@
       }
     }
   }
-  return merged_batch;
+  return Status::OK();
 }

 // When two_write_queues_ is disabled, this function is called from the only
@@ -1223,6 +1231,11 @@ IOStatus DBImpl::WriteToWAL(const WriteBatch& merged_batch,
   assert(log_size != nullptr);
   Slice log_entry = WriteBatchInternal::Contents(&merged_batch);
+  TEST_SYNC_POINT_CALLBACK("DBImpl::WriteToWAL:log_entry", &log_entry);
+  auto s = merged_batch.VerifyChecksum();
+  if (!s.ok()) {
+    return status_to_io_status(std::move(s));
+  }
   *log_size = log_entry.size();
   // When two_write_queues_ WriteToWAL has to be protected from concurrent calls
   // from the two queues anyway and log_write_mutex_ is already held. Otherwise
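The hunk above is the core of the change: the merged batch's per-entry checksums are re-verified immediately before its contents become a WAL record, so a batch corrupted in memory is rejected instead of persisted. A distilled sketch of the same verify-before-write pattern (`AppendToLog` is a hypothetical stand-in for the real `log::Writer` path):

    #include "rocksdb/status.h"
    #include "rocksdb/write_batch.h"

    using ROCKSDB_NAMESPACE::Status;
    using ROCKSDB_NAMESPACE::WriteBatch;

    // Hypothetical stand-in for the WAL writer; only the ordering matters here.
    static Status AppendToLog(const WriteBatch& /*batch*/) { return Status::OK(); }

    static Status WriteVerified(const WriteBatch& merged_batch) {
      // Recheck key-value checksums while the bytes are still in memory.
      Status s = merged_batch.VerifyChecksum();
      if (!s.ok()) {
        return s;  // surfaces as Status::Corruption to the write group
      }
      return AppendToLog(merged_batch);
    }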
@@ -1260,8 +1273,13 @@ IOStatus DBImpl::WriteToWAL(const WriteThread::WriteGroup& write_group,
   // Same holds for all in the batch group
   size_t write_with_wal = 0;
   WriteBatch* to_be_cached_state = nullptr;
-  WriteBatch* merged_batch = MergeBatch(write_group, &tmp_batch_,
-                                        &write_with_wal, &to_be_cached_state);
+  WriteBatch* merged_batch;
+  io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch_, &merged_batch,
+                                        &write_with_wal, &to_be_cached_state));
+  if (UNLIKELY(!io_s.ok())) {
+    return io_s;
+  }
+
   if (merged_batch == write_group.leader->batch) {
     write_group.leader->log_used = logfile_number_;
   } else if (write_with_wal > 1) {
@@ -1351,8 +1369,12 @@ IOStatus DBImpl::ConcurrentWriteToWAL(
   WriteBatch tmp_batch;
   size_t write_with_wal = 0;
   WriteBatch* to_be_cached_state = nullptr;
-  WriteBatch* merged_batch =
-      MergeBatch(write_group, &tmp_batch, &write_with_wal, &to_be_cached_state);
+  WriteBatch* merged_batch;
+  io_s = status_to_io_status(MergeBatch(write_group, &tmp_batch, &merged_batch,
+                                        &write_with_wal, &to_be_cached_state));
+  if (UNLIKELY(!io_s.ok())) {
+    return io_s;
+  }

   // We need to lock log_write_mutex_ since logs_ and alive_log_files might be
   // pushed back concurrently
diff --git a/db/db_kv_checksum_test.cc b/db/db_kv_checksum_test.cc
index 44ee56786..5636c9e6e 100644
--- a/db/db_kv_checksum_test.cc
+++ b/db/db_kv_checksum_test.cc
@@ -25,6 +25,49 @@
 WriteBatchOpType operator+(WriteBatchOpType lhs, const int rhs) {
   return static_cast<WriteBatchOpType>(static_cast<int>(lhs) + rhs);
 }

+std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle,
+                                            WriteBatchOpType op_type) {
+  Status s;
+  WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
+                8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
+  switch (op_type) {
+    case WriteBatchOpType::kPut:
+      s = wb.Put(cf_handle, "key", "val");
+      break;
+    case WriteBatchOpType::kDelete:
+      s = wb.Delete(cf_handle, "key");
+      break;
+    case WriteBatchOpType::kSingleDelete:
+      s = wb.SingleDelete(cf_handle, "key");
+      break;
+    case WriteBatchOpType::kDeleteRange:
+      s = wb.DeleteRange(cf_handle, "begin", "end");
+      break;
+    case WriteBatchOpType::kMerge:
+      s = wb.Merge(cf_handle, "key", "val");
+      break;
+    case WriteBatchOpType::kBlobIndex: {
+      // TODO(ajkr): use public API once available.
+      uint32_t cf_id;
+      if (cf_handle == nullptr) {
+        cf_id = 0;
+      } else {
+        cf_id = cf_handle->GetID();
+      }
+
+      std::string blob_index;
+      BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
+                                  "val");
+
+      s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
+      break;
+    }
+    case WriteBatchOpType::kNum:
+      assert(false);
+  }
+  return {std::move(wb), std::move(s)};
+}
+
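Every batch built by this helper carries 8 protection bytes per entry, so flipping any single byte of the protected payload must fail verification. A standalone sketch of that property (the offset arithmetic assumes the usual layout of a single-Put batch: 12-byte header, then tag, key length, "key", value length, "val"):

    #include <cassert>
    #include <string>
    #include "rocksdb/write_batch.h"

    int main() {
      ROCKSDB_NAMESPACE::WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
                                       8 /* protection_bytes_per_key */,
                                       0 /* default_cf_ts_sz */);
      assert(wb.Put("key", "val").ok());
      assert(wb.VerifyChecksum().ok());

      // Corrupt the first byte of the value ("val" starts at offset 18 under
      // the layout assumed above), mimicking the CorruptWriteBatch() helper
      // defined later in this file.
      std::string& rep = const_cast<std::string&>(wb.Data());
      rep[18] += 1;
      assert(wb.VerifyChecksum().IsCorruption());
      return 0;
    }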
 class DbKvChecksumTest
     : public DBTestBase,
       public ::testing::WithParamInterface<std::tuple<WriteBatchOpType, char>> {
@@ -35,48 +78,6 @@ class DbKvChecksumTest
     corrupt_byte_addend_ = std::get<1>(GetParam());
   }

-  std::pair<WriteBatch, Status> GetWriteBatch(ColumnFamilyHandle* cf_handle) {
-    Status s;
-    WriteBatch wb(0 /* reserved_bytes */, 0 /* max_bytes */,
-                  8 /* protection_bytes_per_entry */, 0 /* default_cf_ts_sz */);
-    switch (op_type_) {
-      case WriteBatchOpType::kPut:
-        s = wb.Put(cf_handle, "key", "val");
-        break;
-      case WriteBatchOpType::kDelete:
-        s = wb.Delete(cf_handle, "key");
-        break;
-      case WriteBatchOpType::kSingleDelete:
-        s = wb.SingleDelete(cf_handle, "key");
-        break;
-      case WriteBatchOpType::kDeleteRange:
-        s = wb.DeleteRange(cf_handle, "begin", "end");
-        break;
-      case WriteBatchOpType::kMerge:
-        s = wb.Merge(cf_handle, "key", "val");
-        break;
-      case WriteBatchOpType::kBlobIndex: {
-        // TODO(ajkr): use public API once available.
-        uint32_t cf_id;
-        if (cf_handle == nullptr) {
-          cf_id = 0;
-        } else {
-          cf_id = cf_handle->GetID();
-        }
-
-        std::string blob_index;
-        BlobIndex::EncodeInlinedTTL(&blob_index, /* expiration */ 9876543210,
-                                    "val");
-
-        s = WriteBatchInternal::PutBlobIndex(&wb, cf_id, "key", blob_index);
-        break;
-      }
-      case WriteBatchOpType::kNum:
-        assert(false);
-    }
-    return {std::move(wb), std::move(s)};
-  }
-
   void CorruptNextByteCallBack(void* arg) {
     Slice encoded = *static_cast<Slice*>(arg);
     if (entry_len_ == std::numeric_limits<size_t>::max()) {
@@ -99,34 +100,28 @@ class DbKvChecksumTest
   size_t entry_len_ = std::numeric_limits<size_t>::max();
 };

-std::string GetTestNameSuffix(
-    ::testing::TestParamInfo<std::tuple<WriteBatchOpType, char>> info) {
-  std::ostringstream oss;
-  switch (std::get<0>(info.param)) {
+std::string GetOpTypeString(const WriteBatchOpType& op_type) {
+  switch (op_type) {
     case WriteBatchOpType::kPut:
-      oss << "Put";
-      break;
+      return "Put";
     case WriteBatchOpType::kDelete:
-      oss << "Delete";
-      break;
+      return "Delete";
     case WriteBatchOpType::kSingleDelete:
-      oss << "SingleDelete";
-      break;
+      return "SingleDelete";
     case WriteBatchOpType::kDeleteRange:
-      oss << "DeleteRange";
+      return "DeleteRange";
       break;
     case WriteBatchOpType::kMerge:
-      oss << "Merge";
+      return "Merge";
       break;
     case WriteBatchOpType::kBlobIndex:
-      oss << "BlobIndex";
+      return "BlobIndex";
       break;
     case WriteBatchOpType::kNum:
       assert(false);
   }
-  oss << "Add"
-      << static_cast<int>(static_cast<unsigned char>(std::get<1>(info.param)));
-  return oss.str();
+  assert(false);
+  return "";
 }

 INSTANTIATE_TEST_CASE_P(
     DbKvChecksumTest, DbKvChecksumTest,
     ::testing::Combine(::testing::Range(static_cast<WriteBatchOpType>(0),
                                         WriteBatchOpType::kNum),
                        ::testing::Values(2, 103, 251)),
-    GetTestNameSuffix);
+    [](const testing::TestParamInfo<std::tuple<WriteBatchOpType, char>>& args) {
+      std::ostringstream oss;
+      oss << GetOpTypeString(std::get<0>(args.param)) << "Add"
+          << static_cast<int>(
+                 static_cast<unsigned char>(std::get<1>(args.param)));
+      return oss.str();
+    });

 TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
   // This test repeatedly attempts to write `WriteBatch`es containing a single
@@ -157,11 +158,16 @@ TEST_P(DbKvChecksumTest, MemTableAddCorrupted) {
     Reopen(options);

     SyncPoint::GetInstance()->EnableProcessing();
-    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */);
+    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
     ASSERT_OK(batch_and_status.second);
     ASSERT_TRUE(
         db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
     SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test would run
+    // numeric_limits<size_t>::max() times until it reports an error (or would
+    // exhaust disk space). This assert reports the error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
   }
 }

@@ -188,14 +194,373 @@
     ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);

     SyncPoint::GetInstance()->EnableProcessing();
-    auto batch_and_status = GetWriteBatch(handles_[1]);
+    auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
     ASSERT_OK(batch_and_status.second);
     ASSERT_TRUE(
         db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
     SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test would run
+    // numeric_limits<size_t>::max() times until it reports an error (or would
+    // exhaust disk space). This assert reports the error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
   }
 }

+TEST_P(DbKvChecksumTest, NoCorruptionCase) {
+  // If this test fails, we may have found a piece of malfunctioning hardware
+  auto batch_and_status = GetWriteBatch(nullptr, op_type_);
+  ASSERT_OK(batch_and_status.second);
+  ASSERT_OK(batch_and_status.first.VerifyChecksum());
+}
+
+TEST_P(DbKvChecksumTest, WriteToWALCorrupted) {
+  // This test repeatedly attempts to write `WriteBatch`es containing a single
+  // entry of type `op_type_`. Each attempt has one byte corrupted by adding
+  // `corrupt_byte_addend_` to its original value. The test repeats until an
+  // attempt has been made on each byte in the encoded write batch. All
+  // attempts are expected to fail with `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::WriteToWAL:log_entry",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  // The first 8 bytes are for the sequence number, which is not protected in
+  // the write batch
+  corrupt_byte_offset_ = 8;
+
+  while (MoreBytesToCorrupt()) {
+    // A corrupted write batch leads to read-only mode, so we have to
+    // reopen for every attempt.
+    Reopen(options);
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type_);
+    ASSERT_OK(batch_and_status.second);
+    ASSERT_TRUE(
+        db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
+    // Confirm that nothing was written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+    SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test would run
+    // numeric_limits<size_t>::max() times until it reports an error (or would
+    // exhaust disk space). This assert reports the error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
+  }
+}
+
+TEST_P(DbKvChecksumTest, WriteToWALWithColumnFamilyCorrupted) {
+  // This test repeatedly attempts to write `WriteBatch`es containing a single
+  // entry of type `op_type_`.
+  // Each attempt has one byte corrupted by adding
+  // `corrupt_byte_addend_` to its original value. The test repeats until an
+  // attempt has been made on each byte in the encoded write batch. All
+  // attempts are expected to fail with `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  CreateAndReopenWithCF({"pikachu"}, options);
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::WriteToWAL:log_entry",
+      std::bind(&DbKvChecksumTest::CorruptNextByteCallBack, this,
+                std::placeholders::_1));
+  // The first 8 bytes are for the sequence number, which is not protected in
+  // the write batch
+  corrupt_byte_offset_ = 8;
+
+  while (MoreBytesToCorrupt()) {
+    // A corrupted write batch leads to read-only mode, so we have to
+    // reopen for every attempt.
+    ReopenWithColumnFamilies({kDefaultColumnFamilyName, "pikachu"}, options);
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto batch_and_status = GetWriteBatch(handles_[1], op_type_);
+    ASSERT_OK(batch_and_status.second);
+    ASSERT_TRUE(
+        db_->Write(WriteOptions(), &batch_and_status.first).IsCorruption());
+    // Confirm that nothing was written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+    SyncPoint::GetInstance()->DisableProcessing();
+
+    // In case the above callback is not invoked, this test would run
+    // numeric_limits<size_t>::max() times until it reports an error (or would
+    // exhaust disk space). This assert reports the error early.
+    ASSERT_TRUE(entry_len_ < std::numeric_limits<size_t>::max());
+  }
+}
+
+class DbKvChecksumTestMergedBatch
+    : public DBTestBase,
+      public ::testing::WithParamInterface<
+          std::tuple<WriteBatchOpType, WriteBatchOpType, char>> {
+ public:
+  DbKvChecksumTestMergedBatch()
+      : DBTestBase("db_kv_checksum_test", /*env_do_fsync=*/false) {
+    op_type1_ = std::get<0>(GetParam());
+    op_type2_ = std::get<1>(GetParam());
+    corrupt_byte_addend_ = std::get<2>(GetParam());
+  }
+
+ protected:
+  WriteBatchOpType op_type1_;
+  WriteBatchOpType op_type2_;
+  char corrupt_byte_addend_;
+};
+
+void CorruptWriteBatch(Slice* content, size_t offset,
+                       char corrupt_byte_addend) {
+  ASSERT_TRUE(offset < content->size());
+  char* buf = const_cast<char*>(content->data());
+  buf[offset] += corrupt_byte_addend;
+}
+
+TEST_P(DbKvChecksumTestMergedBatch, NoCorruptionCase) {
+  // Verify the write batch checksum after a write batch append
+  auto batch1 = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+  ASSERT_OK(batch1.second);
+  auto batch2 = GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+  ASSERT_OK(batch2.second);
+  ASSERT_OK(WriteBatchInternal::Append(&batch1.first, &batch2.first));
+  ASSERT_OK(batch1.first.VerifyChecksum());
+}
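The two tests that follow choreograph a real write group with sync points; the property they ultimately check can be stated much more compactly. A sketch using this file's helpers (not one of the checked-in tests): append two protected batches, flip one byte at the first protected offset, and expect verification to fail:

    // Sketch only; uses GetWriteBatch() and CorruptWriteBatch() from this file.
    void MergedBatchCorruptionSketch() {
      auto b1 = GetWriteBatch(nullptr /* cf_handle */, WriteBatchOpType::kPut);
      auto b2 = GetWriteBatch(nullptr /* cf_handle */, WriteBatchOpType::kDelete);
      ASSERT_OK(b1.second);
      ASSERT_OK(b2.second);
      // Append() also merges the per-entry protection info.
      ASSERT_OK(WriteBatchInternal::Append(&b1.first, &b2.first));
      ASSERT_OK(b1.first.VerifyChecksum());

      // Offset 8 is the first byte past the unprotected 8-byte sequence
      // number; the tests below start corrupting at the same offset.
      Slice content = b1.first.Data();
      CorruptWriteBatch(&content, 8 /* offset */, 1 /* corrupt_byte_addend */);
      ASSERT_TRUE(b1.first.VerifyChecksum().IsCorruption());
    }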
+TEST_P(DbKvChecksumTestMergedBatch, WriteToWALCorrupted) {
+  // This test has two writers repeatedly attempt to write `WriteBatch`es
+  // containing a single entry of type op_type1_ and op_type2_ respectively.
+  // The leader of the write group writes the batch containing the entry of
+  // type op_type1_. One byte of the pre-merged write batches is corrupted by
+  // adding `corrupt_byte_addend_` to the batch's original value during each
+  // attempt. The test repeats until an attempt has been made on each byte in
+  // both pre-merged write batches. All attempts are expected to fail with
+  // `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type1_ == WriteBatchOpType::kMerge ||
+      op_type2_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+
+  auto leader_batch_and_status =
+      GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+  ASSERT_OK(leader_batch_and_status.second);
+  auto follower_batch_and_status =
+      GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+  size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
+  size_t total_bytes =
+      leader_batch_size + follower_batch_and_status.first.GetDataSize();
+  // The first 8 bytes are for the sequence number, which is not protected in
+  // the write batch
+  size_t corrupt_byte_offset = 8;
+
+  std::atomic<bool> follower_joined{false};
+  std::atomic<int> leader_count{0};
+  port::Thread follower_thread;
+  // This callback should only be called by the leader thread
+  SyncPoint::GetInstance()->SetCallBack(
+      "WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
+        auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
+        ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
+
+        // This callback should only be called by the follower thread
+        SyncPoint::GetInstance()->SetCallBack(
+            "WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
+              auto* follower =
+                  reinterpret_cast<WriteThread::Writer*>(arg_follower);
+              // The leader thread will wait on this bool and hence wait until
+              // this writer joins the write group
+              ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
+              if (corrupt_byte_offset >= leader_batch_size) {
+                Slice batch_content = follower->batch->Data();
+                CorruptWriteBatch(&batch_content,
+                                  corrupt_byte_offset - leader_batch_size,
+                                  corrupt_byte_addend_);
+              }
+              // Leader busy waits on this flag
+              follower_joined = true;
+              // So the follower does not enter the outer callback at
+              // WriteThread::JoinBatchGroup:Wait2
+              SyncPoint::GetInstance()->DisableProcessing();
+            });
+
+        // Start the other writer thread which will join the write group as
+        // follower
+        follower_thread = port::Thread([&]() {
+          follower_batch_and_status =
+              GetWriteBatch(nullptr /* cf_handle */, op_type2_);
+          ASSERT_OK(follower_batch_and_status.second);
+          ASSERT_TRUE(
+              db_->Write(WriteOptions(), &follower_batch_and_status.first)
+                  .IsCorruption());
+        });
+
+        ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
+        if (corrupt_byte_offset < leader_batch_size) {
+          Slice batch_content = leader->batch->Data();
+          CorruptWriteBatch(&batch_content, corrupt_byte_offset,
+                            corrupt_byte_addend_);
+        }
+        leader_count++;
+        while (!follower_joined) {
+          // busy waiting
+        }
+      });
+  while (corrupt_byte_offset < total_bytes) {
+    // Reopen the DB since the failed WAL write led to read-only mode
+    Reopen(options);
+    SyncPoint::GetInstance()->EnableProcessing();
+    auto log_size_pre_write = dbfull()->TEST_total_log_size();
+    leader_batch_and_status = GetWriteBatch(nullptr /* cf_handle */, op_type1_);
+    ASSERT_OK(leader_batch_and_status.second);
+    ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first)
+                    .IsCorruption());
+    follower_thread.join();
+    // Prevent the leader thread from entering this callback
+    SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait");
+    ASSERT_EQ(1, leader_count);
+    // Nothing should have been written to WAL
+    ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size());
+    ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption());
+
+    corrupt_byte_offset++;
+    if (corrupt_byte_offset == leader_batch_size) {
+      // skip over the sequence number part of the follower's write batch
+      corrupt_byte_offset += 8;
+    }
+    follower_joined = false;
+    leader_count = 0;
+  }
+  SyncPoint::GetInstance()->DisableProcessing();
+}
+
+TEST_P(DbKvChecksumTestMergedBatch, WriteToWALWithColumnFamilyCorrupted) {
+  // This test has two writers repeatedly attempt to write `WriteBatch`es
+  // containing a single entry of type op_type1_ and op_type2_ respectively.
+  // The leader of the write group writes the batch containing the entry of
+  // type op_type1_. One byte of the pre-merged write batches is corrupted by
+  // adding `corrupt_byte_addend_` to the batch's original value during each
+  // attempt. The test repeats until an attempt has been made on each byte in
+  // both pre-merged write batches. All attempts are expected to fail with
+  // `Status::Corruption`.
+  Options options = CurrentOptions();
+  if (op_type1_ == WriteBatchOpType::kMerge ||
+      op_type2_ == WriteBatchOpType::kMerge) {
+    options.merge_operator = MergeOperators::CreateStringAppendOperator();
+  }
+  CreateAndReopenWithCF({"ramen"}, options);
+
+  auto leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_);
+  ASSERT_OK(leader_batch_and_status.second);
+  auto follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
+  size_t leader_batch_size = leader_batch_and_status.first.GetDataSize();
+  size_t total_bytes =
+      leader_batch_size + follower_batch_and_status.first.GetDataSize();
+  // The first 8 bytes are for the sequence number, which is not protected in
+  // the write batch
+  size_t corrupt_byte_offset = 8;
+
+  std::atomic<bool> follower_joined{false};
+  std::atomic<int> leader_count{0};
+  port::Thread follower_thread;
+  // This callback should only be called by the leader thread
+  SyncPoint::GetInstance()->SetCallBack(
+      "WriteThread::JoinBatchGroup:Wait2", [&](void* arg_leader) {
+        auto* leader = reinterpret_cast<WriteThread::Writer*>(arg_leader);
+        ASSERT_EQ(leader->state, WriteThread::STATE_GROUP_LEADER);
+
+        // This callback should only be called by the follower thread
+        SyncPoint::GetInstance()->SetCallBack(
+            "WriteThread::JoinBatchGroup:Wait", [&](void* arg_follower) {
+              auto* follower =
+                  reinterpret_cast<WriteThread::Writer*>(arg_follower);
+              // The leader thread will wait on this bool and hence wait until
+              // this writer joins the write group
+              ASSERT_NE(follower->state, WriteThread::STATE_GROUP_LEADER);
+              if (corrupt_byte_offset >= leader_batch_size) {
+                Slice batch_content =
+                    WriteBatchInternal::Contents(follower->batch);
+                CorruptWriteBatch(&batch_content,
+                                  corrupt_byte_offset - leader_batch_size,
+                                  corrupt_byte_addend_);
+              }
+              follower_joined = true;
+              // So the follower does not enter the outer callback at
+              // WriteThread::JoinBatchGroup:Wait2
+              SyncPoint::GetInstance()->DisableProcessing();
+            });
+
+        // Start the other writer thread which will join the write group as
+        // follower
+        follower_thread = port::Thread([&]() {
+          follower_batch_and_status = GetWriteBatch(handles_[1], op_type2_);
+          ASSERT_OK(follower_batch_and_status.second);
+          ASSERT_TRUE(
+              db_->Write(WriteOptions(), &follower_batch_and_status.first)
+                  .IsCorruption());
+        });
+
+        ASSERT_EQ(leader->batch->GetDataSize(), leader_batch_size);
+        if (corrupt_byte_offset < leader_batch_size) {
+          Slice batch_content = WriteBatchInternal::Contents(leader->batch);
+          CorruptWriteBatch(&batch_content, corrupt_byte_offset,
+                            corrupt_byte_addend_);
+        }
+        leader_count++;
+        while (!follower_joined) {
+          // busy waiting
+        }
+      });
+  SyncPoint::GetInstance()->EnableProcessing();
+  while (corrupt_byte_offset < total_bytes) {
+    // Reopen the DB since the failed WAL write led to read-only mode
ReopenWithColumnFamilies({kDefaultColumnFamilyName, "ramen"}, options); + SyncPoint::GetInstance()->EnableProcessing(); + auto log_size_pre_write = dbfull()->TEST_total_log_size(); + leader_batch_and_status = GetWriteBatch(handles_[1], op_type1_); + ASSERT_OK(leader_batch_and_status.second); + ASSERT_TRUE(db_->Write(WriteOptions(), &leader_batch_and_status.first) + .IsCorruption()); + follower_thread.join(); + // Prevent leader thread from entering this callback + SyncPoint::GetInstance()->ClearCallBack("WriteThread::JoinBatchGroup:Wait"); + + ASSERT_EQ(1, leader_count); + // Nothing should have been written to WAL + ASSERT_EQ(log_size_pre_write, dbfull()->TEST_total_log_size()); + ASSERT_TRUE(dbfull()->TEST_GetBGError().IsCorruption()); + + corrupt_byte_offset++; + if (corrupt_byte_offset == leader_batch_size) { + // skip over the sequence number part of follower's write batch + corrupt_byte_offset += 8; + } + follower_joined = false; + leader_count = 0; + } + SyncPoint::GetInstance()->DisableProcessing(); +} + +INSTANTIATE_TEST_CASE_P( + DbKvChecksumTestMergedBatch, DbKvChecksumTestMergedBatch, + ::testing::Combine(::testing::Range(static_cast(0), + WriteBatchOpType::kNum), + ::testing::Range(static_cast(0), + WriteBatchOpType::kNum), + ::testing::Values(2, 103, 251)), + [](const testing::TestParamInfo< + std::tuple>& args) { + std::ostringstream oss; + oss << GetOpTypeString(std::get<0>(args.param)) + << GetOpTypeString(std::get<1>(args.param)) << "Add" + << static_cast( + static_cast(std::get<2>(args.param))); + return oss.str(); + }); + +// TODO: add test for transactions +// TODO: add test for corrupted write batch with WAL disabled } // namespace ROCKSDB_NAMESPACE int main(int argc, char** argv) { diff --git a/db/write_batch.cc b/db/write_batch.cc index 788b9bae4..b919ea056 100644 --- a/db/write_batch.cc +++ b/db/write_batch.cc @@ -1491,6 +1491,94 @@ Status WriteBatch::UpdateTimestamps( return s; } +Status WriteBatch::VerifyChecksum() const { + if (prot_info_ == nullptr) { + return Status::OK(); + } + Slice input(rep_.data() + WriteBatchInternal::kHeader, + rep_.size() - WriteBatchInternal::kHeader); + Slice key, value, blob, xid; + char tag = 0; + uint32_t column_family = 0; // default + Status s; + size_t prot_info_idx = 0; + bool checksum_protected = true; + while (!input.empty() && prot_info_idx < prot_info_->entries_.size()) { + // In case key/value/column_family are not updated by + // ReadRecordFromWriteBatch + key.clear(); + value.clear(); + column_family = 0; + s = ReadRecordFromWriteBatch(&input, &tag, &column_family, &key, &value, + &blob, &xid); + if (!s.ok()) { + return s; + } + checksum_protected = true; + // Write batch checksum uses op_type without ColumnFamily (e.g., if op_type + // in the write batch is kTypeColumnFamilyValue, kTypeValue is used to + // compute the checksum), and encodes column family id separately. See + // comment in first `WriteBatchInternal::Put()` for more detail. 
diff --git a/db/write_batch.cc b/db/write_batch.cc
index 788b9bae4..b919ea056 100644
--- a/db/write_batch.cc
+++ b/db/write_batch.cc
@@ -1491,6 +1491,94 @@ Status WriteBatch::UpdateTimestamps(
   return s;
 }

+Status WriteBatch::VerifyChecksum() const {
+  if (prot_info_ == nullptr) {
+    return Status::OK();
+  }
+  Slice input(rep_.data() + WriteBatchInternal::kHeader,
+              rep_.size() - WriteBatchInternal::kHeader);
+  Slice key, value, blob, xid;
+  char tag = 0;
+  uint32_t column_family = 0;  // default
+  Status s;
+  size_t prot_info_idx = 0;
+  bool checksum_protected = true;
+  while (!input.empty() && prot_info_idx < prot_info_->entries_.size()) {
+    // In case key/value/column_family are not updated by
+    // ReadRecordFromWriteBatch
+    key.clear();
+    value.clear();
+    column_family = 0;
+    s = ReadRecordFromWriteBatch(&input, &tag, &column_family, &key, &value,
+                                 &blob, &xid);
+    if (!s.ok()) {
+      return s;
+    }
+    checksum_protected = true;
+    // The write batch checksum uses the op_type without ColumnFamily (e.g., if
+    // the op_type in the write batch is kTypeColumnFamilyValue, kTypeValue is
+    // used to compute the checksum), and encodes the column family id
+    // separately. See the comment in the first `WriteBatchInternal::Put()` for
+    // more detail.
+    switch (tag) {
+      case kTypeColumnFamilyValue:
+      case kTypeValue:
+        tag = kTypeValue;
+        break;
+      case kTypeColumnFamilyDeletion:
+      case kTypeDeletion:
+        tag = kTypeDeletion;
+        break;
+      case kTypeColumnFamilySingleDeletion:
+      case kTypeSingleDeletion:
+        tag = kTypeSingleDeletion;
+        break;
+      case kTypeColumnFamilyRangeDeletion:
+      case kTypeRangeDeletion:
+        tag = kTypeRangeDeletion;
+        break;
+      case kTypeColumnFamilyMerge:
+      case kTypeMerge:
+        tag = kTypeMerge;
+        break;
+      case kTypeColumnFamilyBlobIndex:
+      case kTypeBlobIndex:
+        tag = kTypeBlobIndex;
+        break;
+      case kTypeLogData:
+      case kTypeBeginPrepareXID:
+      case kTypeEndPrepareXID:
+      case kTypeCommitXID:
+      case kTypeRollbackXID:
+      case kTypeNoop:
+      case kTypeBeginPersistedPrepareXID:
+      case kTypeBeginUnprepareXID:
+      case kTypeDeletionWithTimestamp:
+      case kTypeCommitXIDAndTimestamp:
+        checksum_protected = false;
+        break;
+      default:
+        return Status::Corruption(
+            "unknown WriteBatch tag",
+            std::to_string(static_cast<unsigned int>(tag)));
+    }
+    if (checksum_protected) {
+      s = prot_info_->entries_[prot_info_idx++]
+              .StripC(column_family)
+              .StripKVO(key, value, static_cast<ValueType>(tag))
+              .GetStatus();
+      if (!s.ok()) {
+        return s;
+      }
+    }
+  }
+
+  if (prot_info_idx != WriteBatchInternal::Count(this)) {
+    return Status::Corruption("WriteBatch has wrong count");
+  }
+  assert(WriteBatchInternal::Count(this) == prot_info_->entries_.size());
+  return Status::OK();
+}
+
 namespace {

 class MemTableInserter : public WriteBatch::Handler {
@@ -2773,6 +2861,14 @@ Status WriteBatchInternal::Append(WriteBatch* dst, const WriteBatch* src,
                                   const bool wal_only) {
   assert(dst->Count() == 0 ||
          (dst->prot_info_ == nullptr) == (src->prot_info_ == nullptr));
+  if ((src->prot_info_ != nullptr &&
+       src->prot_info_->entries_.size() != src->Count()) ||
+      (dst->prot_info_ != nullptr &&
+       dst->prot_info_->entries_.size() != dst->Count())) {
+    return Status::Corruption(
+        "Write batch has inconsistent count and number of checksums");
+  }
+
   size_t src_len;
   int src_count;
   uint32_t src_flags;
diff --git a/db/write_batch_internal.h b/db/write_batch_internal.h
index 49abed74e..926acc63a 100644
--- a/db/write_batch_internal.h
+++ b/db/write_batch_internal.h
@@ -206,6 +206,10 @@ class WriteBatchInternal {
                      bool batch_per_txn = true, bool hint_per_batch = false);

+  // Appends the src write batch to the dst write batch and updates the count
+  // in the dst write batch. Returns OK if the append is successful. Checks the
+  // number of checksums against the count in both the dst and src write
+  // batches, and returns Corruption if they are inconsistent.
   static Status Append(WriteBatch* dst, const WriteBatch* src,
                        const bool WAL_only = false);

diff --git a/db/write_thread.cc b/db/write_thread.cc
index d59eba263..06d7f4500 100644
--- a/db/write_thread.cc
+++ b/db/write_thread.cc
@@ -389,6 +389,7 @@ void WriteThread::JoinBatchGroup(Writer* w) {
   }

   TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait", w);
+  TEST_SYNC_POINT_CALLBACK("WriteThread::JoinBatchGroup:Wait2", w);

   if (!linked_as_leader) {
     /**
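The public header change below is the whole user-visible surface of this patch. For callers, the practical consequence is a new failure mode of `DB::Write()`; a hedged sketch of handling it (the recovery policy is the caller's choice; the tests above simply reopen the DB):

    #include "rocksdb/db.h"
    #include "rocksdb/write_batch.h"

    using ROCKSDB_NAMESPACE::DB;
    using ROCKSDB_NAMESPACE::Status;
    using ROCKSDB_NAMESPACE::WriteBatch;
    using ROCKSDB_NAMESPACE::WriteOptions;

    Status WriteProtectedBatch(DB* db, WriteBatch* batch) {
      Status s = db->Write(WriteOptions(), batch);
      if (s.IsCorruption()) {
        // The batch bytes changed after they were protected. Nothing reached
        // the WAL, and the DB has entered read-only mode; rebuild the batch
        // (and typically reopen the DB) before retrying.
      }
      return s;
    }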
diff --git a/include/rocksdb/write_batch.h b/include/rocksdb/write_batch.h
index 618e4734c..d8bd108ea 100644
--- a/include/rocksdb/write_batch.h
+++ b/include/rocksdb/write_batch.h
@@ -391,6 +391,12 @@ class WriteBatch : public WriteBatchBase {
   Status UpdateTimestamps(const Slice& ts,
                           std::function<size_t(uint32_t)> ts_sz_func);

+  // Verify the per-key-value checksums of this write batch.
+  // A Corruption status is returned if the verification fails.
+  // If this write batch does not have per-key-value checksums,
+  // an OK status is returned.
+  Status VerifyChecksum() const;
+
   using WriteBatchBase::GetWriteBatch;
   WriteBatch* GetWriteBatch() override { return this; }

diff --git a/tools/db_bench_tool.cc b/tools/db_bench_tool.cc
index a163d8667..46d8a9af1 100644
--- a/tools/db_bench_tool.cc
+++ b/tools/db_bench_tool.cc
@@ -1656,6 +1656,10 @@ static const bool FLAGS_table_cache_numshardbits_dummy __attribute__((__unused__
     RegisterFlagValidator(&FLAGS_table_cache_numshardbits,
                           &ValidateTableCacheNumshardbits);

+DEFINE_uint32(write_batch_protection_bytes_per_key, 0,
+              "Size of the per-key-value checksum in each write batch. "
+              "Currently only values 0 and 8 are supported.");
+
 namespace ROCKSDB_NAMESPACE {
 namespace {
 static Status CreateMemTableRepFactory(
@@ -4910,7 +4914,8 @@ class Benchmark {
     RandomGenerator gen;
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Status s;
     int64_t bytes = 0;
@@ -6699,7 +6704,8 @@ class Benchmark {
   void DoDelete(ThreadState* thread, bool seq) {
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Duration duration(seq ? 0 : FLAGS_duration, deletes_);
     int64_t i = 0;
     std::unique_ptr<const char[]> key_guard;
@@ -6899,7 +6905,8 @@ class Benchmark {
     std::string keys[3];
     WriteBatch batch(/*reserved_bytes=*/0, /*max_bytes=*/0,
-                     /*protection_bytes_per_key=*/0, user_timestamp_size_);
+                     FLAGS_write_batch_protection_bytes_per_key,
+                     user_timestamp_size_);
     Status s;
     for (int i = 0; i < 3; i++) {
       keys[i] = key.ToString() + suffixes[i];
@@ -6931,7 +6938,7 @@ class Benchmark {
     std::string suffixes[3] = {"1", "2", "0"};
     std::string keys[3];

-    WriteBatch batch(0, 0, /*protection_bytes_per_key=*/0,
+    WriteBatch batch(0, 0, FLAGS_write_batch_protection_bytes_per_key,
                      user_timestamp_size_);
     Status s;
     for (int i = 0; i < 3; i++) {
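Usage note: the new flag wires protected batches into every db_bench write path shown above, so the Test Plan's benchmark can be reproduced directly:

    ./db_bench --benchmarks=fillrandom[-X20] -db=/dev/shm/test_rocksdb \
        -write_batch_protection_bytes_per_key=8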