@@ -15,7 +15,7 @@
 from typing import Optional
 
 # Add library dir to the search path.
-sys.path.insert(0, str(pathlib.Path(__file__).parent.parent / "library"))
+sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / "library"))
 from models import resnet50, bert_large, t5_large
 
 # Add benchmark definitions to the search path.
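
For reference, `Path.parents` indexes the same ancestor chain that repeated `.parent` walks, so the two forms are interchangeable; a minimal sketch (the file path below is hypothetical, purely for illustration):

```python
import pathlib

# Hypothetical script location, for illustration only.
p = pathlib.Path("/repo/benchmarks/run_benchmarks.py")

# parents[1] is the same directory as .parent.parent ...
assert p.parent.parent == p.parents[1] == pathlib.Path("/repo")
# ... and the indexed form scales more readably to deeper ancestors.
assert p.parents[2] == pathlib.Path("/")
```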
@@ -46,6 +46,15 @@ def benchmark_lookup(unique_id: str):
   raise ValueError(f"Model definition not supported")
 
 
+def benchmark_lookup(benchmark_id: str):
+  benchmark = tf_inference_benchmarks.ID_TO_BENCHMARK_MAP.get(benchmark_id)
+  if benchmark is None:
+    raise ValueError(f"Id {benchmark_id} does not exist in benchmark suite.")
+
+  model_name, model_class = model_lookup(benchmark.model.id)
+  return (model_name, model_class, benchmark)
+
+
 def dump_result(file_path: str, result: dict) -> None:
   with open(file_path, "r") as f:
     dictObj = json.load(f)
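
A self-contained sketch of how the new `benchmark_lookup` composes the id map with `model_lookup`; the map entry, `SimpleNamespace` objects, and the `model_lookup` table below are invented stand-ins, not the real `tf_inference_benchmarks` definitions:

```python
from types import SimpleNamespace

# Mock benchmark registry keyed by a unique benchmark id.
ID_TO_BENCHMARK_MAP = {
    "models/RESNET50/batch_1": SimpleNamespace(
        name="resnet50-batch1",
        input_batch_size=1,
        model=SimpleNamespace(id="RESNET50"),
    ),
}


def model_lookup(model_id: str):
  # Stand-in for the real model table.
  return {"RESNET50": ("resnet50", object)}[model_id]


def benchmark_lookup(benchmark_id: str):
  benchmark = ID_TO_BENCHMARK_MAP.get(benchmark_id)
  if benchmark is None:
    raise ValueError(f"Id {benchmark_id} does not exist in benchmark suite.")
  model_name, model_class = model_lookup(benchmark.model.id)
  return (model_name, model_class, benchmark)


model_name, model_class, benchmark = benchmark_lookup("models/RESNET50/batch_1")
assert benchmark.input_batch_size == 1
```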
@@ -66,7 +75,8 @@ def bytes_to_mb(bytes: Optional[int]) -> Optional[float]:
 def run_framework_benchmark(model_name: str, model_class: type[tf.Module],
                             batch_size: int, warmup_iterations: int,
                             benchmark_iterations: int, tf_device: str,
-                            hlo_dump_dir: str, dump_hlo: bool, shared_dict) -> None:
+                            hlo_dump_dir: str, dump_hlo: bool,
+                            shared_dict) -> None:
   try:
     with tf.device(tf_device):
       if dump_hlo:
@@ -216,17 +226,16 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
 
   args = argParser.parse_args()
 
-  model_name, model_class, model_definition = benchmark_lookup(
-      args.benchmark_id)
+  model_name, model_class, benchmark = benchmark_lookup(args.benchmark_id)
   print(
-      f"\n\n--- {model_name} {args.benchmark_id} -------------------------------------"
+      f"\n\n--- {benchmark.name} {args.benchmark_id} -------------------------------------"
   )
 
   if os.path.exists(_HLO_DUMP_DIR):
     shutil.rmtree(_HLO_DUMP_DIR)
   os.mkdir(_HLO_DUMP_DIR)
 
-  batch_size = model_definition.input_batch_size
+  batch_size = benchmark.input_batch_size
   benchmark_definition = {
       "benchmark_id": args.benchmark_id,
-      "benchmark_name": model_definition.name,
+      "benchmark_name": benchmark.name,
@@ -248,9 +257,9 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
     shared_dict = manager.dict()
 
     if args.run_in_process:
-      run_framework_benchmark(model_name, model_class, batch_size, args.warmup_iterations,
-                              args.iterations, tf_device, _HLO_DUMP_DIR, dump_hlo,
-                              shared_dict)
+      run_framework_benchmark(model_name, model_class, batch_size,
+                              args.warmup_iterations, args.iterations,
+                              tf_device, _HLO_DUMP_DIR, dump_hlo, shared_dict)
     else:
       p = multiprocessing.Process(target=run_framework_benchmark,
                                   args=(model_name, model_class, batch_size,
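
For context on the in-process/subprocess split these hunks reflow: running a benchmark in a child process isolates the TF runtime and device state per run, and a `multiprocessing.Manager().dict()` proxies results back to the parent. A minimal sketch, with an invented workload and result key:

```python
import multiprocessing


def run_benchmark(batch_size: int, shared_dict) -> None:
  # Stand-in workload; the real script runs the TF model here.
  shared_dict["min_latency_ms"] = 1.23 * batch_size


if __name__ == "__main__":
  with multiprocessing.Manager() as manager:
    shared_dict = manager.dict()
    p = multiprocessing.Process(target=run_benchmark, args=(8, shared_dict))
    p.start()
    p.join()
    # Copy results out before the manager's server process shuts down.
    results = dict(shared_dict)
  print(results)
```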
@@ -269,8 +278,10 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
     shared_dict = manager.dict()
 
     if args.run_in_process:
-      run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR, args.hlo_iterations,
-                             "cuda" if args.device == "gpu" else "cpu", shared_dict)
+      run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR,
+                             args.hlo_iterations,
+                             "cuda" if args.device == "gpu" else "cpu",
+                             shared_dict)
     else:
       p = multiprocessing.Process(
           target=run_compiler_benchmark,