Merge pull request #8 from PINTO0309/post-process
I made a version with all of the post-processing merged into the ONNX model.
Kazuhito00 authored Jul 10, 2023
2 parents 48ddd9f + 86d7f3d commit 07f4d32
Showing 16 changed files with 1,261 additions and 14 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -133,3 +133,7 @@ dmypy.json

# mp4
*.mp4

# tensorrt
*.engine
*.profile
29 changes: 29 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,29 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "simple_demo",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "args": [
                "--width", "640",
                "--height", "480",
            ]
        },
        {
            "name": "simple_demo_with_post",
            "type": "python",
            "request": "launch",
            "program": "${file}",
            "console": "integratedTerminal",
            "justMyCode": true,
            "args": [
                "--width", "640",
                "--height", "480",
            ]
        }
    ]
}
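Both launch configurations pass `--width 640 --height 480` to whichever script is currently open (`${file}`). As a minimal sketch of how a demo script might consume these two flags, assuming they are parsed as integers with `argparse` (the helper name `parse_demo_args` and the `None` defaults are hypothetical and not taken from the repository):

```python
import argparse


def parse_demo_args(argv=None):
    # Hypothetical helper: only the flag names --width and --height
    # come from launch.json; types and defaults are assumptions.
    parser = argparse.ArgumentParser()
    parser.add_argument("--width", type=int, default=None)
    parser.add_argument("--height", type=int, default=None)
    return parser.parse_args(argv)


# Same argument list that the launch configurations pass.
args = parse_demo_args(["--width", "640", "--height", "480"])
print(args.width, args.height)
```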
29 changes: 22 additions & 7 deletions README.md
@@ -39,7 +39,7 @@ Deep写輪眼: NARUTO hand-sign recognition using YOLOX object detection<
* onnxruntime 1.10.0 or Later
* OpenCV 3.4.2 or Later
* Pillow 6.1.0 or Later (Only when running Ninjutsu_demo.py)
* Tensorflow 2.3.0 or Later (Only when running SSD or EfficientDet)
* Tensorflow 2.3.0 or Later (Only when running SSD or EfficientDet, or when merging post-processing into ONNX)

# DataSet
### About the dataset
@@ -138,17 +138,28 @@ It would be helpful if you could report the conditions of any false detection in an Issue.<br>
<pre>
│ simple_demo.py
│ Ninjutsu_demo.py
├─model
│ └─yolox
│ │ yolox_nano.onnx
│ └─yolox_onnx.py
├─post_process_gen_tools
│ │ convert_script.sh
│ │ make_box_gather_nd.py
│ │ make_boxes_scores.py
│ │ make_cxcywh_y1x1y2x2.py
│ │ make_final_batch_nums_final_class_nums_final_box_nums.py
│ │ make_grids.py
│ │ make_input_output_shape_update.py
│ │ make_nms_outputs_merge.py
│ └─make_score_gather_nd.py
├─setting─┬─labels.csv
│ └─jutsu.csv
├─utils
└─_legacy
</pre>
#### simple_demo.py
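@@ -164,6 +175,9 @@ It would be helpful if you could report the conditions of any false detection in an Issue.<br>
#### model
Stores the trained model.

#### post_process_gen_tools
Stores the scripts that merge all post-processing into ONNX.

#### setting
Stores the label data (labels.csv) and the jutsu name data (jutsu.csv).
* labels.csv<br>
@@ -186,13 +200,14 @@ It would be helpful if you could report the conditions of any false detection in an Issue.<br>
Here's how to run the demo.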
```bash
python simple_demo.py
python simple_demo_without_post.py
python Ninjutsu_demo.py
```

In addition, the following options can be specified when running the demo.
<details>
<summary>Option specification</summary>

* --device<br>
Camera device number<br>
Default:
@@ -308,7 +323,7 @@ For YOLOX training, <span id="cite_ref-7">YOLOX-Colaboratory-Trainin
# Affiliations(所属)
-->

# License
# License
NARUTO-HandSignDetection is under [MIT license](https://en.wikipedia.org/wiki/MIT_License).

# License(Font)
26 changes: 19 additions & 7 deletions README_EN.md
@@ -38,7 +38,7 @@ The speed has been greatly improved compared to the Deep写輪眼(using Efficien
* onnxruntime 1.10.0 or Later
* OpenCV 3.4.2 or Later
* Pillow 6.1.0 or Later (Only when running Ninjutsu_demo.py)
* Tensorflow 2.3.0 or Later (Only when running SSD or EfficientDet)
* Tensorflow 2.3.0 or Later (Only when running SSD or EfficientDet, or only when merging post-processing into ONNX)

# DataSet
### About the dataset
@@ -137,17 +137,28 @@ The trained model is published under the 'model' directory. * Move the old versi
<pre>
│ simple_demo.py
│ Ninjutsu_demo.py
├─model
│ └─yolox
│ │ yolox_nano.onnx
│ └─yolox_onnx.py
├─post_process_gen_tools
│ │ convert_script.sh
│ │ make_box_gather_nd.py
│ │ make_boxes_scores.py
│ │ make_cxcywh_y1x1y2x2.py
│ │ make_final_batch_nums_final_class_nums_final_box_nums.py
│ │ make_grids.py
│ │ make_input_output_shape_update.py
│ │ make_nms_outputs_merge.py
│ └─make_score_gather_nd.py
├─setting─┬─labels.csv
│ └─jutsu.csv
├─utils
└─_legacy
</pre>
#### simple_demo.py
@@ -185,13 +196,14 @@ The name of the Ninjutsu name and the required hand-sign are listed.<br>
Here's how to run the demo.
```bash
python simple_demo.py
python simple_demo_without_post.py
python Ninjutsu_demo.py
```

In addition, the following options can be specified when running the demo.
<details>
<summary>Option specification</summary>

* --device<br>
Camera device number<br>
Default:
@@ -308,7 +320,7 @@ Kazuhito Takahashi(https://twitter.com/KzhtTkhs)
# Affiliations(所属)
-->

# License
# License
NARUTO-HandSignDetection is under [MIT license](https://en.wikipedia.org/wiki/MIT_License).

# License(Font)
Binary file added model/yolox/yolox_nano_with_post.onnx
Binary file not shown.
113 changes: 113 additions & 0 deletions model/yolox/yolox_onnx_without_post.py
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import copy

import cv2
import numpy as np
import onnxruntime


class YoloxONNX(object):
    def __init__(
        self,
        model_path='yolox_nano.onnx',
        input_shape=(416, 416),
        class_score_th=0.3,
        with_p6=False,
        providers=[
            (
                'TensorrtExecutionProvider', {
                    'trt_engine_cache_enable': True,
                    'trt_engine_cache_path': '.',
                    'trt_fp16_enable': True,
                }
            ),
            'CUDAExecutionProvider',
            'CPUExecutionProvider',
        ],
    ):
        # Input size
        self.input_shape = input_shape

        # Threshold
        self.class_score_th = class_score_th
        self.with_p6 = with_p6

        # Load the model
        self.onnx_session = onnxruntime.InferenceSession(
            model_path,
            providers=providers,
        )

        self.input_name = self.onnx_session.get_inputs()[0].name
        self.output_name = self.onnx_session.get_outputs()[0].name

    def inference(self, image):
        temp_image = copy.deepcopy(image)
        image_height, image_width = image.shape[0], image.shape[1]

        # Preprocessing
        image, ratio = self._preprocess(temp_image, self.input_shape)

        # Run inference
        results = self.onnx_session.run(
            None,
            {self.input_name: image[None, :, :, :]},
        )

        # Post-processing
        bboxes, scores, class_ids = self._postprocess(
            results[0],
            ratio,
            image_width,
            image_height,
        )

        return bboxes, scores, class_ids

    def _preprocess(self, image, input_size, swap=(2, 0, 1)):
        if len(image.shape) == 3:
            padded_image = np.ones(
                (input_size[0], input_size[1], 3), dtype=np.uint8) * 114
        else:
            padded_image = np.ones(input_size, dtype=np.uint8) * 114

        ratio = min(input_size[0] / image.shape[0],
                    input_size[1] / image.shape[1])
        resized_image = cv2.resize(
            image,
            (int(image.shape[1] * ratio), int(image.shape[0] * ratio)),
            interpolation=cv2.INTER_LINEAR,
        )
        resized_image = resized_image.astype(np.uint8)

        padded_image[:resized_image.shape[0], :resized_image.shape[1]] = resized_image
        padded_image = padded_image.transpose(swap)
        padded_image = np.ascontiguousarray(padded_image, dtype=np.float32)

        return padded_image, ratio

    def _postprocess(
        self,
        dets: np.ndarray,
        ratio,
        max_width: int,
        max_height: int,
    ):
        bbox = np.array([])
        score = np.array([])
        class_id = np.array([])
        if dets is not None and dets.shape[0] >= 1:
            class_ids, scores, bboxes = dets[..., 1:2], dets[..., 2:3], dets[..., 3:]
            keep_idx = np.argmax(scores, axis=0)
            class_id = class_ids[keep_idx, ...]
            score = scores[keep_idx, ...]
            bbox = bboxes[keep_idx, ...][0]
            bbox /= ratio
            bbox[0] = max(0, bbox[0])
            bbox[1] = max(0, bbox[1])
            bbox[2] = min(bbox[2], max_width)
            bbox[3] = min(bbox[3], max_height)
            bbox = bbox[np.newaxis, :]

        return bbox.astype(np.float32), score.astype(np.float32), class_id.astype(np.int32)
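
The `_preprocess` and `_postprocess` methods above form a pair: the image is scaled by a single `ratio` to fit the model input, and the single highest-scoring box is later divided by the same `ratio` and clamped to the original image size. The following is a pure-Python restatement of that arithmetic, not the repository's NumPy code. It assumes each detection row is laid out as `[batch_index, class_id, score, x1, y1, x2, y2]`, which is inferred from the slices `dets[..., 1:2]`, `dets[..., 2:3]`, and `dets[..., 3:]`; the meaning of column 0 is a guess. The function names and sample values are hypothetical.

```python
def letterbox_ratio(image_hw, input_hw=(416, 416)):
    """Scale factor that fits the image inside the model input."""
    return min(input_hw[0] / image_hw[0], input_hw[1] / image_hw[1])


def pick_best_detection(dets, ratio, max_width, max_height):
    """Return (bbox, score, class_id) for the top-scoring row, or None.

    Assumed row layout: [batch_index, class_id, score, x1, y1, x2, y2].
    """
    if not dets:
        return None
    # Keep only the highest-scoring row (first one wins on ties).
    best = max(dets, key=lambda row: row[2])
    # Map the box from model-input coordinates back to the source image.
    x1, y1, x2, y2 = (value / ratio for value in best[3:7])
    # Clamp the box to the source image bounds.
    bbox = [max(0.0, x1), max(0.0, y1), min(x2, max_width), min(y2, max_height)]
    return bbox, best[2], int(best[1])


# A 640x480 frame (width x height) fed to a 416x416 model input.
ratio = letterbox_ratio((480, 640))
dets = [
    [0, 3, 0.42, 10.0, 20.0, 110.0, 120.0],
    [0, 7, 0.91, -6.5, 65.0, 260.0, 330.0],
]
bbox, score, class_id = pick_best_detection(dets, ratio, 640, 480)
print(ratio, bbox, score, class_id)
```

Because only the top-scoring row is kept, this path returns at most one detection per frame.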