Skip to content

Commit

Permalink
[third-party] Fix the issue of inference errors with KIE models in ONN…
Browse files Browse the repository at this point in the history
…X format (PaddlePaddle#14138)

* fix inference KIE model using onnx model

* fix code style

* fix onnx inputs compatibility with det and rec

* fix code style
  • Loading branch information
Alex37882388 authored Nov 1, 2024
1 parent d3d7e85 commit 58e876d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
18 changes: 13 additions & 5 deletions ppstructure/kie/predict_kie_token_ser.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

class SerPredictor(object):
def __init__(self, args):
self.args = args
self.ocr_engine = PaddleOCR(
use_angle_cls=args.use_angle_cls,
det_model_dir=args.det_model_dir,
Expand Down Expand Up @@ -113,15 +114,22 @@ def __call__(self, img):
data[idx] = np.expand_dims(data[idx], axis=0)
else:
data[idx] = [data[idx]]
if self.args.use_onnx:
input_tensor = {
name: data[idx] for idx, name in enumerate(self.input_tensor)
}
self.output_tensors = self.predictor.run(None, input_tensor)
else:
for idx in range(len(self.input_tensor)):
self.input_tensor[idx].copy_from_cpu(data[idx])

for idx in range(len(self.input_tensor)):
self.input_tensor[idx].copy_from_cpu(data[idx])

self.predictor.run()
self.predictor.run()

outputs = []
for output_tensor in self.output_tensors:
output = output_tensor.copy_to_cpu()
output = (
output_tensor if self.args.use_onnx else output_tensor.copy_to_cpu()
)
outputs.append(output)
preds = outputs[0]

Expand Down
8 changes: 7 additions & 1 deletion tools/infer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,13 @@ def create_predictor(args, mode, logger):
providers=["CPUExecutionProvider"],
sess_options=sess_options,
)
return sess, sess.get_inputs()[0], None, None
inputs = sess.get_inputs()
return (
sess,
inputs[0] if len(inputs) == 1 else [vo.name for vo in inputs],
None,
None,
)

else:
file_names = ["model", "inference"]
Expand Down

0 comments on commit 58e876d

Please sign in to comment.