diff --git a/HISTORY.md b/HISTORY.md index d48591ac5..933c43e4a 100644 --- a/HISTORY.md +++ b/HISTORY.md @@ -7,6 +7,7 @@ * Removed arena.h from public header files. * By default, checksums are verified on every read from database * Added is_manual_compaction to CompactionFilter::Context +* Added "virtual void WaitForJoin() = 0" in class Env ## 2.7.0 (01/28/2014) diff --git a/Makefile b/Makefile index 3be3c3d08..6eef8bec0 100644 --- a/Makefile +++ b/Makefile @@ -87,7 +87,8 @@ TESTS = \ version_set_test \ write_batch_test\ deletefile_test \ - table_test + table_test \ + thread_local_test TOOLS = \ sst_dump \ @@ -147,7 +148,7 @@ all: $(LIBRARY) $(PROGRAMS) dbg: $(LIBRARY) $(PROGRAMS) -# Will also generate shared libraries. +# Will also generate shared libraries. release: $(MAKE) clean OPT="-DNDEBUG -O2" $(MAKE) all -j32 @@ -276,6 +277,9 @@ redis_test: utilities/redis/redis_lists_test.o $(LIBOBJECTS) $(TESTHARNESS) histogram_test: util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(CXX) util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o$@ $(LDFLAGS) $(COVERAGEFLAGS) +thread_local_test: util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) + $(CXX) util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS) + corruption_test: db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(CXX) db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS) diff --git a/hdfs/env_hdfs.h b/hdfs/env_hdfs.h index 886ccdac3..17d8fcb2b 100644 --- a/hdfs/env_hdfs.h +++ b/hdfs/env_hdfs.h @@ -47,7 +47,7 @@ private: class HdfsEnv : public Env { public: - HdfsEnv(const std::string& fsname) : fsname_(fsname) { + explicit HdfsEnv(const std::string& fsname) : fsname_(fsname) { posixEnv = Env::Default(); fileSys_ = connectToPath(fsname_); } @@ -108,6 +108,8 @@ class HdfsEnv : public Env { posixEnv->StartThread(function, arg); } + virtual void WaitForJoin() { posixEnv->WaitForJoin(); } + virtual 
Status GetTestDirectory(std::string* path) { return posixEnv->GetTestDirectory(path); } @@ -161,7 +163,7 @@ class HdfsEnv : public Env { */ hdfsFS connectToPath(const std::string& uri) { if (uri.empty()) { - return NULL; + return nullptr; } if (uri.find(kProto) != 0) { // uri doesn't start with hdfs:// -> use default:0, which is special @@ -218,10 +220,10 @@ static const Status notsup; class HdfsEnv : public Env { public: - HdfsEnv(const std::string& fsname) { + explicit HdfsEnv(const std::string& fsname) { fprintf(stderr, "You have not build rocksdb with HDFS support\n"); fprintf(stderr, "Please see hdfs/README for details\n"); - throw new std::exception(); + throw std::exception(); } virtual ~HdfsEnv() { @@ -288,6 +290,8 @@ class HdfsEnv : public Env { virtual void StartThread(void (*function)(void* arg), void* arg) {} + virtual void WaitForJoin() {} + virtual Status GetTestDirectory(std::string* path) {return notsup;} virtual uint64_t NowMicros() {return 0;} diff --git a/include/rocksdb/env.h b/include/rocksdb/env.h index 06e9b94aa..932425027 100644 --- a/include/rocksdb/env.h +++ b/include/rocksdb/env.h @@ -205,6 +205,9 @@ class Env { // When "function(arg)" returns, the thread will be destroyed. virtual void StartThread(void (*function)(void* arg), void* arg) = 0; + // Wait for all threads started by StartThread to terminate. + virtual void WaitForJoin() = 0; + // *path is set to a temporary directory that can be used for testing. It may // or many not have just been created. 
The directory may or may not differ // between runs of the same process, but subsequent calls will return the @@ -634,6 +637,7 @@ class EnvWrapper : public Env { void StartThread(void (*f)(void*), void* a) { return target_->StartThread(f, a); } + void WaitForJoin() { return target_->WaitForJoin(); } virtual Status GetTestDirectory(std::string* path) { return target_->GetTestDirectory(path); } diff --git a/util/env_posix.cc b/util/env_posix.cc index 1ccb32084..fcfea28ab 100644 --- a/util/env_posix.cc +++ b/util/env_posix.cc @@ -1194,6 +1194,8 @@ class PosixEnv : public Env { virtual void StartThread(void (*function)(void* arg), void* arg); + virtual void WaitForJoin(); + virtual Status GetTestDirectory(std::string* result) { const char* env = getenv("TEST_TMPDIR"); if (env && env[0] != '\0') { @@ -1511,6 +1513,13 @@ void PosixEnv::StartThread(void (*function)(void* arg), void* arg) { PthreadCall("unlock", pthread_mutex_unlock(&mu_)); } +void PosixEnv::WaitForJoin() { + for (const auto tid : threads_to_join_) { + pthread_join(tid, nullptr); + } + threads_to_join_.clear(); +} + } // namespace std::string Env::GenerateUniqueId() { diff --git a/util/thread_local.cc b/util/thread_local.cc new file mode 100644 index 000000000..90571b97e --- /dev/null +++ b/util/thread_local.cc @@ -0,0 +1,236 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. 
+ +#include "util/thread_local.h" +#include "util/mutexlock.h" + +#if defined(__GNUC__) && __GNUC__ >= 4 +#define UNLIKELY(x) (__builtin_expect((x), 0)) +#else +#define UNLIKELY(x) (x) +#endif + +namespace rocksdb { + +std::unique_ptr ThreadLocalPtr::StaticMeta::inst_; +port::Mutex ThreadLocalPtr::StaticMeta::mutex_; +#if !defined(OS_MACOSX) +__thread ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::tls_ = nullptr; +#endif + +ThreadLocalPtr::StaticMeta* ThreadLocalPtr::StaticMeta::Instance() { + if (UNLIKELY(inst_ == nullptr)) { + MutexLock l(&mutex_); + if (inst_ == nullptr) { + inst_.reset(new StaticMeta()); + } + } + return inst_.get(); +} + +void ThreadLocalPtr::StaticMeta::OnThreadExit(void* ptr) { + auto* tls = static_cast(ptr); + assert(tls != nullptr); + + auto* inst = Instance(); + pthread_setspecific(inst->pthread_key_, nullptr); + + MutexLock l(&mutex_); + inst->RemoveThreadData(tls); + // Unref stored pointers of current thread from all instances + uint32_t id = 0; + for (auto& e : tls->entries) { + void* raw = e.ptr.load(std::memory_order_relaxed); + if (raw != nullptr) { + auto unref = inst->GetHandler(id); + if (unref != nullptr) { + unref(raw); + } + } + ++id; + } + // Delete thread local structure no matter if it is Mac platform + delete tls; +} + +ThreadLocalPtr::StaticMeta::StaticMeta() : next_instance_id_(0) { + if (pthread_key_create(&pthread_key_, &OnThreadExit) != 0) { + throw std::runtime_error("pthread_key_create failed"); + } + head_.next = &head_; + head_.prev = &head_; +} + +void ThreadLocalPtr::StaticMeta::AddThreadData(ThreadLocalPtr::ThreadData* d) { + mutex_.AssertHeld(); + d->next = &head_; + d->prev = head_.prev; + head_.prev->next = d; + head_.prev = d; +} + +void ThreadLocalPtr::StaticMeta::RemoveThreadData( + ThreadLocalPtr::ThreadData* d) { + mutex_.AssertHeld(); + d->next->prev = d->prev; + d->prev->next = d->next; + d->next = d->prev = d; +} + +ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::GetThreadLocal() { 
+#if defined(OS_MACOSX) + // Make this local variable name look like a member variable so that we + // can share all the code below + ThreadData* tls_ = + static_cast(pthread_getspecific(Instance()->pthread_key_)); +#endif + + if (UNLIKELY(tls_ == nullptr)) { + auto* inst = Instance(); + tls_ = new ThreadData(); + { + // Register it in the global chain, needs to be done before thread exit + // handler registration + MutexLock l(&mutex_); + inst->AddThreadData(tls_); + } + // Even if it is not OS_MACOSX, need to register value for pthread_key_ so that + // its exit handler will be triggered. + if (pthread_setspecific(inst->pthread_key_, tls_) != 0) { + { + MutexLock l(&mutex_); + inst->RemoveThreadData(tls_); + } + delete tls_; + throw std::runtime_error("pthread_setspecific failed"); + } + } + return tls_; +} + +void* ThreadLocalPtr::StaticMeta::Get(uint32_t id) const { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + return nullptr; + } + return tls->entries[id].ptr.load(std::memory_order_relaxed); +} + +void ThreadLocalPtr::StaticMeta::Reset(uint32_t id, void* ptr) { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + // Need mutex to protect entries access within ReclaimId + MutexLock l(&mutex_); + tls->entries.resize(id + 1); + } + tls->entries[id].ptr.store(ptr, std::memory_order_relaxed); +} + +void* ThreadLocalPtr::StaticMeta::Swap(uint32_t id, void* ptr) { + auto* tls = GetThreadLocal(); + if (UNLIKELY(id >= tls->entries.size())) { + // Need mutex to protect entries access within ReclaimId + MutexLock l(&mutex_); + tls->entries.resize(id + 1); + } + return tls->entries[id].ptr.exchange(ptr, std::memory_order_relaxed); +} + +void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector* ptrs) { + MutexLock l(&mutex_); + for (ThreadData* t = head_.next; t != &head_; t = t->next) { + if (id < t->entries.size()) { + void* ptr = + t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed); + if (ptr != 
nullptr) { + ptrs->push_back(ptr); + } + } + } +} + +void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) { + MutexLock l(&mutex_); + handler_map_[id] = handler; +} + +UnrefHandler ThreadLocalPtr::StaticMeta::GetHandler(uint32_t id) { + mutex_.AssertHeld(); + auto iter = handler_map_.find(id); + if (iter == handler_map_.end()) { + return nullptr; + } + return iter->second; +} + +uint32_t ThreadLocalPtr::StaticMeta::GetId() { + MutexLock l(&mutex_); + if (free_instance_ids_.empty()) { + return next_instance_id_++; + } + + uint32_t id = free_instance_ids_.back(); + free_instance_ids_.pop_back(); + return id; +} + +uint32_t ThreadLocalPtr::StaticMeta::PeekId() const { + MutexLock l(&mutex_); + if (!free_instance_ids_.empty()) { + return free_instance_ids_.back(); + } + return next_instance_id_; +} + +void ThreadLocalPtr::StaticMeta::ReclaimId(uint32_t id) { + // This id is not used, go through all thread local data and release + // corresponding value + MutexLock l(&mutex_); + auto unref = GetHandler(id); + for (ThreadData* t = head_.next; t != &head_; t = t->next) { + if (id < t->entries.size()) { + void* ptr = + t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed); + if (ptr != nullptr && unref != nullptr) { + unref(ptr); + } + } + } + handler_map_[id] = nullptr; + free_instance_ids_.push_back(id); +} + +ThreadLocalPtr::ThreadLocalPtr(UnrefHandler handler) + : id_(StaticMeta::Instance()->GetId()) { + if (handler != nullptr) { + StaticMeta::Instance()->SetHandler(id_, handler); + } +} + +ThreadLocalPtr::~ThreadLocalPtr() { + StaticMeta::Instance()->ReclaimId(id_); +} + +void* ThreadLocalPtr::Get() const { + return StaticMeta::Instance()->Get(id_); +} + +void ThreadLocalPtr::Reset(void* ptr) { + StaticMeta::Instance()->Reset(id_, ptr); +} + +void* ThreadLocalPtr::Swap(void* ptr) { + return StaticMeta::Instance()->Swap(id_, ptr); +} + +void ThreadLocalPtr::Scrape(autovector* ptrs) { + StaticMeta::Instance()->Scrape(id_, ptrs); +} + 
+} // namespace rocksdb diff --git a/util/thread_local.h b/util/thread_local.h new file mode 100644 index 000000000..d6fc5f085 --- /dev/null +++ b/util/thread_local.h @@ -0,0 +1,158 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. +// +// Copyright (c) 2011 The LevelDB Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. See the AUTHORS file for names of contributors. + +#pragma once + +#include +#include +#include +#include + +#include "util/autovector.h" +#include "port/port_posix.h" + +namespace rocksdb { + +// Cleanup function that will be called for a stored thread local +// pointer (if not NULL) when one of the following happens: +// (1) a thread terminates +// (2) a ThreadLocalPtr is destroyed +typedef void (*UnrefHandler)(void* ptr); + +// Thread local storage that only stores value of pointer type. The storage +// distinguishes data coming from different threads and different ThreadLocalPtr +// instances. For example, if a regular thread_local variable A is declared +// in DBImpl, two DBImpl objects would share the same A. ThreadLocalPtr avoids +// the conflict. The total storage size equals # of threads * # of +// ThreadLocalPtr instances. It is not efficient in terms of space, but it +// should serve most of our use cases well and keep code simple. +class ThreadLocalPtr { + public: + explicit ThreadLocalPtr(UnrefHandler handler = nullptr); + + ~ThreadLocalPtr(); + + // Return the current pointer stored in thread local + void* Get() const; + + // Set a new pointer value to the thread local storage. 
+ void Reset(void* ptr); + + // Atomically swap the supplied ptr and return the previous value + void* Swap(void* ptr); + + // Return non-nullptr data for all existing threads and reset them + // to nullptr + void Scrape(autovector* ptrs); + + protected: + struct Entry { + Entry() : ptr(nullptr) {} + Entry(const Entry& e) : ptr(e.ptr.load(std::memory_order_relaxed)) {} + std::atomic ptr; + }; + + // This is the structure that is declared as "thread_local" storage. + // The vector keeps a list of atomic pointers for all instances for "current" + // thread. The vector is indexed by an Id that is unique in process and + // associated with one ThreadLocalPtr instance. The Id is assigned by a + // global StaticMeta singleton. So if we instantiated 3 ThreadLocalPtr + // instances, each thread will have a ThreadData with a vector of size 3: + // --------------------------------------------------- + // | | instance 1 | instance 2 | instance 3 | + // --------------------------------------------------- + // | thread 1 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + // | thread 2 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + // | thread 3 | void* | void* | void* | <- ThreadData + // --------------------------------------------------- + struct ThreadData { + ThreadData() : entries() {} + std::vector entries; + ThreadData* next; + ThreadData* prev; + }; + + class StaticMeta { + public: + static StaticMeta* Instance(); + + // Return the next available Id + uint32_t GetId(); + // Return the next available Id without claiming it + uint32_t PeekId() const; + // Return the given Id back to the free pool. This also triggers + // UnrefHandler for associated pointer value (if not NULL) for all threads. + void ReclaimId(uint32_t id); + + // Return the pointer value for the given id for the current thread. 
+ void* Get(uint32_t id) const; + // Reset the pointer value for the given id for the current thread. + // It triggers UnrefHandler if the id has existing pointer value. + void Reset(uint32_t id, void* ptr); + // Atomically swap the supplied ptr and return the previous value + void* Swap(uint32_t id, void* ptr); + // Return data for all existing threads and reset them to nullptr + void Scrape(uint32_t id, autovector* ptrs); + + // Register the UnrefHandler for id + void SetHandler(uint32_t id, UnrefHandler handler); + + private: + StaticMeta(); + + // Get UnrefHandler for id without acquiring mutex + // REQUIRES: mutex locked + UnrefHandler GetHandler(uint32_t id); + + // Triggered before a thread terminates + static void OnThreadExit(void* ptr); + + // Add current thread's ThreadData to the global chain + // REQUIRES: mutex locked + void AddThreadData(ThreadData* d); + + // Remove current thread's ThreadData from the global chain + // REQUIRES: mutex locked + void RemoveThreadData(ThreadData* d); + + static ThreadData* GetThreadLocal(); + + // Singleton instance + static std::unique_ptr inst_; + + uint32_t next_instance_id_; + // Used to recycle Ids in case ThreadLocalPtr is instantiated and destroyed + // frequently. This also prevents it from blowing up the vector space. + autovector free_instance_ids_; + // Chain all thread local structures together. This is necessary since + // when one ThreadLocalPtr gets destroyed, we need to loop over each + // thread's version of pointer corresponding to that instance and + // call UnrefHandler for it. + ThreadData head_; + + std::unordered_map handler_map_; + + // protect inst, next_instance_id_, free_instance_ids_, head_, + // ThreadData.entries + static port::Mutex mutex_; +#if !defined(OS_MACOSX) + // Thread local storage + static __thread ThreadData* tls_; +#endif + // Used to make thread exit trigger possible if !defined(OS_MACOSX). + // Otherwise, used to retrieve thread data. 
+ pthread_key_t pthread_key_; + }; + + const uint32_t id_; +}; + +} // namespace rocksdb diff --git a/util/thread_local_test.cc b/util/thread_local_test.cc new file mode 100644 index 000000000..bc7aa5b52 --- /dev/null +++ b/util/thread_local_test.cc @@ -0,0 +1,456 @@ +// Copyright (c) 2013, Facebook, Inc. All rights reserved. +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. An additional grant +// of patent rights can be found in the PATENTS file in the same directory. + +#include + +#include "rocksdb/env.h" +#include "port/port_posix.h" +#include "util/autovector.h" +#include "util/thread_local.h" +#include "util/testharness.h" +#include "util/testutil.h" + +namespace rocksdb { + +class ThreadLocalTest { + public: + ThreadLocalTest() : env_(Env::Default()) {} + + Env* env_; +}; + +namespace { + +struct Params { + Params(port::Mutex* m, port::CondVar* c, int* unref, int n, + UnrefHandler handler = nullptr) + : mu(m), + cv(c), + unref(unref), + total(n), + started(0), + completed(0), + doWrite(false), + tls1(handler), + tls2(nullptr) {} + + port::Mutex* mu; + port::CondVar* cv; + int* unref; + int total; + int started; + int completed; + bool doWrite; + ThreadLocalPtr tls1; + ThreadLocalPtr* tls2; +}; + +class IDChecker : public ThreadLocalPtr { + public: + static uint32_t PeekId() { return StaticMeta::Instance()->PeekId(); } +}; + +} // anonymous namespace + +TEST(ThreadLocalTest, UniqueIdTest) { + port::Mutex mu; + port::CondVar cv(&mu); + + ASSERT_EQ(IDChecker::PeekId(), 0); + // New ThreadLocal instance bumps id by 1 + { + // Id used 0 + Params p1(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 1); + // Id used 1 + Params p2(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 2); + // Id used 2 + Params p3(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // Id used 3 + Params p4(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 4); + } + // id 3, 
2, 1, 0 are in the free queue in order + ASSERT_EQ(IDChecker::PeekId(), 0); + + // pick up 0 + Params p1(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 1); + // pick up 1 + Params* p2 = new Params(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 2); + // pick up 2 + Params p3(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // return up 1 + delete p2; + ASSERT_EQ(IDChecker::PeekId(), 1); + // Now we have 3, 1 in queue + // pick up 1 + Params p4(&mu, &cv, nullptr, 1); + ASSERT_EQ(IDChecker::PeekId(), 3); + // pick up 3 + Params p5(&mu, &cv, nullptr, 1); + // next new id + ASSERT_EQ(IDChecker::PeekId(), 4); + // After exit, id sequence in queue: + // 3, 1, 2, 0 +} + +TEST(ThreadLocalTest, SequentialReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + ASSERT_EQ(IDChecker::PeekId(), 0); + + port::Mutex mu; + port::CondVar cv(&mu); + Params p(&mu, &cv, nullptr, 1); + ThreadLocalPtr tls2; + p.tls2 = &tls2; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + p.tls1.Reset(reinterpret_cast(1)); + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(1)); + p.tls1.Reset(reinterpret_cast(2)); + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(2)); + + ASSERT_TRUE(p.tls2->Get() == nullptr); + p.tls2->Reset(reinterpret_cast(1)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(1)); + p.tls2->Reset(reinterpret_cast(2)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(2)); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + p.mu->Unlock(); + }; + + for (int iter = 0; iter < 1024; ++iter) { + ASSERT_EQ(IDChecker::PeekId(), 1); + // Another new thread, read/write should not see value from previous thread + env_->StartThread(func, static_cast(&p)); + mu.Lock(); + while (p.completed != iter + 1) { + cv.Wait(); + } + mu.Unlock(); + ASSERT_EQ(IDChecker::PeekId(), 1); + } +} + +TEST(ThreadLocalTest, ConcurrentReadWriteTest) { + // global id list carries over 3, 1, 2, 0 + 
 ASSERT_EQ(IDChecker::PeekId(), 0); + + ThreadLocalPtr tls2; + port::Mutex mu1; + port::CondVar cv1(&mu1); + Params p1(&mu1, &cv1, nullptr, 128); + p1.tls2 = &tls2; + + port::Mutex mu2; + port::CondVar cv2(&mu2); + Params p2(&mu2, &cv2, nullptr, 128); + p2.doWrite = true; + p2.tls2 = &tls2; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + int own = ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + // Let write threads write a different value from the read threads + if (p.doWrite) { + own += 8192; + } + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + auto* env = Env::Default(); + auto start = env->NowMicros(); + + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + // Loop for 1 second + while (env->NowMicros() - start < 1000 * 1000) { + for (int iter = 0; iter < 100000; ++iter) { + ASSERT_TRUE(p.tls1.Get() == reinterpret_cast(own)); + ASSERT_TRUE(p.tls2->Get() == reinterpret_cast(own + 1)); + if (p.doWrite) { + p.tls1.Reset(reinterpret_cast(own)); + p.tls2->Reset(reinterpret_cast(own + 1)); + } + } + } + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + p.mu->Unlock(); + }; + + // Initiate 2 instances: one keeps writing and one keeps reading. + // The read instance should not see data from the write instance. + // Each thread local copy of the value is also different from each + // other. 
