Replace tracked_keys with a new LockTracker interface in TransactionDB (#7013)

Summary:
We're going to support more locking protocols such as range lock in transaction.

However, in current design, `TransactionBase` has a member `tracked_keys` which assumes that point lock (lock a single key) is used, and is used in snapshot checking (isolation protocol). When using range lock, we may use read committed instead of snapshot checking as the isolation protocol.

The most significant usage scenarios of `tracked_keys` are:
1. pessimistic transaction uses it to track the locked keys, and unlock these keys when commit or rollback.
2. optimistic transaction does not lock keys upfront, it only tracks the lock intentions in tracked_keys, and do write conflict checking when commit.
3. each `SavePoint` tracks the keys that are locked since the `SavePoint`, `RollbackToSavePoint` or `PopSavePoint` relies on both the tracked keys in `SavePoint`s and `tracked_keys`.

Based on these scenarios, if we can abstract out a `LockTracker` interface to hold a set of tracked locks (can be keys or key ranges), and have methods that can be composed together to implement the scenarios, then `tracked_keys` can be an internal data structure of one implementation of `LockTracker`. See `utilities/transactions/lock/lock_tracker.h` for the detailed interface design, and `utilities/transactions/lock/point_lock_tracker.cc` for the implementation.

In the future, a `RangeLockTracker` can be implemented to track range locks without affecting other components.

After this PR, a clean interface for lock manager should be possible, and then ideally, we can have pluggable locking protocols.

Pull Request resolved: https://github.com/facebook/rocksdb/pull/7013

Test Plan: Run `transaction_test` and `optimistic_transaction_test`.

Reviewed By: ajkr

Differential Revision: D22163706

Pulled By: cheng-chang

fbshipit-source-id: f2860577b5334e31dd2994f5bc6d7c40d502b1b4
main
Cheng Chang 4 years ago committed by Facebook GitHub Bot
parent cd48ecaa1a
commit 71c7e4935e
  1. 2
      CMakeLists.txt
  2. 2
      TARGETS
  3. 2
      src.mk
  4. 17
      utilities/transactions/lock/lock_tracker.cc
  5. 199
      utilities/transactions/lock/lock_tracker.h
  6. 266
      utilities/transactions/lock/point_lock_tracker.cc
  7. 84
      utilities/transactions/lock/point_lock_tracker.h
  8. 18
      utilities/transactions/optimistic_transaction.cc
  9. 89
      utilities/transactions/pessimistic_transaction.cc
  10. 2
      utilities/transactions/pessimistic_transaction.h
  11. 2
      utilities/transactions/pessimistic_transaction_db.cc
  12. 2
      utilities/transactions/pessimistic_transaction_db.h
  13. 230
      utilities/transactions/transaction_base.cc
  14. 37
      utilities/transactions/transaction_base.h
  15. 27
      utilities/transactions/transaction_lock_mgr.cc
  16. 2
      utilities/transactions/transaction_lock_mgr.h
  17. 27
      utilities/transactions/transaction_util.cc
  18. 35
      utilities/transactions/transaction_util.h
  19. 34
      utilities/transactions/write_unprepared_txn.cc
  20. 2
      utilities/transactions/write_unprepared_txn.h

@ -778,6 +778,8 @@ set(SOURCES
utilities/simulator_cache/sim_cache.cc utilities/simulator_cache/sim_cache.cc
utilities/table_properties_collectors/compact_on_deletion_collector.cc utilities/table_properties_collectors/compact_on_deletion_collector.cc
utilities/trace/file_trace_reader_writer.cc utilities/trace/file_trace_reader_writer.cc
utilities/transactions/lock/lock_tracker.cc
utilities/transactions/lock/point_lock_tracker.cc
utilities/transactions/optimistic_transaction_db_impl.cc utilities/transactions/optimistic_transaction_db_impl.cc
utilities/transactions/optimistic_transaction.cc utilities/transactions/optimistic_transaction.cc
utilities/transactions/pessimistic_transaction.cc utilities/transactions/pessimistic_transaction.cc

@ -358,6 +358,8 @@ cpp_library(
"utilities/simulator_cache/sim_cache.cc", "utilities/simulator_cache/sim_cache.cc",
"utilities/table_properties_collectors/compact_on_deletion_collector.cc", "utilities/table_properties_collectors/compact_on_deletion_collector.cc",
"utilities/trace/file_trace_reader_writer.cc", "utilities/trace/file_trace_reader_writer.cc",
"utilities/transactions/lock/lock_tracker.cc",
"utilities/transactions/lock/point_lock_tracker.cc",
"utilities/transactions/optimistic_transaction.cc", "utilities/transactions/optimistic_transaction.cc",
"utilities/transactions/optimistic_transaction_db_impl.cc", "utilities/transactions/optimistic_transaction_db_impl.cc",
"utilities/transactions/pessimistic_transaction.cc", "utilities/transactions/pessimistic_transaction.cc",

@ -238,6 +238,8 @@ LIB_SOURCES = \
utilities/simulator_cache/sim_cache.cc \ utilities/simulator_cache/sim_cache.cc \
utilities/table_properties_collectors/compact_on_deletion_collector.cc \ utilities/table_properties_collectors/compact_on_deletion_collector.cc \
utilities/trace/file_trace_reader_writer.cc \ utilities/trace/file_trace_reader_writer.cc \
utilities/transactions/lock/lock_tracker.cc \
utilities/transactions/lock/point_lock_tracker.cc \
utilities/transactions/optimistic_transaction.cc \ utilities/transactions/optimistic_transaction.cc \
utilities/transactions/optimistic_transaction_db_impl.cc \ utilities/transactions/optimistic_transaction_db_impl.cc \
utilities/transactions/pessimistic_transaction.cc \ utilities/transactions/pessimistic_transaction.cc \

@ -0,0 +1,17 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#include "utilities/transactions/lock/lock_tracker.h"
#include "utilities/transactions/lock/point_lock_tracker.h"
namespace ROCKSDB_NAMESPACE {
LockTracker* NewLockTracker() {
// TODO: determine the lock tracker implementation based on configuration.
return new PointLockTracker();
}
} // namespace ROCKSDB_NAMESPACE

@ -0,0 +1,199 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#pragma once
#include <memory>
#include "rocksdb/rocksdb_namespace.h"
#include "rocksdb/status.h"
#include "rocksdb/types.h"
namespace ROCKSDB_NAMESPACE {
using ColumnFamilyId = uint32_t;
// Request for locking a single key.
struct PointLockRequest {
// The id of the key's column family.
ColumnFamilyId column_family_id = 0;
// The key to lock.
std::string key;
// The sequence number from which there is no concurrent update to key.
SequenceNumber seq = 0;
// Whether the lock is acquired only for read.
bool read_only = false;
// Whether the lock is in exclusive mode.
bool exclusive = true;
};
// Request for locking a range of keys.
struct RangeLockRequest {
// TODO
};
struct PointLockStatus {
// Whether the key is locked.
bool locked = false;
// Whether the key is locked in exclusive mode.
bool exclusive = true;
// The sequence number in the tracked PointLockRequest.
SequenceNumber seq = 0;
};
// Return status when calling LockTracker::Untrack.
enum class UntrackStatus {
// The lock is not tracked at all, so no lock to untrack.
NOT_TRACKED,
// The lock is untracked but not removed from the tracker.
UNTRACKED,
// The lock is removed from the tracker.
REMOVED,
};
// Tracks the lock requests.
// In PessimisticTransaction, it tracks the locks acquired through LockMgr;
// In OptimisticTransaction, since there is no LockMgr, it tracks the lock
// intention. Not thread-safe.
class LockTracker {
public:
virtual ~LockTracker() {}
// Whether supports locking a specific key.
virtual bool IsPointLockSupported() const = 0;
// Whether supports locking a range of keys.
virtual bool IsRangeLockSupported() const = 0;
// Tracks the acquirement of a lock on key.
//
// If this method is not supported, leave it as a no-op.
virtual void Track(const PointLockRequest& /*lock_request*/) = 0;
// Untracks the lock on a key.
// seq and exclusive in lock_request are not used.
//
// If this method is not supported, leave it as a no-op and
// returns NOT_TRACKED.
virtual UntrackStatus Untrack(const PointLockRequest& /*lock_request*/) = 0;
// Counterpart of Track(const PointLockRequest&) for RangeLockRequest.
virtual void Track(const RangeLockRequest& /*lock_request*/) = 0;
// Counterpart of Untrack(const PointLockRequest&) for RangeLockRequest.
virtual UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) = 0;
// Merges lock requests tracked in the specified tracker into the current
// tracker.
//
// E.g. for point lock, if a key in tracker is not yet tracked,
// track this new key; otherwise, merge the tracked information of the key
// such as lock's exclusiveness, read/write statistics.
//
// If this method is not supported, leave it as a no-op.
//
// REQUIRED: the specified tracker must be of the same concrete class type as
// the current tracker.
virtual void Merge(const LockTracker& /*tracker*/) = 0;
// This is a reverse operation of Merge.
//
// E.g. for point lock, if a key exists in both current and the sepcified
// tracker, then subtract the information (such as read/write statistics) of
// the key in the specified tracker from the current tracker.
//
// If this method is not supported, leave it as a no-op.
//
// REQUIRED:
// The specified tracker must be of the same concrete class type as
// the current tracker.
// The tracked locks in the specified tracker must be a subset of those
// tracked by the current tracker.
virtual void Subtract(const LockTracker& /*tracker*/) = 0;
// Clears all tracked locks.
virtual void Clear() = 0;
// Gets the new locks (excluding the locks that have been tracked before the
// save point) tracked since the specified save point, the result is stored
// in an internally constructed LockTracker and returned.
//
// save_point_tracker is the tracker used by a SavePoint to track locks
// tracked after creating the SavePoint.
//
// The implementation should document whether point lock, or range lock, or
// both are considered in this method.
// If this method is not supported, returns nullptr.
//
// REQUIRED:
// The save_point_tracker must be of the same concrete class type as the
// current tracker.
// The tracked locks in the specified tracker must be a subset of those
// tracked by the current tracker.
virtual LockTracker* GetTrackedLocksSinceSavePoint(
const LockTracker& /*save_point_tracker*/) const = 0;
// Gets lock related information of the key.
//
// If point lock is not supported, always returns LockStatus with
// locked=false.
virtual PointLockStatus GetPointLockStatus(
ColumnFamilyId /*column_family_id*/,
const std::string& /*key*/) const = 0;
// Gets number of tracked point locks.
//
// If point lock is not supported, always returns 0.
virtual uint64_t GetNumPointLocks() const = 0;
class ColumnFamilyIterator {
public:
virtual ~ColumnFamilyIterator() {}
// Whether there are remaining column families.
virtual bool HasNext() const = 0;
// Gets next column family id.
//
// If HasNext is false, calling this method has undefined behavior.
virtual ColumnFamilyId Next() = 0;
};
// Gets an iterator for column families.
//
// Returned iterator must not be nullptr.
// If there is no column family to iterate,
// returns an empty non-null iterator.
// Caller owns the returned pointer.
virtual ColumnFamilyIterator* GetColumnFamilyIterator() const = 0;
class KeyIterator {
public:
virtual ~KeyIterator() {}
// Whether there are remaining keys.
virtual bool HasNext() const = 0;
// Gets the next key.
//
// If HasNext is false, calling this method has undefined behavior.
virtual const std::string& Next() = 0;
};
// Gets an iterator for keys with tracked point locks in the column family.
//
// The column family must exist.
// Returned iterator must not be nullptr.
// Caller owns the returned pointer.
virtual KeyIterator* GetKeyIterator(
ColumnFamilyId /*column_family_id*/) const = 0;
};
// LockTracker should always be constructed through this factory method,
// instead of constructing through concrete implementations' constructor.
// Caller owns the returned pointer.
LockTracker* NewLockTracker();
} // namespace ROCKSDB_NAMESPACE

@ -0,0 +1,266 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#include "utilities/transactions/lock/point_lock_tracker.h"
namespace ROCKSDB_NAMESPACE {
namespace {
class TrackedKeysColumnFamilyIterator
: public LockTracker::ColumnFamilyIterator {
public:
explicit TrackedKeysColumnFamilyIterator(const TrackedKeys& keys)
: tracked_keys_(keys), it_(keys.begin()) {}
bool HasNext() const override { return it_ != tracked_keys_.end(); }
ColumnFamilyId Next() override { return (it_++)->first; }
private:
const TrackedKeys& tracked_keys_;
TrackedKeys::const_iterator it_;
};
class TrackedKeysIterator : public LockTracker::KeyIterator {
public:
TrackedKeysIterator(const TrackedKeys& keys, ColumnFamilyId id)
: key_infos_(keys.at(id)), it_(key_infos_.begin()) {}
bool HasNext() const override { return it_ != key_infos_.end(); }
const std::string& Next() override { return (it_++)->first; }
private:
const TrackedKeyInfos& key_infos_;
TrackedKeyInfos::const_iterator it_;
};
} // namespace
void PointLockTracker::Track(const PointLockRequest& r) {
auto& keys = tracked_keys_[r.column_family_id];
#ifdef __cpp_lib_unordered_map_try_emplace
// use c++17's try_emplace if available, to avoid rehashing the key
// in case it is not already in the map
auto result = keys.try_emplace(r.key, r.seq);
auto it = result.first;
if (!result.second && r.seq < it->second.seq) {
// Now tracking this key with an earlier sequence number
it->second.seq = r.seq;
}
#else
auto it = keys.find(r.key);
if (it == keys.end()) {
auto result = keys.emplace(r.key, TrackedKeyInfo(r.seq));
it = result.first;
} else if (r.seq < it->second.seq) {
// Now tracking this key with an earlier sequence number
it->second.seq = r.seq;
}
#endif
// else we do not update the seq. The smaller the tracked seq, the stronger it
// the guarantee since it implies from the seq onward there has not been a
// concurrent update to the key. So we update the seq if it implies stronger
// guarantees, i.e., if it is smaller than the existing tracked seq.
if (r.read_only) {
it->second.num_reads++;
} else {
it->second.num_writes++;
}
it->second.exclusive = it->second.exclusive || r.exclusive;
}
UntrackStatus PointLockTracker::Untrack(const PointLockRequest& r) {
auto cf_keys = tracked_keys_.find(r.column_family_id);
if (cf_keys == tracked_keys_.end()) {
return UntrackStatus::NOT_TRACKED;
}
auto& keys = cf_keys->second;
auto it = keys.find(r.key);
if (it == keys.end()) {
return UntrackStatus::NOT_TRACKED;
}
bool untracked = false;
auto& info = it->second;
if (r.read_only) {
if (info.num_reads > 0) {
info.num_reads--;
untracked = true;
}
} else {
if (info.num_writes > 0) {
info.num_writes--;
untracked = true;
}
}
bool removed = false;
if (info.num_reads == 0 && info.num_writes == 0) {
keys.erase(it);
if (keys.empty()) {
tracked_keys_.erase(cf_keys);
}
removed = true;
}
if (removed) {
return UntrackStatus::REMOVED;
}
if (untracked) {
return UntrackStatus::UNTRACKED;
}
return UntrackStatus::NOT_TRACKED;
}
void PointLockTracker::Merge(const LockTracker& tracker) {
const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker);
for (const auto& cf_keys : t.tracked_keys_) {
ColumnFamilyId cf = cf_keys.first;
const auto& keys = cf_keys.second;
auto current_cf_keys = tracked_keys_.find(cf);
if (current_cf_keys == tracked_keys_.end()) {
tracked_keys_.emplace(cf_keys);
} else {
auto& current_keys = current_cf_keys->second;
for (const auto& key_info : keys) {
const std::string& key = key_info.first;
const TrackedKeyInfo& info = key_info.second;
// If key was not previously tracked, just copy the whole struct over.
// Otherwise, some merging needs to occur.
auto current_info = current_keys.find(key);
if (current_info == current_keys.end()) {
current_keys.emplace(key_info);
} else {
current_info->second.Merge(info);
}
}
}
}
}
void PointLockTracker::Subtract(const LockTracker& tracker) {
const PointLockTracker& t = static_cast<const PointLockTracker&>(tracker);
for (const auto& cf_keys : t.tracked_keys_) {
ColumnFamilyId cf = cf_keys.first;
const auto& keys = cf_keys.second;
auto& current_keys = tracked_keys_.at(cf);
for (const auto& key_info : keys) {
const std::string& key = key_info.first;
const TrackedKeyInfo& info = key_info.second;
uint32_t num_reads = info.num_reads;
uint32_t num_writes = info.num_writes;
auto current_key_info = current_keys.find(key);
assert(current_key_info != current_keys.end());
// Decrement the total reads/writes of this key by the number of
// reads/writes done since the last SavePoint.
if (num_reads > 0) {
assert(current_key_info->second.num_reads >= num_reads);
current_key_info->second.num_reads -= num_reads;
}
if (num_writes > 0) {
assert(current_key_info->second.num_writes >= num_writes);
current_key_info->second.num_writes -= num_writes;
}
if (current_key_info->second.num_reads == 0 &&
current_key_info->second.num_writes == 0) {
current_keys.erase(current_key_info);
}
}
}
}
LockTracker* PointLockTracker::GetTrackedLocksSinceSavePoint(
const LockTracker& save_point_tracker) const {
// Examine the number of reads/writes performed on all keys written
// since the last SavePoint and compare to the total number of reads/writes
// for each key.
LockTracker* t = new PointLockTracker();
const PointLockTracker& save_point_t =
static_cast<const PointLockTracker&>(save_point_tracker);
for (const auto& cf_keys : save_point_t.tracked_keys_) {
ColumnFamilyId cf = cf_keys.first;
const auto& keys = cf_keys.second;
auto& current_keys = tracked_keys_.at(cf);
for (const auto& key_info : keys) {
const std::string& key = key_info.first;
const TrackedKeyInfo& info = key_info.second;
uint32_t num_reads = info.num_reads;
uint32_t num_writes = info.num_writes;
auto current_key_info = current_keys.find(key);
assert(current_key_info != current_keys.end());
assert(current_key_info->second.num_reads >= num_reads);
assert(current_key_info->second.num_writes >= num_writes);
if (current_key_info->second.num_reads == num_reads &&
current_key_info->second.num_writes == num_writes) {
// All the reads/writes to this key were done in the last savepoint.
PointLockRequest r;
r.column_family_id = cf;
r.key = key;
r.seq = info.seq;
r.read_only = (num_writes == 0);
r.exclusive = info.exclusive;
t->Track(r);
}
}
}
return t;
}
PointLockStatus PointLockTracker::GetPointLockStatus(
ColumnFamilyId column_family_id, const std::string& key) const {
assert(IsPointLockSupported());
PointLockStatus status;
auto it = tracked_keys_.find(column_family_id);
if (it == tracked_keys_.end()) {
return status;
}
const auto& keys = it->second;
auto key_it = keys.find(key);
if (key_it == keys.end()) {
return status;
}
const TrackedKeyInfo& key_info = key_it->second;
status.locked = true;
status.exclusive = key_info.exclusive;
status.seq = key_info.seq;
return status;
}
uint64_t PointLockTracker::GetNumPointLocks() const {
uint64_t num_keys = 0;
for (const auto& cf_keys : tracked_keys_) {
num_keys += cf_keys.second.size();
}
return num_keys;
}
LockTracker::ColumnFamilyIterator* PointLockTracker::GetColumnFamilyIterator()
const {
return new TrackedKeysColumnFamilyIterator(tracked_keys_);
}
LockTracker::KeyIterator* PointLockTracker::GetKeyIterator(
ColumnFamilyId column_family_id) const {
assert(tracked_keys_.find(column_family_id) != tracked_keys_.end());
return new TrackedKeysIterator(tracked_keys_, column_family_id);
}
void PointLockTracker::Clear() { tracked_keys_.clear(); }
} // namespace ROCKSDB_NAMESPACE

@ -0,0 +1,84 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "utilities/transactions/lock/lock_tracker.h"
namespace ROCKSDB_NAMESPACE {
struct TrackedKeyInfo {
// Earliest sequence number that is relevant to this transaction for this key
SequenceNumber seq;
uint32_t num_writes;
uint32_t num_reads;
bool exclusive;
explicit TrackedKeyInfo(SequenceNumber seq_no)
: seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {}
void Merge(const TrackedKeyInfo& info) {
assert(seq <= info.seq);
num_reads += info.num_reads;
num_writes += info.num_writes;
exclusive = exclusive || info.exclusive;
}
};
using TrackedKeyInfos = std::unordered_map<std::string, TrackedKeyInfo>;
using TrackedKeys = std::unordered_map<ColumnFamilyId, TrackedKeyInfos>;
// Tracks point locks on single keys.
class PointLockTracker : public LockTracker {
public:
PointLockTracker() = default;
PointLockTracker(const PointLockTracker&) = delete;
PointLockTracker& operator=(const PointLockTracker&) = delete;
bool IsPointLockSupported() const override { return true; }
bool IsRangeLockSupported() const override { return false; }
void Track(const PointLockRequest& lock_request) override;
UntrackStatus Untrack(const PointLockRequest& lock_request) override;
void Track(const RangeLockRequest& /*lock_request*/) override {}
UntrackStatus Untrack(const RangeLockRequest& /*lock_request*/) override {
return UntrackStatus::NOT_TRACKED;
}
void Merge(const LockTracker& tracker) override;
void Subtract(const LockTracker& tracker) override;
void Clear() override;
virtual LockTracker* GetTrackedLocksSinceSavePoint(
const LockTracker& save_point_tracker) const override;
PointLockStatus GetPointLockStatus(ColumnFamilyId column_family_id,
const std::string& key) const override;
uint64_t GetNumPointLocks() const override;
ColumnFamilyIterator* GetColumnFamilyIterator() const override;
KeyIterator* GetKeyIterator(ColumnFamilyId column_family_id) const override;
private:
TrackedKeys tracked_keys_;
};
} // namespace ROCKSDB_NAMESPACE

@ -97,9 +97,17 @@ Status OptimisticTransaction::CommitWithParallelValidate() {
const size_t space = txn_db_impl->GetLockBucketsSize(); const size_t space = txn_db_impl->GetLockBucketsSize();
std::set<size_t> lk_idxes; std::set<size_t> lk_idxes;
std::vector<std::unique_lock<std::mutex>> lks; std::vector<std::unique_lock<std::mutex>> lks;
for (auto& cfit : GetTrackedKeys()) { std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
for (auto& keyit : cfit.second) { tracked_locks_->GetColumnFamilyIterator());
lk_idxes.insert(fastrange64(GetSliceNPHash64(keyit.first), space)); assert(cf_it != nullptr);
while (cf_it->HasNext()) {
ColumnFamilyId cf = cf_it->Next();
std::unique_ptr<LockTracker::KeyIterator> key_it(
tracked_locks_->GetKeyIterator(cf));
assert(key_it != nullptr);
while (key_it->HasNext()) {
const std::string& key = key_it->Next();
lk_idxes.insert(fastrange64(GetSliceNPHash64(key), space));
} }
} }
// NOTE: in a single txn, all bucket-locks are taken in ascending order. // NOTE: in a single txn, all bucket-locks are taken in ascending order.
@ -109,7 +117,7 @@ Status OptimisticTransaction::CommitWithParallelValidate() {
lks.emplace_back(txn_db_impl->LockBucket(v)); lks.emplace_back(txn_db_impl->LockBucket(v));
} }
Status s = TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(), Status s = TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_,
true /* cache_only */); true /* cache_only */);
if (!s.ok()) { if (!s.ok()) {
return s; return s;
@ -174,7 +182,7 @@ Status OptimisticTransaction::CheckTransactionForConflicts(DB* db) {
// we will do a cache-only conflict check. This can result in TryAgain // we will do a cache-only conflict check. This can result in TryAgain
// getting returned if there is not sufficient memtable history to check // getting returned if there is not sufficient memtable history to check
// for conflicts. // for conflicts.
return TransactionUtil::CheckKeysForConflicts(db_impl, GetTrackedKeys(), return TransactionUtil::CheckKeysForConflicts(db_impl, *tracked_locks_,
true /* cache_only */); true /* cache_only */);
} }

@ -91,7 +91,7 @@ void PessimisticTransaction::Initialize(const TransactionOptions& txn_options) {
} }
PessimisticTransaction::~PessimisticTransaction() { PessimisticTransaction::~PessimisticTransaction() {
txn_db_impl_->UnLock(this, &GetTrackedKeys()); txn_db_impl_->UnLock(this, *tracked_locks_);
if (expiration_time_ > 0) { if (expiration_time_ > 0) {
txn_db_impl_->RemoveExpirableTransaction(txn_id_); txn_db_impl_->RemoveExpirableTransaction(txn_id_);
} }
@ -101,7 +101,7 @@ PessimisticTransaction::~PessimisticTransaction() {
} }
void PessimisticTransaction::Clear() { void PessimisticTransaction::Clear() {
txn_db_impl_->UnLock(this, &GetTrackedKeys()); txn_db_impl_->UnLock(this, *tracked_locks_);
TransactionBaseImpl::Clear(); TransactionBaseImpl::Clear();
} }
@ -132,8 +132,8 @@ WriteCommittedTxn::WriteCommittedTxn(TransactionDB* txn_db,
: PessimisticTransaction(txn_db, write_options, txn_options){}; : PessimisticTransaction(txn_db, write_options, txn_options){};
Status PessimisticTransaction::CommitBatch(WriteBatch* batch) { Status PessimisticTransaction::CommitBatch(WriteBatch* batch) {
TransactionKeyMap keys_to_unlock; std::unique_ptr<LockTracker> keys_to_unlock(NewLockTracker());
Status s = LockBatch(batch, &keys_to_unlock); Status s = LockBatch(batch, keys_to_unlock.get());
if (!s.ok()) { if (!s.ok()) {
return s; return s;
@ -164,7 +164,7 @@ Status PessimisticTransaction::CommitBatch(WriteBatch* batch) {
s = Status::InvalidArgument("Transaction is not in state for commit."); s = Status::InvalidArgument("Transaction is not in state for commit.");
} }
txn_db_impl_->UnLock(this, &keys_to_unlock); txn_db_impl_->UnLock(this, *keys_to_unlock);
return s; return s;
} }
@ -446,12 +446,14 @@ Status PessimisticTransaction::RollbackToSavePoint() {
return Status::InvalidArgument("Transaction is beyond state for rollback."); return Status::InvalidArgument("Transaction is beyond state for rollback.");
} }
// Unlock any keys locked since last transaction if (save_points_ != nullptr && !save_points_->empty()) {
const std::unique_ptr<TransactionKeyMap>& keys = // Unlock any keys locked since last transaction
GetTrackedKeysSinceSavePoint(); auto& save_point_tracker = *save_points_->top().new_locks_;
std::unique_ptr<LockTracker> t(
if (keys) { tracked_locks_->GetTrackedLocksSinceSavePoint(save_point_tracker));
txn_db_impl_->UnLock(this, keys.get()); if (t) {
txn_db_impl_->UnLock(this, *t);
}
} }
return TransactionBaseImpl::RollbackToSavePoint(); return TransactionBaseImpl::RollbackToSavePoint();
@ -460,7 +462,7 @@ Status PessimisticTransaction::RollbackToSavePoint() {
// Lock all keys in this batch. // Lock all keys in this batch.
// On success, caller should unlock keys_to_unlock // On success, caller should unlock keys_to_unlock
Status PessimisticTransaction::LockBatch(WriteBatch* batch, Status PessimisticTransaction::LockBatch(WriteBatch* batch,
TransactionKeyMap* keys_to_unlock) { LockTracker* keys_to_unlock) {
class Handler : public WriteBatch::Handler { class Handler : public WriteBatch::Handler {
public: public:
// Sorted map of column_family_id to sorted set of keys. // Sorted map of column_family_id to sorted set of keys.
@ -516,8 +518,13 @@ Status PessimisticTransaction::LockBatch(WriteBatch* batch,
if (!s.ok()) { if (!s.ok()) {
break; break;
} }
TrackKey(keys_to_unlock, cfh_id, std::move(key), kMaxSequenceNumber, PointLockRequest r;
false, true /* exclusive */); r.column_family_id = cfh_id;
r.key = key;
r.seq = kMaxSequenceNumber;
r.read_only = false;
r.exclusive = true;
keys_to_unlock->Track(r);
} }
if (!s.ok()) { if (!s.ok()) {
@ -526,7 +533,7 @@ Status PessimisticTransaction::LockBatch(WriteBatch* batch,
} }
if (!s.ok()) { if (!s.ok()) {
txn_db_impl_->UnLock(this, keys_to_unlock); txn_db_impl_->UnLock(this, *keys_to_unlock);
} }
return s; return s;
@ -548,28 +555,9 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
} }
uint32_t cfh_id = GetColumnFamilyID(column_family); uint32_t cfh_id = GetColumnFamilyID(column_family);
std::string key_str = key.ToString(); std::string key_str = key.ToString();
bool previously_locked; PointLockStatus status = tracked_locks_->GetPointLockStatus(cfh_id, key_str);
bool lock_upgrade = false; bool previously_locked = status.locked;
bool lock_upgrade = previously_locked && exclusive && !status.exclusive;
// lock this key if this transactions hasn't already locked it
SequenceNumber tracked_at_seq = kMaxSequenceNumber;
const auto& tracked_keys = GetTrackedKeys();
const auto tracked_keys_cf = tracked_keys.find(cfh_id);
if (tracked_keys_cf == tracked_keys.end()) {
previously_locked = false;
} else {
auto iter = tracked_keys_cf->second.find(key_str);
if (iter == tracked_keys_cf->second.end()) {
previously_locked = false;
} else {
if (!iter->second.exclusive && exclusive) {
lock_upgrade = true;
}
previously_locked = true;
tracked_at_seq = iter->second.seq;
}
}
// Lock this key if this transactions hasn't already locked it or we require // Lock this key if this transactions hasn't already locked it or we require
// an upgrade. // an upgrade.
@ -585,6 +573,8 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
// any writes since this transaction's snapshot. // any writes since this transaction's snapshot.
// TODO(agiardullo): could optimize by supporting shared txn locks in the // TODO(agiardullo): could optimize by supporting shared txn locks in the
// future // future
SequenceNumber tracked_at_seq =
status.locked ? status.seq : kMaxSequenceNumber;
if (!do_validate || snapshot_ == nullptr) { if (!do_validate || snapshot_ == nullptr) {
if (assume_tracked && !previously_locked) { if (assume_tracked && !previously_locked) {
s = Status::InvalidArgument( s = Status::InvalidArgument(
@ -614,15 +604,13 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
if (!s.ok()) { if (!s.ok()) {
// Failed to validate key // Failed to validate key
if (!previously_locked) { // Unlock key we just locked
// Unlock key we just locked if (lock_upgrade) {
if (lock_upgrade) { s = txn_db_impl_->TryLock(this, cfh_id, key_str,
s = txn_db_impl_->TryLock(this, cfh_id, key_str, false /* exclusive */);
false /* exclusive */); assert(s.ok());
assert(s.ok()); } else if (!previously_locked) {
} else { txn_db_impl_->UnLock(this, cfh_id, key.ToString());
txn_db_impl_->UnLock(this, cfh_id, key.ToString());
}
} }
} }
} }
@ -645,10 +633,11 @@ Status PessimisticTransaction::TryLock(ColumnFamilyHandle* column_family,
TrackKey(cfh_id, key_str, tracked_at_seq, read_only, exclusive); TrackKey(cfh_id, key_str, tracked_at_seq, read_only, exclusive);
} else { } else {
#ifndef NDEBUG #ifndef NDEBUG
assert(tracked_keys_cf->second.count(key_str) > 0); PointLockStatus lock_status =
const auto& info = tracked_keys_cf->second.find(key_str)->second; tracked_locks_->GetPointLockStatus(cfh_id, key_str);
assert(info.seq <= tracked_at_seq); assert(lock_status.locked);
assert(info.exclusive == exclusive); assert(lock_status.seq <= tracked_at_seq);
assert(lock_status.exclusive == exclusive);
#endif #endif
} }
} }

@ -139,7 +139,7 @@ class PessimisticTransaction : public TransactionBaseImpl {
virtual void Initialize(const TransactionOptions& txn_options); virtual void Initialize(const TransactionOptions& txn_options);
Status LockBatch(WriteBatch* batch, TransactionKeyMap* keys_to_unlock); Status LockBatch(WriteBatch* batch, LockTracker* keys_to_unlock);
Status TryLock(ColumnFamilyHandle* column_family, const Slice& key, Status TryLock(ColumnFamilyHandle* column_family, const Slice& key,
bool read_only, bool exclusive, const bool do_validate = true, bool read_only, bool exclusive, const bool do_validate = true,

@ -402,7 +402,7 @@ Status PessimisticTransactionDB::TryLock(PessimisticTransaction* txn,
} }
void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn, void PessimisticTransactionDB::UnLock(PessimisticTransaction* txn,
const TransactionKeyMap* keys) { const LockTracker& keys) {
lock_mgr_.UnLock(txn, keys, GetEnv()); lock_mgr_.UnLock(txn, keys, GetEnv());
} }

@ -99,7 +99,7 @@ class PessimisticTransactionDB : public TransactionDB {
Status TryLock(PessimisticTransaction* txn, uint32_t cfh_id, Status TryLock(PessimisticTransaction* txn, uint32_t cfh_id,
const std::string& key, bool exclusive); const std::string& key, bool exclusive);
void UnLock(PessimisticTransaction* txn, const TransactionKeyMap* keys); void UnLock(PessimisticTransaction* txn, const LockTracker& keys);
void UnLock(PessimisticTransaction* txn, uint32_t cfh_id, void UnLock(PessimisticTransaction* txn, uint32_t cfh_id,
const std::string& key); const std::string& key);

@ -16,6 +16,7 @@
#include "rocksdb/status.h" #include "rocksdb/status.h"
#include "util/cast_util.h" #include "util/cast_util.h"
#include "util/string_util.h" #include "util/string_util.h"
#include "utilities/transactions/lock/lock_tracker.h"
namespace ROCKSDB_NAMESPACE { namespace ROCKSDB_NAMESPACE {
@ -27,6 +28,7 @@ TransactionBaseImpl::TransactionBaseImpl(DB* db,
cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())), cmp_(GetColumnFamilyUserComparator(db->DefaultColumnFamily())),
start_time_(db_->GetEnv()->NowMicros()), start_time_(db_->GetEnv()->NowMicros()),
write_batch_(cmp_, 0, true, 0), write_batch_(cmp_, 0, true, 0),
tracked_locks_(NewLockTracker()),
indexing_enabled_(true) { indexing_enabled_(true) {
assert(dynamic_cast<DBImpl*>(db_) != nullptr); assert(dynamic_cast<DBImpl*>(db_) != nullptr);
log_number_ = 0; log_number_ = 0;
@ -44,7 +46,7 @@ void TransactionBaseImpl::Clear() {
save_points_.reset(nullptr); save_points_.reset(nullptr);
write_batch_.Clear(); write_batch_.Clear();
commit_time_batch_.Clear(); commit_time_batch_.Clear();
tracked_keys_.clear(); tracked_locks_->Clear();
num_puts_ = 0; num_puts_ = 0;
num_deletes_ = 0; num_deletes_ = 0;
num_merges_ = 0; num_merges_ = 0;
@ -143,37 +145,7 @@ Status TransactionBaseImpl::RollbackToSavePoint() {
assert(s.ok()); assert(s.ok());
// Rollback any keys that were tracked since the last savepoint // Rollback any keys that were tracked since the last savepoint
const TransactionKeyMap& key_map = save_point.new_keys_; tracked_locks_->Subtract(*save_point.new_locks_);
for (const auto& key_map_iter : key_map) {
uint32_t column_family_id = key_map_iter.first;
auto& keys = key_map_iter.second;
auto& cf_tracked_keys = tracked_keys_[column_family_id];
for (const auto& key_iter : keys) {
const std::string& key = key_iter.first;
uint32_t num_reads = key_iter.second.num_reads;
uint32_t num_writes = key_iter.second.num_writes;
auto tracked_keys_iter = cf_tracked_keys.find(key);
assert(tracked_keys_iter != cf_tracked_keys.end());
// Decrement the total reads/writes of this key by the number of
// reads/writes done since the last SavePoint.
if (num_reads > 0) {
assert(tracked_keys_iter->second.num_reads >= num_reads);
tracked_keys_iter->second.num_reads -= num_reads;
}
if (num_writes > 0) {
assert(tracked_keys_iter->second.num_writes >= num_writes);
tracked_keys_iter->second.num_writes -= num_writes;
}
if (tracked_keys_iter->second.num_reads == 0 &&
tracked_keys_iter->second.num_writes == 0) {
cf_tracked_keys.erase(tracked_keys_iter);
}
}
}
save_points_->pop(); save_points_->pop();
@ -204,35 +176,7 @@ Status TransactionBaseImpl::PopSavePoint() {
std::swap(top, save_points_->top()); std::swap(top, save_points_->top());
save_points_->pop(); save_points_->pop();
const TransactionKeyMap& curr_cf_key_map = top.new_keys_; save_points_->top().new_locks_->Merge(*top.new_locks_);
TransactionKeyMap& prev_cf_key_map = save_points_->top().new_keys_;
for (const auto& curr_cf_key_iter : curr_cf_key_map) {
uint32_t column_family_id = curr_cf_key_iter.first;
const std::unordered_map<std::string, TransactionKeyMapInfo>& curr_keys =
curr_cf_key_iter.second;
// If cfid was not previously tracked, just copy everything over.
auto prev_keys_iter = prev_cf_key_map.find(column_family_id);
if (prev_keys_iter == prev_cf_key_map.end()) {
prev_cf_key_map.emplace(curr_cf_key_iter);
} else {
std::unordered_map<std::string, TransactionKeyMapInfo>& prev_keys =
prev_keys_iter->second;
for (const auto& key_iter : curr_keys) {
const std::string& key = key_iter.first;
const TransactionKeyMapInfo& info = key_iter.second;
// If key was not previously tracked, just copy the whole struct over.
// Otherwise, some merging needs to occur.
auto prev_info = prev_keys.find(key);
if (prev_info == prev_keys.end()) {
prev_keys.emplace(key_iter);
} else {
prev_info->second.Merge(info);
}
}
}
}
} }
return write_batch_.PopSavePoint(); return write_batch_.PopSavePoint();
@ -601,106 +545,26 @@ uint64_t TransactionBaseImpl::GetNumDeletes() const { return num_deletes_; }
uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; } uint64_t TransactionBaseImpl::GetNumMerges() const { return num_merges_; }
uint64_t TransactionBaseImpl::GetNumKeys() const { uint64_t TransactionBaseImpl::GetNumKeys() const {
uint64_t count = 0; return tracked_locks_->GetNumPointLocks();
// sum up locked keys in all column families
for (const auto& key_map_iter : tracked_keys_) {
const auto& keys = key_map_iter.second;
count += keys.size();
}
return count;
} }
void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key, void TransactionBaseImpl::TrackKey(uint32_t cfh_id, const std::string& key,
SequenceNumber seq, bool read_only, SequenceNumber seq, bool read_only,
bool exclusive) { bool exclusive) {
PointLockRequest r;
r.column_family_id = cfh_id;
r.key = key;
r.seq = seq;
r.read_only = read_only;
r.exclusive = exclusive;
// Update map of all tracked keys for this transaction // Update map of all tracked keys for this transaction
TrackKey(&tracked_keys_, cfh_id, key, seq, read_only, exclusive); tracked_locks_->Track(r);
if (save_points_ != nullptr && !save_points_->empty()) { if (save_points_ != nullptr && !save_points_->empty()) {
// Update map of tracked keys in this SavePoint // Update map of tracked keys in this SavePoint
TrackKey(&save_points_->top().new_keys_, cfh_id, key, seq, read_only, save_points_->top().new_locks_->Track(r);
exclusive);
}
}
// Add a key to the given TransactionKeyMap
// seq for pessimistic transactions is the sequence number from which we know
// there has not been a concurrent update to the key.
void TransactionBaseImpl::TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
const std::string& key, SequenceNumber seq,
bool read_only, bool exclusive) {
auto& cf_key_map = (*key_map)[cfh_id];
#ifdef __cpp_lib_unordered_map_try_emplace
// use c++17's try_emplace if available, to avoid rehashing the key
// in case it is not already in the map
auto result = cf_key_map.try_emplace(key, seq);
auto iter = result.first;
if (!result.second && seq < iter->second.seq) {
// Now tracking this key with an earlier sequence number
iter->second.seq = seq;
}
#else
auto iter = cf_key_map.find(key);
if (iter == cf_key_map.end()) {
auto result = cf_key_map.emplace(key, TransactionKeyMapInfo(seq));
iter = result.first;
} else if (seq < iter->second.seq) {
// Now tracking this key with an earlier sequence number
iter->second.seq = seq;
}
#endif
// else we do not update the seq. The smaller the tracked seq, the stronger it
// the guarantee since it implies from the seq onward there has not been a
// concurrent update to the key. So we update the seq if it implies stronger
// guarantees, i.e., if it is smaller than the existing tracked seq.
if (read_only) {
iter->second.num_reads++;
} else {
iter->second.num_writes++;
}
iter->second.exclusive |= exclusive;
}
std::unique_ptr<TransactionKeyMap>
TransactionBaseImpl::GetTrackedKeysSinceSavePoint() {
if (save_points_ != nullptr && !save_points_->empty()) {
// Examine the number of reads/writes performed on all keys written
// since the last SavePoint and compare to the total number of reads/writes
// for each key.
TransactionKeyMap* result = new TransactionKeyMap();
for (const auto& key_map_iter : save_points_->top().new_keys_) {
uint32_t column_family_id = key_map_iter.first;
auto& keys = key_map_iter.second;
auto& cf_tracked_keys = tracked_keys_[column_family_id];
for (const auto& key_iter : keys) {
const std::string& key = key_iter.first;
uint32_t num_reads = key_iter.second.num_reads;
uint32_t num_writes = key_iter.second.num_writes;
auto total_key_info = cf_tracked_keys.find(key);
assert(total_key_info != cf_tracked_keys.end());
assert(total_key_info->second.num_reads >= num_reads);
assert(total_key_info->second.num_writes >= num_writes);
if (total_key_info->second.num_reads == num_reads &&
total_key_info->second.num_writes == num_writes) {
// All the reads/writes to this key were done in the last savepoint.
bool read_only = (num_writes == 0);
TrackKey(result, column_family_id, key, key_iter.second.seq,
read_only, key_iter.second.exclusive);
}
}
}
return std::unique_ptr<TransactionKeyMap>(result);
} }
// No SavePoint
return nullptr;
} }
// Gets the write batch that should be used for Put/Merge/Deletes. // Gets the write batch that should be used for Put/Merge/Deletes.
@ -728,54 +592,28 @@ void TransactionBaseImpl::ReleaseSnapshot(const Snapshot* snapshot, DB* db) {
void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family, void TransactionBaseImpl::UndoGetForUpdate(ColumnFamilyHandle* column_family,
const Slice& key) { const Slice& key) {
uint32_t column_family_id = GetColumnFamilyID(column_family); PointLockRequest r;
auto& cf_tracked_keys = tracked_keys_[column_family_id]; r.column_family_id = GetColumnFamilyID(column_family);
std::string key_str = key.ToString(); r.key = key.ToString();
bool can_decrement = false; r.read_only = true;
bool can_unlock __attribute__((__unused__)) = false;
bool can_untrack = false;
if (save_points_ != nullptr && !save_points_->empty()) { if (save_points_ != nullptr && !save_points_->empty()) {
// Check if this key was fetched ForUpdate in this SavePoint // If there is no GetForUpdate of the key in this save point,
auto& cf_savepoint_keys = save_points_->top().new_keys_[column_family_id]; // then cannot untrack from the global lock tracker.
UntrackStatus s = save_points_->top().new_locks_->Untrack(r);
auto savepoint_iter = cf_savepoint_keys.find(key_str); can_untrack = (s != UntrackStatus::NOT_TRACKED);
if (savepoint_iter != cf_savepoint_keys.end()) {
if (savepoint_iter->second.num_reads > 0) {
savepoint_iter->second.num_reads--;
can_decrement = true;
if (savepoint_iter->second.num_reads == 0 &&
savepoint_iter->second.num_writes == 0) {
// No other GetForUpdates or write on this key in this SavePoint
cf_savepoint_keys.erase(savepoint_iter);
can_unlock = true;
}
}
}
} else { } else {
// No SavePoint set // No save point, so can untrack from the global lock tracker.
can_decrement = true; can_untrack = true;
can_unlock = true; }
}
if (can_untrack) {
// We can only decrement the read count for this key if we were able to // If erased from the global tracker, then can unlock the key.
// decrement the read count in the current SavePoint, OR if there is no UntrackStatus s = tracked_locks_->Untrack(r);
// SavePoint set. bool can_unlock = (s == UntrackStatus::REMOVED);
if (can_decrement) { if (can_unlock) {
auto key_iter = cf_tracked_keys.find(key_str); UnlockGetForUpdate(column_family, key);
if (key_iter != cf_tracked_keys.end()) {
if (key_iter->second.num_reads > 0) {
key_iter->second.num_reads--;
if (key_iter->second.num_reads == 0 &&
key_iter->second.num_writes == 0) {
// No other GetForUpdates or writes on this key
assert(can_unlock);
cf_tracked_keys.erase(key_iter);
UnlockGetForUpdate(column_family, key);
}
}
} }
} }
} }

@ -1,7 +1,7 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. // Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the // This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License // COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory). // (found in the LICENSE.Apache file in the root directory).
#pragma once #pragma once
@ -21,6 +21,7 @@
#include "rocksdb/utilities/transaction_db.h" #include "rocksdb/utilities/transaction_db.h"
#include "rocksdb/utilities/write_batch_with_index.h" #include "rocksdb/utilities/write_batch_with_index.h"
#include "util/autovector.h" #include "util/autovector.h"
#include "utilities/transactions/lock/lock_tracker.h"
#include "utilities/transactions/transaction_util.h" #include "utilities/transactions/transaction_util.h"
namespace ROCKSDB_NAMESPACE { namespace ROCKSDB_NAMESPACE {
@ -233,10 +234,6 @@ class TransactionBaseImpl : public Transaction {
return UndoGetForUpdate(nullptr, key); return UndoGetForUpdate(nullptr, key);
}; };
// Get list of keys in this transaction that must not have any conflicts
// with writes in other transactions.
const TransactionKeyMap& GetTrackedKeys() const { return tracked_keys_; }
WriteOptions* GetWriteOptions() override { return &write_options_; } WriteOptions* GetWriteOptions() override { return &write_options_; }
void SetWriteOptions(const WriteOptions& write_options) override { void SetWriteOptions(const WriteOptions& write_options) override {
@ -260,17 +257,10 @@ class TransactionBaseImpl : public Transaction {
void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno, void TrackKey(uint32_t cfh_id, const std::string& key, SequenceNumber seqno,
bool readonly, bool exclusive); bool readonly, bool exclusive);
// Helper function to add a key to the given TransactionKeyMap
static void TrackKey(TransactionKeyMap* key_map, uint32_t cfh_id,
const std::string& key, SequenceNumber seqno,
bool readonly, bool exclusive);
// Called when UndoGetForUpdate determines that this key can be unlocked. // Called when UndoGetForUpdate determines that this key can be unlocked.
virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family, virtual void UnlockGetForUpdate(ColumnFamilyHandle* column_family,
const Slice& key) = 0; const Slice& key) = 0;
std::unique_ptr<TransactionKeyMap> GetTrackedKeysSinceSavePoint();
// Sets a snapshot if SetSnapshotOnNextOperation() has been called. // Sets a snapshot if SetSnapshotOnNextOperation() has been called.
void SetSnapshotIfNeeded(); void SetSnapshotIfNeeded();
@ -310,8 +300,8 @@ class TransactionBaseImpl : public Transaction {
uint64_t num_deletes_ = 0; uint64_t num_deletes_ = 0;
uint64_t num_merges_ = 0; uint64_t num_merges_ = 0;
// Record all keys tracked since the last savepoint // Record all locks tracked since the last savepoint
TransactionKeyMap new_keys_; std::shared_ptr<LockTracker> new_locks_;
SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed, SavePoint(std::shared_ptr<const Snapshot> snapshot, bool snapshot_needed,
std::shared_ptr<TransactionNotifier> snapshot_notifier, std::shared_ptr<TransactionNotifier> snapshot_notifier,
@ -321,19 +311,20 @@ class TransactionBaseImpl : public Transaction {
snapshot_notifier_(snapshot_notifier), snapshot_notifier_(snapshot_notifier),
num_puts_(num_puts), num_puts_(num_puts),
num_deletes_(num_deletes), num_deletes_(num_deletes),
num_merges_(num_merges) {} num_merges_(num_merges),
new_locks_(NewLockTracker()) {}
SavePoint() = default; SavePoint() : new_locks_(NewLockTracker()) {}
}; };
// Records writes pending in this transaction // Records writes pending in this transaction
WriteBatchWithIndex write_batch_; WriteBatchWithIndex write_batch_;
// Map from column_family_id to map of keys that are involved in this // For Pessimistic Transactions this is the set of acquired locks.
// transaction. // Optimistic Transactions will keep note the requested locks (not actually
// For Pessimistic Transactions this is the list of locked keys. // locked), and do conflict checking until commit time based on the tracked
// Optimistic Transactions will wait till commit time to do conflict checking. // lock requests.
TransactionKeyMap tracked_keys_; std::unique_ptr<LockTracker> tracked_locks_;
// Stack of the Snapshot saved at each save point. Saved snapshots may be // Stack of the Snapshot saved at each save point. Saved snapshots may be
// nullptr if there was no snapshot at the time SetSavePoint() was called. // nullptr if there was no snapshot at the time SetSavePoint() was called.

@ -643,26 +643,27 @@ void TransactionLockMgr::UnLock(PessimisticTransaction* txn,
} }
void TransactionLockMgr::UnLock(const PessimisticTransaction* txn, void TransactionLockMgr::UnLock(const PessimisticTransaction* txn,
const TransactionKeyMap* key_map, Env* env) { const LockTracker& tracker, Env* env) {
for (auto& key_map_iter : *key_map) { std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
uint32_t column_family_id = key_map_iter.first; tracker.GetColumnFamilyIterator());
auto& keys = key_map_iter.second; assert(cf_it != nullptr);
while (cf_it->HasNext()) {
std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(column_family_id); ColumnFamilyId cf = cf_it->Next();
std::shared_ptr<LockMap> lock_map_ptr = GetLockMap(cf);
LockMap* lock_map = lock_map_ptr.get(); LockMap* lock_map = lock_map_ptr.get();
if (!lock_map) {
if (lock_map == nullptr) {
// Column Family must have been dropped. // Column Family must have been dropped.
return; return;
} }
// Bucket keys by lock_map_ stripe // Bucket keys by lock_map_ stripe
std::unordered_map<size_t, std::vector<const std::string*>> keys_by_stripe( std::unordered_map<size_t, std::vector<const std::string*>> keys_by_stripe(
std::max(keys.size(), lock_map->num_stripes_)); lock_map->num_stripes_);
std::unique_ptr<LockTracker::KeyIterator> key_it(
for (auto& key_iter : keys) { tracker.GetKeyIterator(cf));
const std::string& key = key_iter.first; assert(key_it != nullptr);
while (key_it->HasNext()) {
const std::string& key = key_it->Next();
size_t stripe_num = lock_map->GetStripe(key); size_t stripe_num = lock_map->GetStripe(key);
keys_by_stripe[stripe_num].push_back(&key); keys_by_stripe[stripe_num].push_back(&key);
} }

@ -77,7 +77,7 @@ class TransactionLockMgr {
// Unlock a key locked by TryLock(). txn must be the same Transaction that // Unlock a key locked by TryLock(). txn must be the same Transaction that
// locked this key. // locked this key.
void UnLock(const PessimisticTransaction* txn, const TransactionKeyMap* keys, void UnLock(const PessimisticTransaction* txn, const LockTracker& tracker,
Env* env); Env* env);
void UnLock(PessimisticTransaction* txn, uint32_t column_family_id, void UnLock(PessimisticTransaction* txn, uint32_t column_family_id,
const std::string& key, Env* env); const std::string& key, Env* env);

@ -137,18 +137,20 @@ Status TransactionUtil::CheckKey(DBImpl* db_impl, SuperVersion* sv,
} }
Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl, Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl,
const TransactionKeyMap& key_map, const LockTracker& tracker,
bool cache_only) { bool cache_only) {
Status result; Status result;
for (auto& key_map_iter : key_map) { std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
uint32_t cf_id = key_map_iter.first; tracker.GetColumnFamilyIterator());
const auto& keys = key_map_iter.second; assert(cf_it != nullptr);
while (cf_it->HasNext()) {
ColumnFamilyId cf = cf_it->Next();
SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf_id); SuperVersion* sv = db_impl->GetAndRefSuperVersion(cf);
if (sv == nullptr) { if (sv == nullptr) {
result = Status::InvalidArgument("Could not access column family " + result = Status::InvalidArgument("Could not access column family " +
ToString(cf_id)); ToString(cf));
break; break;
} }
@ -157,18 +159,21 @@ Status TransactionUtil::CheckKeysForConflicts(DBImpl* db_impl,
// For each of the keys in this transaction, check to see if someone has // For each of the keys in this transaction, check to see if someone has
// written to this key since the start of the transaction. // written to this key since the start of the transaction.
for (const auto& key_iter : keys) { std::unique_ptr<LockTracker::KeyIterator> key_it(
const auto& key = key_iter.first; tracker.GetKeyIterator(cf));
const SequenceNumber key_seq = key_iter.second.seq; assert(key_it != nullptr);
while (key_it->HasNext()) {
const std::string& key = key_it->Next();
PointLockStatus status = tracker.GetPointLockStatus(cf, key);
const SequenceNumber key_seq = status.seq;
result = CheckKey(db_impl, sv, earliest_seq, key_seq, key, cache_only); result = CheckKey(db_impl, sv, earliest_seq, key_seq, key, cache_only);
if (!result.ok()) { if (!result.ok()) {
break; break;
} }
} }
db_impl->ReturnAndCleanupSuperVersion(cf_id, sv); db_impl->ReturnAndCleanupSuperVersion(cf, sv);
if (!result.ok()) { if (!result.ok()) {
break; break;

@ -12,39 +12,14 @@
#include "db/dbformat.h" #include "db/dbformat.h"
#include "db/read_callback.h" #include "db/read_callback.h"
#include "rocksdb/db.h" #include "rocksdb/db.h"
#include "rocksdb/slice.h" #include "rocksdb/slice.h"
#include "rocksdb/status.h" #include "rocksdb/status.h"
#include "rocksdb/types.h" #include "rocksdb/types.h"
#include "utilities/transactions/lock/lock_tracker.h"
namespace ROCKSDB_NAMESPACE { namespace ROCKSDB_NAMESPACE {
struct TransactionKeyMapInfo {
// Earliest sequence number that is relevant to this transaction for this key
SequenceNumber seq;
uint32_t num_writes;
uint32_t num_reads;
bool exclusive;
explicit TransactionKeyMapInfo(SequenceNumber seq_no)
: seq(seq_no), num_writes(0), num_reads(0), exclusive(false) {}
// Used in PopSavePoint to collapse two savepoints together.
void Merge(const TransactionKeyMapInfo& info) {
assert(seq <= info.seq);
num_reads += info.num_reads;
num_writes += info.num_writes;
exclusive |= info.exclusive;
}
};
using TransactionKeyMap =
std::unordered_map<uint32_t,
std::unordered_map<std::string, TransactionKeyMapInfo>>;
class DBImpl; class DBImpl;
struct SuperVersion; struct SuperVersion;
class WriteBatchWithIndex; class WriteBatchWithIndex;
@ -69,17 +44,19 @@ class TransactionUtil {
ReadCallback* snap_checker = nullptr, ReadCallback* snap_checker = nullptr,
SequenceNumber min_uncommitted = kMaxSequenceNumber); SequenceNumber min_uncommitted = kMaxSequenceNumber);
// For each key,SequenceNumber pair in the TransactionKeyMap, this function // For each key,SequenceNumber pair tracked by the LockTracker, this function
// will verify there have been no writes to the key in the db since that // will verify there have been no writes to the key in the db since that
// sequence number. // sequence number.
// //
// Returns OK on success, BUSY if there is a conflicting write, or other error // Returns OK on success, BUSY if there is a conflicting write, or other error
// status for any unexpected errors. // status for any unexpected errors.
// //
// REQUIRED: this function should only be called on the write thread or if the // REQUIRED:
// This function should only be called on the write thread or if the
// mutex is held. // mutex is held.
// tracker must support point lock.
static Status CheckKeysForConflicts(DBImpl* db_impl, static Status CheckKeysForConflicts(DBImpl* db_impl,
const TransactionKeyMap& keys, const LockTracker& tracker,
bool cache_only); bool cache_only);
private: private:

@ -72,10 +72,10 @@ WriteUnpreparedTxn::~WriteUnpreparedTxn() {
} }
} }
// Call tracked_keys_.clear() so that ~PessimisticTransaction does not // Clear the tracked locks so that ~PessimisticTransaction does not
// try to unlock keys for recovered transactions. // try to unlock keys for recovered transactions.
if (recovered_txn_) { if (recovered_txn_) {
tracked_keys_.clear(); tracked_locks_->Clear();
} }
} }
@ -296,7 +296,9 @@ Status WriteUnpreparedTxn::FlushWriteBatchToDBInternal(bool prepared) {
Status AddUntrackedKey(uint32_t cf, const Slice& key) { Status AddUntrackedKey(uint32_t cf, const Slice& key) {
auto str = key.ToString(); auto str = key.ToString();
if (txn_->tracked_keys_[cf].count(str) == 0) { PointLockStatus lock_status =
txn_->tracked_locks_->GetPointLockStatus(cf, str);
if (!lock_status.locked) {
txn_->untracked_keys_[cf].push_back(str); txn_->untracked_keys_[cf].push_back(str);
} }
return Status::OK(); return Status::OK();
@ -639,8 +641,10 @@ Status WriteUnpreparedTxn::CommitInternal() {
} }
Status WriteUnpreparedTxn::WriteRollbackKeys( Status WriteUnpreparedTxn::WriteRollbackKeys(
const TransactionKeyMap& tracked_keys, WriteBatchWithIndex* rollback_batch, const LockTracker& lock_tracker, WriteBatchWithIndex* rollback_batch,
ReadCallback* callback, const ReadOptions& roptions) { ReadCallback* callback, const ReadOptions& roptions) {
// This assertion can be removed when range lock is supported.
assert(lock_tracker.IsPointLockSupported());
const auto& cf_map = *wupt_db_->GetCFHandleMap(); const auto& cf_map = *wupt_db_->GetCFHandleMap();
auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) { auto WriteRollbackKey = [&](const std::string& key, uint32_t cfid) {
const auto& cf_handle = cf_map.at(cfid); const auto& cf_handle = cf_map.at(cfid);
@ -666,11 +670,17 @@ Status WriteUnpreparedTxn::WriteRollbackKeys(
return Status::OK(); return Status::OK();
}; };
for (const auto& cfkey : tracked_keys) { std::unique_ptr<LockTracker::ColumnFamilyIterator> cf_it(
const auto cfid = cfkey.first; lock_tracker.GetColumnFamilyIterator());
const auto& keys = cfkey.second; assert(cf_it != nullptr);
for (const auto& pair : keys) { while (cf_it->HasNext()) {
auto s = WriteRollbackKey(pair.first, cfid); ColumnFamilyId cf = cf_it->Next();
std::unique_ptr<LockTracker::KeyIterator> key_it(
lock_tracker.GetKeyIterator(cf));
assert(key_it != nullptr);
while (key_it->HasNext()) {
const std::string& key = key_it->Next();
auto s = WriteRollbackKey(key, cf);
if (!s.ok()) { if (!s.ok()) {
return s; return s;
} }
@ -709,7 +719,7 @@ Status WriteUnpreparedTxn::RollbackInternal() {
// TODO(lth): We write rollback batch all in a single batch here, but this // TODO(lth): We write rollback batch all in a single batch here, but this
// should be subdivded into multiple batches as well. In phase 2, when key // should be subdivded into multiple batches as well. In phase 2, when key
// sets are read from WAL, this will happen naturally. // sets are read from WAL, this will happen naturally.
WriteRollbackKeys(GetTrackedKeys(), &rollback_batch, &callback, roptions); WriteRollbackKeys(*tracked_locks_, &rollback_batch, &callback, roptions);
// The Rollback marker will be used as a batch separator // The Rollback marker will be used as a batch separator
WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_); WriteBatchInternal::MarkRollback(rollback_batch.GetWriteBatch(), name_);
@ -790,7 +800,7 @@ Status WriteUnpreparedTxn::RollbackInternal() {
void WriteUnpreparedTxn::Clear() { void WriteUnpreparedTxn::Clear() {
if (!recovered_txn_) { if (!recovered_txn_) {
txn_db_impl_->UnLock(this, &GetTrackedKeys()); txn_db_impl_->UnLock(this, *tracked_locks_);
} }
unprep_seqs_.clear(); unprep_seqs_.clear();
flushed_save_points_.reset(nullptr); flushed_save_points_.reset(nullptr);
@ -842,7 +852,7 @@ Status WriteUnpreparedTxn::RollbackToSavePointInternal() {
WriteUnpreparedTxn::SavePoint& top = flushed_save_points_->back(); WriteUnpreparedTxn::SavePoint& top = flushed_save_points_->back();
assert(save_points_ != nullptr && save_points_->size() > 0); assert(save_points_ != nullptr && save_points_->size() > 0);
const TransactionKeyMap& tracked_keys = save_points_->top().new_keys_; const LockTracker& tracked_keys = *save_points_->top().new_locks_;
ReadOptions roptions; ReadOptions roptions;
roptions.snapshot = top.snapshot_->snapshot(); roptions.snapshot = top.snapshot_->snapshot();

@ -212,7 +212,7 @@ class WriteUnpreparedTxn : public WritePreparedTxn {
friend class WriteUnpreparedTxnDB; friend class WriteUnpreparedTxnDB;
const std::map<SequenceNumber, size_t>& GetUnpreparedSequenceNumbers(); const std::map<SequenceNumber, size_t>& GetUnpreparedSequenceNumbers();
Status WriteRollbackKeys(const TransactionKeyMap& tracked_keys, Status WriteRollbackKeys(const LockTracker& tracked_keys,
WriteBatchWithIndex* rollback_batch, WriteBatchWithIndex* rollback_batch,
ReadCallback* callback, const ReadOptions& roptions); ReadCallback* callback, const ReadOptions& roptions);

Loading…
Cancel
Save