Skip to content

Commit b1bfb77

Browse files
committed
Add YOLOInference for downstream inference frameworks
1 parent e9567cc commit b1bfb77

File tree

3 files changed

+98
-0
lines changed

3 files changed

+98
-0
lines changed

yolort/models/yolo.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def load_from_yolov5(
190190
score_thresh: float = 0.25,
191191
nms_thresh: float = 0.45,
192192
version: str = "r6.0",
193+
post_process: Optional[nn.Module] = None,
193194
):
194195
"""
195196
Load model state from the checkpoint trained by YOLOv5.
@@ -220,6 +221,7 @@ def load_from_yolov5(
220221
anchor_grids=model_info["anchor_grids"],
221222
score_thresh=score_thresh,
222223
nms_thresh=nms_thresh,
224+
post_process=post_process,
223225
)
224226

225227
model.load_state_dict(model_info["state_dict"])

yolort/relaying/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
22
from .trace_wrapper import get_trace_module
3+
from .yolo_inference import YOLOInference
4+
5+
__all__ = ["get_trace_module", "YOLOInference"]

yolort/relaying/yolo_inference.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
2+
from typing import Dict, List, Tuple
3+
4+
import torch
5+
from torch import nn, Tensor
6+
from yolort.models import YOLO
7+
from yolort.models._utils import decode_single
8+
from yolort.models.box_head import _concat_pred_logits
9+
10+
__all__ = ["YOLOInference"]
11+
12+
13+
class YOLOInference(YOLO):
14+
"""
15+
A deployment friendly wrapper of YOLO.
16+
17+
Remove the ``torchvision::nms`` in this warpper, due to the fact that some third-party
18+
inference frameworks currently do not support this operator very well.
19+
"""
20+
21+
def __init__(
22+
self,
23+
checkpoint_path: str,
24+
score_thresh: float = 0.25,
25+
version: str = "r6.0",
26+
):
27+
post_process = PostProcess(score_thresh)
28+
29+
self.model = YOLO.load_from_yolov5(
30+
checkpoint_path,
31+
version=version,
32+
post_process=post_process,
33+
)
34+
35+
def forward(self, inputs: Tensor):
36+
"""
37+
Args:
38+
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
39+
"""
40+
# Compute the detections
41+
outputs = self.model(inputs)
42+
43+
return outputs
44+
45+
46+
class PostProcess(nn.Module):
47+
"""
48+
This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
49+
50+
Args:
51+
score_thresh (float): Score threshold used for postprocessing the detections.
52+
"""
53+
54+
def __init__(self, score_thresh: float) -> None:
55+
super().__init__()
56+
self.score_thresh = score_thresh
57+
58+
def forward(
59+
self,
60+
head_outputs: List[Tensor],
61+
anchors_tuple: Tuple[Tensor, Tensor, Tensor],
62+
) -> List[Dict[str, Tensor]]:
63+
"""
64+
Just concat the predict logits, ignore the original ``torchvision::nms`` module
65+
from original ``yolort.models.box_head.PostProcess``.
66+
67+
Args:
68+
head_outputs (List[Tensor]): The predicted locations and class/object confidence,
69+
shape of the element is (N, A, H, W, K).
70+
anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
71+
"""
72+
batch_size = len(head_outputs[0])
73+
74+
all_pred_logits = _concat_pred_logits(head_outputs)
75+
76+
detections: List[Dict[str, Tensor]] = []
77+
78+
for idx in range(batch_size): # image idx, image inference
79+
pred_logits = torch.sigmoid(all_pred_logits[idx])
80+
81+
# Compute conf
82+
# box_conf x class_conf, w/ shape: num_anchors x num_classes
83+
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
84+
85+
boxes = decode_single(pred_logits[:, :4], anchors_tuple)
86+
87+
# remove low scoring boxes
88+
inds, labels = torch.where(scores > self.score_thresh)
89+
boxes, scores = boxes[inds], scores[inds, labels]
90+
91+
detections.append({"scores": scores, "labels": labels, "boxes": boxes})
92+
93+
return detections

0 commit comments

Comments
 (0)