+ for (int th = 0; th < p1.total; ++th) { + env_->StartThread(func, static_cast(&p1)); + } + for (int th = 0; th < p2.total; ++th) { + env_->StartThread(func, static_cast(&p2)); + } + + mu1.Lock(); + while (p1.completed != p1.total) { + cv1.Wait(); + } + mu1.Unlock(); + + mu2.Lock(); + while (p2.completed != p2.total) { + cv2.Wait(); + } + mu2.Unlock(); + + ASSERT_EQ(IDChecker::PeekId(), 3); +} + +TEST(ThreadLocalTest, Unref) { + ASSERT_EQ(IDChecker::PeekId(), 0); + + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + // Case 0: no unref triggered if ThreadLocalPtr is never accessed + auto func0 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func0, static_cast(&p)); + } + env_->WaitForJoin(); + ASSERT_EQ(unref_count, 0); + } + + // Case 1: unref triggered by thread exit + auto func1 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + ThreadLocalPtr tls2(unref); + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = &tls2; + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func1, static_cast(&p)); + } + + env_->WaitForJoin(); + + // N threads x 2 ThreadLocal instance cleanup on thread exit + ASSERT_EQ(unref_count, 2 * p.total); + } + + 
// Case 2: unref triggered by ThreadLocal instance destruction + auto func2 = [](void* ptr) { + auto& p = *static_cast(ptr); + + p.mu->Lock(); + ++(p.started); + p.cv->SignalAll(); + while (p.started != p.total) { + p.cv->Wait(); + } + p.mu->Unlock(); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func2, static_cast(&p)); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + // Now destroy one ThreadLocal instance + delete p.tls2; + p.tls2 = nullptr; + // instance destroy for N threads + ASSERT_EQ(unref_count, p.total); + + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + // additional N threads exit unref for the left instance + ASSERT_EQ(unref_count, 2 * p.total); + } +} + +TEST(ThreadLocalTest, Swap) { + ThreadLocalPtr tls; + tls.Reset(reinterpret_cast(1)); + ASSERT_EQ(reinterpret_cast(tls.Swap(nullptr)), 1); + ASSERT_TRUE(tls.Swap(reinterpret_cast(2)) == nullptr); + ASSERT_EQ(reinterpret_cast(tls.Get()), 2); + ASSERT_EQ(reinterpret_cast(tls.Swap(reinterpret_cast(3))), 2); +} + +TEST(ThreadLocalTest, Scrape) { + auto unref = [](void* ptr) { + auto& p = *static_cast(ptr); + p.mu->Lock(); + ++(*p.unref); + p.mu->Unlock(); + }; + + auto func = [](void* ptr) { + auto& p = *static_cast(ptr); + + ASSERT_TRUE(p.tls1.Get() == nullptr); + ASSERT_TRUE(p.tls2->Get() == nullptr); + + 
p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.tls1.Reset(ptr); + p.tls2->Reset(ptr); + + p.mu->Lock(); + ++(p.completed); + p.cv->SignalAll(); + + // Waiting for instruction to exit thread + while (p.completed != 0) { + p.cv->Wait(); + } + p.mu->Unlock(); + }; + + for (int th = 1; th <= 128; th += th) { + port::Mutex mu; + port::CondVar cv(&mu); + int unref_count = 0; + Params p(&mu, &cv, &unref_count, th, unref); + p.tls2 = new ThreadLocalPtr(unref); + + for (int i = 0; i < p.total; ++i) { + env_->StartThread(func, static_cast(&p)); + } + + // Wait for all threads to finish using Params + mu.Lock(); + while (p.completed != p.total) { + cv.Wait(); + } + mu.Unlock(); + + ASSERT_EQ(unref_count, 0); + + // Scrape all thread local data. No unref at thread + // exit or ThreadLocalPtr destruction + autovector ptrs; + p.tls1.Scrape(&ptrs); + p.tls2->Scrape(&ptrs); + delete p.tls2; + // Signal to exit + mu.Lock(); + p.completed = 0; + cv.SignalAll(); + mu.Unlock(); + env_->WaitForJoin(); + + ASSERT_EQ(unref_count, 0); + } +} + +} // namespace rocksdb + +int main(int argc, char** argv) { + return rocksdb::test::RunAllTests(); +}