diff --git a/db/db_impl/db_impl_open.cc b/db/db_impl/db_impl_open.cc
index 0bd773fbe..b8b78a8c8 100644
--- a/db/db_impl/db_impl_open.cc
+++ b/db/db_impl/db_impl_open.cc
@@ -1108,9 +1108,11 @@ Status DBImpl::RecoverLogFiles(const std::vector<uint64_t>& wal_numbers,
     TEST_SYNC_POINT_CALLBACK("DBImpl::RecoverLogFiles:BeforeReadWal",
                              /*arg=*/nullptr);
+    uint64_t record_checksum;
     while (!stop_replay_by_wal_filter &&
            reader.ReadRecord(&record, &scratch,
-                             immutable_db_options_.wal_recovery_mode) &&
+                             immutable_db_options_.wal_recovery_mode,
+                             &record_checksum) &&
            status.ok()) {
       if (record.size() < WriteBatchInternal::kHeader) {
         reporter.Corruption(record.size(),
@@ -1126,8 +1128,13 @@ Status DBImpl::RecoverLogFiles(const std::vector<uint64_t>& wal_numbers,
       if (!status.ok()) {
         return status;
       }
-      status = WriteBatchInternal::UpdateProtectionInfo(&batch,
-                                                        8 /* bytes_per_key */);
+      TEST_SYNC_POINT_CALLBACK(
+          "DBImpl::RecoverLogFiles:BeforeUpdateProtectionInfo:batch", &batch);
+      TEST_SYNC_POINT_CALLBACK(
+          "DBImpl::RecoverLogFiles:BeforeUpdateProtectionInfo:checksum",
+          &record_checksum);
+      status = WriteBatchInternal::UpdateProtectionInfo(
+          &batch, 8 /* bytes_per_key */, &record_checksum);
       if (!status.ok()) {
         return status;
       }
diff --git a/db/db_kv_checksum_test.cc b/db/db_kv_checksum_test.cc
index bf5c5b464..7c8c0190e 100644
--- a/db/db_kv_checksum_test.cc
+++ b/db/db_kv_checksum_test.cc
@@ -627,6 +627,39 @@ INSTANTIATE_TEST_CASE_P(
 // TODO: add test for transactions
 // TODO: add test for corrupted write batch with WAL disabled
+
+class DbKVChecksumWALToWriteBatchTest : public DBTestBase {
+ public:
+  DbKVChecksumWALToWriteBatchTest()
+      : DBTestBase("db_kv_checksum_test", /*env_do_fsync=*/false) {}
+};
+
+TEST_F(DbKVChecksumWALToWriteBatchTest, WriteBatchChecksumHandoff) {
+  Options options = CurrentOptions();
+  Reopen(options);
+  ASSERT_OK(db_->Put(WriteOptions(), "key", "val"));
+  std::string content = "";
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::RecoverLogFiles:BeforeUpdateProtectionInfo:batch",
+      [&](void* batch_ptr) {
+        WriteBatch* batch = reinterpret_cast<WriteBatch*>(batch_ptr);
+        content.assign(batch->Data().data(), batch->GetDataSize());
+        Slice batch_content = batch->Data();
+        // Corrupt first bit
+        CorruptWriteBatch(&batch_content, 0, 1);
+      });
+  SyncPoint::GetInstance()->SetCallBack(
+      "DBImpl::RecoverLogFiles:BeforeUpdateProtectionInfo:checksum",
+      [&](void* checksum_ptr) {
+        // Verify that checksum is produced on the batch content
+        uint64_t checksum = *reinterpret_cast<uint64_t*>(checksum_ptr);
+        ASSERT_EQ(checksum, XXH3_64bits(content.data(), content.size()));
+      });
+  SyncPoint::GetInstance()->EnableProcessing();
+  ASSERT_TRUE(TryReopen(options).IsCorruption());
+  SyncPoint::GetInstance()->DisableProcessing();
+};
+
 }  // namespace ROCKSDB_NAMESPACE
 
 int main(int argc, char** argv) {
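The test above works because the checksum handed to UpdateProtectionInfo was computed from the WAL record before the ":batch" sync-point callback corrupts the batch in memory, so even a single flipped bit makes the recomputed XXH3-64 digest diverge. A minimal standalone sketch of that property follows; it is not part of the patch and assumes the standalone xxhash.h header (bundled in RocksDB as util/xxhash.h).

// Standalone sketch, not part of the patch: flipping a single bit of the
// serialized batch changes its XXH3-64 digest, so the checksum captured from
// the WAL record no longer matches and the recovery path reports corruption.
#include <cassert>
#include <cstdint>
#include <string>

#include "xxhash.h"

int main() {
  std::string batch_rep = "example serialized write batch contents";
  // Digest as it would be captured while reading the WAL record.
  uint64_t recorded = XXH3_64bits(batch_rep.data(), batch_rep.size());

  // Corrupt one bit of the first byte, like CorruptWriteBatch(..., 0, 1).
  batch_rep[0] ^= 0x01;

  // Recomputing over the corrupted contents no longer matches.
  uint64_t recomputed = XXH3_64bits(batch_rep.data(), batch_rep.size());
  assert(recorded != recomputed);
  return 0;
}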
diff --git a/db/log_reader.cc b/db/log_reader.cc
index 0c1852e82..a9f43ae80 100644
--- a/db/log_reader.cc
+++ b/db/log_reader.cc
@@ -43,13 +43,17 @@ Reader::Reader(std::shared_ptr<Logger> info_log,
       first_record_read_(false),
       compression_type_(kNoCompression),
       compression_type_record_read_(false),
-      uncompress_(nullptr) {}
+      uncompress_(nullptr),
+      hash_state_(nullptr) {}
 
 Reader::~Reader() {
   delete[] backing_store_;
   if (uncompress_) {
     delete uncompress_;
   }
+  if (hash_state_) {
+    XXH3_freeState(hash_state_);
+  }
 }
 
 // For kAbsoluteConsistency, on clean shutdown we don't expect any error
@@ -60,9 +64,15 @@ Reader::~Reader() {
 // TODO krad: Evaluate if we need to move to a more strict mode where we
 // restrict the inconsistency to only the last log
 bool Reader::ReadRecord(Slice* record, std::string* scratch,
-                        WALRecoveryMode wal_recovery_mode) {
+                        WALRecoveryMode wal_recovery_mode, uint64_t* checksum) {
   scratch->clear();
   record->clear();
+  if (checksum != nullptr) {
+    if (hash_state_ == nullptr) {
+      hash_state_ = XXH3_createState();
+    }
+    XXH3_64bits_reset(hash_state_);
+  }
   if (uncompress_) {
     uncompress_->Reset();
   }
@@ -86,6 +96,10 @@ bool Reader::ReadRecord(Slice* record, std::string* scratch,
           // at the beginning of the next block.
           ReportCorruption(scratch->size(), "partial record without end(1)");
         }
+        if (checksum != nullptr) {
+          // No need to stream since the record is a single fragment
+          *checksum = XXH3_64bits(fragment.data(), fragment.size());
+        }
         prospective_record_offset = physical_record_offset;
         scratch->clear();
         *record = fragment;
@@ -101,6 +115,10 @@ bool Reader::ReadRecord(Slice* record, std::string* scratch,
           // of a block followed by a kFullType or kFirstType record
           // at the beginning of the next block.
           ReportCorruption(scratch->size(), "partial record without end(2)");
+          XXH3_64bits_reset(hash_state_);
+        }
+        if (checksum != nullptr) {
+          XXH3_64bits_update(hash_state_, fragment.data(), fragment.size());
         }
         prospective_record_offset = physical_record_offset;
         scratch->assign(fragment.data(), fragment.size());
@@ -113,6 +131,9 @@ bool Reader::ReadRecord(Slice* record, std::string* scratch,
           ReportCorruption(fragment.size(),
                            "missing start of fragmented record(1)");
         } else {
+          if (checksum != nullptr) {
+            XXH3_64bits_update(hash_state_, fragment.data(), fragment.size());
+          }
           scratch->append(fragment.data(), fragment.size());
         }
         break;
@@ -123,6 +144,10 @@ bool Reader::ReadRecord(Slice* record, std::string* scratch,
           ReportCorruption(fragment.size(),
                            "missing start of fragmented record(2)");
         } else {
+          if (checksum != nullptr) {
+            XXH3_64bits_update(hash_state_, fragment.data(), fragment.size());
+            *checksum = XXH3_64bits_digest(hash_state_);
+          }
           scratch->append(fragment.data(), fragment.size());
           *record = Slice(*scratch);
           last_record_offset_ = prospective_record_offset;
@@ -509,7 +534,8 @@ void Reader::InitCompression(const CompressionTypeRecord& compression_record) {
 }
 
 bool FragmentBufferedReader::ReadRecord(Slice* record, std::string* scratch,
-                                        WALRecoveryMode /*unused*/) {
+                                        WALRecoveryMode /*unused*/,
+                                        uint64_t* /* checksum */) {
   assert(record != nullptr);
   assert(scratch != nullptr);
   record->clear();
diff --git a/db/log_reader.h b/db/log_reader.h
index dbea81728..677939099 100644
--- a/db/log_reader.h
+++ b/db/log_reader.h
@@ -18,6 +18,7 @@
 #include "rocksdb/slice.h"
 #include "rocksdb/status.h"
 #include "util/compression.h"
+#include "util/xxhash.h"
 
 namespace ROCKSDB_NAMESPACE {
 class Logger;
@@ -61,12 +62,17 @@ class Reader {
   // Read the next record into *record.  Returns true if read
   // successfully, false if we hit end of the input.  May use
-  // "*scratch" as temporary storage.  The contents filled in *record
+  // "*scratch" as temporary storage. The contents filled in *record
   // will only be valid until the next mutating operation on this
   // reader or the next mutation to *scratch.
+  // If record_checksum is not nullptr, then this function will calculate the
+  // checksum of the record read and set record_checksum to it. The checksum is
+  // calculated from the original buffers that contain the contents of the
+  // record.
   virtual bool ReadRecord(Slice* record, std::string* scratch,
                           WALRecoveryMode wal_recovery_mode =
-                              WALRecoveryMode::kTolerateCorruptedTailRecords);
+                              WALRecoveryMode::kTolerateCorruptedTailRecords,
+                          uint64_t* record_checksum = nullptr);
 
   // Returns the physical offset of the last record returned by ReadRecord.
   //
@@ -145,6 +151,8 @@ class Reader {
   std::unique_ptr<char[]> uncompressed_buffer_;
   // Reusable uncompressed record
   std::string uncompressed_record_;
+  // Used for stream hashing log record
+  XXH3_state_t* hash_state_;
 
   // Extend record types with the following special values
   enum {
@@ -191,7 +199,8 @@ class FragmentBufferedReader : public Reader {
   ~FragmentBufferedReader() override {}
   bool ReadRecord(Slice* record, std::string* scratch,
                   WALRecoveryMode wal_recovery_mode =
-                      WALRecoveryMode::kTolerateCorruptedTailRecords) override;
+                      WALRecoveryMode::kTolerateCorruptedTailRecords,
+                  uint64_t* record_checksum = nullptr) override;
   void UnmarkEOF() override;
 
  private:
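The handoff relies on XXH3's streaming API producing the same digest as a one-shot XXH3_64bits over the reassembled bytes: Reader::ReadRecord hashes a single-fragment record in one shot, streams first/middle/last fragments through hash_state_, and UpdateProtectionInfo later recomputes a one-shot hash over the whole batch. A standalone sketch of that equivalence follows; it is not part of the patch and assumes the standalone xxhash.h header (bundled in RocksDB as util/xxhash.h).

// Standalone sketch, not part of the patch: the streaming XXH3 digest over the
// individual fragments equals the one-shot hash over the reassembled record,
// which is why the batch contents can be verified with a single XXH3_64bits call.
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

#include "xxhash.h"

int main() {
  // A record split across several WAL fragments (first/middle/last).
  std::vector<std::string> fragments = {"first-", "middle-", "last"};

  // Stream each fragment, as Reader::ReadRecord does via hash_state_.
  XXH3_state_t* state = XXH3_createState();
  XXH3_64bits_reset(state);
  std::string reassembled;
  for (const std::string& f : fragments) {
    XXH3_64bits_update(state, f.data(), f.size());
    reassembled.append(f);
  }
  uint64_t streamed = XXH3_64bits_digest(state);
  XXH3_freeState(state);

  // One-shot hash over the reassembled record gives the same digest.
  uint64_t one_shot = XXH3_64bits(reassembled.data(), reassembled.size());
  assert(streamed == one_shot);
  return 0;
}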
diff --git a/db/write_batch.cc b/db/write_batch.cc
index 90e72b751..4301800d0 100644
--- a/db/write_batch.cc
+++ b/db/write_batch.cc
@@ -3063,7 +3063,8 @@ size_t WriteBatchInternal::AppendedByteSize(size_t leftByteSize,
 }
 
 Status WriteBatchInternal::UpdateProtectionInfo(WriteBatch* wb,
-                                                size_t bytes_per_key) {
+                                                size_t bytes_per_key,
+                                                uint64_t* checksum) {
   if (bytes_per_key == 0) {
     if (wb->prot_info_ != nullptr) {
       wb->prot_info_.reset();
@@ -3076,7 +3077,14 @@ Status WriteBatchInternal::UpdateProtectionInfo(WriteBatch* wb,
     if (wb->prot_info_ == nullptr) {
       wb->prot_info_.reset(new WriteBatch::ProtectionInfo());
       ProtectionInfoUpdater prot_info_updater(wb->prot_info_.get());
-      return wb->Iterate(&prot_info_updater);
+      Status s = wb->Iterate(&prot_info_updater);
+      if (s.ok() && checksum != nullptr) {
+        uint64_t expected_hash = XXH3_64bits(wb->rep_.data(), wb->rep_.size());
+        if (expected_hash != *checksum) {
+          return Status::Corruption("Write batch content corrupted.");
+        }
+      }
+      return s;
     } else {
       // Already protected.
       return Status::OK();
diff --git a/db/write_batch_internal.h b/db/write_batch_internal.h
index 9bc5ab98a..ee8690c28 100644
--- a/db/write_batch_internal.h
+++ b/db/write_batch_internal.h
@@ -240,7 +240,10 @@ class WriteBatchInternal {
     return wb.has_key_with_ts_;
   }
 
-  static Status UpdateProtectionInfo(WriteBatch* wb, size_t bytes_per_key);
+  // Update per-key value protection information on this write batch.
+  // If checksum is provided, the batch content is verified against the checksum.
+  static Status UpdateProtectionInfo(WriteBatch* wb, size_t bytes_per_key,
+                                     uint64_t* checksum = nullptr);
 };
 
 // LocalSavePoint is similar to a scope guard
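Because the new checksum parameter defaults to nullptr, existing UpdateProtectionInfo call sites are unchanged; only WAL recovery passes the digest captured by log::Reader::ReadRecord, and a mismatch surfaces as Status::Corruption before the freshly built protection info is trusted. A sketch of the two call shapes follows; ProtectBatchSketch is a hypothetical helper, not part of the patch, and since it uses the internal write_batch_internal.h header it only builds inside the RocksDB source tree.

// In-tree sketch, not part of the patch: shows the old two-argument call and
// the recovery-path call that threads the WAL-derived checksum through the
// new third parameter.
#include <cstdint>

#include "db/write_batch_internal.h"
#include "rocksdb/write_batch.h"

namespace ROCKSDB_NAMESPACE {

Status ProtectBatchSketch(WriteBatch* batch, const uint64_t* wal_checksum) {
  if (wal_checksum == nullptr) {
    // Existing call sites: build protection info without any verification.
    return WriteBatchInternal::UpdateProtectionInfo(batch,
                                                    8 /* bytes_per_key */);
  }
  // Recovery path: verify the batch contents against the checksum computed
  // from the WAL record before trusting the freshly built protection info.
  uint64_t checksum = *wal_checksum;
  return WriteBatchInternal::UpdateProtectionInfo(batch, 8 /* bytes_per_key */,
                                                  &checksum);
}

}  // namespace ROCKSDB_NAMESPACE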