diff --git a/utilities/transactions/write_unprepared_transaction_test.cc b/utilities/transactions/write_unprepared_transaction_test.cc index af7b7694d..af0bb03bd 100644 --- a/utilities/transactions/write_unprepared_transaction_test.cc +++ b/utilities/transactions/write_unprepared_transaction_test.cc @@ -659,6 +659,56 @@ TEST_P(WriteUnpreparedTransactionTest, SavePoint) { delete txn; } +TEST_P(WriteUnpreparedTransactionTest, UntrackedKeys) { + WriteOptions woptions; + TransactionOptions txn_options; + txn_options.write_batch_flush_threshold = 1; + + Transaction* txn = db->BeginTransaction(woptions, txn_options); + auto wb = txn->GetWriteBatch()->GetWriteBatch(); + ASSERT_OK(txn->Put("a", "a")); + ASSERT_OK(wb->Put("a_untrack", "a_untrack")); + txn->SetSavePoint(); + ASSERT_OK(txn->Put("b", "b")); + ASSERT_OK(txn->Put("b_untrack", "b_untrack")); + + ReadOptions roptions; + std::string value; + ASSERT_OK(txn->Get(roptions, "a", &value)); + ASSERT_EQ(value, "a"); + ASSERT_OK(txn->Get(roptions, "a_untrack", &value)); + ASSERT_EQ(value, "a_untrack"); + ASSERT_OK(txn->Get(roptions, "b", &value)); + ASSERT_EQ(value, "b"); + ASSERT_OK(txn->Get(roptions, "b_untrack", &value)); + ASSERT_EQ(value, "b_untrack"); + + // b and b_untrack should be rolled back. + ASSERT_OK(txn->RollbackToSavePoint()); + ASSERT_OK(txn->Get(roptions, "a", &value)); + ASSERT_EQ(value, "a"); + ASSERT_OK(txn->Get(roptions, "a_untrack", &value)); + ASSERT_EQ(value, "a_untrack"); + auto s = txn->Get(roptions, "b", &value); + ASSERT_TRUE(s.IsNotFound()); + s = txn->Get(roptions, "b_untrack", &value); + ASSERT_TRUE(s.IsNotFound()); + + // Everything should be rolled back. + ASSERT_OK(txn->Rollback()); + s = txn->Get(roptions, "a", &value); + ASSERT_TRUE(s.IsNotFound()); + s = txn->Get(roptions, "a_untrack", &value); + ASSERT_TRUE(s.IsNotFound()); + s = txn->Get(roptions, "b", &value); + ASSERT_TRUE(s.IsNotFound()); + s = txn->Get(roptions, "b_untrack", &value); + ASSERT_TRUE(s.IsNotFound()); + + delete txn; +} + + } // namespace rocksdb int main(int argc, char** argv) { diff --git a/utilities/transactions/write_unprepared_txn.cc b/utilities/transactions/write_unprepared_txn.cc index bcaeb1eae..dd4caed8c 100644 --- a/utilities/transactions/write_unprepared_txn.cc +++ b/utilities/transactions/write_unprepared_txn.cc @@ -95,6 +95,7 @@ void WriteUnpreparedTxn::Initialize(const TransactionOptions& txn_options) { largest_validated_seq_ = 0; assert(active_iterators_.empty()); active_iterators_.clear(); + untracked_keys_.clear(); } Status WriteUnpreparedTxn::HandleWrite(std::function do_write) { @@ -286,6 +287,65 @@ Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) { } } + struct UntrackedKeyHandler : public WriteBatch::Handler { + WriteUnpreparedTxn* txn_; + bool rollback_merge_operands_; + + UntrackedKeyHandler(WriteUnpreparedTxn* txn, bool rollback_merge_operands) + : txn_(txn), rollback_merge_operands_(rollback_merge_operands) {} + + Status AddUntrackedKey(uint32_t cf, const Slice& key) { + auto str = key.ToString(); + if (txn_->tracked_keys_[cf].count(str) == 0) { + txn_->untracked_keys_[cf].push_back(str); + } + return Status::OK(); + } + + Status PutCF(uint32_t cf, const Slice& key, const Slice&) override { + return AddUntrackedKey(cf, key); + } + + Status DeleteCF(uint32_t cf, const Slice& key) override { + return AddUntrackedKey(cf, key); + } + + Status SingleDeleteCF(uint32_t cf, const Slice& key) override { + return AddUntrackedKey(cf, key); + } + + Status MergeCF(uint32_t cf, const Slice& key, const Slice&) override { + if (rollback_merge_operands_) { + return AddUntrackedKey(cf, key); + } + return Status::OK(); + } + + // The only expected 2PC marker is the initial Noop marker. + Status MarkNoop(bool empty_batch) override { + return empty_batch ? Status::OK() : Status::InvalidArgument(); + } + + Status MarkBeginPrepare(bool) override { return Status::InvalidArgument(); } + + Status MarkEndPrepare(const Slice&) override { + return Status::InvalidArgument(); + } + + Status MarkCommit(const Slice&) override { + return Status::InvalidArgument(); + } + + Status MarkRollback(const Slice&) override { + return Status::InvalidArgument(); + } + }; + + UntrackedKeyHandler handler( + this, wupt_db_->txn_db_options_.rollback_merge_operands); + auto s = GetWriteBatch()->GetWriteBatch()->Iterate(&handler); + assert(s.ok()); + // TODO(lth): Reduce duplicate code with WritePrepared prepare logic. WriteOptions write_options = write_options_; write_options.disableWAL = false; @@ -311,11 +371,10 @@ Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) { // from the current transaction. This means that if log_number_ is set, // WriteImpl should not overwrite that value, so set log_used to nullptr if // log_number_ is already set. - auto s = - db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(), - /*callback*/ nullptr, &last_log_number_, /*log ref*/ - 0, !DISABLE_MEMTABLE, &seq_used, prepare_batch_cnt_, - &add_prepared_callback); + s = db_impl_->WriteImpl(write_options, GetWriteBatch()->GetWriteBatch(), + /*callback*/ nullptr, &last_log_number_, + /*log ref*/ 0, !DISABLE_MEMTABLE, &seq_used, + prepare_batch_cnt_, &add_prepared_callback); if (log_number_ == 0) { log_number_ = last_log_number_; } @@ -577,6 +636,59 @@ Status WriteUnpreparedTxn::CommitInternal() { return s; } +Status WriteUnpreparedTxn::WriteRollbackKeys( + const TransactionKeyMap& tracked_keys, WriteBatchWithIndex* rollback_batch, + ReadCallback* callback, const ReadOptions& roptions) { + const auto& cf_map = *wupt_db_->GetCFHandleMap(); + auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) { + const auto& cf_handle = cf_map.at(cfid); + PinnableSlice pinnable_val; + bool not_used; + DBImpl::GetImplOptions get_impl_options; + get_impl_options.column_family = cf_handle; + get_impl_options.value = &pinnable_val; + get_impl_options.value_found = ¬_used; + get_impl_options.callback = callback; + auto s = db_impl_->GetImpl(roptions, key, get_impl_options); + + if (s.ok()) { + s = rollback_batch->Put(cf_handle, key, pinnable_val); + assert(s.ok()); + } else if (s.IsNotFound()) { + s = rollback_batch->Delete(cf_handle, key); + assert(s.ok()); + } else { + return s; + } + + return Status::OK(); + }; + + for (const auto& cfkey : tracked_keys) { + const auto cfid = cfkey.first; + const auto& keys = cfkey.second; + for (const auto& pair : keys) { + auto s = WriteRollbackKey(pair.first, cfid); + if (!s.ok()) { + return s; + } + } + } + + for (const auto& cfkey : untracked_keys_) { + const auto cfid = cfkey.first; + const auto& keys = cfkey.second; + for (const auto& key : keys) { + auto s = WriteRollbackKey(key, cfid); + if (!s.ok()) { + return s; + } + } + } + + return Status::OK(); +} + Status WriteUnpreparedTxn::RollbackInternal() { // TODO(lth): Reduce duplicate code with WritePrepared rollback logic. WriteBatchWithIndex rollback_batch( @@ -584,7 +696,6 @@ Status WriteUnpreparedTxn::RollbackInternal() { assert(GetId() != kMaxSequenceNumber); assert(GetId() > 0); Status s; - const auto& cf_map = *wupt_db_->GetCFHandleMap(); auto read_at_seq = kMaxSequenceNumber; ReadOptions roptions; // to prevent callback's seq to be overrriden inside DBImpk::Get @@ -592,34 +703,8 @@ Status WriteUnpreparedTxn::RollbackInternal() { // Note that we do not use WriteUnpreparedTxnReadCallback because we do not // need to read our own writes when reading prior versions of the key for // rollback. - const auto& tracked_keys = GetTrackedKeys(); WritePreparedTxnReadCallback callback(wpt_db_, read_at_seq); - for (const auto& cfkey : tracked_keys) { - const auto cfid = cfkey.first; - const auto& keys = cfkey.second; - for (const auto& pair : keys) { - const auto& key = pair.first; - const auto& cf_handle = cf_map.at(cfid); - PinnableSlice pinnable_val; - bool not_used; - DBImpl::GetImplOptions get_impl_options; - get_impl_options.column_family = cf_handle; - get_impl_options.value = &pinnable_val; - get_impl_options.value_found = ¬_used; - get_impl_options.callback = &callback; - s = db_impl_->GetImpl(roptions, key, get_impl_options); - - if (s.ok()) { - s = rollback_batch.Put(cf_handle, key, pinnable_val); - assert(s.ok()); - } else if (s.IsNotFound()) { - s = rollback_batch.Delete(cf_handle, key); - assert(s.ok()); - } else { - return s; - } - } - } + WriteRollbackKeys(GetTrackedKeys(), &rollback_batch, &callback, roptions); // The Rollback marker will be used as a batch separator WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_); @@ -701,6 +786,7 @@ void WriteUnpreparedTxn::Clear() { largest_validated_seq_ = 0; assert(active_iterators_.empty()); active_iterators_.clear(); + untracked_keys_.clear(); TransactionBaseImpl::Clear(); } @@ -745,7 +831,6 @@ Status WriteUnpreparedTxn::RollbackToSavePointInternal() { assert(save_points_ != nullptr && save_points_->size() > 0); const TransactionKeyMap& tracked_keys = save_points_->top().new_keys_; - // TODO(lth): Reduce duplicate code with RollbackInternal logic. ReadOptions roptions; roptions.snapshot = top.snapshot_->snapshot(); SequenceNumber min_uncommitted = @@ -756,34 +841,7 @@ Status WriteUnpreparedTxn::RollbackToSavePointInternal() { WriteUnpreparedTxnReadCallback callback(wupt_db_, snap_seq, min_uncommitted, top.unprep_seqs_, kBackedByDBSnapshot); - const auto& cf_map = *wupt_db_->GetCFHandleMap(); - for (const auto& cfkey : tracked_keys) { - const auto cfid = cfkey.first; - const auto& keys = cfkey.second; - - for (const auto& pair : keys) { - const auto& key = pair.first; - const auto& cf_handle = cf_map.at(cfid); - PinnableSlice pinnable_val; - bool not_used; - DBImpl::GetImplOptions get_impl_options; - get_impl_options.column_family = cf_handle; - get_impl_options.value = &pinnable_val; - get_impl_options.value_found = ¬_used; - get_impl_options.callback = &callback; - s = db_impl_->GetImpl(roptions, key, get_impl_options); - - if (s.ok()) { - s = write_batch_.Put(cf_handle, key, pinnable_val); - assert(s.ok()); - } else if (s.IsNotFound()) { - s = write_batch_.Delete(cf_handle, key); - assert(s.ok()); - } else { - return s; - } - } - } + WriteRollbackKeys(tracked_keys, &write_batch_, &callback, roptions); const bool kPrepared = true; s = FlushWriteBatchToDBInternal(!kPrepared); diff --git a/utilities/transactions/write_unprepared_txn.h b/utilities/transactions/write_unprepared_txn.h index f39f39891..e15ce3487 100644 --- a/utilities/transactions/write_unprepared_txn.h +++ b/utilities/transactions/write_unprepared_txn.h @@ -212,6 +212,9 @@ class WriteUnpreparedTxn : public WritePreparedTxn { friend class WriteUnpreparedTxnDB; const std::map& GetUnpreparedSequenceNumbers(); + Status WriteRollbackKeys(const TransactionKeyMap& tracked_keys, + WriteBatchWithIndex* rollback_batch, + ReadCallback* callback, const ReadOptions& roptions); Status MaybeFlushWriteBatchToDB(); Status FlushWriteBatchToDB(bool prepared); @@ -259,6 +262,7 @@ class WriteUnpreparedTxn : public WritePreparedTxn { // value when calling RollbackToSavepoint. SequenceNumber largest_validated_seq_; + using KeySet = std::unordered_map>; struct SavePoint { // Record of unprep_seqs_ at this savepoint. The set of unprep_seq is // used during RollbackToSavepoint to determine visibility when restoring @@ -315,6 +319,21 @@ class WriteUnpreparedTxn : public WritePreparedTxn { // batch, it is possible that the delta iterator on the iterator will point to // invalid memory. std::vector active_iterators_; + + // Untracked keys that we have to rollback. + // + // TODO(lth): Currently we we do not record untracked keys per-savepoint. + // This means that when rolling back to savepoints, we have to check all + // keys in the current transaction for rollback. Note that this is only + // inefficient, but still correct because we take a snapshot at every + // savepoint, and we will use that snapshot to construct the rollback batch. + // The rollback batch will then contain a reissue of the same marker. + // + // A more optimal solution would be to only check keys changed since the + // last savepoint. Also, it may make sense to merge this into tracked_keys_ + // and differentiate between tracked but not locked keys to avoid having two + // very similar data structures. + KeySet untracked_keys_; }; } // namespace rocksdb