#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import gc import random import sys import time from os import path import numpy as np kSampleSize = 16 # The sample size used when performing eviction. kMicrosInSecond = 1000000 kSecondsInMinute = 60 kSecondsInHour = 3600 class TraceRecord: """ A trace record represents a block access. It holds the same struct as BlockCacheTraceRecord in trace_replay/block_cache_tracer.h """ def __init__( self, access_time, block_id, block_type, block_size, cf_id, cf_name, level, fd, caller, no_insert, get_id, key_id, kv_size, is_hit, ): self.access_time = access_time self.block_id = block_id self.block_type = block_type self.block_size = block_size self.cf_id = cf_id self.cf_name = cf_name self.level = level self.fd = fd self.caller = caller if no_insert == 1: self.no_insert = True else: self.no_insert = False self.get_id = get_id self.key_id = key_id self.kv_size = kv_size if is_hit == 1: self.is_hit = True else: self.is_hit = False class CacheEntry: """A cache entry stored in the cache.""" def __init__(self, value_size, cf_id, level, block_type, access_number): self.value_size = value_size self.last_access_number = access_number self.num_hits = 0 self.cf_id = 0 self.level = level self.block_type = block_type def __repr__(self): """Debug string.""" return "s={},last={},hits={},cf={},l={},bt={}".format( self.value_size, self.last_access_number, self.num_hits, self.cf_id, self.level, self.block_type, ) class HashEntry: """A hash entry stored in a hash table.""" def __init__(self, key, hash, value): self.key = key self.hash = hash self.value = value def __repr__(self): return "k={},h={},v=[{}]".format(self.key, self.hash, self.value) class HashTable: """ A custom implementation of hash table to support fast random sampling. It is closed hashing and uses chaining to resolve hash conflicts. It grows/shrinks the hash table upon insertion/deletion to support fast lookups and random samplings. """ def __init__(self): self.table = [None] * 32 self.elements = 0 def random_sample(self, sample_size): """Randomly sample 'sample_size' hash entries from the table.""" samples = [] index = random.randint(0, len(self.table)) pos = (index + 1) % len(self.table) searches = 0 # Starting from index, adding hash entries to the sample list until # sample_size is met or we ran out of entries. while pos != index and len(samples) < sample_size: if self.table[pos] is not None: for i in range(len(self.table[pos])): if self.table[pos][i] is None: continue samples.append(self.table[pos][i]) if len(samples) > sample_size: break pos += 1 pos = pos % len(self.table) searches += 1 return samples def insert(self, key, hash, value): """ Insert a hash entry in the table. Replace the old entry if it already exists. """ self.grow() inserted = False index = hash % len(self.table) if self.table[index] is None: self.table[index] = [] for i in range(len(self.table[index])): if self.table[index][i] is not None: if ( self.table[index][i].hash == hash and self.table[index][i].key == key ): # The entry already exists in the table. self.table[index][i] = HashEntry(key, hash, value) return continue self.table[index][i] = HashEntry(key, hash, value) inserted = True break if not inserted: self.table[index].append(HashEntry(key, hash, value)) self.elements += 1 def resize(self, new_size): if new_size == len(self.table): return if new_size == 0: return if self.elements < 100: return new_table = [None] * new_size # Copy 'self.table' to new_table. for i in range(len(self.table)): entries = self.table[i] if entries is None: continue for j in range(len(entries)): if entries[j] is None: continue index = entries[j].hash % new_size if new_table[index] is None: new_table[index] = [] new_table[index].append(entries[j]) self.table = new_table del new_table # Manually call python gc here to free the memory as 'self.table' # might be very large. gc.collect() def grow(self): if self.elements < len(self.table): return new_size = int(len(self.table) * 1.2) self.resize(new_size) def delete(self, key, hash): index = hash % len(self.table) entries = self.table[index] deleted = False if entries is None: return for i in range(len(entries)): if ( entries[i] is not None and entries[i].hash == hash and entries[i].key == key ): entries[i] = None self.elements -= 1 deleted = True break if deleted: self.shrink() def shrink(self): if self.elements * 2 >= len(self.table): return new_size = int(len(self.table) * 0.7) self.resize(new_size) def lookup(self, key, hash): index = hash % len(self.table) entries = self.table[index] if entries is None: return None for entry in entries: if entry is not None and entry.hash == hash and entry.key == key: return entry.value return None class MissRatioStats: def __init__(self, time_unit): self.num_misses = 0 self.num_accesses = 0 self.time_unit = time_unit self.time_misses = {} self.time_accesses = {} def update_metrics(self, access_time, is_hit): access_time /= kMicrosInSecond * self.time_unit self.num_accesses += 1 if access_time not in self.time_accesses: self.time_accesses[access_time] = 0 self.time_accesses[access_time] += 1 if not is_hit: self.num_misses += 1 if access_time not in self.time_misses: self.time_misses[access_time] = 0 self.time_misses[access_time] += 1 def reset_counter(self): self.num_misses = 0 self.num_accesses = 0 def miss_ratio(self): return float(self.num_misses) * 100.0 / float(self.num_accesses) def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end): start /= kMicrosInSecond * self.time_unit end /= kMicrosInSecond * self.time_unit header_file_path = "{}/header-ml-miss-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) if not path.exists(header_file_path): with open(header_file_path, "w+") as header_file: header = "time" for trace_time in range(start, end): header += ",{}".format(trace_time) header_file.write(header + "\n") file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) with open(file_path, "w+") as file: row = "{}".format(cache_type) for trace_time in range(start, end): row += ",{}".format(self.time_misses.get(trace_time, 0)) file.write(row + "\n") def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end): start /= kMicrosInSecond * self.time_unit end /= kMicrosInSecond * self.time_unit header_file_path = "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) if not path.exists(header_file_path): with open(header_file_path, "w+") as header_file: header = "time" for trace_time in range(start, end): header += ",{}".format(trace_time) header_file.write(header + "\n") file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) with open(file_path, "w+") as file: row = "{}".format(cache_type) for trace_time in range(start, end): naccesses = self.time_accesses.get(trace_time, 0) miss_ratio = 0 if naccesses > 0: miss_ratio = float( self.time_misses.get(trace_time, 0) * 100.0 ) / float(naccesses) row += ",{0:.2f}".format(miss_ratio) file.write(row + "\n") class PolicyStats: def __init__(self, time_unit, policies): self.time_selected_polices = {} self.time_accesses = {} self.policy_names = {} self.time_unit = time_unit for i in range(len(policies)): self.policy_names[i] = policies[i].policy_name() def update_metrics(self, access_time, selected_policy): access_time /= kMicrosInSecond * self.time_unit if access_time not in self.time_accesses: self.time_accesses[access_time] = 0 self.time_accesses[access_time] += 1 if access_time not in self.time_selected_polices: self.time_selected_polices[access_time] = {} policy_name = self.policy_names[selected_policy] if policy_name not in self.time_selected_polices[access_time]: self.time_selected_polices[access_time][policy_name] = 0 self.time_selected_polices[access_time][policy_name] += 1 def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end): start /= kMicrosInSecond * self.time_unit end /= kMicrosInSecond * self.time_unit header_file_path = "{}/header-ml-policy-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) if not path.exists(header_file_path): with open(header_file_path, "w+") as header_file: header = "time" for trace_time in range(start, end): header += ",{}".format(trace_time) header_file.write(header + "\n") file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) with open(file_path, "w+") as file: for policy in self.policy_names: policy_name = self.policy_names[policy] row = "{}-{}".format(cache_type, policy_name) for trace_time in range(start, end): row += ",{}".format( self.time_selected_polices.get(trace_time, {}).get( policy_name, 0 ) ) file.write(row + "\n") def write_policy_ratio_timeline( self, cache_type, cache_size, file_path, start, end ): start /= kMicrosInSecond * self.time_unit end /= kMicrosInSecond * self.time_unit header_file_path = "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) if not path.exists(header_file_path): with open(header_file_path, "w+") as header_file: header = "time" for trace_time in range(start, end): header += ",{}".format(trace_time) header_file.write(header + "\n") file_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format( result_dir, self.time_unit, cache_type, cache_size ) with open(file_path, "w+") as file: for policy in self.policy_names: policy_name = self.policy_names[policy] row = "{}-{}".format(cache_type, policy_name) for trace_time in range(start, end): naccesses = self.time_accesses.get(trace_time, 0) ratio = 0 if naccesses > 0: ratio = float( self.time_selected_polices.get(trace_time, {}).get( policy_name, 0 ) * 100.0 ) / float(naccesses) row += ",{0:.2f}".format(ratio) file.write(row + "\n") class Policy(object): """ A policy maintains a set of evicted keys. It returns a reward of one to itself if it has not evicted a missing key. Otherwise, it gives itself 0 reward. """ def __init__(self): self.evicted_keys = {} def evict(self, key, max_size): self.evicted_keys[key] = 0 def delete(self, key): self.evicted_keys.pop(key, None) def prioritize_samples(self, samples): raise NotImplementedError def policy_name(self): raise NotImplementedError def generate_reward(self, key): if key in self.evicted_keys: return 0 return 1 class LRUPolicy(Policy): def prioritize_samples(self, samples): return sorted( samples, cmp=lambda e1, e2: e1.value.last_access_number - e2.value.last_access_number, ) def policy_name(self): return "lru" class MRUPolicy(Policy): def prioritize_samples(self, samples): return sorted( samples, cmp=lambda e1, e2: e2.value.last_access_number - e1.value.last_access_number, ) def policy_name(self): return "mru" class LFUPolicy(Policy): def prioritize_samples(self, samples): return sorted(samples, cmp=lambda e1, e2: e1.value.num_hits - e2.value.num_hits) def policy_name(self): return "lfu" class MLCache(object): def __init__(self, cache_size, enable_cache_row_key, policies): self.cache_size = cache_size self.used_size = 0 self.miss_ratio_stats = MissRatioStats(kSecondsInMinute) self.policy_stats = PolicyStats(kSecondsInMinute, policies) self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour) self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies) self.table = HashTable() self.enable_cache_row_key = enable_cache_row_key self.get_id_row_key_map = {} self.policies = policies def _lookup(self, key, hash): value = self.table.lookup(key, hash) if value is not None: value.last_access_number = self.miss_ratio_stats.num_accesses value.num_hits += 1 return True return False def _select_policy(self, trace_record, key): raise NotImplementedError def cache_name(self): raise NotImplementedError def _evict(self, policy_index, value_size): # Randomly sample n entries. samples = self.table.random_sample(kSampleSize) samples = self.policies[policy_index].prioritize_samples(samples) for hash_entry in samples: self.used_size -= hash_entry.value.value_size self.table.delete(hash_entry.key, hash_entry.hash) self.policies[policy_index].evict( key=hash_entry.key, max_size=self.table.elements ) if self.used_size + value_size <= self.cache_size: break def _insert(self, trace_record, key, hash, value_size): if value_size > self.cache_size: return policy_index = self._select_policy(trace_record, key) self.policies[policy_index].delete(key) self.policy_stats.update_metrics(trace_record.access_time, policy_index) self.per_hour_policy_stats.update_metrics( trace_record.access_time, policy_index ) while self.used_size + value_size > self.cache_size: self._evict(policy_index, value_size) self.table.insert( key, hash, CacheEntry( value_size, trace_record.cf_id, trace_record.level, trace_record.block_type, self.miss_ratio_stats.num_accesses, ), ) self.used_size += value_size def _access_kv(self, trace_record, key, hash, value_size, no_insert): if self._lookup(key, hash): return True if not no_insert and value_size > 0: self._insert(trace_record, key, hash, value_size) return False def _update_stats(self, access_time, is_hit): self.miss_ratio_stats.update_metrics(access_time, is_hit) self.per_hour_miss_ratio_stats.update_metrics(access_time, is_hit) def access(self, trace_record): assert self.used_size <= self.cache_size if ( self.enable_cache_row_key and trace_record.caller == 1 and trace_record.key_id != 0 and trace_record.get_id != 0 ): # This is a get request. if trace_record.get_id not in self.get_id_row_key_map: self.get_id_row_key_map[trace_record.get_id] = {} self.get_id_row_key_map[trace_record.get_id]["h"] = False if self.get_id_row_key_map[trace_record.get_id]["h"]: # We treat future accesses as hits since this get request # completes. self._update_stats(trace_record.access_time, is_hit=True) return if trace_record.key_id not in self.get_id_row_key_map[trace_record.get_id]: # First time seen this key. is_hit = self._access_kv( trace_record, key="g{}".format(trace_record.key_id), hash=trace_record.key_id, value_size=trace_record.kv_size, no_insert=False, ) inserted = False if trace_record.kv_size > 0: inserted = True self.get_id_row_key_map[trace_record.get_id][ trace_record.key_id ] = inserted self.get_id_row_key_map[trace_record.get_id]["h"] = is_hit if self.get_id_row_key_map[trace_record.get_id]["h"]: # We treat future accesses as hits since this get request # completes. self._update_stats(trace_record.access_time, is_hit=True) return # Access its blocks. is_hit = self._access_kv( trace_record, key="b{}".format(trace_record.block_id), hash=trace_record.block_id, value_size=trace_record.block_size, no_insert=trace_record.no_insert, ) self._update_stats(trace_record.access_time, is_hit) if ( trace_record.kv_size > 0 and not self.get_id_row_key_map[trace_record.get_id][ trace_record.key_id ] ): # Insert the row key-value pair. self._access_kv( trace_record, key="g{}".format(trace_record.key_id), hash=trace_record.key_id, value_size=trace_record.kv_size, no_insert=False, ) # Mark as inserted. self.get_id_row_key_map[trace_record.get_id][trace_record.key_id] = True return # Access the block. is_hit = self._access_kv( trace_record, key="b{}".format(trace_record.block_id), hash=trace_record.block_id, value_size=trace_record.block_size, no_insert=trace_record.no_insert, ) self._update_stats(trace_record.access_time, is_hit) class ThompsonSamplingCache(MLCache): """ An implementation of Thompson Sampling for the Bernoulli Bandit [1]. [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband, and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found. Trends Mach. Learn. 11, 1 (July 2018), 1-96. DOI: https://doi.org/10.1561/2200000070 """ def __init__(self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1): super(ThompsonSamplingCache, self).__init__( cache_size, enable_cache_row_key, policies ) self._as = {} self._bs = {} for _i in range(len(policies)): self._as = [init_a] * len(self.policies) self._bs = [init_b] * len(self.policies) def _select_policy(self, trace_record, key): samples = [ np.random.beta(self._as[x], self._bs[x]) for x in range(len(self.policies)) ] selected_policy = max(range(len(self.policies)), key=lambda x: samples[x]) reward = self.policies[selected_policy].generate_reward(key) assert reward <= 1 and reward >= 0 self._as[selected_policy] += reward self._bs[selected_policy] += 1 - reward return selected_policy def cache_name(self): if self.enable_cache_row_key: return "Hybrid ThompsonSampling (ts_hybrid)" return "ThompsonSampling (ts)" class LinUCBCache(MLCache): """ An implementation of LinUCB with disjoint linear models [2]. [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010. A contextual-bandit approach to personalized news article recommendation. In Proceedings of the 19th international conference on World wide web (WWW '10). ACM, New York, NY, USA, 661-670. DOI=http://dx.doi.org/10.1145/1772690.1772758 """ def __init__(self, cache_size, enable_cache_row_key, policies): super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies) self.nfeatures = 4 # Block type, caller, level, cf. self.th = np.zeros((len(self.policies), self.nfeatures)) self.eps = 0.2 self.b = np.zeros_like(self.th) self.A = np.zeros((len(self.policies), self.nfeatures, self.nfeatures)) self.A_inv = np.zeros((len(self.policies), self.nfeatures, self.nfeatures)) for i in range(len(self.policies)): self.A[i] = np.identity(self.nfeatures) self.th_hat = np.zeros_like(self.th) self.p = np.zeros(len(self.policies)) self.alph = 0.2 def _select_policy(self, trace_record, key): x_i = np.zeros(self.nfeatures) # The current context vector x_i[0] = trace_record.block_type x_i[1] = trace_record.caller x_i[2] = trace_record.level x_i[3] = trace_record.cf_id p = np.zeros(len(self.policies)) for a in range(len(self.policies)): self.th_hat[a] = self.A_inv[a].dot(self.b[a]) ta = x_i.dot(self.A_inv[a]).dot(x_i) a_upper_ci = self.alph * np.sqrt(ta) a_mean = self.th_hat[a].dot(x_i) p[a] = a_mean + a_upper_ci p = p + (np.random.random(len(p)) * 0.000001) selected_policy = p.argmax() reward = self.policies[selected_policy].generate_reward(key) assert reward <= 1 and reward >= 0 self.A[selected_policy] += np.outer(x_i, x_i) self.b[selected_policy] += reward * x_i self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy]) del x_i return selected_policy def cache_name(self): if self.enable_cache_row_key: return "Hybrid LinUCB (linucb_hybrid)" return "LinUCB (linucb)" def parse_cache_size(cs): cs = cs.replace("\n", "") if cs[-1] == "M": return int(cs[: len(cs) - 1]) * 1024 * 1024 if cs[-1] == "G": return int(cs[: len(cs) - 1]) * 1024 * 1024 * 1024 if cs[-1] == "T": return int(cs[: len(cs) - 1]) * 1024 * 1024 * 1024 * 1024 return int(cs) def create_cache(cache_type, cache_size, downsample_size): policies = [] policies.append(LRUPolicy()) policies.append(MRUPolicy()) policies.append(LFUPolicy()) cache_size = cache_size / downsample_size enable_cache_row_key = False if "hybrid" in cache_type: enable_cache_row_key = True cache_type = cache_type[:-7] if cache_type == "ts": return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies) elif cache_type == "linucb": return LinUCBCache(cache_size, enable_cache_row_key, policies) else: print("Unknown cache type {}".format(cache_type)) assert False return None def run(trace_file_path, cache_type, cache, warmup_seconds): warmup_complete = False num = 0 trace_start_time = 0 trace_duration = 0 start_time = time.time() time_interval = 1 trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute) with open(trace_file_path, "r") as trace_file: for line in trace_file: num += 1 if num % 1000000 == 0: # Force a python gc periodically to reduce memory usage. gc.collect() ts = line.split(",") timestamp = int(ts[0]) if trace_start_time == 0: trace_start_time = timestamp trace_duration = timestamp - trace_start_time if not warmup_complete and trace_duration > warmup_seconds * 1000000: cache.miss_ratio_stats.reset_counter() warmup_complete = True record = TraceRecord( access_time=int(ts[0]), block_id=int(ts[1]), block_type=int(ts[2]), block_size=int(ts[3]), cf_id=int(ts[4]), cf_name=ts[5], level=int(ts[6]), fd=int(ts[7]), caller=int(ts[8]), no_insert=int(ts[9]), get_id=int(ts[10]), key_id=int(ts[11]), kv_size=int(ts[12]), is_hit=int(ts[13]), ) trace_miss_ratio_stats.update_metrics( record.access_time, is_hit=record.is_hit ) cache.access(record) del record if num % 100 != 0: continue # Report progress every 10 seconds. now = time.time() if now - start_time > time_interval * 10: print( "Take {} seconds to process {} trace records with trace " "duration of {} seconds. Throughput: {} records/second. " "Trace miss ratio {}".format( now - start_time, num, trace_duration / 1000000, num / (now - start_time), trace_miss_ratio_stats.miss_ratio(), ) ) time_interval += 1 print( "{},0,0,{},{},{}".format( cache_type, cache.cache_size, cache.miss_ratio_stats.miss_ratio(), cache.miss_ratio_stats.num_accesses, ) ) now = time.time() print( "Take {} seconds to process {} trace records with trace duration of {} " "seconds. Throughput: {} records/second. Trace miss ratio {}".format( now - start_time, num, trace_duration / 1000000, num / (now - start_time), trace_miss_ratio_stats.miss_ratio(), ) ) return trace_start_time, trace_duration def report_stats( cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time ): cache_label = "{}-{}".format(cache_type, cache_size) with open("{}/data-ml-mrc-{}".format(result_dir, cache_label), "w+") as mrc_file: mrc_file.write( "{},0,0,{},{},{}\n".format( cache_type, cache_size, cache.miss_ratio_stats.miss_ratio(), cache.miss_ratio_stats.num_accesses, ) ) cache.policy_stats.write_policy_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.policy_stats.write_policy_ratio_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.miss_ratio_stats.write_miss_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.miss_ratio_stats.write_miss_ratio_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.per_hour_policy_stats.write_policy_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.per_hour_policy_stats.write_policy_ratio_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.per_hour_miss_ratio_stats.write_miss_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) cache.per_hour_miss_ratio_stats.write_miss_ratio_timeline( cache_type, cache_size, result_dir, trace_start_time, trace_end_time ) if __name__ == "__main__": if len(sys.argv) <= 6: print( "Must provide 6 arguments. " "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). " "2) cache size (xM, xG, xT). " "3) The sampling frequency used to collect the trace. (The " "simulation scales down the cache size by the sampling frequency). " "4) Warmup seconds (The number of seconds used for warmup). " "5) Trace file path. " "6) Result directory (A directory that saves generated results)" ) exit(1) cache_type = sys.argv[1] cache_size = parse_cache_size(sys.argv[2]) downsample_size = int(sys.argv[3]) warmup_seconds = int(sys.argv[4]) trace_file_path = sys.argv[5] result_dir = sys.argv[6] cache = create_cache(cache_type, cache_size, downsample_size) trace_start_time, trace_duration = run( trace_file_path, cache_type, cache, warmup_seconds ) trace_end_time = trace_start_time + trace_duration report_stats( cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time )