diff --git a/db/column_family.cc b/db/column_family.cc index ff6b8fe6c..8b4e007ed 100644 --- a/db/column_family.cc +++ b/db/column_family.cc @@ -86,6 +86,10 @@ ColumnFamilyHandleImpl::~ColumnFamilyHandleImpl() { uint32_t ColumnFamilyHandleImpl::GetID() const { return cfd()->GetID(); } +const Comparator* ColumnFamilyHandleImpl::user_comparator() const { + return cfd()->user_comparator(); +} + ColumnFamilyOptions SanitizeOptions(const InternalKeyComparator* icmp, const ColumnFamilyOptions& src) { ColumnFamilyOptions result = src; @@ -726,4 +730,13 @@ uint32_t GetColumnFamilyID(ColumnFamilyHandle* column_family) { return column_family_id; } +const Comparator* GetColumnFamilyUserComparator( + ColumnFamilyHandle* column_family) { + if (column_family != nullptr) { + auto cfh = reinterpret_cast(column_family); + return cfh->user_comparator(); + } + return nullptr; +} + } // namespace rocksdb diff --git a/db/column_family.h b/db/column_family.h index f1ef13cf1..65b4b53ba 100644 --- a/db/column_family.h +++ b/db/column_family.h @@ -49,6 +49,7 @@ class ColumnFamilyHandleImpl : public ColumnFamilyHandle { // destroy without mutex virtual ~ColumnFamilyHandleImpl(); virtual ColumnFamilyData* cfd() const { return cfd_; } + virtual const Comparator* user_comparator() const; virtual uint32_t GetID() const; @@ -448,4 +449,7 @@ class ColumnFamilyMemTablesImpl : public ColumnFamilyMemTables { extern uint32_t GetColumnFamilyID(ColumnFamilyHandle* column_family); +extern const Comparator* GetColumnFamilyUserComparator( + ColumnFamilyHandle* column_family); + } // namespace rocksdb diff --git a/db/write_batch_test.cc b/db/write_batch_test.cc index d8fa52d40..ba7451078 100644 --- a/db/write_batch_test.cc +++ b/db/write_batch_test.cc @@ -289,6 +289,9 @@ class ColumnFamilyHandleImplDummy : public ColumnFamilyHandleImpl { explicit ColumnFamilyHandleImplDummy(int id) : ColumnFamilyHandleImpl(nullptr, nullptr, nullptr), id_(id) {} uint32_t GetID() const override { return id_; } + const Comparator* user_comparator() const override { + return BytewiseComparator(); + } private: uint32_t id_; @@ -320,7 +323,7 @@ TEST(WriteBatchTest, ColumnFamiliesBatchTest) { } TEST(WriteBatchTest, ColumnFamiliesBatchWithIndexTest) { - WriteBatchWithIndex batch(BytewiseComparator(), 20); + WriteBatchWithIndex batch; ColumnFamilyHandleImplDummy zero(0), two(2), three(3), eight(8); batch.Put(&zero, Slice("foo"), Slice("bar")); batch.Put(&two, Slice("twofoo"), Slice("bar2")); diff --git a/include/rocksdb/utilities/write_batch_with_index.h b/include/rocksdb/utilities/write_batch_with_index.h index c09f53d11..85c80850f 100644 --- a/include/rocksdb/utilities/write_batch_with_index.h +++ b/include/rocksdb/utilities/write_batch_with_index.h @@ -11,8 +11,9 @@ #pragma once -#include "rocksdb/status.h" +#include "rocksdb/comparator.h" #include "rocksdb/slice.h" +#include "rocksdb/status.h" #include "rocksdb/write_batch.h" namespace rocksdb { @@ -56,12 +57,14 @@ class WBWIIterator { // A user can call NewIterator() to create an iterator. class WriteBatchWithIndex { public: - // index_comparator indicates the order when iterating data in the write - // batch. Technically, it doesn't have to be the same as the one used in - // the DB. + // backup_index_comparator: the backup comparator used to compare keys + // within the same column family, if column family is not given in the + // interface, or we can't find a column family from the column family handle + // passed in, backup_index_comparator will be used for the column family. // reserved_bytes: reserved bytes in underlying WriteBatch - explicit WriteBatchWithIndex(const Comparator* index_comparator, - size_t reserved_bytes = 0); + explicit WriteBatchWithIndex( + const Comparator* backup_index_comparator = BytewiseComparator(), + size_t reserved_bytes = 0); virtual ~WriteBatchWithIndex(); WriteBatch* GetWriteBatch(); diff --git a/utilities/write_batch_with_index/write_batch_with_index.cc b/utilities/write_batch_with_index/write_batch_with_index.cc index 68b3d3970..2caa2e4cc 100644 --- a/utilities/write_batch_with_index/write_batch_with_index.cc +++ b/utilities/write_batch_with_index/write_batch_with_index.cc @@ -20,7 +20,6 @@ class ReadableWriteBatch : public WriteBatch { Status GetEntryFromDataOffset(size_t data_offset, WriteType* type, Slice* Key, Slice* value, Slice* blob) const; }; -} // namespace // Key used by skip list, as the binary searchable index of WriteBatchWithIndex. struct WriteBatchIndexEntry { @@ -38,44 +37,28 @@ struct WriteBatchIndexEntry { class WriteBatchEntryComparator { public: - WriteBatchEntryComparator(const Comparator* comparator, + WriteBatchEntryComparator(const Comparator* default_comparator, const ReadableWriteBatch* write_batch) - : comparator_(comparator), write_batch_(write_batch) {} + : default_comparator_(default_comparator), write_batch_(write_batch) {} // Compare a and b. Return a negative value if a is less than b, 0 if they // are equal, and a positive value if a is greater than b int operator()(const WriteBatchIndexEntry* entry1, const WriteBatchIndexEntry* entry2) const; + void SetComparatorForCF(uint32_t column_family_id, + const Comparator* comparator) { + cf_comparator_map_[column_family_id] = comparator; + } + private: - const Comparator* comparator_; + const Comparator* default_comparator_; + std::unordered_map cf_comparator_map_; const ReadableWriteBatch* write_batch_; }; typedef SkipList WriteBatchEntrySkipList; -struct WriteBatchWithIndex::Rep { - Rep(const Comparator* index_comparator, size_t reserved_bytes = 0) - : write_batch(reserved_bytes), - comparator(index_comparator, &write_batch), - skip_list(comparator, &arena) {} - ReadableWriteBatch write_batch; - WriteBatchEntryComparator comparator; - Arena arena; - WriteBatchEntrySkipList skip_list; - - WriteBatchIndexEntry* GetEntry(ColumnFamilyHandle* column_family) { - return GetEntryWithCfId(GetColumnFamilyID(column_family)); - } - - WriteBatchIndexEntry* GetEntryWithCfId(uint32_t column_family_id) { - auto* mem = arena.Allocate(sizeof(WriteBatchIndexEntry)); - auto* index_entry = new (mem) - WriteBatchIndexEntry(write_batch.GetDataSize(), column_family_id); - return index_entry; - } -}; - class WBWIIteratorImpl : public WBWIIterator { public: WBWIIteratorImpl(uint32_t column_family_id, @@ -138,6 +121,35 @@ class WBWIIteratorImpl : public WBWIIterator { } } }; +} // namespace + +struct WriteBatchWithIndex::Rep { + Rep(const Comparator* index_comparator, size_t reserved_bytes = 0) + : write_batch(reserved_bytes), + comparator(index_comparator, &write_batch), + skip_list(comparator, &arena) {} + ReadableWriteBatch write_batch; + WriteBatchEntryComparator comparator; + Arena arena; + WriteBatchEntrySkipList skip_list; + + WriteBatchIndexEntry* GetEntry(ColumnFamilyHandle* column_family) { + uint32_t cf_id = GetColumnFamilyID(column_family); + const auto* cf_cmp = GetColumnFamilyUserComparator(column_family); + if (cf_cmp != nullptr) { + comparator.SetComparatorForCF(cf_id, cf_cmp); + } + + return GetEntryWithCfId(cf_id); + } + + WriteBatchIndexEntry* GetEntryWithCfId(uint32_t column_family_id) { + auto* mem = arena.Allocate(sizeof(WriteBatchIndexEntry)); + auto* index_entry = new (mem) + WriteBatchIndexEntry(write_batch.GetDataSize(), column_family_id); + return index_entry; + } +}; Status ReadableWriteBatch::GetEntryFromDataOffset(size_t data_offset, WriteType* type, Slice* Key, @@ -179,9 +191,9 @@ Status ReadableWriteBatch::GetEntryFromDataOffset(size_t data_offset, return Status::OK(); } -WriteBatchWithIndex::WriteBatchWithIndex(const Comparator* index_comparator, - size_t reserved_bytes) - : rep(new Rep(index_comparator, reserved_bytes)) {} +WriteBatchWithIndex::WriteBatchWithIndex( + const Comparator* default_index_comparator, size_t reserved_bytes) + : rep(new Rep(default_index_comparator, reserved_bytes)) {} WriteBatchWithIndex::~WriteBatchWithIndex() { delete rep; } @@ -287,7 +299,14 @@ int WriteBatchEntryComparator::operator()( key2 = *(entry2->search_key); } - int cmp = comparator_->Compare(key1, key2); + int cmp; + auto comparator_for_cf = cf_comparator_map_.find(entry1->column_family); + if (comparator_for_cf != cf_comparator_map_.end()) { + cmp = comparator_for_cf->second->Compare(key1, key2); + } else { + cmp = default_comparator_->Compare(key1, key2); + } + if (cmp != 0) { return cmp; } else if (entry1->offset > entry2->offset) { diff --git a/utilities/write_batch_with_index/write_batch_with_index_test.cc b/utilities/write_batch_with_index/write_batch_with_index_test.cc index fdceed4c4..ad8c110c1 100644 --- a/utilities/write_batch_with_index/write_batch_with_index_test.cc +++ b/utilities/write_batch_with_index/write_batch_with_index_test.cc @@ -19,12 +19,16 @@ namespace rocksdb { namespace { class ColumnFamilyHandleImplDummy : public ColumnFamilyHandleImpl { public: - explicit ColumnFamilyHandleImplDummy(int id) - : ColumnFamilyHandleImpl(nullptr, nullptr, nullptr), id_(id) {} + explicit ColumnFamilyHandleImplDummy(int id, const Comparator* comparator) + : ColumnFamilyHandleImpl(nullptr, nullptr, nullptr), + id_(id), + comparator_(comparator) {} uint32_t GetID() const override { return id_; } + const Comparator* user_comparator() const override { return comparator_; } private: uint32_t id_; + const Comparator* comparator_; }; struct Entry { @@ -90,8 +94,9 @@ TEST(WriteBatchWithIndexTest, TestValueAsSecondaryIndex) { index_map[e.value].push_back(&e); } - WriteBatchWithIndex batch(BytewiseComparator(), 20); - ColumnFamilyHandleImplDummy data(6), index(8); + WriteBatchWithIndex batch(nullptr, 20); + ColumnFamilyHandleImplDummy data(6, BytewiseComparator()); + ColumnFamilyHandleImplDummy index(8, BytewiseComparator()); for (auto& e : entries) { if (e.type == kPutRecord) { batch.Put(&data, e.key, e.value); @@ -230,6 +235,107 @@ TEST(WriteBatchWithIndexTest, TestValueAsSecondaryIndex) { } } +class ReverseComparator : public Comparator { + public: + ReverseComparator() {} + + virtual const char* Name() const override { + return "rocksdb.ReverseComparator"; + } + + virtual int Compare(const Slice& a, const Slice& b) const override { + return 0 - BytewiseComparator()->Compare(a, b); + } + + virtual void FindShortestSeparator(std::string* start, + const Slice& limit) const {} + virtual void FindShortSuccessor(std::string* key) const {} +}; + +TEST(WriteBatchWithIndexTest, TestComparatorForCF) { + ReverseComparator reverse_cmp; + ColumnFamilyHandleImplDummy cf1(6, nullptr); + ColumnFamilyHandleImplDummy reverse_cf(66, &reverse_cmp); + ColumnFamilyHandleImplDummy cf2(88, BytewiseComparator()); + WriteBatchWithIndex batch(BytewiseComparator(), 20); + + batch.Put(&cf1, "ddd", ""); + batch.Put(&cf2, "aaa", ""); + batch.Put(&cf2, "eee", ""); + batch.Put(&cf1, "ccc", ""); + batch.Put(&reverse_cf, "a11", ""); + batch.Put(&cf1, "bbb", ""); + batch.Put(&reverse_cf, "a33", ""); + batch.Put(&reverse_cf, "a22", ""); + + { + std::unique_ptr iter(batch.NewIterator(&cf1)); + iter->Seek(""); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("bbb", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("ccc", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("ddd", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + } + + { + std::unique_ptr iter(batch.NewIterator(&cf2)); + iter->Seek(""); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("aaa", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("eee", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + } + + { + std::unique_ptr iter(batch.NewIterator(&reverse_cf)); + iter->Seek(""); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->Seek("z"); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("a33", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("a22", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("a11", iter->Entry().key.ToString()); + iter->Next(); + ASSERT_OK(iter->status()); + ASSERT_TRUE(!iter->Valid()); + + iter->Seek("a22"); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("a22", iter->Entry().key.ToString()); + + iter->Seek("a13"); + ASSERT_OK(iter->status()); + ASSERT_TRUE(iter->Valid()); + ASSERT_EQ("a11", iter->Entry().key.ToString()); + } +} + } // namespace int main(int argc, char** argv) { return rocksdb::test::RunAllTests(); }