Skip to content

Commit 2331417

Browse files
authored
Fix --list-algorithms using path names instead of algorithm names (fi… (#569)
* fix --list-algorithms using path names instead of algorithm names (fixes #555)
1 parent e38c914 commit 2331417

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

ann_benchmarks/definitions.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -145,16 +145,15 @@ def load_configs(point_type: str, base_dir: str = "ann_benchmarks/algorithms") -
145145
print(f"Error loading YAML from {config_file}: {e}")
146146
return configs
147147

148-
def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> Dict[str, Dict[str, Any]]:
149-
"""Load algorithm configurations for a given point_type."""
148+
def _get_definitions(base_dir: str = "ann_benchmarks/algorithms") -> List[Dict[str, Any]]:
149+
"""Load algorithm configurations."""
150150
config_files = get_config_files(base_dir=base_dir)
151-
configs = {}
151+
configs = []
152152
for config_file in config_files:
153153
with open(config_file, 'r') as stream:
154154
try:
155155
config_data = yaml.safe_load(stream)
156-
algorithm_name = os.path.basename(os.path.dirname(config_file))
157-
configs[algorithm_name] = config_data
156+
configs.append(config_data)
158157
except yaml.YAMLError as e:
159158
print(f"Error loading YAML from {config_file}: {e}")
160159
return configs
@@ -211,16 +210,27 @@ def list_algorithms(base_dir: str = "ann_benchmarks/algorithms") -> None:
211210
base_dir (str, optional): The base directory where the algorithms are stored.
212211
Defaults to "ann_benchmarks/algorithms".
213212
"""
214-
definitions = _get_definitions(base_dir)
215-
216-
print("The following algorithms are supported...", definitions)
217-
for algorithm in definitions:
213+
all_configs = _get_definitions(base_dir)
214+
data = {}
215+
for algo_configs in all_configs:
216+
for point_type, config_for_point_type in algo_configs.items():
217+
for metric, ccc in config_for_point_type.items():
218+
algo_name = ccc[0]["name"]
219+
if algo_name not in data:
220+
data[algo_name] = {}
221+
if point_type not in data[algo_name]:
222+
data[algo_name][point_type] = []
223+
data[algo_name][point_type].append(metric)
224+
225+
print("The following algorithms are supported:", ", ".join(data))
226+
print("Details of supported metrics and data types: ")
227+
for algorithm in data:
218228
print('\t... for the algorithm "%s"...' % algorithm)
219229

220-
for point_type in definitions[algorithm]:
230+
for point_type in data[algorithm]:
221231
print('\t\t... and the point type "%s", metrics: ' % point_type)
222232

223-
for metric in definitions[algorithm][point_type]:
233+
for metric in data[algorithm][point_type]:
224234
print("\t\t\t%s" % metric)
225235

226236

0 commit comments

Comments
 (0)