From 748e3acc11b65f0703b1f991f2eabc48322305cb Mon Sep 17 00:00:00 2001 From: Glebanister Date: Wed, 19 May 2021 16:45:40 -0700 Subject: [PATCH] Add StartThread type checking wrapper (#8303) Summary: - Add class `FunctorWrapper` to invoke the function with given parameters - Implement `StartThreadTyped` which wraps `StartThread` with type checking cover - Demonstrate `StartThreadTyped` in test `util/thread_local_test.cc` https://github.com/facebook/rocksdb/issues/8285 Pull Request resolved: https://github.com/facebook/rocksdb/pull/8303 Reviewed By: ajkr Differential Revision: D28539318 Pulled By: pdillinger fbshipit-source-id: 624789c236bde31163deda95c1e1471aee68933e --- include/rocksdb/env.h | 20 ++++++++++- include/rocksdb/functor_wrapper.h | 55 +++++++++++++++++++++++++++++++ util/thread_local_test.cc | 46 +++++++++++++------------- 3 files changed, 96 insertions(+), 25 deletions(-) create mode 100644 include/rocksdb/functor_wrapper.h diff --git a/include/rocksdb/env.h b/include/rocksdb/env.h index fc776bc71..919a41ed7 100644 --- a/include/rocksdb/env.h +++ b/include/rocksdb/env.h @@ -17,12 +17,15 @@ #pragma once #include + #include #include #include #include #include #include + +#include "rocksdb/functor_wrapper.h" #include "rocksdb/status.h" #include "rocksdb/thread_status.h" @@ -35,7 +38,7 @@ #if defined(__GNUC__) || defined(__clang__) #define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param) \ - __attribute__((__format__(__printf__, format_param, dots_param))) + __attribute__((__format__(__printf__, format_param, dots_param))) #else #define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param) #endif @@ -422,6 +425,21 @@ class Env { // When "function(arg)" returns, the thread will be destroyed. virtual void StartThread(void (*function)(void* arg), void* arg) = 0; + // Start a new thread, invoking "function(args...)" within the new thread. + // When "function(args...)" returns, the thread will be destroyed. + template + void StartThreadTyped(FunctionT function, Args&&... args) { + using FWType = FunctorWrapper; + StartThread( + [](void* arg) { + auto* functor = static_cast(arg); + functor->invoke(); + delete functor; + }, + new FWType(std::function(function), + std::forward(args)...)); + } + // Wait for all threads started by StartThread to terminate. virtual void WaitForJoin() {} diff --git a/include/rocksdb/functor_wrapper.h b/include/rocksdb/functor_wrapper.h new file mode 100644 index 000000000..c5f7414b1 --- /dev/null +++ b/include/rocksdb/functor_wrapper.h @@ -0,0 +1,55 @@ +// Copyright (c) 2011-present, Facebook, Inc. All rights reserved. +// This source code is licensed under both the GPLv2 (found in the +// COPYING file in the root directory) and Apache 2.0 License +// (found in the LICENSE.Apache file in the root directory). + +#pragma once + +#include +#include + +#include "rocksdb/rocksdb_namespace.h" + +namespace ROCKSDB_NAMESPACE { + +namespace detail { +template +struct IndexSequence {}; + +template +struct IndexSequenceHelper + : public IndexSequenceHelper {}; + +template +struct IndexSequenceHelper<0U, Next...> { + using type = IndexSequence; +}; + +template +using make_index_sequence = typename IndexSequenceHelper::type; + +template +void call(Function f, Tuple t, IndexSequence) { + f(std::get(t)...); +} + +template +void call(Function f, Tuple t) { + static constexpr auto size = std::tuple_size::value; + call(f, t, make_index_sequence{}); +} +} // namespace detail + +template +class FunctorWrapper { + public: + explicit FunctorWrapper(std::function functor, Args &&...args) + : functor_(std::move(functor)), args_(std::forward(args)...) {} + + void invoke() { detail::call(functor_, args_); } + + private: + std::function functor_; + std::tuple args_; +}; +} // namespace ROCKSDB_NAMESPACE diff --git a/util/thread_local_test.cc b/util/thread_local_test.cc index e719c7daa..7baab2fde 100644 --- a/util/thread_local_test.cc +++ b/util/thread_local_test.cc @@ -3,9 +3,11 @@ // COPYING file in the root directory) and Apache 2.0 License // (found in the LICENSE.Apache file in the root directory). -#include +#include "util/thread_local.h" + #include #include +#include #include "port/port.h" #include "rocksdb/env.h" @@ -13,7 +15,6 @@ #include "test_util/testharness.h" #include "test_util/testutil.h" #include "util/autovector.h" -#include "util/thread_local.h" namespace ROCKSDB_NAMESPACE { @@ -51,10 +52,8 @@ struct Params { }; class IDChecker : public ThreadLocalPtr { -public: - static uint32_t PeekId() { - return TEST_PeekId(); - } + public: + static uint32_t PeekId() { return TEST_PeekId(); } }; } // anonymous namespace @@ -122,9 +121,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) { ASSERT_GT(IDChecker::PeekId(), base_id); base_id = IDChecker::PeekId(); - auto func = [](void* ptr) { - auto& params = *static_cast(ptr); - + auto func = [](Params* ptr) { + Params& params = *ptr; ASSERT_TRUE(params.tls1.Get() == nullptr); params.tls1.Reset(reinterpret_cast(1)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast(1)); @@ -146,7 +144,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) { for (int iter = 0; iter < 1024; ++iter) { ASSERT_EQ(IDChecker::PeekId(), base_id); // Another new thread, read/write should not see value from previous thread - env_->StartThread(func, static_cast(&p)); + env_->StartThreadTyped(func, &p); + mu.Lock(); while (p.completed != iter + 1) { cv.Wait(); @@ -221,10 +220,10 @@ TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) { // Each thread local copy of the value are also different from each // other. for (int th = 0; th < p1.total; ++th) { - env_->StartThread(func, static_cast(&p1)); + env_->StartThreadTyped(func, &p1); } for (int th = 0; th < p2.total; ++th) { - env_->StartThread(func, static_cast(&p2)); + env_->StartThreadTyped(func, &p2); } mu1.Lock(); @@ -251,9 +250,8 @@ TEST_F(ThreadLocalTest, Unref) { }; // Case 0: no unref triggered if ThreadLocalPtr is never accessed - auto func0 = [](void* ptr) { - auto& p = *static_cast(ptr); - + auto func0 = [](Params* ptr) { + auto& p = *ptr; p.mu->Lock(); ++(p.started); p.cv->SignalAll(); @@ -270,15 +268,15 @@ TEST_F(ThreadLocalTest, Unref) { Params p(&mu, &cv, &unref_count, th, unref); for (int i = 0; i < p.total; ++i) { - env_->StartThread(func0, static_cast(&p)); + env_->StartThreadTyped(func0, &p); } env_->WaitForJoin(); ASSERT_EQ(unref_count, 0); } // Case 1: unref triggered by thread exit - auto func1 = [](void* ptr) { - auto& p = *static_cast(ptr); + auto func1 = [](Params* ptr) { + auto& p = *ptr; p.mu->Lock(); ++(p.started); @@ -307,7 +305,7 @@ TEST_F(ThreadLocalTest, Unref) { p.tls2 = &tls2; for (int i = 0; i < p.total; ++i) { - env_->StartThread(func1, static_cast(&p)); + env_->StartThreadTyped(func1, &p); } env_->WaitForJoin(); @@ -317,8 +315,8 @@ TEST_F(ThreadLocalTest, Unref) { } // Case 2: unref triggered by ThreadLocal instance destruction - auto func2 = [](void* ptr) { - auto& p = *static_cast(ptr); + auto func2 = [](Params* ptr) { + auto& p = *ptr; p.mu->Lock(); ++(p.started); @@ -356,7 +354,7 @@ TEST_F(ThreadLocalTest, Unref) { p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { - env_->StartThread(func2, static_cast(&p)); + env_->StartThreadTyped(func2, &p); } // Wait for all threads to finish using Params @@ -431,7 +429,7 @@ TEST_F(ThreadLocalTest, Scrape) { p.tls2 = new ThreadLocalPtr(unref); for (int i = 0; i < p.total; ++i) { - env_->StartThread(func, static_cast(&p)); + env_->StartThreadTyped(func, &p); } // Wait for all threads to finish using Params @@ -490,7 +488,7 @@ TEST_F(ThreadLocalTest, Fold) { }; for (int th = 0; th < params.total; ++th) { - env_->StartThread(func, static_cast(¶ms)); + env_->StartThread(func, ¶ms); } // Wait for all threads to finish using Params