You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
rocksdb/tools/block_cache_analyzer/block_cache_pysim.py

865 lines
31 KiB

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import gc
import random
import sys
import time
from os import path
import numpy as np
# Number of cache entries sampled per eviction round.
kSampleSize = 16
# Time conversion factors; trace timestamps are in microseconds.
kMicrosInSecond = 1000 * 1000
kSecondsInMinute = 60
kSecondsInHour = 60 * kSecondsInMinute
class TraceRecord:
    """
    One block access parsed from a block cache trace.

    Mirrors BlockCacheTraceRecord in trace_replay/block_cache_tracer.h. The
    integer flags `no_insert` and `is_hit` are stored as booleans that are
    True only when the raw value equals 1; all other fields are stored as
    given.
    """

    def __init__(
        self,
        access_time,
        block_id,
        block_type,
        block_size,
        cf_id,
        cf_name,
        level,
        fd,
        caller,
        no_insert,
        get_id,
        key_id,
        kv_size,
        is_hit,
    ):
        # Which block was accessed, and where it lives.
        self.access_time = access_time
        self.block_id = block_id
        self.block_type = block_type
        self.block_size = block_size
        self.cf_id = cf_id
        self.cf_name = cf_name
        self.level = level
        self.fd = fd
        self.caller = caller
        # A raw value of exactly 1 means the block must not be inserted on a
        # miss; anything else means it may be.
        self.no_insert = bool(no_insert == 1)
        # Get-request context used for row key-value caching.
        self.get_id = get_id
        self.key_id = key_id
        self.kv_size = kv_size
        # Hit/miss outcome recorded in the trace (1 means hit).
        self.is_hit = bool(is_hit == 1)
class CacheEntry:
    """A cache entry stored in the cache."""

    def __init__(self, value_size, cf_id, level, block_type, access_number):
        # Size charged against the cache capacity.
        self.value_size = value_size
        # Access counter at the last insert or hit (the LRU/MRU sort key).
        self.last_access_number = access_number
        # Number of hits since insertion (the LFU sort key).
        self.num_hits = 0
        # Store the caller's column family instead of discarding it.
        self.cf_id = cf_id
        self.level = level
        self.block_type = block_type

    def __repr__(self):
        """Debug string."""
        return "s={},last={},hits={},cf={},l={},bt={}".format(
            self.value_size,
            self.last_access_number,
            self.num_hits,
            self.cf_id,
            self.level,
            self.block_type,
        )
class HashEntry:
    """One (key, hash, value) triple held in a HashTable bucket."""

    def __init__(self, key, hash, value):
        self.key, self.hash, self.value = key, hash, value

    def __repr__(self):
        parts = (self.key, self.hash, self.value)
        return "k={},h={},v=[{}]".format(*parts)
