Skip to content

Commit 10ff9ed

Browse files
author
Che-Yu Wu
committed
Fix iree-tf/benchmark/benchmark_model.py
1 parent 064f2aa commit 10ff9ed

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

iree-tf/benchmark/benchmark_model.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from typing import Optional
1616

1717
# Add library dir to the search path.
18-
sys.path.insert(0, str(pathlib.Path(__file__).parent.parent / "library"))
18+
sys.path.insert(0, str(pathlib.Path(__file__).parents[1] / "library"))
1919
from models import resnet50, bert_large, t5_large
2020

2121
# Add benchmark definitions to the search path.
@@ -46,6 +46,15 @@ def benchmark_lookup(unique_id: str):
4646
raise ValueError(f"Model definition not supported")
4747

4848

49+
def benchmark_lookup(benchmark_id: str):
50+
benchmark = tf_inference_benchmarks.ID_TO_BENCHMARK_MAP.get(benchmark_id)
51+
if benchmark is None:
52+
raise ValueError(f"Id {benchmark_id} does not exist in benchmark suite.")
53+
54+
model_name, model_class = model_lookup(benchmark.model.id)
55+
return (model_name, model_class, benchmark)
56+
57+
4958
def dump_result(file_path: str, result: dict) -> None:
5059
with open(file_path, "r") as f:
5160
dictObj = json.load(f)
@@ -66,7 +75,8 @@ def bytes_to_mb(bytes: Optional[int]) -> Optional[float]:
6675
def run_framework_benchmark(model_name: str, model_class: type[tf.Module],
6776
batch_size: int, warmup_iterations: int,
6877
benchmark_iterations: int, tf_device: str,
69-
hlo_dump_dir: str, dump_hlo: bool, shared_dict) -> None:
78+
hlo_dump_dir: str, dump_hlo: bool,
79+
shared_dict) -> None:
7080
try:
7181
with tf.device(tf_device):
7282
if dump_hlo:
@@ -216,17 +226,16 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
216226

217227
args = argParser.parse_args()
218228

219-
model_name, model_class, model_definition = benchmark_lookup(
220-
args.benchmark_id)
229+
model_name, model_class, benchmark = benchmark_lookup(args.benchmark_id)
221230
print(
222-
f"\n\n--- {model_name} {args.benchmark_id} -------------------------------------"
231+
f"\n\n--- {benchmark.name} {args.benchmark_id} -------------------------------------"
223232
)
224233

225234
if os.path.exists(_HLO_DUMP_DIR):
226235
shutil.rmtree(_HLO_DUMP_DIR)
227236
os.mkdir(_HLO_DUMP_DIR)
228237

229-
batch_size = model_definition.input_batch_size
238+
batch_size = benchmark.input_batch_size
230239
benchmark_definition = {
231240
"benchmark_id": args.benchmark_id,
232241
"benchmark_name": model_definition.name,
@@ -248,9 +257,9 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
248257
shared_dict = manager.dict()
249258

250259
if args.run_in_process:
251-
run_framework_benchmark(model_name, model_class, batch_size, args.warmup_iterations,
252-
args.iterations, tf_device, _HLO_DUMP_DIR, dump_hlo,
253-
shared_dict)
260+
run_framework_benchmark(model_name, model_class, batch_size,
261+
args.warmup_iterations, args.iterations,
262+
tf_device, _HLO_DUMP_DIR, dump_hlo, shared_dict)
254263
else:
255264
p = multiprocessing.Process(target=run_framework_benchmark,
256265
args=(model_name, model_class, batch_size,
@@ -269,8 +278,10 @@ def run_compiler_benchmark(hlo_benchmark_tool_path: str, hlo_dir: str,
269278
shared_dict = manager.dict()
270279

271280
if args.run_in_process:
272-
run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR, args.hlo_iterations,
273-
"cuda" if args.device == "gpu" else "cpu", shared_dict)
281+
run_compiler_benchmark(args.hlo_benchmark_path, _HLO_DUMP_DIR,
282+
args.hlo_iterations,
283+
"cuda" if args.device == "gpu" else "cpu",
284+
shared_dict)
274285
else:
275286
p = multiprocessing.Process(
276287
target=run_compiler_benchmark,

0 commit comments

Comments (0)