diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index 4fe6e357..34ba6ef8 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -190,6 +190,7 @@ def load_from_yolov5(
         score_thresh: float = 0.25,
         nms_thresh: float = 0.45,
         version: str = "r6.0",
+        post_process: Optional[nn.Module] = None,
     ):
         """
         Load model state from the checkpoint trained by YOLOv5.
@@ -220,6 +221,7 @@ def load_from_yolov5(
             anchor_grids=model_info["anchor_grids"],
             score_thresh=score_thresh,
             nms_thresh=nms_thresh,
+            post_process=post_process,
         )
 
         model.load_state_dict(model_info["state_dict"])
diff --git a/yolort/relaying/__init__.py b/yolort/relaying/__init__.py
index 82822225..e051b86b 100644
--- a/yolort/relaying/__init__.py
+++ b/yolort/relaying/__init__.py
@@ -1,2 +1,5 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
 from .trace_wrapper import get_trace_module
+from .yolo_inference import YOLOInference
+
+__all__ = ["get_trace_module", "YOLOInference"]
diff --git a/yolort/relaying/yolo_inference.py b/yolort/relaying/yolo_inference.py
new file mode 100644
index 00000000..18af4bfa
--- /dev/null
+++ b/yolort/relaying/yolo_inference.py
@@ -0,0 +1,110 @@
+# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
+from typing import Dict, List, Tuple
+
+import torch
+from torch import nn, Tensor
+from yolort.models import YOLO
+from yolort.models._utils import decode_single
+from yolort.models.box_head import _concat_pred_logits
+
+__all__ = ["YOLOInference"]
+
+
+class YOLOInference(nn.Module):
+    """
+    A deployment friendly wrapper of YOLO.
+
+    Remove the ``torchvision::nms`` in this wrapper, due to the fact that some third-party
+    inference frameworks currently do not support this operator very well.
+
+    Args:
+        checkpoint_path (str): Path of the checkpoint trained by YOLOv5, passed to
+            ``YOLO.load_from_yolov5``.
+        score_thresh (float): Score threshold used for postprocessing the detections.
+        version (str): Version string passed to ``YOLO.load_from_yolov5``.
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: str,
+        score_thresh: float = 0.25,
+        version: str = "r6.0",
+    ):
+        # ``nn.Module`` must be initialized before a sub-module can be assigned,
+        # otherwise ``self.model = ...`` below raises an ``AttributeError``.
+        super().__init__()
+
+        post_process = PostProcess(score_thresh)
+
+        self.model = YOLO.load_from_yolov5(
+            checkpoint_path,
+            version=version,
+            post_process=post_process,
+        )
+
+    def forward(self, inputs: Tensor):
+        """
+        Args:
+            inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
+
+        Returns:
+            The output of the wrapped model for ``inputs``.
+        """
+        # Compute the detections
+        return self.model(inputs)
+
+
+class PostProcess(nn.Module):
+    """
+    This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
+
+    Args:
+        score_thresh (float): Score threshold used for postprocessing the detections.
+    """
+
+    def __init__(self, score_thresh: float) -> None:
+        super().__init__()
+        self.score_thresh = score_thresh
+
+    def forward(
+        self,
+        head_outputs: List[Tensor],
+        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+    ) -> List[Dict[str, Tensor]]:
+        """
+        Just concat the predict logits, ignore the original ``torchvision::nms`` module
+        from original ``yolort.models.box_head.PostProcess``.
+
+        Args:
+            head_outputs (List[Tensor]): The predicted locations and class/object confidence,
+                shape of the element is (N, A, H, W, K).
+            anchors_tuple (Tuple[Tensor, Tensor, Tensor]): Anchor information, forwarded
+                unchanged to ``decode_single`` to decode the boxes.
+
+        Returns:
+            List[Dict[str, Tensor]]: One dict per image with the keys ``scores``, ``labels``
+                and ``boxes``. Every (anchor, class) pair scoring above ``score_thresh``
+                yields one entry; overlapping boxes are not suppressed.
+        """
+        batch_size = len(head_outputs[0])
+
+        all_pred_logits = _concat_pred_logits(head_outputs)
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for idx in range(batch_size):  # image idx, image inference
+            pred_logits = torch.sigmoid(all_pred_logits[idx])
+
+            # Compute conf
+            # box_conf x class_conf, w/ shape: num_anchors x num_classes
+            scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
+
+            boxes = decode_single(pred_logits[:, :4], anchors_tuple)
+
+            # remove low scoring boxes
+            inds, labels = torch.where(scores > self.score_thresh)
+            boxes, scores = boxes[inds], scores[inds, labels]
+
+            detections.append({"scores": scores, "labels": labels, "boxes": boxes})
+
+        return detections