class HashTable:
    """
    A custom hash table that supports fast random sampling.

    Each bucket is a Python list (separate chaining). Deleting an entry
    leaves a `None` hole in its bucket, and later insertions into that
    bucket reuse the hole. The bucket array grows by 20% once the number of
    entries reaches the number of buckets. It shrinks to 70% once the entry
    count drops below half the bucket count. Tables holding fewer than 100
    entries are never resized.
    """

    def __init__(self):
        self.table = [None] * 32
        self.elements = 0

    def random_sample(self, sample_size):
        """
        Return up to `sample_size` entries from the table.

        Buckets are scanned circularly from a uniformly chosen start bucket,
        and each bucket is visited at most once. Fewer than `sample_size`
        entries are returned only when the table holds fewer entries.
        """
        samples = []
        if sample_size <= 0:
            return samples
        num_buckets = len(self.table)
        # randrange excludes its upper bound, so `start` is a valid index.
        start = random.randrange(num_buckets)
        for offset in range(num_buckets):
            bucket = self.table[(start + offset) % num_buckets]
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                samples.append(entry)
                if len(samples) == sample_size:
                    return samples
        return samples

    def insert(self, key, hash, value):
        """
        Insert a hash entry in the table. Replace the old entry if it already
        exists.
        """
        self.grow()
        index = hash % len(self.table)
        bucket = self.table[index]
        if bucket is None:
            bucket = self.table[index] = []
        # Scan the whole bucket for an existing entry before reusing a hole.
        # Otherwise a hole in front of the existing entry would produce a
        # duplicate key.
        free_slot = None
        for i, entry in enumerate(bucket):
            if entry is None:
                if free_slot is None:
                    free_slot = i
                continue
            if entry.hash == hash and entry.key == key:
                # The entry already exists in the table.
                bucket[i] = HashEntry(key, hash, value)
                return
        new_entry = HashEntry(key, hash, value)
        if free_slot is None:
            bucket.append(new_entry)
        else:
            bucket[free_slot] = new_entry
        self.elements += 1

    def resize(self, new_size):
        """
        Rehash all entries into `new_size` buckets, dropping `None` holes.

        Does nothing when the size is unchanged, when `new_size` is zero, or
        when the table holds fewer than 100 entries.
        """
        if new_size == len(self.table) or new_size == 0 or self.elements < 100:
            return
        new_table = [None] * new_size
        for bucket in self.table:
            if bucket is None:
                continue
            for entry in bucket:
                if entry is None:
                    continue
                index = entry.hash % new_size
                if new_table[index] is None:
                    new_table[index] = []
                new_table[index].append(entry)
        self.table = new_table
        # Manually call python gc here to free the memory as the old table
        # might be very large.
        gc.collect()

    def grow(self):
        """Grow the bucket array by 20% once entries >= buckets."""
        if self.elements < len(self.table):
            return
        self.resize(int(len(self.table) * 1.2))

    def delete(self, key, hash):
        """Remove the entry for (key, hash) if present, then maybe shrink."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return
        for i, entry in enumerate(entries):
            if entry is not None and entry.hash == hash and entry.key == key:
                entries[i] = None
                self.elements -= 1
                self.shrink()
                return

    def shrink(self):
        """Shrink the bucket array to 70% once entries < buckets / 2."""
        if self.elements * 2 >= len(self.table):
            return
        self.resize(int(len(self.table) * 0.7))

    def lookup(self, key, hash):
        """Return the value stored for (key, hash), or None if absent."""
        entries = self.table[hash % len(self.table)]
        if entries is None:
            return None
        for entry in entries:
            if entry is not None and entry.hash == hash and entry.key == key:
                return entry.value
        return None
class MissRatioStats:
    """
    Counts accesses and misses, in total and per time bucket.

    Access times are in microseconds; buckets are `time_unit` seconds wide.
    `reset_counter` clears only the running totals and leaves the
    per-bucket timelines untouched.
    """

    def __init__(self, time_unit):
        self.num_misses = 0
        self.num_accesses = 0
        self.time_unit = time_unit
        # Bucket -> number of misses / accesses observed in that bucket.
        self.time_misses = {}
        self.time_accesses = {}

    def _bucket(self, timestamp):
        """Map a microsecond timestamp to its integer time bucket."""
        # Floor division keeps bucket keys integral under Python 3. This lets
        # them match the integer range() scanned by the writers below.
        return timestamp // (kMicrosInSecond * self.time_unit)

    def update_metrics(self, access_time, is_hit):
        bucket = self._bucket(access_time)
        self.num_accesses += 1
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        if not is_hit:
            self.num_misses += 1
            self.time_misses[bucket] = self.time_misses.get(bucket, 0) + 1

    def reset_counter(self):
        """Reset the running totals (e.g. after warmup)."""
        self.num_misses = 0
        self.num_accesses = 0

    def miss_ratio(self):
        """Return the miss ratio as a percentage, or 0.0 if nothing was accessed."""
        if self.num_accesses == 0:
            return 0.0
        return float(self.num_misses) * 100.0 / float(self.num_accesses)

    def _write_header(self, header_file_path, start, end):
        """Write the "time,<bucket>,..." header unless the file already exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            columns = ["time"] + [str(t) for t in range(start, end)]
            header_file.write(",".join(columns) + "\n")

    def write_miss_timeline(self, cache_type, cache_size, result_dir, start, end):
        """Write per-bucket miss counts for buckets in [start, end)."""
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-miss-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            row = ["{}".format(cache_type)]
            for trace_time in range(start, end):
                row.append("{}".format(self.time_misses.get(trace_time, 0)))
            file.write(",".join(row) + "\n")

    def write_miss_ratio_timeline(self, cache_type, cache_size, result_dir, start, end):
        """Write per-bucket miss ratios (in percent) for buckets in [start, end)."""
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-miss-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-miss-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            row = ["{}".format(cache_type)]
            for trace_time in range(start, end):
                naccesses = self.time_accesses.get(trace_time, 0)
                miss_ratio = 0.0
                if naccesses > 0:
                    miss_ratio = (
                        self.time_misses.get(trace_time, 0) * 100.0 / naccesses
                    )
                row.append("{0:.2f}".format(miss_ratio))
            file.write(",".join(row) + "\n")
