thread local pointer storage

Summary:
This is not a generic thread local implementation in the sense that it
only takes pointer. But it does support multiple instances per thread
and lets user plugin function to perform cleanup when thread exits or an
instance gets destroyed.

Test Plan: unit test for now

Reviewers: haobo, igor, sdong, dhruba

Reviewed By: igor

CC: leveldb, kailiu

Differential Revision: https://reviews.facebook.net/D16131
main
Lei Jin 11 years ago
parent 4209516359
commit b2795b799e
  1. 1
      HISTORY.md
  2. 6
      Makefile
  3. 12
      hdfs/env_hdfs.h
  4. 4
      include/rocksdb/env.h
  5. 9
      util/env_posix.cc
  6. 236
      util/thread_local.cc
  7. 158
      util/thread_local.h
  8. 456
      util/thread_local_test.cc

@ -7,6 +7,7 @@
* Removed arena.h from public header files. * Removed arena.h from public header files.
* By default, checksums are verified on every read from database * By default, checksums are verified on every read from database
* Added is_manual_compaction to CompactionFilter::Context * Added is_manual_compaction to CompactionFilter::Context
* Added "virtual void WaitForJoin() = 0" in class Env
## 2.7.0 (01/28/2014) ## 2.7.0 (01/28/2014)

@ -87,7 +87,8 @@ TESTS = \
version_set_test \ version_set_test \
write_batch_test\ write_batch_test\
deletefile_test \ deletefile_test \
table_test table_test \
thread_local_test
TOOLS = \ TOOLS = \
sst_dump \ sst_dump \
@ -276,6 +277,9 @@ redis_test: utilities/redis/redis_lists_test.o $(LIBOBJECTS) $(TESTHARNESS)
histogram_test: util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) histogram_test: util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o$@ $(LDFLAGS) $(COVERAGEFLAGS) $(CXX) util/histogram_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o$@ $(LDFLAGS) $(COVERAGEFLAGS)
thread_local_test: util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) util/thread_local_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS)
corruption_test: db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) corruption_test: db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS)
$(CXX) db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS) $(CXX) db/corruption_test.o $(LIBOBJECTS) $(TESTHARNESS) $(EXEC_LDFLAGS) -o $@ $(LDFLAGS) $(COVERAGEFLAGS)

