diff --git a/HISTORY.md b/HISTORY.md index e755d0e31..87813faf6 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -10,6 +10,7 @@ ### Public API Changes * Removed class Env::RandomRWFile and Env::NewRandomRWFile(). * Renamed DBOptions.num_subcompactions to DBOptions.max_subcompactions to make the name better match the actual functionality of the option. +* Added Equal() method to the Comparator interface that can optionally be overwritten in cases where equality comparisons can be done more efficiently than three-way comparisons. ## 3.13.0 (8/6/2015) ### New Features diff --git a/db/builder.cc b/db/builder.cc index 47a1fb40f..d7c9f8c85 100644 --- a/db/builder.cc +++ b/db/builder.cc @@ -149,8 +149,8 @@ Status BuildTable( // first key), then we skip it, since it is an older version. // Otherwise we output the key and mark it as the "new" previous key. if (!has_current_user_key || - internal_comparator.user_comparator()->Compare( - ikey.user_key, current_user_key.GetKey()) != 0) { + !internal_comparator.user_comparator()->Equal( + ikey.user_key, current_user_key.GetKey())) { // First occurrence of this user key current_user_key.SetKey(ikey.user_key); has_current_user_key = true; diff --git a/db/compaction_job.cc b/db/compaction_job.cc index 9bd333905..ffb294edb 100644 --- a/db/compaction_job.cc +++ b/db/compaction_job.cc @@ -634,8 +634,8 @@ void CompactionJob::ProcessKeyValueCompaction(SubCompactionState* sub_compact) { } if (!has_current_user_key || - cfd->user_comparator()->Compare(ikey.user_key, - current_user_key.GetKey()) != 0) { + !cfd->user_comparator()->Equal(ikey.user_key, + current_user_key.GetKey())) { // First occurrence of this user key current_user_key.SetKey(ikey.user_key); has_current_user_key = true; diff --git a/db/db_iter.cc b/db/db_iter.cc index 471cadf9d..587da72ad 100644 --- a/db/db_iter.cc +++ b/db/db_iter.cc @@ -298,7 +298,7 @@ void DBIter::MergeValuesNewToOld() { continue; } - if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) { + if (!user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { // hit the next user key, stop right here break; } @@ -400,7 +400,7 @@ void DBIter::PrevInternal() { return; } FindParseableKey(&ikey, kReverse); - if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) { + if (user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { FindPrevUserKey(); } return; @@ -409,8 +409,7 @@ void DBIter::PrevInternal() { break; } FindParseableKey(&ikey, kReverse); - if (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) { - + if (user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { FindPrevUserKey(); } } @@ -434,7 +433,7 @@ bool DBIter::FindValueForCurrentKey() { size_t num_skipped = 0; while (iter_->Valid() && ikey.sequence <= sequence_ && - (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0)) { + user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { // We iterate too much: let's use Seek() to avoid too much key comparisons if (num_skipped >= max_skip_) { return FindValueForCurrentKeyUsingSeek(); @@ -461,7 +460,7 @@ bool DBIter::FindValueForCurrentKey() { } PERF_COUNTER_ADD(internal_key_skipped_count, 1); - assert(user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0); + assert(user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())); iter_->Prev(); ++num_skipped; FindParseableKey(&ikey, kReverse); @@ -531,7 +530,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() { // in operands std::deque operands; while (iter_->Valid() && - (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) == 0) && + user_comparator_->Equal(ikey.user_key, saved_key_.GetKey()) && ikey.type == kTypeMerge) { operands.push_front(iter_->value().ToString()); iter_->Next(); @@ -539,7 +538,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() { } if (!iter_->Valid() || - (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) || + !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey()) || ikey.type == kTypeDeletion) { { StopWatchNano timer(env_, statistics_ != nullptr); @@ -550,7 +549,7 @@ bool DBIter::FindValueForCurrentKeyUsingSeek() { } // Make iter_ valid and point to saved_key_ if (!iter_->Valid() || - (user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0)) { + !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { iter_->Seek(last_key); RecordTick(statistics_, NUMBER_OF_RESEEKS_IN_ITERATION); } @@ -581,7 +580,7 @@ void DBIter::FindNextUserKey() { ParsedInternalKey ikey; FindParseableKey(&ikey, kForward); while (iter_->Valid() && - user_comparator_->Compare(ikey.user_key, saved_key_.GetKey()) != 0) { + !user_comparator_->Equal(ikey.user_key, saved_key_.GetKey())) { iter_->Next(); FindParseableKey(&ikey, kForward); } diff --git a/db/memtable.cc b/db/memtable.cc index d1bbd3960..e712a7b9c 100644 --- a/db/memtable.cc +++ b/db/memtable.cc @@ -404,8 +404,8 @@ static bool SaveValue(void* arg, const char* entry) { // all entries with overly large sequence numbers. uint32_t key_length; const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); - if (s->mem->GetInternalKeyComparator().user_comparator()->Compare( - Slice(key_ptr, key_length - 8), s->key->user_key()) == 0) { + if (s->mem->GetInternalKeyComparator().user_comparator()->Equal( + Slice(key_ptr, key_length - 8), s->key->user_key())) { // Correct user key const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); ValueType type; @@ -563,8 +563,8 @@ void MemTable::Update(SequenceNumber seq, const char* entry = iter->key(); uint32_t key_length = 0; const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); - if (comparator_.comparator.user_comparator()->Compare( - Slice(key_ptr, key_length - 8), lkey.user_key()) == 0) { + if (comparator_.comparator.user_comparator()->Equal( + Slice(key_ptr, key_length - 8), lkey.user_key())) { // Correct user key const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); ValueType type; @@ -624,8 +624,8 @@ bool MemTable::UpdateCallback(SequenceNumber seq, const char* entry = iter->key(); uint32_t key_length = 0; const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); - if (comparator_.comparator.user_comparator()->Compare( - Slice(key_ptr, key_length - 8), lkey.user_key()) == 0) { + if (comparator_.comparator.user_comparator()->Equal( + Slice(key_ptr, key_length - 8), lkey.user_key())) { // Correct user key const uint64_t tag = DecodeFixed64(key_ptr + key_length - 8); ValueType type; @@ -695,8 +695,8 @@ size_t MemTable::CountSuccessiveMergeEntries(const LookupKey& key) { const char* entry = iter->key(); uint32_t key_length = 0; const char* iter_key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); - if (comparator_.comparator.user_comparator()->Compare( - Slice(iter_key_ptr, key_length - 8), key.user_key()) != 0) { + if (!comparator_.comparator.user_comparator()->Equal( + Slice(iter_key_ptr, key_length - 8), key.user_key())) { break; } diff --git a/db/memtablerep_bench.cc b/db/memtablerep_bench.cc index 5bdfa836d..a2a872226 100644 --- a/db/memtablerep_bench.cc +++ b/db/memtablerep_bench.cc @@ -312,9 +312,10 @@ class ReadBenchmarkThread : public BenchmarkThread { assert(callback_args != nullptr); uint32_t key_length; const char* key_ptr = GetVarint32Ptr(entry, entry + 5, &key_length); - if ((callback_args->comparator)->user_comparator()->Compare( - Slice(key_ptr, key_length - 8), callback_args->key->user_key()) == - 0) { + if ((callback_args->comparator) + ->user_comparator() + ->Equal(Slice(key_ptr, key_length - 8), + callback_args->key->user_key())) { callback_args->found = true; } return false; diff --git a/db/merge_helper.cc b/db/merge_helper.cc index 427013806..7d5f4c1e9 100644 --- a/db/merge_helper.cc +++ b/db/merge_helper.cc @@ -91,8 +91,7 @@ Status MergeHelper::MergeUntil(Iterator* iter, const SequenceNumber stop_before, assert(!"corrupted internal key is not expected"); } break; - } else if (user_comparator_->Compare(ikey.user_key, orig_ikey.user_key) != - 0) { + } else if (!user_comparator_->Equal(ikey.user_key, orig_ikey.user_key)) { // hit a different user key, stop right here hit_the_next_user_key = true; break; diff --git a/db/version_set.cc b/db/version_set.cc index 692e6f6da..1a649e24f 100644 --- a/db/version_set.cc +++ b/db/version_set.cc @@ -1581,7 +1581,7 @@ bool VersionStorageInfo::HasOverlappingUserKey( files[last_file].largest_key); const Slice first_key_after = ExtractUserKey( files[last_file+1].smallest_key); - if (user_cmp->Compare(last_key_in_input, first_key_after) == 0) { + if (user_cmp->Equal(last_key_in_input, first_key_after)) { // The last user key in input overlaps with the next file's first key return true; } @@ -1596,7 +1596,7 @@ bool VersionStorageInfo::HasOverlappingUserKey( files[first_file].smallest_key); const Slice& last_key_before = ExtractUserKey( files[first_file-1].largest_key); - if (user_cmp->Compare(first_key_in_input, last_key_before) == 0) { + if (user_cmp->Equal(first_key_in_input, last_key_before)) { // The first user key in input overlaps with the previous file's last key return true; } diff --git a/include/rocksdb/comparator.h b/include/rocksdb/comparator.h index 5b7dc1021..8fc2710aa 100644 --- a/include/rocksdb/comparator.h +++ b/include/rocksdb/comparator.h @@ -29,6 +29,15 @@ class Comparator { // > 0 iff "a" > "b" virtual int Compare(const Slice& a, const Slice& b) const = 0; + // Compares two slices for equality. The following invariant should always + // hold (and is the default implementation): + // Equal(a, b) iff Compare(a, b) == 0 + // Overwrite only if equality comparisons can be done more efficiently than + // three-way comparisons. + virtual bool Equal(const Slice& a, const Slice& b) const { + return Compare(a, b) == 0; + } + // The name of the comparator. Used to check for comparator // mismatches (i.e., a DB created with one comparator is // accessed using a different comparator. diff --git a/table/cuckoo_table_reader.cc b/table/cuckoo_table_reader.cc index 51ffc6ffa..8c0329c66 100644 --- a/table/cuckoo_table_reader.cc +++ b/table/cuckoo_table_reader.cc @@ -137,13 +137,13 @@ Status CuckooTableReader::Get(const ReadOptions& readOptions, const Slice& key, const char* bucket = &file_data_.data()[offset]; for (uint32_t block_idx = 0; block_idx < cuckoo_block_size_; ++block_idx, bucket += bucket_length_) { - if (ucomp_->Compare(Slice(unused_key_.data(), user_key.size()), - Slice(bucket, user_key.size())) == 0) { + if (ucomp_->Equal(Slice(unused_key_.data(), user_key.size()), + Slice(bucket, user_key.size()))) { return Status::OK(); } // Here, we compare only the user key part as we support only one entry // per user key and we don't support sanpshot. - if (ucomp_->Compare(user_key, Slice(bucket, user_key.size())) == 0) { + if (ucomp_->Equal(user_key, Slice(bucket, user_key.size()))) { Slice value(bucket + key_length_, value_length_); if (is_last_level_) { get_context->SaveValue(value); diff --git a/table/get_context.cc b/table/get_context.cc index 5ac3525cd..77ac8f4f7 100644 --- a/table/get_context.cc +++ b/table/get_context.cc @@ -71,7 +71,7 @@ bool GetContext::SaveValue(const ParsedInternalKey& parsed_key, const Slice& value) { assert((state_ != kMerge && parsed_key.type != kTypeMerge) || merge_context_ != nullptr); - if (ucmp_->Compare(parsed_key.user_key, user_key_) == 0) { + if (ucmp_->Equal(parsed_key.user_key, user_key_)) { appendToReplayLog(replay_log_, parsed_key.type, value); // Key matches. Process it diff --git a/table/merger.cc b/table/merger.cc index bc615602f..242587ea8 100644 --- a/table/merger.cc +++ b/table/merger.cc @@ -131,8 +131,7 @@ class MergingIterator : public Iterator { for (auto& child : children_) { if (&child != current_) { child.Seek(key()); - if (child.Valid() && - comparator_->Compare(key(), child.key()) == 0) { + if (child.Valid() && comparator_->Equal(key(), child.key())) { child.Next(); } } diff --git a/util/comparator.cc b/util/comparator.cc index 821a2540b..6d7709db5 100644 --- a/util/comparator.cc +++ b/util/comparator.cc @@ -32,6 +32,10 @@ class BytewiseComparatorImpl : public Comparator { return a.compare(b); } + virtual bool Equal(const Slice& a, const Slice& b) const override { + return a == b; + } + virtual void FindShortestSeparator(std::string* start, const Slice& limit) const override { // Find length of common prefix diff --git a/util/db_test_util.cc b/util/db_test_util.cc index 9af097d02..186528a2c 100644 --- a/util/db_test_util.cc +++ b/util/db_test_util.cc @@ -553,7 +553,7 @@ std::string DBTestBase::AllEntriesFor(const Slice& user_key, int cf) { if (!ParseInternalKey(iter->key(), &ikey)) { result += "CORRUPTED"; } else { - if (last_options_.comparator->Compare(ikey.user_key, user_key) != 0) { + if (!last_options_.comparator->Equal(ikey.user_key, user_key)) { break; } if (!first) { diff --git a/util/hash_cuckoo_rep.cc b/util/hash_cuckoo_rep.cc index ea2b7bca2..6e5057a73 100644 --- a/util/hash_cuckoo_rep.cc +++ b/util/hash_cuckoo_rep.cc @@ -299,8 +299,8 @@ void HashCuckooRep::Get(const LookupKey& key, void* callback_args, const char* bucket = cuckoo_array_[GetHash(user_key, hid)].load(std::memory_order_acquire); if (bucket != nullptr) { - auto bucket_user_key = UserKey(bucket); - if (user_key.compare(bucket_user_key) == 0) { + Slice bucket_user_key = UserKey(bucket); + if (user_key == bucket_user_key) { callback_func(callback_args, bucket); break; } @@ -466,10 +466,10 @@ bool HashCuckooRep::FindCuckooPath(const char* internal_key, } // again, we can perform no barrier load safely here as the current // thread is the only writer. - auto bucket_user_key = + Slice bucket_user_key = UserKey(cuckoo_array_[step.bucket_id_].load(std::memory_order_relaxed)); if (step.prev_step_id_ != CuckooStep::kNullStep) { - if (bucket_user_key.compare(user_key) == 0) { + if (bucket_user_key == user_key) { // then there is a loop in the current path, stop discovering this path. continue; } diff --git a/utilities/write_batch_with_index/write_batch_with_index.cc b/utilities/write_batch_with_index/write_batch_with_index.cc index c4a4a1500..e3caa7932 100644 --- a/utilities/write_batch_with_index/write_batch_with_index.cc +++ b/utilities/write_batch_with_index/write_batch_with_index.cc @@ -92,8 +92,8 @@ class BaseDeltaIterator : public Iterator { AdvanceBase(); } if (DeltaValid() && BaseValid()) { - if (comparator_->Compare(delta_iterator_->Entry().key, - base_iterator_->key()) == 0) { + if (comparator_->Equal(delta_iterator_->Entry().key, + base_iterator_->key())) { equal_keys_ = true; } } @@ -127,8 +127,8 @@ class BaseDeltaIterator : public Iterator { AdvanceBase(); } if (DeltaValid() && BaseValid()) { - if (comparator_->Compare(delta_iterator_->Entry().key, - base_iterator_->key()) == 0) { + if (comparator_->Equal(delta_iterator_->Entry().key, + base_iterator_->key())) { equal_keys_ = true; } }