diff --git a/util/timer.h b/util/timer.h index 75479d43d..7a38c3260 100644 --- a/util/timer.h +++ b/util/timer.h @@ -35,12 +35,12 @@ namespace ROCKSDB_NAMESPACE { // A map from a function name to the function keeps track of all the functions. class Timer { public: - Timer(Env* env) + explicit Timer(Env* env) : env_(env), mutex_(env), cond_var_(&mutex_), - running_(false) { - } + running_(false), + executing_task_(false) {} ~Timer() {} @@ -64,10 +64,22 @@ class Timer { void Cancel(const std::string& fn_name) { InstrumentedMutexLock l(&mutex_); + // Mark the function with fn_name as invalid so that it will not be + // requeued. auto it = map_.find(fn_name); - if (it != map_.end()) { - if (it->second) { - it->second->Cancel(); + if (it != map_.end() && it->second) { + it->second->Cancel(); + } + + // If the currently running function is fn_name, then we need to wait + // until it finishes before returning to caller. + while (!heap_.empty() && executing_task_) { + FunctionInfo* func_info = heap_.top(); + assert(func_info); + if (func_info->name == fn_name) { + WaitForTaskCompleteIfNecessary(); + } else { + break; } } } @@ -84,8 +96,8 @@ class Timer { return false; } - thread_.reset(new port::Thread(&Timer::Run, this)); running_ = true; + thread_.reset(new port::Thread(&Timer::Run, this)); return true; } @@ -96,8 +108,8 @@ class Timer { if (!running_) { return false; } - CancelAllWithLock(); running_ = false; + CancelAllWithLock(); cond_var_.SignalAll(); } @@ -121,6 +133,7 @@ class Timer { } FunctionInfo* current_fn = heap_.top(); + assert(current_fn); if (!current_fn->IsValid()) { heap_.pop(); @@ -129,8 +142,13 @@ class Timer { } if (current_fn->next_run_time_us <= env_->NowMicros()) { + executing_task_ = true; + mutex_.Unlock(); // Execute the work current_fn->fn(); + mutex_.Lock(); + executing_task_ = false; + cond_var_.SignalAll(); // Remove the work from the heap once it is done executing. // Note that we are just removing the pointer from the heap. Its @@ -138,7 +156,9 @@ class Timer { // So current_fn is still a valid ptr. heap_.pop(); - if (current_fn->repeat_every_us > 0) { + // current_fn may be cancelled already. + if (current_fn->IsValid() && current_fn->repeat_every_us > 0) { + assert(running_); current_fn->next_run_time_us = env_->NowMicros() + current_fn->repeat_every_us; @@ -152,14 +172,25 @@ class Timer { } void CancelAllWithLock() { + mutex_.AssertHeld(); if (map_.empty() && heap_.empty()) { return; } + // With mutex_ held, set all tasks to invalid so that they will not be + // re-queued. + for (auto& elem : map_) { + auto& func_info = elem.second; + assert(func_info); + func_info->Cancel(); + } + + // WaitForTaskCompleteIfNecessary() may release mutex_ + WaitForTaskCompleteIfNecessary(); + while (!heap_.empty()) { heap_.pop(); } - map_.clear(); } @@ -179,25 +210,29 @@ class Timer { // calls `Cancel()`. bool valid; - FunctionInfo(std::function&& _fn, - const std::string& _name, - const uint64_t _next_run_time_us, - uint64_t _repeat_every_us) - : fn(std::move(_fn)), - name(_name), - next_run_time_us(_next_run_time_us), - repeat_every_us(_repeat_every_us), - valid(true) {} + FunctionInfo(std::function&& _fn, const std::string& _name, + const uint64_t _next_run_time_us, uint64_t _repeat_every_us) + : fn(std::move(_fn)), + name(_name), + next_run_time_us(_next_run_time_us), + repeat_every_us(_repeat_every_us), + valid(true) {} void Cancel() { valid = false; } - bool IsValid() { - return valid; - } + bool IsValid() const { return valid; } }; + void WaitForTaskCompleteIfNecessary() { + mutex_.AssertHeld(); + while (executing_task_) { + TEST_SYNC_POINT("Timer::WaitForTaskCompleteIfNecessary:TaskExecuting"); + cond_var_.Wait(); + } + } + struct RunTimeOrder { bool operator()(const FunctionInfo* f1, const FunctionInfo* f2) { @@ -212,7 +247,7 @@ class Timer { InstrumentedCondVar cond_var_; std::unique_ptr thread_; bool running_; - + bool executing_task_; std::priority_queue, diff --git a/util/timer_test.cc b/util/timer_test.cc index 56135935d..0200cf69b 100644 --- a/util/timer_test.cc +++ b/util/timer_test.cc @@ -283,6 +283,81 @@ TEST_F(TimerTest, AddAfterStartTest) { ASSERT_EQ(kIterations, count); } +TEST_F(TimerTest, CancelRunningTask) { + constexpr char kTestFuncName[] = "test_func"; + mock_env_->set_current_time(0); + Timer timer(mock_env_.get()); + ASSERT_TRUE(timer.Start()); + int* value = new int; + ASSERT_NE(nullptr, value); // make linter happy + *value = 0; + SyncPoint::GetInstance()->DisableProcessing(); + SyncPoint::GetInstance()->LoadDependency({ + {"TimerTest::CancelRunningTask:test_func:0", + "TimerTest::CancelRunningTask:BeforeCancel"}, + {"Timer::WaitForTaskCompleteIfNecessary:TaskExecuting", + "TimerTest::CancelRunningTask:test_func:1"}, + }); + SyncPoint::GetInstance()->EnableProcessing(); + timer.Add( + [&]() { + *value = 1; + TEST_SYNC_POINT("TimerTest::CancelRunningTask:test_func:0"); + TEST_SYNC_POINT("TimerTest::CancelRunningTask:test_func:1"); + }, + kTestFuncName, 0, 1 * kSecond); + port::Thread control_thr([&]() { + TEST_SYNC_POINT("TimerTest::CancelRunningTask:BeforeCancel"); + timer.Cancel(kTestFuncName); + // Verify that *value has been set to 1. + ASSERT_EQ(1, *value); + delete value; + value = nullptr; + }); + mock_env_->set_current_time(1); + control_thr.join(); + ASSERT_TRUE(timer.Shutdown()); +} + +TEST_F(TimerTest, ShutdownRunningTask) { + constexpr char kTestFunc1Name[] = "test_func1"; + constexpr char kTestFunc2Name[] = "test_func2"; + mock_env_->set_current_time(0); + Timer timer(mock_env_.get()); + + SyncPoint::GetInstance()->DisableProcessing(); + SyncPoint::GetInstance()->LoadDependency({ + {"TimerTest::ShutdownRunningTest:test_func:0", + "TimerTest::ShutdownRunningTest:BeforeShutdown"}, + {"Timer::WaitForTaskCompleteIfNecessary:TaskExecuting", + "TimerTest::ShutdownRunningTest:test_func:1"}, + }); + SyncPoint::GetInstance()->EnableProcessing(); + + ASSERT_TRUE(timer.Start()); + + int* value = new int; + ASSERT_NE(nullptr, value); + *value = 0; + timer.Add( + [&]() { + TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:test_func:0"); + *value = 1; + TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:test_func:1"); + }, + kTestFunc1Name, 0, 1 * kSecond); + + timer.Add([&]() { ++(*value); }, kTestFunc2Name, 0, 1 * kSecond); + + port::Thread control_thr([&]() { + TEST_SYNC_POINT("TimerTest::ShutdownRunningTest:BeforeShutdown"); + timer.Shutdown(); + }); + mock_env_->set_current_time(1); + control_thr.join(); + delete value; +} + } // namespace ROCKSDB_NAMESPACE int main(int argc, char** argv) {