// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. An additional grant // of patent rights can be found in the PATENTS file in the same directory. #include <thread> #include <atomic> #include <string> #include "rocksdb/env.h" #include "port/port.h" #include "util/autovector.h" #include "util/sync_point.h" #include "util/testharness.h" #include "util/testutil.h" #include "util/thread_local.h" namespace rocksdb { class ThreadLocalTest : public testing::Test { public: ThreadLocalTest() : env_(Env::Default()) {} Env* env_; }; namespace { struct Params { Params(port::Mutex* m, port::CondVar* c, int* u, int n, UnrefHandler handler = nullptr) : mu(m), cv(c), unref(u), total(n), started(0), completed(0), doWrite(false), tls1(handler), tls2(nullptr) {} port::Mutex* mu; port::CondVar* cv; int* unref; int total; int started; int completed; bool doWrite; ThreadLocalPtr tls1; ThreadLocalPtr* tls2; }; class IDChecker : public ThreadLocalPtr { public: static uint32_t PeekId() { return Instance()->PeekId(); } }; } // anonymous namespace TEST_F(ThreadLocalTest, UniqueIdTest) { port::Mutex mu; port::CondVar cv(&mu); ASSERT_EQ(IDChecker::PeekId(), 0u); // New ThreadLocal instance bumps id by 1 { // Id used 0 Params p1(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 1u); // Id used 1 Params p2(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 2u); // Id used 2 Params p3(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 3u); // Id used 3 Params p4(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 4u); } // id 3, 2, 1, 0 are in the free queue in order ASSERT_EQ(IDChecker::PeekId(), 0u); // pick up 0 Params p1(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 1u); // pick up 1 Params* p2 = new Params(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 2u); // pick up 2 Params p3(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 3u); // return up 1 delete p2; ASSERT_EQ(IDChecker::PeekId(), 1u); // Now we have 3, 1 in queue // pick up 1 Params p4(&mu, &cv, nullptr, 1u); ASSERT_EQ(IDChecker::PeekId(), 3u); // pick up 3 Params p5(&mu, &cv, nullptr, 1u); // next new id ASSERT_EQ(IDChecker::PeekId(), 4u); // After exit, id sequence in queue: // 3, 1, 2, 0 } TEST_F(ThreadLocalTest, SequentialReadWriteTest) { // global id list carries over 3, 1, 2, 0 ASSERT_EQ(IDChecker::PeekId(), 0u); port::Mutex mu; port::CondVar cv(&mu); Params p(&mu, &cv, nullptr, 1); ThreadLocalPtr tls2; p.tls2 = &tls2; auto func = [](void* ptr) { auto& params = *static_cast<Params*>(ptr); ASSERT_TRUE(params.tls1.Get() == nullptr); params.tls1.Reset(reinterpret_cast<int*>(1)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1)); params.tls1.Reset(reinterpret_cast<int*>(2)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(2)); ASSERT_TRUE(params.tls2->Get() == nullptr); params.tls2->Reset(reinterpret_cast<int*>(1)); ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(1)); params.tls2->Reset(reinterpret_cast<int*>(2)); ASSERT_TRUE(params.tls2->Get() == reinterpret_cast<int*>(2)); params.mu->Lock(); ++(params.completed); params.cv->SignalAll(); params.mu->Unlock(); }; for (int iter = 0; iter < 1024; ++iter) { ASSERT_EQ(IDChecker::PeekId(), 1u); // Another new thread, read/write should not see value from previous thread env_->StartThread(func, static_cast<void*>(&p)); mu.Lock(); while (p.completed != iter + 1) { cv.Wait(); } mu.Unlock(); ASSERT_EQ(IDChecker::PeekId(), 1u); } } TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) { // global id list carries over 3, 1, 2, 0 ASSERT_EQ(IDChecker::PeekId(), 0u); ThreadLocalPtr tls2; port::Mutex mu1; port::CondVar cv1(&mu1); Params p1(&mu1, &cv1, nullptr, 16); p1.tls2 = &tls2; port::Mutex mu2; port::CondVar cv2(&mu2); Params p2(&mu2, &cv2, nullptr, 16); p2.doWrite = true; p2.tls2 = &tls2; auto func = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); // Size_T switches size along with the ptr size // we want to cast to. size_t own = ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); // Let write threads write a different value from the read threads if (p.doWrite) { own += 8192; } ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); auto* env = Env::Default(); auto start = env->NowMicros(); p.tls1.Reset(reinterpret_cast<size_t*>(own)); p.tls2->Reset(reinterpret_cast<size_t*>(own + 1)); // Loop for 1 second while (env->NowMicros() - start < 1000 * 1000) { for (int iter = 0; iter < 100000; ++iter) { ASSERT_TRUE(p.tls1.Get() == reinterpret_cast<size_t*>(own)); ASSERT_TRUE(p.tls2->Get() == reinterpret_cast<size_t*>(own + 1)); if (p.doWrite) { p.tls1.Reset(reinterpret_cast<size_t*>(own)); p.tls2->Reset(reinterpret_cast<size_t*>(own + 1)); } } } p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); p.mu->Unlock(); }; // Initiate 2 instnaces: one keeps writing and one keeps reading. // The read instance should not see data from the write instance. // Each thread local copy of the value are also different from each // other. for (int th = 0; th < p1.total; ++th) { env_->StartThread(func, static_cast<void*>(&p1)); } for (int th = 0; th < p2.total; ++th) { env_->StartThread(func, static_cast<void*>(&p2)); } mu1.Lock(); while (p1.completed != p1.total) { cv1.Wait(); } mu1.Unlock(); mu2.Lock(); while (p2.completed != p2.total) { cv2.Wait(); } mu2.Unlock(); ASSERT_EQ(IDChecker::PeekId(), 3u); } TEST_F(ThreadLocalTest, Unref) { ASSERT_EQ(IDChecker::PeekId(), 0u); auto unref = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); ++(*p.unref); p.mu->Unlock(); }; // Case 0: no unref triggered if ThreadLocalPtr is never accessed auto func0 = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func0, static_cast<void*>(&p)); } env_->WaitForJoin(); ASSERT_EQ(unref_count, 0); } // Case 1: unref triggered by thread exit auto func1 = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; ThreadLocalPtr tls2(unref); Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = &tls2; for (int i = 0; i < p.total; ++i) { env_->StartThread(func1, static_cast<void*>(&p)); } env_->WaitForJoin(); // N threads x 2 ThreadLocal instance cleanup on thread exit ASSERT_EQ(unref_count, 2 * p.total); } // Case 2: unref triggered by ThreadLocal instance destruction auto func2 = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); ++(p.started); p.cv->SignalAll(); while (p.started != p.total) { p.cv->Wait(); } p.mu->Unlock(); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); // Waiting for instruction to exit thread while (p.completed != 0) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func2, static_cast<void*>(&p)); } // Wait for all threads to finish using Params mu.Lock(); while (p.completed != p.total) { cv.Wait(); } mu.Unlock(); // Now destroy one ThreadLocal instance delete p.tls2; p.tls2 = nullptr; // instance destroy for N threads ASSERT_EQ(unref_count, p.total); // Signal to exit mu.Lock(); p.completed = 0; cv.SignalAll(); mu.Unlock(); env_->WaitForJoin(); // additional N threads exit unref for the left instance ASSERT_EQ(unref_count, 2 * p.total); } } TEST_F(ThreadLocalTest, Swap) { ThreadLocalPtr tls; tls.Reset(reinterpret_cast<void*>(1)); ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(nullptr)), 1); ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(2)) == nullptr); ASSERT_EQ(reinterpret_cast<int64_t>(tls.Get()), 2); ASSERT_EQ(reinterpret_cast<int64_t>(tls.Swap(reinterpret_cast<void*>(3))), 2); } TEST_F(ThreadLocalTest, Scrape) { auto unref = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); p.mu->Lock(); ++(*p.unref); p.mu->Unlock(); }; auto func = [](void* ptr) { auto& p = *static_cast<Params*>(ptr); ASSERT_TRUE(p.tls1.Get() == nullptr); ASSERT_TRUE(p.tls2->Get() == nullptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.tls1.Reset(ptr); p.tls2->Reset(ptr); p.mu->Lock(); ++(p.completed); p.cv->SignalAll(); // Waiting for instruction to exit thread while (p.completed != 0) { p.cv->Wait(); } p.mu->Unlock(); }; for (int th = 1; th <= 128; th += th) { port::Mutex mu; port::CondVar cv(&mu); int unref_count = 0; Params p(&mu, &cv, &unref_count, th, unref); p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { env_->StartThread(func, static_cast<void*>(&p)); } // Wait for all threads to finish using Params mu.Lock(); while (p.completed != p.total) { cv.Wait(); } mu.Unlock(); ASSERT_EQ(unref_count, 0); // Scrape all thread local data. No unref at thread // exit or ThreadLocalPtr destruction autovector<void*> ptrs; p.tls1.Scrape(&ptrs, nullptr); p.tls2->Scrape(&ptrs, nullptr); delete p.tls2; // Signal to exit mu.Lock(); p.completed = 0; cv.SignalAll(); mu.Unlock(); env_->WaitForJoin(); ASSERT_EQ(unref_count, 0); } } TEST_F(ThreadLocalTest, CompareAndSwap) { ThreadLocalPtr tls; ASSERT_TRUE(tls.Swap(reinterpret_cast<void*>(1)) == nullptr); void* expected = reinterpret_cast<void*>(1); // Swap in 2 ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(2), expected)); expected = reinterpret_cast<void*>(100); // Fail Swap, still 2 ASSERT_TRUE(!tls.CompareAndSwap(reinterpret_cast<void*>(2), expected)); ASSERT_EQ(expected, reinterpret_cast<void*>(2)); // Swap in 3 expected = reinterpret_cast<void*>(2); ASSERT_TRUE(tls.CompareAndSwap(reinterpret_cast<void*>(3), expected)); ASSERT_EQ(tls.Get(), reinterpret_cast<void*>(3)); } namespace { void* AccessThreadLocal(void* arg) { TEST_SYNC_POINT("AccessThreadLocal:Start"); ThreadLocalPtr tlp; tlp.Reset(new std::string("hello RocksDB")); TEST_SYNC_POINT("AccessThreadLocal:End"); return nullptr; } } // namespace // The following test is disabled as it requires manual steps to run it // correctly. // // Currently we have no way to acess SyncPoint w/o ASAN error when the // child thread dies after the main thread dies. So if you manually enable // this test and only see an ASAN error on SyncPoint, it means you pass the // test. TEST_F(ThreadLocalTest, DISABLED_MainThreadDiesFirst) { rocksdb::SyncPoint::GetInstance()->LoadDependency( {{"AccessThreadLocal:Start", "MainThreadDiesFirst:End"}, {"PosixEnv::~PosixEnv():End", "AccessThreadLocal:End"}}); // Triggers the initialization of singletons. Env::Default(); #ifndef ROCKSDB_LITE try { #endif // ROCKSDB_LITE std::thread th(&AccessThreadLocal, nullptr); th.detach(); TEST_SYNC_POINT("MainThreadDiesFirst:End"); #ifndef ROCKSDB_LITE } catch (const std::system_error& ex) { std::cerr << "Start thread: " << ex.code() << std::endl; ASSERT_TRUE(false); } #endif // ROCKSDB_LITE } } // namespace rocksdb int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); }