@ -16,7 +16,6 @@
# include "rocksdb/db.h"
# include "rocksdb/write_batch.h"
# include "port/port.h"
# include "util/logging.h"
# include "util/random.h"
# include "util/sync_point.h"
# include "util/testharness.h"
@ -107,6 +106,10 @@ TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
std : : vector < std : : pair < string , string > > kvs_ ;
} ;
// In each scenario we'll launch multiple threads to write.
// The size of each array equals to number of threads, and
// each boolean in it denote whether callback of corresponding
// thread should succeed or fail.
std : : vector < std : : vector < WriteOP > > write_scenarios = {
{ true } ,
{ false } ,
@ -145,23 +148,37 @@ TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
db_impl = dynamic_cast < DBImpl * > ( db ) ;
ASSERT_TRUE ( db_impl ) ;
std : : atomic < uint64_t > threads_waiting ( 0 ) ;
// Writers that have called JoinBatchGroup.
std : : atomic < uint64_t > threads_joining ( 0 ) ;
// Writers that have linked to the queue
std : : atomic < uint64_t > threads_linked ( 0 ) ;
// Writers that pass WriteThread::JoinBatchGroup:Wait sync-point.
std : : atomic < uint64_t > threads_verified ( 0 ) ;
std : : atomic < uint64_t > seq ( db_impl - > GetLatestSequenceNumber ( ) ) ;
ASSERT_EQ ( db_impl - > GetLatestSequenceNumber ( ) , 0 ) ;
rocksdb : : SyncPoint : : GetInstance ( ) - > SetCallBack (
" WriteThread::JoinBatchGroup:Start " , [ & ] ( void * ) {
uint64_t cur_threads_joining = threads_joining . fetch_add ( 1 ) ;
// Wait for the last joined writer to link to the queue.
// In this way the writers link to the queue one by one.
// This allows us to confidently detect the first writer
// who increases threads_linked as the leader.
while ( threads_linked . load ( ) < cur_threads_joining ) {
}
} ) ;
// Verification once writers call JoinBatchGroup.
rocksdb : : SyncPoint : : GetInstance ( ) - > SetCallBack (
" WriteThread::JoinBatchGroup:Wait " , [ & ] ( void * arg ) {
uint64_t cur_threads_waiting = 0 ;
uint64_t cur_threads_linked = threads_linked . fetch_add ( 1 ) ;
bool is_leader = false ;
bool is_last = false ;
// who am i
do {
cur_threads_waiting = threads_waiting . load ( ) ;
is_leader = ( cur_threads_waiting = = 0 ) ;
is_last = ( cur_threads_waiting = = write_group . size ( ) - 1 ) ;
} while ( ! threads_waiting . compare_exchange_strong (
cur_threads_waiting , cur_threads_waiting + 1 ) ) ;
is_leader = ( cur_threads_linked = = 0 ) ;
is_last = ( cur_threads_linked = = write_group . size ( ) - 1 ) ;
// check my state
auto * writer = reinterpret_cast < WriteThread : : Writer * > ( arg ) ;
@ -185,8 +202,10 @@ TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
! write_group . back ( ) . callback_ . should_fail_ ) ;
}
// wait for friends
while ( threads_waiting . load ( ) < write_group . size ( ) ) {
threads_verified . fetch_add ( 1 ) ;
// Wait here until all verification in this sync-point
// callback finish for all writers.
while ( threads_verified . load ( ) < write_group . size ( ) ) {
}
} ) ;
@ -211,17 +230,20 @@ TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
std : : atomic < uint32_t > thread_num ( 0 ) ;
std : : atomic < char > dummy_key ( 0 ) ;
// Each write thread create a random write batch and write to DB
// with a write callback.
std : : function < void ( ) > write_with_callback_func = [ & ] ( ) {
uint32_t i = thread_num . fetch_add ( 1 ) ;
Random rnd ( i ) ;
// leaders gotta lead
while ( i > 0 & & threads_waiting . load ( ) < 1 ) {
while ( i > 0 & & threads_verified . load ( ) < 1 ) {
}
// loser has to lose
while ( i = = write_group . size ( ) - 1 & &
threads_waiting . load ( ) < write_group . size ( ) - 1 ) {
threads_verified . load ( ) < write_group . size ( ) - 1 ) {
}
auto & write_op = write_group . at ( i ) ;
@ -231,11 +253,7 @@ TEST_F(WriteCallbackTest, WriteWithCallbackTest) {
// insert some keys
for ( uint32_t j = 0 ; j < rnd . Next ( ) % 50 ; j + + ) {
// grab unique key
char my_key = 0 ;
do {
my_key = dummy_key . load ( ) ;
} while (
! dummy_key . compare_exchange_strong ( my_key , my_key + 1 ) ) ;
char my_key = dummy_key . fetch_add ( 1 ) ;
string skey ( 5 , my_key ) ;
string sval ( 10 , my_key ) ;