
Commit f9ad2fa
Committed Apr 29, 2021
running version of framework
1 parent cd3ba30 commit f9ad2fa

17 files changed: +969 -15 lines
 

algos.yaml (+21)
@@ -0,0 +1,21 @@
float:
  any:
    faiss-ivf:
      docker-tag: billion-scale-benchmark-faiss
      module: benchmark.algorithms.faiss_inmem
      constructor: FaissIVF
      base-args: ["@metric"]
      run-groups:
        base:
          args: [[1024, 2048, 4096, 8192]]
          query-args: [[1, 5, 10, 50, 100, 200]]
  euclidean:
    faiss-ivf:
      docker-tag: billion-scale-benchmark-faiss
      module: benchmark.algorithms.faiss_inmem
      constructor: FaissIVF
      base-args: ["euclidean"]
      run-groups:
        args: [[1024]]
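Each run-group expands into a Cartesian product of its argument lists. A rough sketch (not part of the commit) of how the base group above expands, mirroring _generate_combinations in benchmark/algorithms/definitions.py below:

# Each element of `args` is one positional constructor slot; an inner list
# enumerates the choices for that slot.
from itertools import product

args = [[1024, 2048, 4096, 8192]]             # one slot, four choices
print([list(x) for x in product(*args)])      # [[1024], [2048], [4096], [8192]]

# With base-args ["@metric"] substituted, each combination becomes a call
# such as FaissIVF("euclidean", 1024); every query-args value (1, 5, ..., 200)
# is later applied to that instance via set_query_arguments.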

benchmark/__init__.py (+1)
@@ -0,0 +1 @@
from __future__ import absolute_import

benchmark/algorithms/base.py (+19 -5)
@@ -1,18 +1,25 @@
 from __future__ import absolute_import
+import psutil

 class BaseANN(object):
     def done(self):
         pass

-    def fit(self, X, name):
+    def fit(self, dataset):
         """
-        Build the index for the data points given as X.
-        Pass name as well to store index.
+        Build the index for the data points given by the dataset name.
+        Assumes that after fitting, the index is loaded in memory.
         """
         pass

-    def load_index(self, name):
-        """Load the index from name."""
+    def load_index(self, dataset):
+        """
+        Load the index for dataset. Returns False if the index
+        is not available, True otherwise.
+
+        Checking the index usually involves the dataset name
+        and the index build parameters passed during construction.
+        """
         pass

     def query(self, X, k):
@@ -40,3 +47,10 @@ def get_additional(self):

     def __str__(self):
         return self.name
+
+    def get_memory_usage(self):
+        """Return the current memory usage of this algorithm instance
+        (in kilobytes), or None if this information is not available."""
+        # return in kB for backwards compatibility
+        return psutil.Process().memory_info().rss / 1024
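To illustrate the contract, here is a minimal, hypothetical subclass (not part of the commit): an exact-search baseline implementing fit, query, and get_results the way the runner expects. The class name and brute-force logic are illustrative assumptions.

import numpy
from benchmark.algorithms.base import BaseANN
from benchmark.datasets import DATASETS


class BruteForce(BaseANN):
    """Hypothetical exact-search baseline; illustrates the interface only."""

    def __init__(self, metric):
        self._metric = metric
        self.name = "BruteForce"

    def fit(self, dataset):
        # The "index" is simply the raw points held in memory.
        self._data = DATASETS[dataset].get_dataset()

    def load_index(self, dataset):
        return False  # nothing is persisted, so always rebuild

    def query(self, X, k):
        # Squared euclidean distances between every query and every point,
        # then the k smallest per query row.
        d = ((X[:, None, :] - self._data[None, :, :]) ** 2).sum(-1)
        self.res = numpy.argsort(d, axis=1)[:, :k]

    def get_results(self):
        return self.res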

benchmark/algorithms/definitions.py (+173)
@@ -0,0 +1,173 @@
from __future__ import absolute_import
from os import sep as pathsep
import collections
import importlib
import os
import sys
import traceback
import yaml
from enum import Enum
from itertools import product


Definition = collections.namedtuple(
    'Definition',
    ['algorithm', 'constructor', 'module', 'docker_tag',
     'arguments', 'query_argument_groups', 'disabled'])


def instantiate_algorithm(definition):
    print('Trying to instantiate %s.%s(%s)' %
          (definition.module, definition.constructor, definition.arguments))
    module = importlib.import_module(definition.module)
    constructor = getattr(module, definition.constructor)
    return constructor(*definition.arguments)


class InstantiationStatus(Enum):
    AVAILABLE = 0
    NO_CONSTRUCTOR = 1
    NO_MODULE = 2


def algorithm_status(definition):
    try:
        module = importlib.import_module(definition.module)
        if hasattr(module, definition.constructor):
            return InstantiationStatus.AVAILABLE
        else:
            return InstantiationStatus.NO_CONSTRUCTOR
    except ImportError:
        return InstantiationStatus.NO_MODULE


def _generate_combinations(args):
    if isinstance(args, list):
        args = [el if isinstance(el, list) else [el] for el in args]
        return [list(x) for x in product(*args)]
    elif isinstance(args, dict):
        flat = []
        for k, v in args.items():
            if isinstance(v, list):
                flat.append([(k, el) for el in v])
            else:
                flat.append([(k, v)])
        return [dict(x) for x in product(*flat)]
    else:
        raise TypeError("No args handling exists for %s" % type(args).__name__)


def _substitute_variables(arg, vs):
    if isinstance(arg, dict):
        return dict([(k, _substitute_variables(v, vs))
                     for k, v in arg.items()])
    elif isinstance(arg, list):
        return [_substitute_variables(a, vs) for a in arg]
    elif isinstance(arg, str) and arg in vs:
        return vs[arg]
    else:
        return arg


def _get_definitions(definition_file):
    with open(definition_file, "r") as f:
        return yaml.load(f, yaml.SafeLoader)


def list_algorithms(definition_file):
    definitions = _get_definitions(definition_file)

    print('The following algorithms are supported...')
    for point in definitions:
        print('\t... for the point type "%s"...' % point)
        for metric in definitions[point]:
            print('\t\t... and the distance metric "%s":' % metric)
            for algorithm in definitions[point][metric]:
                print('\t\t\t%s' % algorithm)


def get_unique_algorithms(definition_file):
    definitions = _get_definitions(definition_file)
    algos = set()
    for point in definitions:
        for metric in definitions[point]:
            for algorithm in definitions[point][metric]:
                algos.add(algorithm)
    return list(sorted(algos))


def get_definitions(definition_file, dimension, point_type="float",
                    distance_metric="euclidean", count=10):
    definitions = _get_definitions(definition_file)

    algorithm_definitions = {}
    if "any" in definitions[point_type]:
        algorithm_definitions.update(definitions[point_type]["any"])
    algorithm_definitions.update(definitions[point_type][distance_metric])

    definitions = []
    for (name, algo) in algorithm_definitions.items():
        for k in ['docker-tag', 'module', 'constructor']:
            if k not in algo:
                raise Exception(
                    'algorithm %s does not define a "%s" property' % (name, k))

        base_args = []
        if "base-args" in algo:
            base_args = algo["base-args"]

        for run_group in algo["run-groups"].values():
            if "arg-groups" in run_group:
                groups = []
                for arg_group in run_group["arg-groups"]:
                    if isinstance(arg_group, dict):
                        # Dictionaries need to be expanded into lists in order
                        # for the subsequent call to _generate_combinations to
                        # do the right thing
                        groups.append(_generate_combinations(arg_group))
                    else:
                        groups.append(arg_group)
                args = _generate_combinations(groups)
            elif "args" in run_group:
                args = _generate_combinations(run_group["args"])
            else:
                assert False, "? what? %s" % run_group

            if "query-arg-groups" in run_group:
                groups = []
                for arg_group in run_group["query-arg-groups"]:
                    if isinstance(arg_group, dict):
                        groups.append(_generate_combinations(arg_group))
                    else:
                        groups.append(arg_group)
                query_args = _generate_combinations(groups)
            elif "query-args" in run_group:
                query_args = _generate_combinations(run_group["query-args"])
            else:
                query_args = []

            for arg_group in args:
                aargs = []
                aargs.extend(base_args)
                if isinstance(arg_group, list):
                    aargs.extend(arg_group)
                else:
                    aargs.append(arg_group)

                vs = {
                    "@count": count,
                    "@metric": distance_metric,
                    "@dimension": dimension
                }
                aargs = [_substitute_variables(arg, vs) for arg in aargs]
                definitions.append(Definition(
                    algorithm=name,
                    docker_tag=algo['docker-tag'],
                    module=algo['module'],
                    constructor=algo['constructor'],
                    arguments=aargs,
                    query_argument_groups=query_args,
                    disabled=algo.get('disabled', False)
                ))

    return definitions
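A quick, hypothetical REPL session showing what these helpers produce (outputs derived directly from the code above):

from benchmark.algorithms.definitions import (_generate_combinations,
                                              _substitute_variables)

# Lists enumerate choices per positional slot:
print(_generate_combinations([[1024, 2048], [1, 5]]))
# -> [[1024, 1], [1024, 5], [2048, 1], [2048, 5]]

# Dicts enumerate choices per keyword:
print(_generate_combinations({"n_list": [1024, 2048], "seed": 7}))
# -> [{'n_list': 1024, 'seed': 7}, {'n_list': 2048, 'seed': 7}]

# "@metric"/"@count"/"@dimension" placeholders are filled in last:
print(_substitute_variables(["@metric", 1024], {"@metric": "euclidean"}))
# -> ['euclidean', 1024]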

benchmark/algorithms/faiss.py

Whitespace-only changes.

benchmark/algorithms/faiss_inmem.py (+75)
@@ -0,0 +1,75 @@
from __future__ import absolute_import
#import sys
#sys.path.append("install/lib-faiss")  # noqa
import numpy
import sklearn.preprocessing
import ctypes
import faiss
import os
from benchmark.algorithms.base import BaseANN
from benchmark.datasets import DATASETS


class Faiss(BaseANN):
    def query(self, X, n):
        if self._metric == 'angular':
            X /= numpy.linalg.norm(X)
        self.res = self.index.search(X.astype(numpy.float32), n)

    def get_results(self):
        D, I = self.res
        return I
        # res = []
        # for i in range(len(D)):
        #     r = []
        #     for l, d in zip(L[i], D[i]):
        #         if l != -1:
        #             r.append(l)
        #     res.append(r)
        # return res


class FaissIVF(Faiss):
    def __init__(self, metric, n_list):
        self._n_list = n_list
        self._metric = metric

    def index_name(self, name):
        return f"data/ivf_{name}_{self._n_list}_{self._metric}"

    def fit(self, dataset):
        X = DATASETS[dataset].get_dataset()  # assumes it fits into memory

        if self._metric == 'angular':
            X = sklearn.preprocessing.normalize(X, axis=1, norm='l2')

        if X.dtype != numpy.float32:
            X = X.astype(numpy.float32)

        self.quantizer = faiss.IndexFlatL2(X.shape[1])
        index = faiss.IndexIVFFlat(
            self.quantizer, X.shape[1], self._n_list, faiss.METRIC_L2)
        index.train(X)
        index.add(X)
        faiss.write_index(index, self.index_name(dataset))
        self.index = index

    def load_index(self, dataset):
        if not os.path.exists(self.index_name(dataset)):
            return False

        self.index = faiss.read_index(self.index_name(dataset))
        return True

    def set_query_arguments(self, n_probe):
        faiss.cvar.indexIVF_stats.reset()
        self._n_probe = n_probe
        self.index.nprobe = self._n_probe

    def get_additional(self):
        return {"dist_comps": faiss.cvar.indexIVF_stats.ndis +  # noqa
                faiss.cvar.indexIVF_stats.nq * self._n_list}

    def __str__(self):
        return 'FaissIVF(n_list=%d, n_probe=%d)' % (self._n_list,
                                                    self._n_probe)
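A hedged sketch of driving this class by hand, outside the Docker harness; the dataset key and parameter values are assumptions:

from benchmark.algorithms.faiss_inmem import FaissIVF
from benchmark.datasets import DATASETS

algo = FaissIVF("euclidean", 1024)       # n_list=1024, as in algos.yaml
if not algo.load_index("sift-1M"):       # reuse a previously written index...
    algo.fit("sift-1M")                  # ...or train, add, and persist one
algo.set_query_arguments(50)             # n_probe=50
algo.query(DATASETS["sift-1M"].get_queries(), 10)
print(algo.get_results().shape)          # (num_queries, 10) neighbor ids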

benchmark/datasets.py (+6 -7)
@@ -1,4 +1,3 @@
-import h5py
 import numpy
 import os
 import random
@@ -67,7 +66,7 @@ def get_groundtruth(self, k=None):
         for each query."""
         pass

-    def query_type(self):
+    def search_type(self):
         """
         "knn" or "range"
         """
@@ -156,9 +155,9 @@ def get_groundtruth(self, k=None):
         return gt

     def distance(self):
-        return "Euclidean"
+        return "euclidean"

-    def query_type(self):
+    def search_type(self):
         return "knn"

     def __str__(self):
@@ -204,11 +203,11 @@ def get_groundtruth(self, k=None):
         gt = gt[:, :k]
         return gt

-    def query_type(self):
+    def search_type(self):
         return "knn"

     def distance(self):
-        return "Euclidean"
+        return "euclidean"

     def __str__(self):
         return f"Deep1B"
@@ -226,7 +225,7 @@ def __init__(self, nb_M=1000):
         self.base_url = "https://storage.yandexcloud.net/yandex-research/ann-datasets/T2I/"

     def distance(self):
-        return "IP"
+        return "ip"

     def __str__(self):
         return f"TextToImage"

benchmark/distances.py (+16)
@@ -0,0 +1,16 @@
from scipy.spatial.distance import pdist as scipy_pdist
import itertools
import numpy as np

def pdist(a, b, metric):
    return scipy_pdist([a, b], metric=metric)[0]

metrics = {
    'euclidean': {
        'distance': lambda a, b: pdist(a, b, "euclidean"),
    },
    'angular': {
        'distance': lambda a, b: pdist(a, b, "cosine"),
    }
}
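The wrapper reduces scipy's pairwise interface to a two-point distance; for example (values follow scipy's definitions):

from benchmark.distances import metrics

print(metrics['euclidean']['distance']([0, 0], [3, 4]))  # 5.0
print(metrics['angular']['distance']([1, 0], [0, 1]))    # 1.0 (cosine distance)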

benchmark/main.py (+185)
@@ -0,0 +1,185 @@
from __future__ import absolute_import
import argparse
import logging
import logging.config

import docker
import multiprocessing.pool
import os
import psutil
import random
import shutil
import sys
import traceback

from benchmark.datasets import DATASETS
from benchmark.algorithms.definitions import (get_definitions,
                                              list_algorithms,
                                              algorithm_status,
                                              InstantiationStatus)
from benchmark.results import get_result_filename
from benchmark.runner import run, run_docker


def positive_int(s):
    i = None
    try:
        i = int(s)
    except ValueError:
        pass
    if not i or i < 1:
        raise argparse.ArgumentTypeError("%r is not a positive integer" % s)
    return i


def run_worker(args, queue):
    while not queue.empty():
        definition = queue.get()
        memory_margin = 500e6  # reserve some extra memory for misc stuff
        mem_limit = int((psutil.virtual_memory().available - memory_margin))
        #mem_limit = 128e9  # 128gb for competition
        cpu_limit = "0-%d" % (multiprocessing.cpu_count() - 1)

        run_docker(definition, args.dataset, args.count,
                   args.runs, args.timeout, cpu_limit, mem_limit)


def main():
    parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument(
        '--dataset',
        metavar='NAME',
        help='the dataset to load training points from',
        default='sift-1M',
        choices=DATASETS.keys())
    parser.add_argument(
        "-k", "--count",
        default=10,
        type=positive_int,
        help="the number of near neighbours to search for")
    parser.add_argument(
        '--definitions',
        metavar='FILE',
        help='load algorithm definitions from FILE',
        default='algos.yaml')
    parser.add_argument(
        '--algorithm',
        metavar='NAME',
        help='run only the named algorithm',
        default=None)
    parser.add_argument(
        '--docker-tag',
        metavar='NAME',
        help='run only algorithms in a particular docker image',
        default=None)
    parser.add_argument(
        '--list-algorithms',
        help='print the names of all known algorithms and exit',
        action='store_true')
    parser.add_argument(
        '--force',
        help='re-run algorithms even if their results already exist',
        action='store_true')
    parser.add_argument(
        '--runs',
        metavar='COUNT',
        type=positive_int,
        help='run each algorithm instance %(metavar)s times and use only'
             ' the best result',
        default=5)
    parser.add_argument(
        '--timeout',
        type=int,
        help='Timeout (in seconds) for each individual algorithm run, or -1'
             ' if no timeout should be set',
        default=12 * 3600)
    parser.add_argument(
        '--max-n-algorithms',
        type=int,
        help='Max number of algorithms to run (just used for testing)',
        default=-1)

    args = parser.parse_args()
    if args.timeout == -1:
        args.timeout = None

    if args.list_algorithms:
        list_algorithms(args.definitions)
        sys.exit(0)

    logging.config.fileConfig("logging.conf")
    logger = logging.getLogger("annb")

    dataset = DATASETS[args.dataset]
    dimension = dataset.d
    point_type = 'float'
    distance = dataset.distance()
    definitions = get_definitions(
        args.definitions, dimension, point_type, distance, args.count)

    # Filter out, from the loaded definitions, all those query argument groups
    # that correspond to experiments that have already been run. (This might
    # mean removing a definition altogether, so we can't just use a list
    # comprehension.)
    filtered_definitions = []
    for definition in definitions:
        query_argument_groups = definition.query_argument_groups
        if not query_argument_groups:
            query_argument_groups = [[]]
        not_yet_run = []
        for query_arguments in query_argument_groups:
            fn = get_result_filename(args.dataset,
                                     args.count, definition,
                                     query_arguments)
            if args.force or not os.path.exists(fn):
                not_yet_run.append(query_arguments)
        if not_yet_run:
            if definition.query_argument_groups:
                definition = definition._replace(
                    query_argument_groups=not_yet_run)
            filtered_definitions.append(definition)
    definitions = filtered_definitions

    random.shuffle(definitions)

    if args.algorithm:
        logger.info(f'running only {args.algorithm}')
        definitions = [d for d in definitions if d.algorithm == args.algorithm]

    # See which Docker images we have available
    docker_client = docker.from_env()
    docker_tags = set()
    for image in docker_client.images.list():
        for tag in image.tags:
            tag = tag.split(':')[0]
            docker_tags.add(tag)

    if args.docker_tag:
        logger.info(f'running only {args.docker_tag}')
        definitions = [
            d for d in definitions if d.docker_tag == args.docker_tag]

    if set(d.docker_tag for d in definitions).difference(docker_tags):
        logger.info(f'not all docker images available, only: {set(docker_tags)}')
        logger.info(f'missing docker images: '
                    f'{str(set(d.docker_tag for d in definitions).difference(docker_tags))}')
        definitions = [
            d for d in definitions if d.docker_tag in docker_tags]

    if args.max_n_algorithms >= 0:
        definitions = definitions[:args.max_n_algorithms]

    if len(definitions) == 0:
        raise Exception('Nothing to run')
    else:
        logger.info(f'Order: {definitions}')

    queue = multiprocessing.Queue()
    for definition in definitions:
        queue.put(definition)
    #run_worker(args, queue)
    workers = [multiprocessing.Process(target=run_worker, args=(args, queue))
               for i in range(1)]
    [worker.start() for worker in workers]
    [worker.join() for worker in workers]
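The harness is normally launched through run.py (shown further below), which calls main(). A hedged sketch of an equivalent programmatic launch; the argument values are assumptions:

import sys
from benchmark.main import main

# Equivalent to: python run.py --dataset sift-1M --algorithm faiss-ivf --count 10
sys.argv = ["run.py", "--dataset", "sift-1M",
            "--algorithm", "faiss-ivf", "--count", "10"]
main()  # queues the matching definitions and runs each in its Docker image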

benchmark/results.py (+56)
@@ -0,0 +1,56 @@
from __future__ import absolute_import

import h5py
import json
import os
import re
import traceback


def get_result_filename(dataset=None, count=None, definition=None,
                        query_arguments=None):
    d = ['results']
    if dataset:
        d.append(dataset)
    if count:
        d.append(str(count))
    if definition:
        d.append(definition.algorithm)
        data = definition.arguments + query_arguments
        d.append(re.sub(r'\W+', '_', json.dumps(data, sort_keys=True))
                 .strip('_'))
    return os.path.join(*d)


def store_results(dataset, count, definition, query_arguments, attrs, results):
    fn = get_result_filename(
        dataset, count, definition, query_arguments) + '.hdf5'
    head, tail = os.path.split(fn)
    if not os.path.isdir(head):
        os.makedirs(head)
    f = h5py.File(fn, 'w')
    for k, v in attrs.items():
        f.attrs[k] = v
    neighbors = f.create_dataset('neighbors', (len(results), count), 'i')
    for i, idxs in enumerate(results):
        neighbors[i] = idxs
    f.close()


def load_all_results(dataset=None, count=None):
    for root, _, files in os.walk(get_result_filename(dataset, count)):
        for fn in files:
            if os.path.splitext(fn)[-1] != '.hdf5':
                continue
            try:
                f = h5py.File(os.path.join(root, fn), 'r+')
                properties = dict(f.attrs)
                yield properties, f
                f.close()
            except:
                print('Was unable to read', fn)
                traceback.print_exc()


def get_unique_algorithms():
    return set(properties['algo'] for properties, _ in load_all_results())
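An illustration of the resulting path layout; the Definition below is hypothetical but mirrors the faiss-ivf entry in algos.yaml:

from benchmark.algorithms.definitions import Definition
from benchmark.results import get_result_filename

d = Definition(algorithm="faiss-ivf", constructor="FaissIVF",
               module="benchmark.algorithms.faiss_inmem",
               docker_tag="billion-scale-benchmark-faiss",
               arguments=["euclidean", 1024],
               query_argument_groups=[[50]], disabled=False)

# Build and query arguments are JSON-encoded, then runs of non-word
# characters collapse to underscores:
print(get_result_filename("sift-1M", 10, d, [50]))
# -> results/sift-1M/10/faiss-ivf/euclidean_1024_50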

benchmark/runner.py (+216)
@@ -0,0 +1,216 @@
import argparse
import json
import logging
import os
import threading
import time
import traceback

import colors
import docker
import numpy
import psutil

from benchmark.algorithms.definitions import (Definition,
                                              instantiate_algorithm)
from benchmark.datasets import DATASETS
from benchmark.results import store_results


def run_individual_query(algo, X, distance, count, run_count, search_type):
    best_search_time = float('inf')
    for i in range(run_count):
        print('Run %d/%d...' % (i + 1, run_count))

        start = time.time()
        if search_type == "knn":
            algo.query(X, count)
        else:
            algo.range_query(X, count)
        total = (time.time() - start)

        results = algo.get_results()

        search_time = total
        best_search_time = min(best_search_time, search_time)

    attrs = {
        "best_search_time": best_search_time,
        "name": str(algo),
        "run_count": run_count,
        "distance": distance,
        "type": search_type,
        "count": int(count)
    }
    additional = algo.get_additional()
    for k in additional:
        attrs[k] = additional[k]
    return (attrs, results)


def run(definition, dataset, count, run_count):
    algo = instantiate_algorithm(definition)
    assert not definition.query_argument_groups \
        or hasattr(algo, "set_query_arguments"), """\
error: query argument groups have been specified for %s.%s(%s), but the \
algorithm instantiated from it does not implement the set_query_arguments \
function""" % (definition.module, definition.constructor, definition.arguments)

    ds = DATASETS[dataset]
    #X_train = numpy.array(D['train'])
    X = ds.get_queries()
    distance = ds.distance()
    search_type = ds.search_type()
    print(f"Running {definition.algorithm} on {dataset}")
    print(f"Got {len(X)} queries")

    try:
        # Try loading the index from the file
        memory_usage_before = algo.get_memory_usage()
        if not algo.load_index(dataset):
            # Build the index if it is not available
            t0 = time.time()
            algo.fit(dataset)
            build_time = time.time() - t0
            print('Built index in', build_time)

        index_size = algo.get_memory_usage() - memory_usage_before
        print('Index size: ', index_size)

        query_argument_groups = definition.query_argument_groups
        # Make sure that algorithms with no query argument groups still get run
        # once by providing them with a single, empty, harmless group
        if not query_argument_groups:
            query_argument_groups = [[]]

        for pos, query_arguments in enumerate(query_argument_groups, 1):
            print("Running query argument group %d of %d..." %
                  (pos, len(query_argument_groups)))
            if query_arguments:
                algo.set_query_arguments(*query_arguments)
            descriptor, results = run_individual_query(
                algo, X, distance, count, run_count, search_type)
            # A bit unclear how to set this correctly if we usually load from file
            #descriptor["build_time"] = build_time
            descriptor["index_size"] = index_size
            descriptor["algo"] = definition.algorithm
            descriptor["dataset"] = dataset
            store_results(dataset, count, definition,
                          query_arguments, descriptor, results)
    finally:
        algo.done()


def run_from_cmdline():
    parser = argparse.ArgumentParser('''

            NOTICE: You probably want to run.py rather than this script.

''')
    parser.add_argument(
        '--dataset',
        choices=DATASETS.keys(),
        help='Dataset to benchmark on.',
        required=True)
    parser.add_argument(
        '--algorithm',
        help='Name of algorithm for saving the results.',
        required=True)
    parser.add_argument(
        '--module',
        help='Python module containing algorithm. E.g. "ann_benchmarks.algorithms.annoy"',
        required=True)
    parser.add_argument(
        '--constructor',
        help='Constructor to load from module. E.g. "Annoy"',
        required=True)
    parser.add_argument(
        '--count',
        help='k: Number of nearest neighbours for the algorithm to return.',
        required=True,
        type=int)
    parser.add_argument(
        '--runs',
        help='Number of times to run the algorithm. Will use the fastest run-time over the bunch.',
        required=True,
        type=int)
    parser.add_argument(
        'build',
        help='JSON of arguments to pass to the constructor. E.g. ["angular", 100]'
    )
    parser.add_argument(
        'queries',
        help='JSON of arguments to pass to the queries. E.g. [100]',
        nargs='*',
        default=[])
    args = parser.parse_args()
    algo_args = json.loads(args.build)
    print(algo_args)
    query_args = [json.loads(q) for q in args.queries]

    definition = Definition(
        algorithm=args.algorithm,
        docker_tag=None,  # not needed
        module=args.module,
        constructor=args.constructor,
        arguments=algo_args,
        query_argument_groups=query_args,
        disabled=False
    )
    run(definition, args.dataset, args.count, args.runs)


def run_docker(definition, dataset, count, runs, timeout, cpu_limit,
               mem_limit=None):
    cmd = ['--dataset', dataset,
           '--algorithm', definition.algorithm,
           '--module', definition.module,
           '--constructor', definition.constructor,
           '--runs', str(runs),
           '--count', str(count)]
    cmd.append(json.dumps(definition.arguments))
    cmd += [json.dumps(qag) for qag in definition.query_argument_groups]

    client = docker.from_env()
    if mem_limit is None:
        mem_limit = psutil.virtual_memory().available

    container = client.containers.run(
        definition.docker_tag,
        cmd,
        volumes={
            os.path.abspath('benchmark'):
                {'bind': '/home/app/benchmark', 'mode': 'ro'},
            os.path.abspath('data'):
                {'bind': '/home/app/data', 'mode': 'rw'},
            os.path.abspath('results'):
                {'bind': '/home/app/results', 'mode': 'rw'},
        },
        cpuset_cpus=cpu_limit,
        mem_limit=mem_limit,
        detach=True)
    logger = logging.getLogger(f"annb.{container.short_id}")

    logger.info('Created container %s: CPU limit %s, mem limit %s, timeout %d, command %s' %
                (container.short_id, cpu_limit, mem_limit, timeout, cmd))

    def stream_logs():
        for line in container.logs(stream=True):
            logger.info(colors.color(line.decode().rstrip(), fg='blue'))

    t = threading.Thread(target=stream_logs, daemon=True)
    t.start()

    try:
        exit_code = container.wait(timeout=timeout)

        # Report the container's logs if it exited with an error code
        if exit_code not in [0, None]:
            logger.error(colors.color(container.logs().decode(), fg='red'))
            logger.error('Child process for container %s raised exception %d' % (container.short_id, exit_code))
    except:
        logger.error('Container.wait for container %s failed with exception' % container.short_id)
        logger.error('Invoked with %s' % cmd)
        traceback.print_exc()
    finally:
        container.remove(force=True)

install/Dockerfile (+3 -2)
@@ -5,6 +5,7 @@ RUN apt-get install -y python3-numpy python3-scipy python3-pip build-essential g
 RUN pip3 install -U pip

 WORKDIR /home/app
-#COPY requirements.txt run_algorithm.py ./
-COPY requirements.txt ./
+COPY requirements.txt run_algorithm.py ./
 RUN pip3 install -rrequirements.txt
+
+ENTRYPOINT ["python3", "run_algorithm.py"]

install/Dockerfile.faiss (+1 -1)
@@ -1,4 +1,4 @@
-FROM ann-benchmarks
+FROM billion-scale-benchmark

 RUN apt-get update && apt-get install -y libopenblas-base libopenblas-dev libpython3-dev swig python3-dev libssl-dev wget
 RUN wget https://github.com/Kitware/CMake/releases/download/v3.18.3/cmake-3.18.3-Linux-x86_64.sh && mkdir cmake && sh cmake-3.18.3-Linux-x86_64.sh --skip-license --prefix=cmake && rm cmake-3.18.3-Linux-x86_64.sh

logging.conf (+34)
@@ -0,0 +1,34 @@
[loggers]
keys=root,annb

[handlers]
keys=consoleHandler,fileHandler

[formatters]
keys=simpleFormatter

[formatter_simpleFormatter]
format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
datefmt=

[handler_consoleHandler]
class=StreamHandler
level=INFO
formatter=simpleFormatter
args=(sys.stdout,)

[handler_fileHandler]
class=FileHandler
level=INFO
formatter=simpleFormatter
args=('annb.log','w')

[logger_root]
level=WARN
handlers=consoleHandler

[logger_annb]
level=INFO
handlers=consoleHandler,fileHandler
qualname=annb
propagate=0

plot.py (+154)
@@ -0,0 +1,154 @@
import os
import matplotlib as mpl
mpl.use('Agg')  # noqa
import matplotlib.pyplot as plt
import numpy as np
import argparse

from benchmark.datasets import DATASETS
from benchmark.algorithms.definitions import get_definitions
from benchmark.plotting.metrics import all_metrics as metrics
from benchmark.plotting.utils import (get_plot_label, compute_metrics,
                                      create_linestyles, create_pointset)
from benchmark.results import (store_results, load_all_results,
                               get_unique_algorithms)


def create_plot(all_data, raw, x_scale, y_scale, xn, yn, fn_out, linestyles):
    xm, ym = (metrics[xn], metrics[yn])
    # Now generate each plot
    handles = []
    labels = []
    plt.figure(figsize=(12, 9))

    # Sorting by mean y-value helps aligning plots with labels
    def mean_y(algo):
        xs, ys, ls, axs, ays, als = create_pointset(all_data[algo], xn, yn)
        return -np.log(np.array(ys)).mean()
    # Find range for logit x-scale
    min_x, max_x = 1, 0
    for algo in sorted(all_data.keys(), key=mean_y):
        xs, ys, ls, axs, ays, als = create_pointset(all_data[algo], xn, yn)
        min_x = min([min_x] + [x for x in xs if x > 0])
        max_x = max([max_x] + [x for x in xs if x < 1])
        color, faded, linestyle, marker = linestyles[algo]
        handle, = plt.plot(xs, ys, '-', label=algo, color=color,
                           ms=7, mew=3, lw=3, linestyle=linestyle,
                           marker=marker)
        handles.append(handle)
        if raw:
            handle2, = plt.plot(axs, ays, '-', label=algo, color=faded,
                                ms=5, mew=2, lw=2, linestyle=linestyle,
                                marker=marker)
        labels.append(algo)

    ax = plt.gca()
    ax.set_ylabel(ym['description'])
    ax.set_xlabel(xm['description'])
    # Custom scales of the type --x-scale a3
    if x_scale[0] == 'a':
        alpha = int(x_scale[1:])
        fun = lambda x: 1 - (1 - x) ** (1 / alpha)
        inv_fun = lambda x: 1 - (1 - x) ** alpha
        ax.set_xscale('function', functions=(fun, inv_fun))
        if alpha <= 3:
            ticks = [inv_fun(x) for x in np.arange(0, 1.2, .2)]
            plt.xticks(ticks)
        if alpha > 3:
            from matplotlib import ticker
            ax.xaxis.set_major_formatter(ticker.LogitFormatter())
            #plt.xticks(ticker.LogitLocator().tick_values(min_x, max_x))
            plt.xticks([0, 1/2, 1-1e-1, 1-1e-2, 1-1e-3, 1-1e-4, 1])
    # Other x-scales
    else:
        ax.set_xscale(x_scale)
    ax.set_yscale(y_scale)
    ax.set_title(get_plot_label(xm, ym))
    box = plt.gca().get_position()
    # plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height])
    ax.legend(handles, labels, loc='center left',
              bbox_to_anchor=(1, 0.5), prop={'size': 9})
    plt.grid(b=True, which='major', color='0.65', linestyle='-')
    plt.setp(ax.get_xminorticklabels(), visible=True)

    # Logit scale has to be a subset of (0,1)
    if 'lim' in xm and x_scale != 'logit':
        x0, x1 = xm['lim']
        plt.xlim(max(x0, 0), min(x1, 1))
    elif x_scale == 'logit':
        plt.xlim(min_x, max_x)
    if 'lim' in ym:
        plt.ylim(ym['lim'])

    # Workaround for bug https://github.com/matplotlib/matplotlib/issues/6789
    ax.spines['bottom']._adjust_location()

    plt.savefig(fn_out, bbox_inches='tight')
    plt.close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset',
        metavar="DATASET",
        default='sift-1M')
    parser.add_argument(
        '--count',
        default=10)
    parser.add_argument(
        '--definitions',
        metavar='FILE',
        help='load algorithm definitions from FILE',
        default='algos.yaml')
    parser.add_argument(
        '--limit',
        default=-1)
    parser.add_argument(
        '-o', '--output')
    parser.add_argument(
        '-x', '--x-axis',
        help='Which metric to use on the X-axis',
        choices=metrics.keys(),
        default="k-nn")
    parser.add_argument(
        '-y', '--y-axis',
        help='Which metric to use on the Y-axis',
        choices=metrics.keys(),
        default="qps")
    parser.add_argument(
        '-X', '--x-scale',
        help='Scale to use when drawing the X-axis. Typically linear, logit or a2',
        default='linear')
    parser.add_argument(
        '-Y', '--y-scale',
        help='Scale to use when drawing the Y-axis',
        choices=["linear", "log", "symlog", "logit"],
        default='linear')
    parser.add_argument(
        '--raw',
        help='Show raw results (not just Pareto frontier) in faded colours',
        action='store_true')
    parser.add_argument(
        '--recompute',
        help='Clears the cache and recomputes the metrics',
        action='store_true')
    args = parser.parse_args()

    if not args.output:
        args.output = 'results/%s.png' % (args.dataset)
        print('writing output to %s' % args.output)

    dataset = DATASETS[args.dataset]
    count = int(args.count)
    unique_algorithms = get_unique_algorithms()
    results = load_all_results(args.dataset, count)
    linestyles = create_linestyles(sorted(unique_algorithms))
    runs = compute_metrics(dataset.get_groundtruth(k=args.count),
                           results, args.x_axis, args.y_axis, args.recompute)
    if not runs:
        raise Exception('Nothing to plot')

    create_plot(runs, args.raw, args.x_scale,
                args.y_scale, args.x_axis, args.y_axis, args.output,
                linestyles)
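The custom a-scales stretch the high-recall end of the axis. A small sketch of the forward transform used above:

# Axis position is 1 - (1 - x) ** (1 / alpha), so x values near 1 spread out.
alpha = 3
fun = lambda x: 1 - (1 - x) ** (1 / alpha)
for x in (0.5, 0.9, 0.99, 0.999):
    print(x, round(fun(x), 3))
# 0.5 -> 0.206 | 0.9 -> 0.536 | 0.99 -> 0.785 | 0.999 -> 0.9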

run.py (+6)
@@ -0,0 +1,6 @@
from benchmark.main import main
from multiprocessing import freeze_support

if __name__ == "__main__":
    freeze_support()
    main()

run_algorithm.py (+3)
@@ -0,0 +1,3 @@
from benchmark.runner import run_from_cmdline

run_from_cmdline()
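Inside a container, run_algorithm.py receives the argv assembled by run_docker above. A hedged reconstruction for the faiss-ivf definition (values mirror algos.yaml; positional arguments are JSON):

argv = ["--dataset", "sift-1M",
        "--algorithm", "faiss-ivf",
        "--module", "benchmark.algorithms.faiss_inmem",
        "--constructor", "FaissIVF",
        "--runs", "5",
        "--count", "10",
        '["euclidean", 1024]',   # build: constructor arguments
        "[1]", "[5]", "[50]"]    # queries: one group per n_probe setting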
