@@ -56,8 +56,6 @@ class Dataset:
56
56
name : str
57
57
path : Path
58
58
group : str
59
- data_format : str
60
- num_classes : int
61
59
num_repeat : int = 1
62
60
extra_overrides : dict | None = None
63
61
@@ -155,10 +153,6 @@ def run(
155
153
str (data_root ),
156
154
"--work_dir" ,
157
155
str (sub_work_dir ),
158
- "--model.num_classes" ,
159
- str (dataset .num_classes ),
160
- "--data.config.data_format" ,
161
- dataset .data_format ,
162
156
"--engine.device" ,
163
157
self .accelerator ,
164
158
]
@@ -172,7 +166,10 @@ def run(
172
166
start_time = time ()
173
167
self ._run_command (command )
174
168
extra_metrics = {"train/e2e_time" : time () - start_time }
175
- self ._rename_raw_data (work_dir = sub_work_dir / ".latest" / "train" , replaces = {"epoch" : "train/epoch" })
169
+ self ._rename_raw_data (
170
+ work_dir = sub_work_dir / ".latest" / "train" ,
171
+ replaces = {"train_" : "train/" , "{pre}" : "train/" },
172
+ )
176
173
self ._log_metrics (
177
174
work_dir = sub_work_dir / ".latest" / "train" ,
178
175
tags = tags ,
@@ -187,6 +184,10 @@ def run(
187
184
str (sub_work_dir ),
188
185
]
189
186
self ._run_command (command )
187
+ self ._rename_raw_data (
188
+ work_dir = sub_work_dir / ".latest" / "test" ,
189
+ replaces = {"test_" : "test/" , "{pre}" : "test/" },
190
+ )
190
191
self ._log_metrics (work_dir = sub_work_dir / ".latest" / "test" , tags = tags , criteria = criteria )
191
192
192
193
# Export & test
@@ -215,7 +216,10 @@ def run(
215
216
]
216
217
self ._run_command (command )
217
218
218
- self ._rename_raw_data (work_dir = sub_work_dir / ".latest" / "test" , replaces = {"test" : "export" })
219
+ self ._rename_raw_data (
220
+ work_dir = sub_work_dir / ".latest" / "test" ,
221
+ replaces = {"test" : "export" , "{pre}" : "export/" },
222
+ )
219
223
self ._log_metrics (work_dir = sub_work_dir / ".latest" / "test" , tags = tags , criteria = criteria )
220
224
221
225
# Optimize & test
@@ -250,7 +254,10 @@ def run(
250
254
]
251
255
self ._run_command (command )
252
256
253
- self ._rename_raw_data (work_dir = sub_work_dir / ".latest" / "test" , replaces = {"test" : "optimize" })
257
+ self ._rename_raw_data (
258
+ work_dir = sub_work_dir / ".latest" / "test" ,
259
+ replaces = {"test" : "optimize" , "{pre}" : "optimize/" },
260
+ )
254
261
self ._log_metrics (work_dir = sub_work_dir / ".latest" / "test" , tags = tags , criteria = criteria )
255
262
256
263
# Force memory clean up
@@ -310,11 +317,25 @@ def _log_metrics(
310
317
metrics .to_csv (work_dir / "benchmark.raw.csv" , index = False )
311
318
312
319
def _rename_raw_data (self , work_dir : Path , replaces : dict [str , str ]) -> None :
320
+ replaces = {** self .NAME_MAPPING , ** replaces }
321
+
322
+ def _rename_col (col_name : str ) -> str :
323
+ for src_str , dst_str in replaces .items ():
324
+ if src_str == "{pre}" :
325
+ if not col_name .startswith (dst_str ):
326
+ col_name = dst_str + col_name
327
+ elif src_str == "{post}" :
328
+ if not col_name .endswith (dst_str ):
329
+ col_name = col_name + dst_str
330
+ else :
331
+ col_name = col_name .replace (src_str , dst_str )
332
+ return col_name
333
+
313
334
csv_files = work_dir .glob ("**/metrics.csv" )
314
335
for csv_file in csv_files :
315
336
data = pd .read_csv (csv_file )
316
- for src_str , dst_str in replaces . items ():
317
- data . columns = data .columns . str . replace (src_str , dst_str )
337
+ data = data . rename ( columns = _rename_col ) # Column names
338
+ data = data .replace (replaces ) # Values
318
339
data .to_csv (csv_file , index = False )
319
340
320
341
@staticmethod
@@ -338,7 +359,7 @@ def load_result(result_path: Path) -> pd.DataFrame | None:
338
359
return pd .concat (results , ignore_index = True ).set_index (["task" , "model" , "data_group" , "data" ])
339
360
340
361
@staticmethod
341
- def average_result (data : pd .DataFrame , keys : list [str ]) -> pd .DataFrame :
362
+ def average_result (data : pd .DataFrame , keys : list [str ]) -> pd .DataFrame | None :
342
363
"""Average result w.r.t. given keys
343
364
344
365
Args:
@@ -348,6 +369,9 @@ def average_result(data: pd.DataFrame, keys: list[str]) -> pd.DataFrame:
348
369
Retruns:
349
370
pd.DataFrame: Averaged result table
350
371
"""
372
+ if data is None :
373
+ return None
374
+
351
375
# Flatten index
352
376
index_names = data .index .names
353
377
column_names = data .columns
@@ -391,3 +415,5 @@ def check(self, result: pd.DataFrame, criteria: list[Criterion]):
391
415
392
416
for criterion in criteria :
393
417
criterion (result_entry , target_entry )
418
+
419
+ NAME_MAPPING : dict [str , str ] = {} # noqa: RUF012
0 commit comments