Skip to content

Commit

Permalink
Add YOLOInference for downstream inference frameworks (#238)
Browse files Browse the repository at this point in the history
* Fixing docstring

* Fixing compatibility issue

* Minor fix of the docstring

* Separate out the concatenation part in PostProcess

* Deprecated the ONNX exporting in ncnn

* Fixing jit trace

* Add YOLOInference for downstream inference frameworks

* Add test_yolo_inference

* Use attempt_download

* Add test_attempt_download

* Abstract the _decode_pred_logits

* Abstract the LogitsDecoder
  • Loading branch information
zhiqwang authored Nov 27, 2021
1 parent 203b9bb commit 12389a3
Show file tree
Hide file tree
Showing 12 changed files with 251 additions and 210 deletions.
64 changes: 0 additions & 64 deletions deployment/ncnn/export_onnx.py

This file was deleted.

Empty file removed deployment/ncnn/tools/__init__.py
Empty file.
76 changes: 0 additions & 76 deletions deployment/ncnn/tools/yolort_deploy_friendly.py

This file was deleted.

9 changes: 2 additions & 7 deletions test/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from yolort.models.backbone_utils import darknet_pan_backbone
from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
from yolort.models.transformer import darknet_tan_backbone
from yolort.v5 import get_yolov5_size
from yolort.v5 import get_yolov5_size, attempt_download


@contextlib.contextmanager
Expand Down Expand Up @@ -351,16 +351,11 @@ def test_load_from_yolov5(
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
score_thresh = 0.25

model_yolov5 = YOLOv5.load_from_yolov5(
Expand Down
37 changes: 36 additions & 1 deletion test/test_relaying.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,45 @@
import pytest
import torch
from torch import Tensor
from torch.jit._trace import TopLevelTracedModule
from yolort.models import yolov5s
from yolort.relaying import get_trace_module
from yolort.relaying import get_trace_module, YOLOInference
from yolort.v5 import attempt_download


def test_get_trace_module():
model_func = yolov5s(pretrained=True)
script_module = get_trace_module(model_func, input_shape=(416, 320))
assert isinstance(script_module, TopLevelTracedModule)
assert script_module.code is not None


@pytest.mark.parametrize(
"arch, version, upstream_version, hash_prefix",
[
("yolov5s", "r4.0", "v4.0", "9ca9a642"),
("yolov5n", "r6.0", "v6.0", "649e089f"),
("yolov5s", "r6.0", "v6.0", "c3b140f3"),
("yolov5n6", "r6.0", "v6.0", "beecbbae"),
],
)
def test_yolo_inference(arch, version, upstream_version, hash_prefix):

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url)
score_thresh = 0.25

model = YOLOInference(
checkpoint_path,
score_thresh=score_thresh,
version=version,
)
model.eval()
samples = torch.rand(1, 3, 320, 320)
outs = model(samples)

assert isinstance(outs[0], dict)
assert isinstance(outs[0]["boxes"], Tensor)
assert isinstance(outs[0]["labels"], Tensor)
assert isinstance(outs[0]["scores"], Tensor)
16 changes: 3 additions & 13 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
load_yolov5_model,
scale_coords,
non_max_suppression,
attempt_download,
)


Expand All @@ -36,15 +37,10 @@ def test_load_from_ultralytics(
hash_prefix: str,
use_p6: bool,
):
checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"
base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
model_info = load_from_ultralytics(checkpoint_path, version=version)
assert isinstance(model_info, dict)
assert model_info["num_classes"] == 80
Expand All @@ -65,16 +61,10 @@ def test_load_from_ultralytics_voc(
hash_prefix: str,
):
img_path = "test/assets/bus.jpg"
checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"

base_url = "https://github.com/ultralytics/yolov5/releases/download/"
model_url = f"{base_url}/{upstream_version}/{arch}.pt"

torch.hub.download_url_to_file(
model_url,
checkpoint_path,
hash_prefix=hash_prefix,
)
checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

# Preprocess
img_raw = cv2.imread(img_path)
Expand Down
27 changes: 15 additions & 12 deletions test/test_v5.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,27 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path
import hashlib

import torch
from torch import Tensor
from yolort.v5 import load_yolov5_model
from yolort.v5 import load_yolov5_model, attempt_download


def test_attempt_download():

model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")
with open(checkpoint_path, "rb") as f:
bytes = f.read() # read entire file as bytes
readable_hash = hashlib.sha256(bytes).hexdigest()
assert readable_hash[:8] == "9ca9a642"


def test_load_yolov5_model():
img_path = "test/assets/zidane.jpg"

yolov5s_r40_path = Path("yolov5s.pt")

if not yolov5s_r40_path.exists():
torch.hub.download_url_to_file(
"https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt",
yolov5s_r40_path,
hash_prefix="9ca9a642",
)
model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")

model = load_yolov5_model(str(yolov5s_r40_path), autoshape=True, verbose=False)
model = load_yolov5_model(checkpoint_path, autoshape=True, verbose=False)
results = model(img_path)

assert isinstance(results.pred, list)
Expand Down
Loading

0 comments on commit 12389a3

Please sign in to comment.