diff --git a/tools/infer/utility.py b/tools/infer/utility.py index 57b585601fe..78406dc0a18 100644 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -225,11 +225,13 @@ def create_predictor(args, mode, logger): else: file_names = ["model", "inference"] + model_formats = ["pdmodel", "json"] for file_name in file_names: - model_file_path = "{}/{}.pdmodel".format(model_dir, file_name) params_file_path = "{}/{}.pdiparams".format(model_dir, file_name) - if os.path.exists(model_file_path) and os.path.exists(params_file_path): - break + for model_format in model_formats: + model_file_path = "{}/{}.{}".format(model_dir, file_name, model_format) + if os.path.exists(model_file_path) and os.path.exists(params_file_path): + break if not os.path.exists(model_file_path): raise ValueError( "not find model.pdmodel or inference.pdmodel in {}".format(model_dir)