@ -47,7 +47,7 @@ private:
class HdfsEnv : public Env { class HdfsEnv : public Env {
public: public:
HdfsEnv(const std::string& fsname) : fsname_(fsname) { explicit HdfsEnv(const std::string& fsname) : fsname_(fsname) {
posixEnv = Env::Default(); posixEnv = Env::Default();
fileSys_ = connectToPath(fsname_); fileSys_ = connectToPath(fsname_);
} }
@ -108,6 +108,8 @@ class HdfsEnv : public Env {
posixEnv->StartThread(function, arg); posixEnv->StartThread(function, arg);
} }
virtual void WaitForJoin() { posixEnv->WaitForJoin(); }
virtual Status GetTestDirectory(std::string* path) { virtual Status GetTestDirectory(std::string* path) {
return posixEnv->GetTestDirectory(path); return posixEnv->GetTestDirectory(path);
} }
@ -161,7 +163,7 @@ class HdfsEnv : public Env {
*/ */
hdfsFS connectToPath(const std::string& uri) { hdfsFS connectToPath(const std::string& uri) {
if (uri.empty()) { if (uri.empty()) {
return NULL; return nullptr;
} }
if (uri.find(kProto) != 0) { if (uri.find(kProto) != 0) {
// uri doesn't start with hdfs:// -> use default:0, which is special // uri doesn't start with hdfs:// -> use default:0, which is special
@ -218,10 +220,10 @@ static const Status notsup;
class HdfsEnv : public Env { class HdfsEnv : public Env {
public: public:
HdfsEnv(const std::string& fsname) { explicit HdfsEnv(const std::string& fsname) {
fprintf(stderr, "You have not build rocksdb with HDFS support\n"); fprintf(stderr, "You have not build rocksdb with HDFS support\n");
fprintf(stderr, "Please see hdfs/README for details\n"); fprintf(stderr, "Please see hdfs/README for details\n");
throw new std::exception(); throw std::exception();
} }
virtual ~HdfsEnv() { virtual ~HdfsEnv() {
@ -288,6 +290,8 @@ class HdfsEnv : public Env {
virtual void StartThread(void (*function)(void* arg), void* arg) {} virtual void StartThread(void (*function)(void* arg), void* arg) {}
virtual void WaitForJoin() {}
virtual Status GetTestDirectory(std::string* path) {return notsup;} virtual Status GetTestDirectory(std::string* path) {return notsup;}
virtual uint64_t NowMicros() {return 0;} virtual uint64_t NowMicros() {return 0;}

@ -205,6 +205,9 @@ class Env {
// When "function(arg)" returns, the thread will be destroyed. // When "function(arg)" returns, the thread will be destroyed.
virtual void StartThread(void (*function)(void* arg), void* arg) = 0; virtual void StartThread(void (*function)(void* arg), void* arg) = 0;
// Wait for all threads started by StartThread to terminate.
virtual void WaitForJoin() = 0;
// *path is set to a temporary directory that can be used for testing. It may // *path is set to a temporary directory that can be used for testing. It may
// or many not have just been created. The directory may or may not differ // or many not have just been created. The directory may or may not differ
// between runs of the same process, but subsequent calls will return the // between runs of the same process, but subsequent calls will return the
@ -634,6 +637,7 @@ class EnvWrapper : public Env {
void StartThread(void (*f)(void*), void* a) { void StartThread(void (*f)(void*), void* a) {
return target_->StartThread(f, a); return target_->StartThread(f, a);
} }
void WaitForJoin() { return target_->WaitForJoin(); }
virtual Status GetTestDirectory(std::string* path) { virtual Status GetTestDirectory(std::string* path) {
return target_->GetTestDirectory(path); return target_->GetTestDirectory(path);
} }

@ -1194,6 +1194,8 @@ class PosixEnv : public Env {
virtual void StartThread(void (*function)(void* arg), void* arg); virtual void StartThread(void (*function)(void* arg), void* arg);
virtual void WaitForJoin();
virtual Status GetTestDirectory(std::string* result) { virtual Status GetTestDirectory(std::string* result) {
const char* env = getenv("TEST_TMPDIR"); const char* env = getenv("TEST_TMPDIR");
if (env && env[0] != '\0') { if (env && env[0] != '\0') {
@ -1511,6 +1513,13 @@ void PosixEnv::StartThread(void (*function)(void* arg), void* arg) {
PthreadCall("unlock", pthread_mutex_unlock(&mu_)); PthreadCall("unlock", pthread_mutex_unlock(&mu_));
} }
void PosixEnv::WaitForJoin() {
for (const auto tid : threads_to_join_) {
pthread_join(tid, nullptr);
}
threads_to_join_.clear();
}
} // namespace } // namespace
std::string Env::GenerateUniqueId() { std::string Env::GenerateUniqueId() {

@ -0,0 +1,236 @@
// Copyright (c) 2013, Facebook, Inc. All rights reserved.
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
//
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.
#include "util/thread_local.h"
#include "util/mutexlock.h"
#if defined(__GNUC__) && __GNUC__ >= 4
#define UNLIKELY(x) (__builtin_expect((x), 0))
#else
#define UNLIKELY(x) (x)
#endif
namespace rocksdb {
std::unique_ptr<ThreadLocalPtr::StaticMeta> ThreadLocalPtr::StaticMeta::inst_;
port::Mutex ThreadLocalPtr::StaticMeta::mutex_;
#if !defined(OS_MACOSX)
__thread ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::tls_ = nullptr;
#endif
ThreadLocalPtr::StaticMeta* ThreadLocalPtr::StaticMeta::Instance() {
if (UNLIKELY(inst_ == nullptr)) {
MutexLock l(&mutex_);
if (inst_ == nullptr) {
inst_.reset(new StaticMeta());
}
}
return inst_.get();
}
void ThreadLocalPtr::StaticMeta::OnThreadExit(void* ptr) {
auto* tls = static_cast<ThreadData*>(ptr);
assert(tls != nullptr);
auto* inst = Instance();
pthread_setspecific(inst->pthread_key_, nullptr);
MutexLock l(&mutex_);
inst->RemoveThreadData(tls);
// Unref stored pointers of current thread from all instances
uint32_t id = 0;
for (auto& e : tls->entries) {
void* raw = e.ptr.load(std::memory_order_relaxed);
if (raw != nullptr) {
auto unref = inst->GetHandler(id);
if (unref != nullptr) {
unref(raw);
}
}
++id;
}
// Delete thread local structure no matter if it is Mac platform
delete tls;
}
ThreadLocalPtr::StaticMeta::StaticMeta() : next_instance_id_(0) {
if (pthread_key_create(&pthread_key_, &OnThreadExit) != 0) {
throw std::runtime_error("pthread_key_create failed");
}
head_.next = &head_;
head_.prev = &head_;
}
void ThreadLocalPtr::StaticMeta::AddThreadData(ThreadLocalPtr::ThreadData* d) {
mutex_.AssertHeld();
d->next = &head_;
d->prev = head_.prev;
head_.prev->next = d;
head_.prev = d;
}
void ThreadLocalPtr::StaticMeta::RemoveThreadData(
ThreadLocalPtr::ThreadData* d) {
mutex_.AssertHeld();
d->next->prev = d->prev;
d->prev->next = d->next;
d->next = d->prev = d;
}
ThreadLocalPtr::ThreadData* ThreadLocalPtr::StaticMeta::GetThreadLocal() {
#if defined(OS_MACOSX)
// Make this local variable name look like a member variable so that we
// can share all the code below
ThreadData* tls_ =
static_cast<ThreadData*>(pthread_getspecific(Instance()->pthread_key_));
#endif
if (UNLIKELY(tls_ == nullptr)) {
auto* inst = Instance();
tls_ = new ThreadData();
{
// Register it in the global chain, needs to be done before thread exit
// handler registration
MutexLock l(&mutex_);
inst->AddThreadData(tls_);
}
// Even it is not OS_MACOSX, need to register value for pthread_key_ so that
// its exit handler will be triggered.
if (pthread_setspecific(inst->pthread_key_, tls_) != 0) {
{
MutexLock l(&mutex_);
inst->RemoveThreadData(tls_);
}
delete tls_;
throw std::runtime_error("pthread_setspecific failed");
}
}
return tls_;
}
void* ThreadLocalPtr::StaticMeta::Get(uint32_t id) const {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
return nullptr;
}
return tls->entries[id].ptr.load(std::memory_order_relaxed);
}
void ThreadLocalPtr::StaticMeta::Reset(uint32_t id, void* ptr) {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
// Need mutex to protect entries access within ReclaimId
MutexLock l(&mutex_);
tls->entries.resize(id + 1);
}
tls->entries[id].ptr.store(ptr, std::memory_order_relaxed);
}
void* ThreadLocalPtr::StaticMeta::Swap(uint32_t id, void* ptr) {
auto* tls = GetThreadLocal();
if (UNLIKELY(id >= tls->entries.size())) {
// Need mutex to protect entries access within ReclaimId
MutexLock l(&mutex_);
tls->entries.resize(id + 1);
}
return tls->entries[id].ptr.exchange(ptr, std::memory_order_relaxed);
}
void ThreadLocalPtr::StaticMeta::Scrape(uint32_t id, autovector<void*>* ptrs) {
MutexLock l(&mutex_);
for (ThreadData* t = head_.next; t != &head_; t = t->next) {
if (id < t->entries.size()) {
void* ptr =
t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed);
if (ptr != nullptr) {
ptrs->push_back(ptr);
}
}
}
}
void ThreadLocalPtr::StaticMeta::SetHandler(uint32_t id, UnrefHandler handler) {
MutexLock l(&mutex_);
handler_map_[id] = handler;
}
UnrefHandler ThreadLocalPtr::StaticMeta::GetHandler(uint32_t id) {
mutex_.AssertHeld();
auto iter = handler_map_.find(id);
if (iter == handler_map_.end()) {
return nullptr;
}
return iter->second;
}
uint32_t ThreadLocalPtr::StaticMeta::GetId() {
MutexLock l(&mutex_);
if (free_instance_ids_.empty()) {
return next_instance_id_++;
}
uint32_t id = free_instance_ids_.back();
free_instance_ids_.pop_back();
return id;
}
uint32_t ThreadLocalPtr::StaticMeta::PeekId() const {
MutexLock l(&mutex_);
if (!free_instance_ids_.empty()) {
return free_instance_ids_.back();
}
return next_instance_id_;
}
void ThreadLocalPtr::StaticMeta::ReclaimId(uint32_t id) {
// This id is not used, go through all thread local data and release
// corresponding value
MutexLock l(&mutex_);
auto unref = GetHandler(id);
for (ThreadData* t = head_.next; t != &head_; t = t->next) {
if (id < t->entries.size()) {
void* ptr =
t->entries[id].ptr.exchange(nullptr, std::memory_order_relaxed);
if (ptr != nullptr && unref != nullptr) {
unref(ptr);
}
}
}
handler_map_[id] = nullptr;
free_instance_ids_.push_back(id);
}
ThreadLocalPtr::ThreadLocalPtr(UnrefHandler handler)
: id_(StaticMeta::Instance()->GetId()) {
if (handler != nullptr) {
StaticMeta::Instance()->SetHandler(id_, handler);
}
}
ThreadLocalPtr::~ThreadLocalPtr() {
StaticMeta::Instance()->ReclaimId(id_);
}
void* ThreadLocalPtr::Get() const {
return StaticMeta::Instance()->Get(id_);
}
void ThreadLocalPtr::Reset(void* ptr) {
StaticMeta::Instance()->Reset(id_, ptr);
}
void* ThreadLocalPtr::Swap(void* ptr) {
return StaticMeta::Instance()->Swap(id_, ptr);
}
void ThreadLocalPtr::Scrape(autovector<void*>* ptrs) {
StaticMeta::Instance()->Scrape(id_, ptrs);
}
} // namespace rocksdb

@ -0,0 +1,158 @@
// Copyright (c) 2013, Facebook, Inc. All rights reserved.
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
//
// Copyright (c) 2011 The LevelDB Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file. See the AUTHORS file for names of contributors.
#pragma once
#include <atomic>
#include <memory>
#include <unordered_map>
#include <vector>
#include "util/autovector.h"
#include "port/port_posix.h"
namespace rocksdb {
// Cleanup function that will be called for a stored thread local
// pointer (if not NULL) when one of the following happens:
// (1) a thread terminates
// (2) a ThreadLocalPtr is destroyed
typedef void (*UnrefHandler)(void* ptr);
// Thread local storage that only stores value of pointer type. The storage
// distinguish data coming from different thread and different ThreadLocalPtr
// instances. For example, if a regular thread_local variable A is declared
// in DBImpl, two DBImpl objects would share the same A. ThreadLocalPtr avoids
// the confliction. The total storage size equals to # of threads * # of
// ThreadLocalPtr instances. It is not efficient in terms of space, but it
// should serve most of our use cases well and keep code simple.
class ThreadLocalPtr {
public:
explicit ThreadLocalPtr(UnrefHandler handler = nullptr);
~ThreadLocalPtr();
// Return the current pointer stored in thread local
void* Get() const;
// Set a new pointer value to the thread local storage.
void Reset(void* ptr);
// Atomically swap the supplied ptr and return the previous value
void* Swap(void* ptr);
// Return non-nullptr data for all existing threads and reset them
// to nullptr
void Scrape(autovector<void*>* ptrs);
protected:
struct Entry {
Entry() : ptr(nullptr) {}
Entry(const Entry& e) : ptr(e.ptr.load(std::memory_order_relaxed)) {}
std::atomic<void*> ptr;
};
// This is the structure that is declared as "thread_local" storage.
// The vector keep list of atomic pointer for all instances for "current"
// thread. The vector is indexed by an Id that is unique in process and
// associated with one ThreadLocalPtr instance. The Id is assigned by a
// global StaticMeta singleton. So if we instantiated 3 ThreadLocalPtr
// instances, each thread will have a ThreadData with a vector of size 3:
// ---------------------------------------------------
// | | instance 1 | instance 2 | instnace 3 |
// ---------------------------------------------------
// | thread 1 | void* | void* | void* | <- ThreadData
// ---------------------------------------------------
// | thread 2 | void* | void* | void* | <- ThreadData
// ---------------------------------------------------
// | thread 3 | void* | void* | void* | <- ThreadData
// ---------------------------------------------------
struct ThreadData {
ThreadData() : entries() {}
std::vector<Entry> entries;
ThreadData* next;
ThreadData* prev;
};
class StaticMeta {
public:
static StaticMeta* Instance();
// Return the next available Id
uint32_t GetId();
// Return the next availabe Id without claiming it
uint32_t PeekId() const;
// Return the given Id back to the free pool. This also triggers
// UnrefHandler for associated pointer value (if not NULL) for all threads.
void ReclaimId(uint32_t id);
// Return the pointer value for the given id for the current thread.
void* Get(uint32_t id) const;
// Reset the pointer value for the given id for the current thread.
// It triggers UnrefHanlder if the id has existing pointer value.
void Reset(uint32_t id, void* ptr);
// Atomically swap the supplied ptr and return the previous value
void* Swap(uint32_t id, void* ptr);
// Return data for all existing threads and return them to nullptr
void Scrape(uint32_t id, autovector<void*>* ptrs);
// Register the UnrefHandler for id
void SetHandler(uint32_t id, UnrefHandler handler);
private:
StaticMeta();
// Get UnrefHandler for id with acquiring mutex
// REQUIRES: mutex locked
UnrefHandler GetHandler(uint32_t id);
// Triggered before a thread terminates
static void OnThreadExit(void* ptr);
// Add current thread's ThreadData to the global chain
// REQUIRES: mutex locked
void AddThreadData(ThreadData* d);
// Remove current thread's ThreadData from the global chain
// REQUIRES: mutex locked
void RemoveThreadData(ThreadData* d);
static ThreadData* GetThreadLocal();
// Singleton instance
static std::unique_ptr<StaticMeta> inst_;
uint32_t next_instance_id_;
// Used to recycle Ids in case ThreadLocalPtr is instantiated and destroyed
// frequently. This also prevents it from blowing up the vector space.
autovector<uint32_t> free_instance_ids_;
// Chain all thread local structure together. This is necessary since
// when one ThreadLocalPtr gets destroyed, we need to loop over each
// thread's version of pointer corresponding to that instance and
// call UnrefHandler for it.
ThreadData head_;
std::unordered_map<uint32_t, UnrefHandler> handler_map_;
// protect inst, next_instance_id_, free_instance_ids_, head_,
// ThreadData.entries
static port::Mutex mutex_;
#if !defined(OS_MACOSX)
// Thread local storage
static __thread ThreadData* tls_;
#endif
// Used to make thread exit trigger possible if !defined(OS_MACOSX).
// Otherwise, used to retrieve thread data.
pthread_key_t pthread_key_;
};
const uint32_t id_;
};
} // namespace rocksdb

@ -0,0 +1,456 @@
// Copyright (c) 2013, Facebook, Inc. All rights reserved.
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree. An additional grant
// of patent rights can be found in the PATENTS file in the same directory.
#include <atomic>
#include "rocksdb/env.h"
#include "port/port_posix.h"
#include "util/autovector.h"
#include "util/thread_local.h"
#include "util/testharness.h"
#include "util/testutil.h"
namespace rocksdb {
class ThreadLocalTest {
public:
ThreadLocalTest() : env_(Env::Default()) {}
Env* env_;
};
namespace {
struct Params {
Params(port::Mutex* m, port::CondVar* c, int* unref, int n,
UnrefHandler handler = nullptr)
: mu(m),
cv(c),
unref(unref),
total(n),
started(0),
completed(0),
doWrite(false),
tls1(handler),
tls2(nullptr) {}
port::Mutex* mu;
port::CondVar* cv;
int* unref;
int total;
int started;
int completed;
bool doWrite;
ThreadLocalPtr tls1;
ThreadLocalPtr* tls2;
};
class IDChecker : public ThreadLocalPtr {
public:
static uint32_t PeekId() { return StaticMeta::Instance()->PeekId(); }
};
} // anonymous namespace
TEST(ThreadLocalTest, UniqueIdTest) {
port::Mutex mu;
port::CondVar cv(&mu);
ASSERT_EQ(IDChecker::PeekId(), 0);
// New ThreadLocal instance bumps id by 1
{
// Id used 0
Params p1(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 1);
// Id used 1
Params p2(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 2);
// Id used 2
Params p3(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 3);
// Id used 3
Params p4(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 4);
}
// id 3, 2, 1, 0 are in the free queue in order
ASSERT_EQ(IDChecker::PeekId(), 0);
// pick up 0
Params p1(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 1);
// pick up 1
Params* p2 = new Params(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 2);
// pick up 2
Params p3(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 3);
// return up 1
delete p2;
ASSERT_EQ(IDChecker::PeekId(), 1);
// Now we have 3, 1 in queue
// pick up 1
Params p4(&mu, &cv, nullptr, 1);
ASSERT_EQ(IDChecker::PeekId(), 3);
// pick up 3
Params p5(&mu, &cv, nullptr, 1);
// next new id
ASSERT_EQ(IDChecker::PeekId(), 4);
// After exit, id sequence in queue:
// 3, 1, 2, 0
}
TEST(ThreadLocalTest, SequentialReadWriteTest) {
// global id list carries over 3, 1, 2, 0
ASSERT_EQ(IDChecker::PeekId(), 0);
port::Mutex mu;
port::CondVar cv(&mu);
Params p(&mu, &cv, nullptr, 1);
ThreadLocalPtr tls2;
p.tls2 = &tls2;
auto func = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
ASSERT_TRUE(p.tls1.Get() == nullptr);
p.tls1.Reset(reinterpret_cast<int*>(1));
ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<int*>(1));
p.tls1.Reset(reinterpret_cast<int*>(2));
ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<int*>(2));
ASSERT_TRUE(p.tls2->Get() == nullptr);
p.tls2->Reset(reinterpret_cast<int*>(1));
ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<int*>(1));
p.tls2->Reset(reinterpret_cast<int*>(2));
ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<int*>(2));
p.mu->Lock();
++(p.completed);
p.cv->SignalAll();
p.mu->Unlock();
};
for (int iter = 0; iter < 1024; ++iter) {
ASSERT_EQ(IDChecker::PeekId(), 1);
// Another new thread, read/write should not see value from previous thread
env_->StartThread(func, static_cast<void*>(&p));
mu.Lock();
while (p.completed != iter + 1) {
cv.Wait();
}
mu.Unlock();
ASSERT_EQ(IDChecker::PeekId(), 1);
}
}
TEST(ThreadLocalTest, ConcurrentReadWriteTest) {
// global id list carries over 3, 1, 2, 0
ASSERT_EQ(IDChecker::PeekId(), 0);
ThreadLocalPtr tls2;
port::Mutex mu1;
port::CondVar cv1(&mu1);
Params p1(&mu1, &cv1, nullptr, 128);
p1.tls2 = &tls2;
port::Mutex mu2;
port::CondVar cv2(&mu2);
Params p2(&mu2, &cv2, nullptr, 128);
p2.doWrite = true;
p2.tls2 = &tls2;
auto func = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
int own = ++(p.started);
p.cv->SignalAll();
while (p.started != p.total) {
p.cv->Wait();
}
p.mu->Unlock();
// Let write threads write a different value from the read threads
if (p.doWrite) {
own += 8192;
}
ASSERT_TRUE(p.tls1.Get() == nullptr);
ASSERT_TRUE(p.tls2->Get() == nullptr);
auto* env = Env::Default();
auto start = env->NowMicros();
p.tls1.Reset(reinterpret_cast<int*>(own));
p.tls2->Reset(reinterpret_cast<int*>(own + 1));
// Loop for 1 second
while (env->NowMicros() - start < 1000 * 1000) {
for (int iter = 0; iter < 100000; ++iter) {
ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<int*>(own));
ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<int*>(own + 1));
if (p.doWrite) {
p.tls1.Reset(reinterpret_cast<int*>(own));
p.tls2->Reset(reinterpret_cast<int*>(own + 1));
}
}
}
p.mu->Lock();
++(p.completed);
p.cv->SignalAll();
p.mu->Unlock();
};
// Initiate 2 instnaces: one keeps writing and one keeps reading.
// The read instance should not see data from the write instance.
// Each thread local copy of the value are also different from each
// other.
for (int th = 0; th < p1.total; ++th) {
env_->StartThread(func, static_cast<void*>(&p1));
}
for (int th = 0; th < p2.total; ++th) {
env_->StartThread(func, static_cast<void*>(&p2));
}
mu1.Lock();
while (p1.completed != p1.total) {
cv1.Wait();
}
mu1.Unlock();
mu2.Lock();
while (p2.completed != p2.total) {
cv2.Wait();
}
mu2.Unlock();
ASSERT_EQ(IDChecker::PeekId(), 3);
}
TEST(ThreadLocalTest, Unref) {
ASSERT_EQ(IDChecker::PeekId(), 0);
auto unref = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
++(*p.unref);
p.mu->Unlock();
};
// Case 0: no unref triggered if ThreadLocalPtr is never accessed
auto func0 = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
++(p.started);
p.cv->SignalAll();
while (p.started != p.total) {
p.cv->Wait();
}
p.mu->Unlock();
};
for (int th = 1; th <= 128; th += th) {
port::Mutex mu;
port::CondVar cv(&mu);
int unref_count = 0;
Params p(&mu, &cv, &unref_count, th, unref);
for (int i = 0; i < p.total; ++i) {
env_->StartThread(func0, static_cast<void*>(&p));
}
env_->WaitForJoin();
ASSERT_EQ(unref_count, 0);
}
// Case 1: unref triggered by thread exit
auto func1 = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
++(p.started);
p.cv->SignalAll();
while (p.started != p.total) {
p.cv->Wait();
}
p.mu->Unlock();
ASSERT_TRUE(p.tls1.Get() == nullptr);
ASSERT_TRUE(p.tls2->Get() == nullptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
};
for (int th = 1; th <= 128; th += th) {
port::Mutex mu;
port::CondVar cv(&mu);
int unref_count = 0;
ThreadLocalPtr tls2(unref);
Params p(&mu, &cv, &unref_count, th, unref);
p.tls2 = &tls2;
for (int i = 0; i < p.total; ++i) {
env_->StartThread(func1, static_cast<void*>(&p));
}
env_->WaitForJoin();
// N threads x 2 ThreadLocal instance cleanup on thread exit
ASSERT_EQ(unref_count, 2 * p.total);
}
// Case 2: unref triggered by ThreadLocal instance destruction
auto func2 = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
++(p.started);
p.cv->SignalAll();
while (p.started != p.total) {
p.cv->Wait();
}
p.mu->Unlock();
ASSERT_TRUE(p.tls1.Get() == nullptr);
ASSERT_TRUE(p.tls2->Get() == nullptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
p.mu->Lock();
++(p.completed);
p.cv->SignalAll();
// Waiting for instruction to exit thread
while (p.completed != 0) {
p.cv->Wait();
}
p.mu->Unlock();
};
for (int th = 1; th <= 128; th += th) {
port::Mutex mu;
port::CondVar cv(&mu);
int unref_count = 0;
Params p(&mu, &cv, &unref_count, th, unref);
p.tls2 = new ThreadLocalPtr(unref);
for (int i = 0; i < p.total; ++i) {
env_->StartThread(func2, static_cast<void*>(&p));
}
// Wait for all threads to finish using Params
mu.Lock();
while (p.completed != p.total) {
cv.Wait();
}
mu.Unlock();
// Now destroy one ThreadLocal instance
delete p.tls2;
p.tls2 = nullptr;
// instance destroy for N threads
ASSERT_EQ(unref_count, p.total);
// Signal to exit
mu.Lock();
p.completed = 0;
cv.SignalAll();
mu.Unlock();
env_->WaitForJoin();
// additional N threads exit unref for the left instance
ASSERT_EQ(unref_count, 2 * p.total);
}
}
TEST(ThreadLocalTest, Swap) {
ThreadLocalPtr tls;
tls.Reset(reinterpret_cast<void*>(1));
ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1);
ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr);
ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2);
ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2);
}
TEST(ThreadLocalTest, Scrape) {
auto unref = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
p.mu->Lock();
++(*p.unref);
p.mu->Unlock();
};
auto func = [](void* ptr) {
auto& p = *static_cast<Params*>(ptr);
ASSERT_TRUE(p.tls1.Get() == nullptr);
ASSERT_TRUE(p.tls2->Get() == nullptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
p.tls1.Reset(ptr);
p.tls2->Reset(ptr);
p.mu->Lock();
++(p.completed);
p.cv->SignalAll();
// Waiting for instruction to exit thread
while (p.completed != 0) {
p.cv->Wait();
}
p.mu->Unlock();
};
for (int th = 1; th <= 128; th += th) {
port::Mutex mu;
port::CondVar cv(&mu);
int unref_count = 0;
Params p(&mu, &cv, &unref_count, th, unref);
p.tls2 = new ThreadLocalPtr(unref);
for (int i = 0; i < p.total; ++i) {
env_->StartThread(func, static_cast<void*>(&p));
}
// Wait for all threads to finish using Params
mu.Lock();
while (p.completed != p.total) {
cv.Wait();
}
mu.Unlock();
ASSERT_EQ(unref_count, 0);
// Scrape all thread local data. No unref at thread
// exit or ThreadLocalPtr destruction
autovector<void*> ptrs;
p.tls1.Scrape(&ptrs);
p.tls2->Scrape(&ptrs);
delete p.tls2;
// Signal to exit
mu.Lock();
p.completed = 0;
cv.SignalAll();
mu.Unlock();
env_->WaitForJoin();
ASSERT_EQ(unref_count, 0);
}
}
} // namespace rocksdb
int main(int argc, char** argv) {
return rocksdb::test::RunAllTests();
}
Loading…
Cancel
Save