Add StartThread type checking wrapper (#8303)

Summary:
- Add class `FunctorWrapper` to invoke the function with given parameters
- Implement `StartThreadTyped` which wraps `StartThread` with type checking cover
- Demonstrate `StartThreadTyped` in test `util/thread_local_test.cc`

https://github.com/facebook/rocksdb/issues/8285

Pull Request resolved: https://github.com/facebook/rocksdb/pull/8303

Reviewed By: ajkr

Differential Revision: D28539318

Pulled By: pdillinger

fbshipit-source-id: 624789c236bde31163deda95c1e1471aee68933e
main
Glebanister 4 years ago committed by Facebook GitHub Bot
parent 13232e11d4
commit 748e3acc11
  1. 20
      include/rocksdb/env.h
  2. 55
      include/rocksdb/functor_wrapper.h
  3. 46
      util/thread_local_test.cc

@ -17,12 +17,15 @@
#pragma once #pragma once
#include <stdint.h> #include <stdint.h>
#include <cstdarg> #include <cstdarg>
#include <functional> #include <functional>
#include <limits> #include <limits>
#include <memory> #include <memory>
#include <string> #include <string>
#include <vector> #include <vector>
#include "rocksdb/functor_wrapper.h"
#include "rocksdb/status.h" #include "rocksdb/status.h"
#include "rocksdb/thread_status.h" #include "rocksdb/thread_status.h"
@ -35,7 +38,7 @@
#if defined(__GNUC__) || defined(__clang__) #if defined(__GNUC__) || defined(__clang__)
#define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param) \ #define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param) \
__attribute__((__format__(__printf__, format_param, dots_param))) __attribute__((__format__(__printf__, format_param, dots_param)))
#else #else
#define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param) #define ROCKSDB_PRINTF_FORMAT_ATTR(format_param, dots_param)
#endif #endif
@ -422,6 +425,21 @@ class Env {
// When "function(arg)" returns, the thread will be destroyed. // When "function(arg)" returns, the thread will be destroyed.
virtual void StartThread(void (*function)(void* arg), void* arg) = 0; virtual void StartThread(void (*function)(void* arg), void* arg) = 0;
// Start a new thread, invoking "function(args...)" within the new thread.
// When "function(args...)" returns, the thread will be destroyed.
template <typename FunctionT, typename... Args>
void StartThreadTyped(FunctionT function, Args&&... args) {
using FWType = FunctorWrapper<Args...>;
StartThread(
[](void* arg) {
auto* functor = static_cast<FWType*>(arg);
functor->invoke();
delete functor;
},
new FWType(std::function<void(Args...)>(function),
std::forward<Args>(args)...));
}
// Wait for all threads started by StartThread to terminate. // Wait for all threads started by StartThread to terminate.
virtual void WaitForJoin() {} virtual void WaitForJoin() {}

@ -0,0 +1,55 @@
// Copyright (c) 2011-present, Facebook, Inc. All rights reserved.
// This source code is licensed under both the GPLv2 (found in the
// COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory).
#pragma once
#include <memory>
#include <utility>
#include "rocksdb/rocksdb_namespace.h"
namespace ROCKSDB_NAMESPACE {
namespace detail {
template <std::size_t...>
struct IndexSequence {};
template <std::size_t N, std::size_t... Next>
struct IndexSequenceHelper
: public IndexSequenceHelper<N - 1U, N - 1U, Next...> {};
template <std::size_t... Next>
struct IndexSequenceHelper<0U, Next...> {
using type = IndexSequence<Next...>;
};
template <std::size_t N>
using make_index_sequence = typename IndexSequenceHelper<N>::type;
template <typename Function, typename Tuple, size_t... I>
void call(Function f, Tuple t, IndexSequence<I...>) {
f(std::get<I>(t)...);
}
template <typename Function, typename Tuple>
void call(Function f, Tuple t) {
static constexpr auto size = std::tuple_size<Tuple>::value;
call(f, t, make_index_sequence<size>{});
}
} // namespace detail
template <typename... Args>
class FunctorWrapper {
public:
explicit FunctorWrapper(std::function<void(Args...)> functor, Args &&...args)
: functor_(std::move(functor)), args_(std::forward<Args>(args)...) {}
void invoke() { detail::call(functor_, args_); }
private:
std::function<void(Args...)> functor_;
std::tuple<Args...> args_;
};
} // namespace ROCKSDB_NAMESPACE

@ -3,9 +3,11 @@
// COPYING file in the root directory) and Apache 2.0 License // COPYING file in the root directory) and Apache 2.0 License
// (found in the LICENSE.Apache file in the root directory). // (found in the LICENSE.Apache file in the root directory).
#include <thread> #include "util/thread_local.h"
#include <atomic> #include <atomic>
#include <string> #include <string>
#include <thread>
#include "port/port.h" #include "port/port.h"
#include "rocksdb/env.h" #include "rocksdb/env.h"
@ -13,7 +15,6 @@
#include "test_util/testharness.h" #include "test_util/testharness.h"
#include "test_util/testutil.h" #include "test_util/testutil.h"
#include "util/autovector.h" #include "util/autovector.h"
#include "util/thread_local.h"
namespace ROCKSDB_NAMESPACE { namespace ROCKSDB_NAMESPACE {
@ -51,10 +52,8 @@ struct Params {
}; };
class IDChecker : public ThreadLocalPtr { class IDChecker : public ThreadLocalPtr {
public: public:
static uint32_t PeekId() { static uint32_t PeekId() { return TEST_PeekId(); }
return TEST_PeekId();
}
}; };
} // anonymous namespace } // anonymous namespace
@ -122,9 +121,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
ASSERT_GT(IDChecker::PeekId(), base_id); ASSERT_GT(IDChecker::PeekId(), base_id);
base_id = IDChecker::PeekId(); base_id = IDChecker::PeekId();
auto func = [](void* ptr) { auto func = [](Params* ptr) {
auto& params = *static_cast<Params*>(ptr); Params& params = *ptr;
ASSERT_TRUE(params.tls1.Get() == nullptr); ASSERT_TRUE(params.tls1.Get() == nullptr);
params.tls1.Reset(reinterpret_cast<int*>(1)); params.tls1.Reset(reinterpret_cast<int*>(1));
ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1)); ASSERT_TRUE(params.tls1.Get() == reinterpret_cast<int*>(1));
@ -146,7 +144,8 @@ TEST_F(ThreadLocalTest, SequentialReadWriteTest) {
for (int iter = 0; iter < 1024; ++iter) { for (int iter = 0; iter < 1024; ++iter) {
ASSERT_EQ(IDChecker::PeekId(), base_id); ASSERT_EQ(IDChecker::PeekId(), base_id);
// Another new thread, read/write should not see value from previous thread // Another new thread, read/write should not see value from previous thread
env_->StartThread(func, static_cast<void*>(&p)); env_->StartThreadTyped(func, &p);
mu.Lock(); mu.Lock();
while (p.completed != iter + 1) { while (p.completed != iter + 1) {
cv.Wait(); cv.Wait();
@ -221,10 +220,10 @@ TEST_F(ThreadLocalTest, ConcurrentReadWriteTest) {
// Each thread local copy of the value are also different from each // Each thread local copy of the value are also different from each
// other. // other.
for (int th = 0; th < p1.total; ++th) { for (int th = 0; th < p1.total; ++th) {
env_->StartThread(func, static_cast<void*>(&p1)); env_->StartThreadTyped(func, &p1);
} }
for (int th = 0; th < p2.total; ++th) { for (int th = 0; th < p2.total; ++th) {
env_->StartThread(func, static_cast<void*>(&p2)); env_->StartThreadTyped(func, &p2);
} }
mu1.Lock(); mu1.Lock();
@ -251,9 +250,8 @@ TEST_F(ThreadLocalTest, Unref) {
}; };
// Case 0: no unref triggered if ThreadLocalPtr is never accessed // Case 0: no unref triggered if ThreadLocalPtr is never accessed
auto func0 = [](void* ptr) { auto func0 = [](Params* ptr) {
auto& p = *static_cast<Params*>(ptr); auto& p = *ptr;
p.mu->Lock(); p.mu->Lock();
++(p.started); ++(p.started);
p.cv->SignalAll(); p.cv->SignalAll();
@ -270,15 +268,15 @@ TEST_F(ThreadLocalTest, Unref) {
Params p(&mu, &cv, &unref_count, th, unref); Params p(&mu, &cv, &unref_count, th, unref);
for (int i = 0; i < p.total; ++i) { for (int i = 0; i < p.total; ++i) {
env_->StartThread(func0, static_cast<void*>(&p)); env_->StartThreadTyped(func0, &p);
} }
env_->WaitForJoin(); env_->WaitForJoin();
ASSERT_EQ(unref_count, 0); ASSERT_EQ(unref_count, 0);
} }
// Case 1: unref triggered by thread exit // Case 1: unref triggered by thread exit
auto func1 = [](void* ptr) { auto func1 = [](Params* ptr) {
auto& p = *static_cast<Params*>(ptr); auto& p = *ptr;
p.mu->Lock(); p.mu->Lock();
++(p.started); ++(p.started);
@ -307,7 +305,7 @@ TEST_F(ThreadLocalTest, Unref) {
p.tls2 = &tls2; p.tls2 = &tls2;
for (int i = 0; i < p.total; ++i) { for (int i = 0; i < p.total; ++i) {
env_->StartThread(func1, static_cast<void*>(&p)); env_->StartThreadTyped(func1, &p);
} }
env_->WaitForJoin(); env_->WaitForJoin();
@ -317,8 +315,8 @@ TEST_F(ThreadLocalTest, Unref) {
} }
// Case 2: unref triggered by ThreadLocal instance destruction // Case 2: unref triggered by ThreadLocal instance destruction
auto func2 = [](void* ptr) { auto func2 = [](Params* ptr) {
auto& p = *static_cast<Params*>(ptr); auto& p = *ptr;
p.mu->Lock(); p.mu->Lock();
++(p.started); ++(p.started);
@ -356,7 +354,7 @@ TEST_F(ThreadLocalTest, Unref) {
p.tls2 = new ThreadLocalPtr(unref); p.tls2 = new ThreadLocalPtr(unref);
for (int i = 0; i < p.total; ++i) { for (int i = 0; i < p.total; ++i) {
env_->StartThread(func2, static_cast<void*>(&p)); env_->StartThreadTyped(func2, &p);
} }
// Wait for all threads to finish using Params // Wait for all threads to finish using Params
@ -431,7 +429,7 @@ TEST_F(ThreadLocalTest, Scrape) {
p.tls2 = new ThreadLocalPtr(unref); p.tls2 = new ThreadLocalPtr(unref);
for (int i = 0; i < p.total; ++i) { for (int i = 0; i < p.total; ++i) {
env_->StartThread(func, static_cast<void*>(&p)); env_->StartThreadTyped(func, &p);
} }
// Wait for all threads to finish using Params // Wait for all threads to finish using Params
@ -490,7 +488,7 @@ TEST_F(ThreadLocalTest, Fold) {
}; };
for (int th = 0; th < params.total; ++th) { for (int th = 0; th < params.total; ++th) {
env_->StartThread(func, static_cast<void*>(&params)); env_->StartThread(func, &params);
} }
// Wait for all threads to finish using Params // Wait for all threads to finish using Params

Loading…
Cancel
Save