Add YOLOInference for downstream inference frameworks
zhiqwang committed Nov 27, 2021
1 parent e9567cc commit b1bfb77
Showing 3 changed files with 98 additions and 0 deletions.
2 changes: 2 additions & 0 deletions yolort/models/yolo.py
@@ -190,6 +190,7 @@ def load_from_yolov5(
score_thresh: float = 0.25,
nms_thresh: float = 0.45,
version: str = "r6.0",
post_process: Optional[nn.Module] = None,
):
"""
Load model state from the checkpoint trained by YOLOv5.
@@ -220,6 +221,7 @@ def load_from_yolov5(
anchor_grids=model_info["anchor_grids"],
score_thresh=score_thresh,
nms_thresh=nms_thresh,
post_process=post_process,
)

model.load_state_dict(model_info["state_dict"])
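The new ``post_process`` argument gives callers a hook to swap in their own post-processing module when loading a YOLOv5 checkpoint. A minimal sketch of that usage, assuming a local ``yolov5s.pt`` checkpoint (the path and threshold are illustrative, not part of this commit):

from yolort.models import YOLO
from yolort.relaying.yolo_inference import PostProcess

# Illustrative checkpoint path and score threshold.
checkpoint_path = "yolov5s.pt"

# Any nn.Module with a forward(head_outputs, anchors_tuple) signature could be
# passed here; the NMS-free PostProcess below is the one added in this commit.
model = YOLO.load_from_yolov5(
    checkpoint_path,
    version="r6.0",
    post_process=PostProcess(score_thresh=0.3),
)
model = model.eval()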
3 changes: 3 additions & 0 deletions yolort/relaying/__init__.py
@@ -1,2 +1,5 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from .trace_wrapper import get_trace_module
from .yolo_inference import YOLOInference

__all__ = ["get_trace_module", "YOLOInference"]
93 changes: 93 additions & 0 deletions yolort/relaying/yolo_inference.py
@@ -0,0 +1,93 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from typing import Dict, List, Tuple

import torch
from torch import nn, Tensor
from yolort.models import YOLO
from yolort.models._utils import decode_single
from yolort.models.box_head import _concat_pred_logits

__all__ = ["YOLOInference"]


class YOLOInference(nn.Module):
"""
A deployment-friendly wrapper of YOLO.
The ``torchvision::nms`` call is removed from this wrapper because some third-party
inference frameworks currently do not support this operator very well.
"""

def __init__(
self,
checkpoint_path: str,
score_thresh: float = 0.25,
version: str = "r6.0",
):
super().__init__()
post_process = PostProcess(score_thresh)

self.model = YOLO.load_from_yolov5(
checkpoint_path,
version=version,
post_process=post_process,
)

def forward(self, inputs: Tensor):
"""
Args:
inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
"""
# Compute the detections
outputs = self.model(inputs)

return outputs


class PostProcess(nn.Module):
"""
A simplified version of ``yolort.models.box_head.PostProcess`` that removes the ``torchvision::nms`` call.
Args:
score_thresh (float): Score threshold used for postprocessing the detections.
"""

def __init__(self, score_thresh: float) -> None:
super().__init__()
self.score_thresh = score_thresh

def forward(
self,
head_outputs: List[Tensor],
anchors_tuple: Tuple[Tensor, Tensor, Tensor],
) -> List[Dict[str, Tensor]]:
"""
Concatenate the predicted logits and skip the ``torchvision::nms`` step used by
the original ``yolort.models.box_head.PostProcess``.
Args:
head_outputs (List[Tensor]): The predicted locations and class/object confidence,
shape of the element is (N, A, H, W, K).
anchors_tuple (Tuple[Tensor, Tensor, Tensor]): Anchor-related tensors consumed by
``decode_single`` to decode the box predictions.
"""
batch_size = len(head_outputs[0])

all_pred_logits = _concat_pred_logits(head_outputs)

detections: List[Dict[str, Tensor]] = []

for idx in range(batch_size):  # per-image inference
pred_logits = torch.sigmoid(all_pred_logits[idx])

# Compute conf
# box_conf x class_conf, w/ shape: num_anchors x num_classes
scores = pred_logits[:, 5:] * pred_logits[:, 4:5]

boxes = decode_single(pred_logits[:, :4], anchors_tuple)

# remove low scoring boxes
inds, labels = torch.where(scores > self.score_thresh)
boxes, scores = boxes[inds], scores[inds, labels]

detections.append({"scores": scores, "labels": labels, "boxes": boxes})

return detections
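Putting the pieces together, a downstream pipeline could instantiate the wrapper directly and run the NMS-free forward pass before handing the raw, thresholded detections to its own NMS implementation. A rough usage sketch (the checkpoint path and input size are illustrative):

import torch

from yolort.relaying import YOLOInference

# Illustrative checkpoint path; score_thresh and version match the defaults above.
model = YOLOInference("yolov5s.pt", score_thresh=0.25, version="r6.0")
model = model.eval()

# Dummy batch of shape [batch_size, 3, H, W].
images = torch.rand(1, 3, 640, 640)

with torch.no_grad():
    detections = model(images)

# Each dict holds score-thresholded (but not NMS-suppressed) predictions.
print(detections[0]["boxes"].shape, detections[0]["scores"].shape)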