class PolicyStats:
    """
    Counts, per time bucket, how often each policy was selected.

    Access times are in microseconds; buckets are `time_unit` seconds wide.
    """

    def __init__(self, time_unit, policies):
        # Bucket -> {policy name -> number of selections in that bucket}.
        self.time_selected_polices = {}
        # Bucket -> total number of selections in that bucket.
        self.time_accesses = {}
        # Policy index -> policy name.
        self.policy_names = {}
        self.time_unit = time_unit
        for i, policy in enumerate(policies):
            self.policy_names[i] = policy.policy_name()

    def _bucket(self, timestamp):
        """Map a microsecond timestamp to its integer time bucket."""
        # Floor division keeps bucket keys integral under Python 3. This lets
        # them match the integer range() scanned by the writers below.
        return timestamp // (kMicrosInSecond * self.time_unit)

    def update_metrics(self, access_time, selected_policy):
        bucket = self._bucket(access_time)
        self.time_accesses[bucket] = self.time_accesses.get(bucket, 0) + 1
        policy_name = self.policy_names[selected_policy]
        per_policy = self.time_selected_polices.setdefault(bucket, {})
        per_policy[policy_name] = per_policy.get(policy_name, 0) + 1

    def _write_header(self, header_file_path, start, end):
        """Write the "time,<bucket>,..." header unless the file already exists."""
        if path.exists(header_file_path):
            return
        with open(header_file_path, "w+") as header_file:
            columns = ["time"] + [str(t) for t in range(start, end)]
            header_file.write(",".join(columns) + "\n")

    def _selections(self, trace_time, policy_name):
        return self.time_selected_polices.get(trace_time, {}).get(policy_name, 0)

    def write_policy_timeline(self, cache_type, cache_size, result_dir, start, end):
        """Write per-bucket selection counts, one row per policy."""
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-policy-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        file_path = "{}/data-ml-policy-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(file_path, "w+") as file:
            for policy_name in self.policy_names.values():
                row = ["{}-{}".format(cache_type, policy_name)]
                for trace_time in range(start, end):
                    row.append("{}".format(self._selections(trace_time, policy_name)))
                file.write(",".join(row) + "\n")

    def write_policy_ratio_timeline(
        self, cache_type, cache_size, file_path, start, end
    ):
        """
        Write per-bucket selection ratios (in percent), one row per policy.

        `file_path` is the directory that receives the output files. It takes
        the same value as `result_dir` in write_policy_timeline; the parameter
        name is kept for backward compatibility with keyword callers.
        """
        result_dir = file_path
        start = self._bucket(start)
        end = self._bucket(end)
        self._write_header(
            "{}/header-ml-policy-ratio-timeline-{}-{}-{}".format(
                result_dir, self.time_unit, cache_type, cache_size
            ),
            start,
            end,
        )
        data_file_path = "{}/data-ml-policy-ratio-timeline-{}-{}-{}".format(
            result_dir, self.time_unit, cache_type, cache_size
        )
        with open(data_file_path, "w+") as file:
            for policy_name in self.policy_names.values():
                row = ["{}-{}".format(cache_type, policy_name)]
                for trace_time in range(start, end):
                    naccesses = self.time_accesses.get(trace_time, 0)
                    ratio = 0.0
                    if naccesses > 0:
                        ratio = (
                            self._selections(trace_time, policy_name)
                            * 100.0
                            / naccesses
                        )
                    row.append("{0:.2f}".format(ratio))
                file.write(",".join(row) + "\n")
class Policy(object):
    """
    A policy maintains a set of evicted keys. It rewards itself with 1 when
    a missing key is not one it evicted, and with 0 otherwise.

    Subclasses implement `prioritize_samples`, which orders sampled hash
    entries so that the entries to evict first come first.
    """

    def __init__(self):
        # Keys evicted by this policy; a key is removed again via `delete`.
        self.evicted_keys = {}

    def evict(self, key, max_size):
        """Record that this policy evicted `key` (`max_size` is not used)."""
        self.evicted_keys[key] = 0

    def delete(self, key):
        """Forget `key` if it was recorded as evicted."""
        self.evicted_keys.pop(key, None)

    def prioritize_samples(self, samples):
        raise NotImplementedError

    def policy_name(self):
        raise NotImplementedError

    def generate_reward(self, key):
        """Return 0 if this policy previously evicted `key`, else 1."""
        return 0 if key in self.evicted_keys else 1


class LRUPolicy(Policy):
    """Evict the least recently accessed sampled entries first."""

    def prioritize_samples(self, samples):
        # Python 3's sorted() has no `cmp` argument; sort by key instead.
        return sorted(samples, key=lambda e: e.value.last_access_number)

    def policy_name(self):
        return "lru"


class MRUPolicy(Policy):
    """Evict the most recently accessed sampled entries first."""

    def prioritize_samples(self, samples):
        # reverse=True keeps equal entries in their original (stable) order.
        return sorted(
            samples, key=lambda e: e.value.last_access_number, reverse=True
        )

    def policy_name(self):
        return "mru"


class LFUPolicy(Policy):
    """Evict the least frequently hit sampled entries first."""

    def prioritize_samples(self, samples):
        return sorted(samples, key=lambda e: e.value.num_hits)

    def policy_name(self):
        return "lfu"
class MLCache(object):
    """
    Base class for the bandit-driven cache simulators.

    Items are stored in a HashTable under the key "b<block_id>" (blocks) or
    "g<key_id>" (row key-value pairs) and are charged against `cache_size`
    by their size. On each insertion, the subclass's `_select_policy` picks
    one of `policies`. That policy then orders a random sample of cached
    entries for eviction. Miss ratios and policy choices are recorded in
    per-minute and per-hour statistics.
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        self.cache_size = cache_size
        # Sum of the value sizes of all entries currently in the table.
        self.used_size = 0
        self.miss_ratio_stats = MissRatioStats(kSecondsInMinute)
        self.policy_stats = PolicyStats(kSecondsInMinute, policies)
        self.per_hour_miss_ratio_stats = MissRatioStats(kSecondsInHour)
        self.per_hour_policy_stats = PolicyStats(kSecondsInHour, policies)
        self.table = HashTable()
        # When True, Get requests may also be served from cached row
        # key-value pairs ("hybrid" mode).
        self.enable_cache_row_key = enable_cache_row_key
        # get_id -> {"h": True once the Get is treated as a hit,
        #            key_id: whether the row key-value pair counts as inserted}.
        # NOTE(review): entries are never removed, so this grows with the
        # number of distinct Get requests in the trace.
        self.get_id_row_key_map = {}
        self.policies = policies

    def _lookup(self, key, hash):
        """
        Return True on a cache hit. A hit stamps the entry with the current
        per-minute access count and increments its hit counter.
        """
        value = self.table.lookup(key, hash)
        if value is not None:
            value.last_access_number = self.miss_ratio_stats.num_accesses
            value.num_hits += 1
            return True
        return False

    def _select_policy(self, trace_record, key):
        """Return the index into `self.policies` to use for this insertion."""
        raise NotImplementedError

    def cache_name(self):
        raise NotImplementedError

    def _evict(self, policy_index, value_size):
        """
        Evict sampled entries in the order chosen by the selected policy.
        Stops as soon as `value_size` more bytes fit, or when the sample is
        exhausted.
        """
        # Randomly sample n entries.
        samples = self.table.random_sample(kSampleSize)
        samples = self.policies[policy_index].prioritize_samples(samples)
        for hash_entry in samples:
            self.used_size -= hash_entry.value.value_size
            self.table.delete(hash_entry.key, hash_entry.hash)
            # Record the eviction so the policy earns no reward if this key
            # is requested again.
            self.policies[policy_index].evict(
                key=hash_entry.key, max_size=self.table.elements
            )
            if self.used_size + value_size <= self.cache_size:
                break

    def _insert(self, trace_record, key, hash, value_size):
        """
        Insert `key`, evicting entries until it fits. Items larger than the
        whole cache are never inserted. The chosen policy is recorded in the
        per-minute and per-hour policy statistics.
        """
        if value_size > self.cache_size:
            return
        policy_index = self._select_policy(trace_record, key)
        # The key is being (re)inserted, so drop it from the selected
        # policy's evicted-key set.
        self.policies[policy_index].delete(key)
        self.policy_stats.update_metrics(trace_record.access_time, policy_index)
        self.per_hour_policy_stats.update_metrics(
            trace_record.access_time, policy_index
        )
        # NOTE(review): this never terminates if eviction cannot free space,
        # e.g. if the table is empty while used_size is still too large.
        while self.used_size + value_size > self.cache_size:
            self._evict(policy_index, value_size)
        self.table.insert(
            key,
            hash,
            CacheEntry(
                value_size,
                trace_record.cf_id,
                trace_record.level,
                trace_record.block_type,
                self.miss_ratio_stats.num_accesses,
            ),
        )
        self.used_size += value_size

    def _access_kv(self, trace_record, key, hash, value_size, no_insert):
        """
        Look up `key` and return True on a hit. On a miss, insert it unless
        `no_insert` is set or `value_size` is not positive.
        """
        if self._lookup(key, hash):
            return True
        if not no_insert and value_size > 0:
            self._insert(trace_record, key, hash, value_size)
        return False

    def _update_stats(self, access_time, is_hit):
        self.miss_ratio_stats.update_metrics(access_time, is_hit)
        self.per_hour_miss_ratio_stats.update_metrics(access_time, is_hit)

    def access(self, trace_record):
        """
        Simulate one trace record and update the miss-ratio statistics.

        In hybrid mode, a record from a Get request (caller == 1 with
        non-zero key_id and get_id) first consults the cached row key-value
        pair. Once a Get has hit on its row, all of its remaining accesses
        count as hits. Every other record is a plain block access.
        """
        assert self.used_size <= self.cache_size
        if (
            self.enable_cache_row_key
            and trace_record.caller == 1
            and trace_record.key_id != 0
            and trace_record.get_id != 0
        ):
            # This is a get request.
            if trace_record.get_id not in self.get_id_row_key_map:
                self.get_id_row_key_map[trace_record.get_id] = {}
                self.get_id_row_key_map[trace_record.get_id]["h"] = False
            if self.get_id_row_key_map[trace_record.get_id]["h"]:
                # We treat future accesses as hits since this get request
                # completes.
                self._update_stats(trace_record.access_time, is_hit=True)
                return
            if trace_record.key_id not in self.get_id_row_key_map[trace_record.get_id]:
                # First time seen this key.
                # On a miss with kv_size > 0, _access_kv inserts the row now.
                is_hit = self._access_kv(
                    trace_record,
                    key="g{}".format(trace_record.key_id),
                    hash=trace_record.key_id,
                    value_size=trace_record.kv_size,
                    no_insert=False,
                )
                inserted = False
                if trace_record.kv_size > 0:
                    inserted = True
                self.get_id_row_key_map[trace_record.get_id][
                    trace_record.key_id
                ] = inserted
                self.get_id_row_key_map[trace_record.get_id]["h"] = is_hit
            if self.get_id_row_key_map[trace_record.get_id]["h"]:
                # We treat future accesses as hits since this get request
                # completes.
                self._update_stats(trace_record.access_time, is_hit=True)
                return
            # Access its blocks.
            is_hit = self._access_kv(
                trace_record,
                key="b{}".format(trace_record.block_id),
                hash=trace_record.block_id,
                value_size=trace_record.block_size,
                no_insert=trace_record.no_insert,
            )
            self._update_stats(trace_record.access_time, is_hit)
            # A later record of the same Get may carry the row's size even if
            # the first one did not; insert the row once that size is known.
            if (
                trace_record.kv_size > 0
                and not self.get_id_row_key_map[trace_record.get_id][
                    trace_record.key_id
                ]
            ):
                # Insert the row key-value pair.
                self._access_kv(
                    trace_record,
                    key="g{}".format(trace_record.key_id),
                    hash=trace_record.key_id,
                    value_size=trace_record.kv_size,
                    no_insert=False,
                )
                # Mark as inserted.
                self.get_id_row_key_map[trace_record.get_id][trace_record.key_id] = True
            return
        # Access the block.
        is_hit = self._access_kv(
            trace_record,
            key="b{}".format(trace_record.block_id),
            hash=trace_record.block_id,
            value_size=trace_record.block_size,
            no_insert=trace_record.no_insert,
        )
        self._update_stats(trace_record.access_time, is_hit)
class ThompsonSamplingCache(MLCache):
    """
    An implementation of Thompson Sampling for the Bernoulli Bandit [1].

    Each eviction policy is an arm whose reward probability has a Beta(a, b)
    posterior, starting at Beta(init_a, init_b). On every insertion, one
    sample is drawn per arm and the arm with the largest sample is chosen.
    That arm's posterior is then updated with its 0/1 reward.

    [1] Daniel J. Russo, Benjamin Van Roy, Abbas Kazerouni, Ian Osband,
    and Zheng Wen. 2018. A Tutorial on Thompson Sampling. Found.
    Trends Mach. Learn. 11, 1 (July 2018), 1-96.
    DOI: https://doi.org/10.1561/2200000070
    """

    def __init__(
        self, cache_size, enable_cache_row_key, policies, init_a=1, init_b=1
    ):
        super(ThompsonSamplingCache, self).__init__(
            cache_size, enable_cache_row_key, policies
        )
        # Per-policy Beta parameters: the prior plus observed rewards (a) and
        # non-rewards (b).
        self._as = [init_a] * len(self.policies)
        self._bs = [init_b] * len(self.policies)

    def _select_policy(self, trace_record, key):
        """Choose a policy by Thompson sampling and update its posterior."""
        samples = [
            np.random.beta(self._as[x], self._bs[x]) for x in range(len(self.policies))
        ]
        selected_policy = max(range(len(self.policies)), key=lambda x: samples[x])
        # Reward is 0 if the selected policy had evicted `key`, else 1.
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        self._as[selected_policy] += reward
        self._bs[selected_policy] += 1 - reward
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid ThompsonSampling (ts_hybrid)"
        return "ThompsonSampling (ts)"
class LinUCBCache(MLCache):
    """
    An implementation of LinUCB with disjoint linear models [2].

    Each policy a keeps a ridge-regression model (A_a, b_a) over a 4-feature
    context: block type, caller, level and column family. The policy with
    the largest upper confidence bound,
    theta_a . x + alph * sqrt(x^T A_a^-1 x), is chosen.

    [2] Lihong Li, Wei Chu, John Langford, and Robert E. Schapire. 2010.
    A contextual-bandit approach to personalized news article recommendation.
    In Proceedings of the 19th international conference on World wide web
    (WWW '10). ACM, New York, NY, USA, 661-670.
    DOI=http://dx.doi.org/10.1145/1772690.1772758
    """

    def __init__(self, cache_size, enable_cache_row_key, policies):
        super(LinUCBCache, self).__init__(cache_size, enable_cache_row_key, policies)
        self.nfeatures = 4  # Block type, caller, level, cf.
        self.th = np.zeros((len(self.policies), self.nfeatures))
        # NOTE(review): eps and p are not read within this class.
        self.eps = 0.2
        self.b = np.zeros_like(self.th)
        self.A = np.zeros((len(self.policies), self.nfeatures, self.nfeatures))
        self.A_inv = np.zeros((len(self.policies), self.nfeatures, self.nfeatures))
        for i in range(len(self.policies)):
            # A_a starts as the identity, so its inverse is the identity too.
            # A zero A_inv would give never-selected policies a zero
            # confidence bound, and they would never be explored.
            self.A[i] = np.identity(self.nfeatures)
            self.A_inv[i] = np.identity(self.nfeatures)
        self.th_hat = np.zeros_like(self.th)
        self.p = np.zeros(len(self.policies))
        self.alph = 0.2

    def _select_policy(self, trace_record, key):
        """
        Choose the policy with the highest UCB for this record's context,
        then update that policy's model with the observed 0/1 reward.
        """
        x_i = np.zeros(self.nfeatures)  # The current context vector
        x_i[0] = trace_record.block_type
        x_i[1] = trace_record.caller
        x_i[2] = trace_record.level
        x_i[3] = trace_record.cf_id
        p = np.zeros(len(self.policies))
        for a in range(len(self.policies)):
            self.th_hat[a] = self.A_inv[a].dot(self.b[a])
            ta = x_i.dot(self.A_inv[a]).dot(x_i)
            a_upper_ci = self.alph * np.sqrt(ta)
            a_mean = self.th_hat[a].dot(x_i)
            p[a] = a_mean + a_upper_ci
        # Tiny random jitter breaks ties between equal scores.
        p = p + (np.random.random(len(p)) * 0.000001)
        selected_policy = int(p.argmax())
        reward = self.policies[selected_policy].generate_reward(key)
        assert reward <= 1 and reward >= 0
        self.A[selected_policy] += np.outer(x_i, x_i)
        self.b[selected_policy] += reward * x_i
        self.A_inv[selected_policy] = np.linalg.inv(self.A[selected_policy])
        return selected_policy

    def cache_name(self):
        if self.enable_cache_row_key:
            return "Hybrid LinUCB (linucb_hybrid)"
        return "LinUCB (linucb)"
def parse_cache_size(cs):
    """
    Parse a cache size such as "64M" into a number of bytes.

    Accepts a plain integer, or an integer followed by one of the binary
    unit suffixes K, M, G or T (powers of 1024). Surrounding whitespace,
    including a trailing newline or carriage return, is ignored.

    Raises ValueError for empty or malformed input.
    """
    cs = cs.strip()
    if not cs:
        raise ValueError("empty cache size")
    multipliers = {
        "K": 1024,
        "M": 1024 * 1024,
        "G": 1024 * 1024 * 1024,
        "T": 1024 * 1024 * 1024 * 1024,
    }
    unit = cs[-1]
    if unit in multipliers:
        return int(cs[:-1]) * multipliers[unit]
    return int(cs)
def create_cache(cache_type, cache_size, downsample_size):
    """
    Build the simulator named by `cache_type`.

    cache_type: "ts" or "linucb", optionally with a "_hybrid" suffix, which
        also caches the row key-value pairs of Get requests.
    cache_size: full cache size in bytes. It is divided by `downsample_size`
        (the trace sampling frequency) to match the sampled trace.

    Raises ValueError for an unknown cache type.
    """
    policies = [LRUPolicy(), MRUPolicy(), LFUPolicy()]
    # Floor division keeps the capacity an int under Python 3. Entry sizes
    # are ints, so every capacity comparison gives the same result as with
    # the exact quotient.
    cache_size = cache_size // downsample_size
    enable_cache_row_key = cache_type.endswith("_hybrid")
    if enable_cache_row_key:
        cache_type = cache_type[: -len("_hybrid")]
    if cache_type == "ts":
        return ThompsonSamplingCache(cache_size, enable_cache_row_key, policies)
    if cache_type == "linucb":
        return LinUCBCache(cache_size, enable_cache_row_key, policies)
    print("Unknown cache type {}".format(cache_type))
    # Raise instead of `assert False`, which is stripped under `python -O`.
    raise ValueError("Unknown cache type {}".format(cache_type))
def run(trace_file_path, cache_type, cache, warmup_seconds):
    """
    Replay the trace at `trace_file_path` through `cache`.

    Each line is one comma-separated trace record with microsecond
    timestamps. Once more than `warmup_seconds` of trace time has passed,
    the cache's per-minute miss counters are reset so that warmup traffic
    is excluded from the final miss ratio. Progress is printed roughly
    every 10 seconds of wall-clock time.

    Returns (trace_start_time, trace_duration) in microseconds.
    """
    records_seen = 0
    trace_start_time = 0
    trace_duration = 0
    warmup_complete = False
    start_time = time.time()
    # The next progress report is due after 10 * report_round seconds.
    report_round = 1
    trace_miss_ratio_stats = MissRatioStats(kSecondsInMinute)

    def print_progress(now):
        elapsed = now - start_time
        print(
            "Take {} seconds to process {} trace records with trace duration of {} "
            "seconds. Throughput: {} records/second. Trace miss ratio {}".format(
                elapsed,
                records_seen,
                trace_duration / 1000000,
                records_seen / elapsed,
                trace_miss_ratio_stats.miss_ratio(),
            )
        )

    with open(trace_file_path, "r") as trace_file:
        for line in trace_file:
            records_seen += 1
            if records_seen % 1000000 == 0:
                # Force a python gc periodically to reduce memory usage.
                gc.collect()
            fields = line.split(",")
            timestamp = int(fields[0])
            if trace_start_time == 0:
                trace_start_time = timestamp
            trace_duration = timestamp - trace_start_time
            if not warmup_complete and trace_duration > warmup_seconds * 1000000:
                cache.miss_ratio_stats.reset_counter()
                warmup_complete = True
            record = TraceRecord(
                access_time=timestamp,
                block_id=int(fields[1]),
                block_type=int(fields[2]),
                block_size=int(fields[3]),
                cf_id=int(fields[4]),
                cf_name=fields[5],
                level=int(fields[6]),
                fd=int(fields[7]),
                caller=int(fields[8]),
                no_insert=int(fields[9]),
                get_id=int(fields[10]),
                key_id=int(fields[11]),
                kv_size=int(fields[12]),
                is_hit=int(fields[13]),
            )
            trace_miss_ratio_stats.update_metrics(
                record.access_time, is_hit=record.is_hit
            )
            cache.access(record)
            del record
            # Read the wall clock only once every 100 records.
            if records_seen % 100 == 0:
                now = time.time()
                if now - start_time > report_round * 10:
                    print_progress(now)
                    report_round += 1
    print(
        "{},0,0,{},{},{}".format(
            cache_type,
            cache.cache_size,
            cache.miss_ratio_stats.miss_ratio(),
            cache.miss_ratio_stats.num_accesses,
        )
    )
    print_progress(time.time())
    return trace_start_time, trace_duration
def report_stats(
    cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
):
    """
    Write the simulation results for `cache` into `result_dir`.

    Writes the final miss ratio to "data-ml-mrc-<cache_type>-<cache_size>".
    It then asks the per-minute and per-hour statistics to write their
    policy and miss timelines for trace_start_time .. trace_end_time.
    """
    label = "{}-{}".format(cache_type, cache_size)
    mrc_path = "{}/data-ml-mrc-{}".format(result_dir, label)
    with open(mrc_path, "w+") as mrc_file:
        summary = "{},0,0,{},{},{}\n".format(
            cache_type,
            cache_size,
            cache.miss_ratio_stats.miss_ratio(),
            cache.miss_ratio_stats.num_accesses,
        )
        mrc_file.write(summary)
    timeline_args = (
        cache_type,
        cache_size,
        result_dir,
        trace_start_time,
        trace_end_time,
    )
    # Per-minute statistics first, then per-hour.
    for policy_stats, miss_stats in (
        (cache.policy_stats, cache.miss_ratio_stats),
        (cache.per_hour_policy_stats, cache.per_hour_miss_ratio_stats),
    ):
        policy_stats.write_policy_timeline(*timeline_args)
        policy_stats.write_policy_ratio_timeline(*timeline_args)
        miss_stats.write_miss_timeline(*timeline_args)
        miss_stats.write_miss_ratio_timeline(*timeline_args)
if __name__ == "__main__":
    # Script entry point: parse the six positional arguments, replay the
    # trace, and write the results.
    if len(sys.argv) <= 6:
        print(
            "Must provide 6 arguments. "
            "1) cache_type (ts, ts_hybrid, linucb, linucb_hybrid). "
            "2) cache size (xM, xG, xT). "
            "3) The sampling frequency used to collect the trace. (The "
            "simulation scales down the cache size by the sampling frequency). "
            "4) Warmup seconds (The number of seconds used for warmup). "
            "5) Trace file path. "
            "6) Result directory (A directory that saves generated results)"
        )
        # sys.exit rather than the site-provided exit(), which is meant for
        # the interactive shell and is missing under `python -S`.
        sys.exit(1)
    cache_type = sys.argv[1]
    cache_size = parse_cache_size(sys.argv[2])
    downsample_size = int(sys.argv[3])
    warmup_seconds = int(sys.argv[4])
    trace_file_path = sys.argv[5]
    result_dir = sys.argv[6]
    cache = create_cache(cache_type, cache_size, downsample_size)
    trace_start_time, trace_duration = run(
        trace_file_path, cache_type, cache, warmup_seconds
    )
    trace_end_time = trace_start_time + trace_duration
    report_stats(
        cache, cache_type, cache_size, result_dir, trace_start_time, trace_end_time
    )