Merge pull request #33 from ternaus/try_onnx
Add converter to onnx
ternaus authored Jul 8, 2021
2 parents 29611c7 + ff9a050 commit 0bfa402
Showing 26 changed files with 538 additions and 245 deletions.
26 changes: 9 additions & 17 deletions .github/workflows/ci.yml
@@ -13,38 +13,30 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [3.6, 3.7, 3.8]
python-version: [3.6, 3.7, 3.8, 3.9]
steps:
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v1.1.1
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- name: Cache pip
uses: actions/cache@v1
with:
path: ~/.cache/pip # This path is specific to Ubuntu
# Look to see if there is a cache hit for the corresponding requirements file
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: |
${{ runner.os }}-pip-
${{ runner.os }}-
# You can test your matrix by printing the current Python version
- name: Display Python version
run: python -c "import sys; print(sys.version)"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install black flake8 mypy pytest hypothesis
pip install -r requirements_dev.txt
- name: Run black
run:
black --check .
- name: Run flake8
run: flake8
- name: Run Pylint
run: pylint retinaface
- name: Run Mypy
run: mypy retinaface
# - name: tests
# run: |
# pip install .[tests]
# pytest
- name: tests
run: |
pip install .[tests]
pytest
112 changes: 58 additions & 54 deletions .pre-commit-config.yaml
@@ -1,56 +1,60 @@
exclude: _pb2\.py$
repos:
- repo: https://github.com/pre-commit/mirrors-isort
rev: f0001b2 # Use the revision sha / tag you want to point at
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/psf/black
rev: 20.8b1
hooks:
- id: black
- repo: https://github.com/asottile/yesqa
rev: v1.1.0
hooks:
- id: yesqa
additional_dependencies:
- flake8-bugbear==20.1.4
- flake8-builtins==1.5.2
- flake8-comprehensions==3.2.2
- flake8-tidy-imports==4.1.0
- flake8==3.7.9
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v2.3.0
hooks:
- id: check-docstring-first
- id: check-json
- id: check-merge-conflict
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
- id: flake8
- id: requirements-txt-fixer
- repo: https://github.com/pre-commit/mirrors-pylint
rev: d230ffd
hooks:
- id: pylint
- repo: https://github.com/asottile/pyupgrade
rev: v2.19.4
hooks:
- id: pyupgrade
args: [ "--py38-plus" ]
- repo: https://github.com/pre-commit/mirrors-isort
rev: 1ba6bfc # Use the revision sha / tag you want to point at
hooks:
- id: isort
args: ["--profile", "black"]
- repo: https://github.com/psf/black
rev: 21.6b0
hooks:
- id: black
- repo: https://gitlab.com/pycqa/flake8
rev: 3.9.2
hooks:
- id: flake8
language_version: python3
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.0.1
hooks:
- id: check-docstring-first
- id: check-json
- id: check-merge-conflict
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: trailing-whitespace
- id: requirements-txt-fixer
- repo: https://github.com/pre-commit/mirrors-pylint
rev: 56b3cb4
hooks:
- id: pylint
args:
- --max-line-length=120
- --ignore-imports=yes
- -d duplicate-code
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.9.0
hooks:
- id: python-check-mock-methods
- id: python-use-type-annotations
- id: python-check-blanket-noqa
- id: text-unicode-replacement-char
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 9feadeb
hooks:
- id: mypy
exclude: ^tests/
args:
- --max-line-length=119
- --ignore-imports=yes
- -d duplicate-code
- repo: https://github.com/asottile/pyupgrade
rev: v2.7.3
hooks:
- id: pyupgrade
args: ['--py37-plus']
- repo: https://github.com/pre-commit/pygrep-hooks
rev: v1.5.1
hooks:
- id: python-check-mock-methods
- id: python-use-type-annotations
- repo: https://github.com/pre-commit/mirrors-mypy
rev: 9feadeb
hooks:
- id: mypy
args: [--ignore-missing-imports, --warn-no-return, --warn-redundant-casts, --disallow-incomplete-defs]
args:
  [
--disallow-untyped-defs,
--check-untyped-defs,
--warn-redundant-casts,
--no-implicit-optional,
--strict-optional
]
7 changes: 6 additions & 1 deletion .pylintrc
Expand Up @@ -148,7 +148,12 @@ disable=print-statement,
too-few-public-methods,
attribute-defined-outside-init,
too-many-locals,
too-many-arguments
too-many-arguments,
too-many-instance-attributes,
unused-argument,
no-member,
arguments-differ,
super-init-not-called

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
16 changes: 14 additions & 2 deletions README.md
@@ -28,10 +28,10 @@ Todo:
* Horizontal Flip is not implemented in Albumentations
* Spatial transforms like rotations or transpose are not implemented yet.

Color transforms are defined in the config.
Color transforms defined in the config.

### Added mAP calculation for validation
In order to track thr progress, mAP metric is calculated on validation.
In order to track the progress, mAP metric is calculated on validation.

## Installation

@@ -102,6 +102,11 @@ You can convert the default labels of the WiderFaces to the json of the proper format


## Training
### Install dependencies
```
pip install -r requirements.txt
pip install -r requirements_dev.txt
```

### Define config
Example configs could be found at [retinaface/configs](retinaface/configs)
@@ -183,3 +188,10 @@ python -m torch.distributed.launch --nproc_per_node=<num_gpus> retinaface/infere
https://retinaface.herokuapp.com/

Code for the web app: https://github.com/ternaus/retinaface_demo

### Converting to ONNX
Inference on CPU can be sped up by converting the model to ONNX.

```
Ex: python -m converters.to_onnx -m 1280 -o retinaface1280.onnx
```
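
For quick reference, the exported model can be run directly with `onnxruntime`. The sketch below assumes the export command above (`retinaface1280.onnx` with a 1280×1280 input) and mirrors the preprocessing in `converters/to_onnx.py`; the image path is a placeholder.

```
import albumentations as albu
import cv2
import numpy as np
import onnxruntime as ort

MAX_SIZE = 1280  # must match the -m value used at export time

# Preprocessing mirrors converters/to_onnx.py: resize the longest side,
# normalize, then pad to a (MAX_SIZE, MAX_SIZE) square.
image = cv2.imread("example.jpg")  # placeholder path
image = albu.Compose([albu.LongestMaxSize(max_size=MAX_SIZE), albu.Normalize(p=1)])(image=image)["image"]
height, width = image.shape[:2]
image = cv2.copyMakeBorder(image, 0, MAX_SIZE - height, 0, MAX_SIZE - width, borderType=cv2.BORDER_CONSTANT)

session = ort.InferenceSession("retinaface1280.onnx")
# NCHW float32 batch of one, matching input_names=["input"] from the export
batch = np.expand_dims(np.transpose(image, (2, 0, 1)), 0).astype(np.float32)
boxes, scores, landmarks = session.run(None, {"input": batch})
```

Confidence filtering (threshold 0.7) and NMS are part of the exported graph, so the returned boxes are already filtered and deduplicated.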
Empty file added converters/__init__.py
153 changes: 153 additions & 0 deletions converters/to_onnx.py
@@ -0,0 +1,153 @@
import argparse
from typing import Dict, List, Tuple, Union

import albumentations as albu
import cv2
import numpy as np
import onnx
import onnxruntime as ort
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from torchvision.ops import nms

from retinaface.box_utils import decode, decode_landm
from retinaface.network import RetinaFace
from retinaface.prior_box import priorbox
from retinaface.utils import tensor_from_rgb_image, vis_annotations

state_dict = model_zoo.load_url(
"https://github.com/ternaus/retinaface/releases/download/0.01/retinaface_resnet50_2020-07-20-f168fae3c.zip",
progress=True,
map_location="cpu",
)


class M(nn.Module):
def __init__(self, max_size: int = 1280):
super().__init__()
self.model = RetinaFace(
name="Resnet50",
pretrained=False,
return_layers={"layer2": 1, "layer3": 2, "layer4": 3},
in_channels=256,
out_channels=256,
)
self.model.load_state_dict(state_dict)

self.max_size = max_size

self.scale_landmarks = torch.from_numpy(np.tile([self.max_size, self.max_size], 5))
self.scale_bboxes = torch.from_numpy(np.tile([self.max_size, self.max_size], 2))

self.prior_box = priorbox(
min_sizes=[[16, 32], [64, 128], [256, 512]],
steps=[8, 16, 32],
clip=False,
image_size=(self.max_size, self.max_size),
)
self.nms_threshold: float = 0.4
self.variance = [0.1, 0.2]
self.confidence_threshold: float = 0.7

def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
loc, conf, land = self.model(x)

conf = F.softmax(conf, dim=-1)

boxes = decode(loc.data[0], self.prior_box, self.variance)

boxes *= self.scale_bboxes
scores = conf[0][:, 1]

landmarks = decode_landm(land.data[0], self.prior_box, self.variance)
landmarks *= self.scale_landmarks

# ignore low scores
valid_index = torch.where(scores > self.confidence_threshold)[0]
boxes = boxes[valid_index]
landmarks = landmarks[valid_index]
scores = scores[valid_index]

# do NMS
keep = nms(boxes, scores, self.nms_threshold)
boxes = boxes[keep, :]

landmarks = landmarks[keep]
scores = scores[keep]
return boxes, scores, landmarks


def prepare_image(image: np.ndarray, max_size: int = 1280) -> np.ndarray:
image = albu.Compose([albu.LongestMaxSize(max_size=max_size), albu.Normalize(p=1)])(image=image)["image"]

height, width = image.shape[:2]

return cv2.copyMakeBorder(image, 0, max_size - height, 0, max_size - width, borderType=cv2.BORDER_CONSTANT)


def main() -> None:
parser = argparse.ArgumentParser()
arg = parser.add_argument
arg(
"-m",
"--max_size",
type=int,
help="Size of the input image. The onnx model will predict on (max_size, max_size)",
required=True,
)

arg("-o", "--output_file", type=str, help="Path to save onnx model.", required=True)
args = parser.parse_args()

raw_image = cv2.imread("tests/data/13.jpg")

image = prepare_image(raw_image, args.max_size)

x = tensor_from_rgb_image(image).unsqueeze(0).float()

model = M(max_size=args.max_size)
model.eval()
with torch.no_grad():
out_torch = model(x)

torch.onnx.export(
model,
x,
args.output_file,
verbose=True,
opset_version=12,
input_names=["input"],
export_params=True,
do_constant_folding=True,
)

onnx_model = onnx.load(args.output_file)
onnx.checker.check_model(onnx_model)

ort_session = ort.InferenceSession(args.output_file)

outputs = ort_session.run(None, {"input": np.expand_dims(np.transpose(image, (2, 0, 1)), 0)})

for i in range(3):
if not np.allclose(out_torch[i].numpy(), outputs[i]):
raise ValueError("torch and onnx models do not match!")

annotations: List[Dict[str, List[Union[float, List[float]]]]] = []

for box_id, box in enumerate(outputs[0]):
annotations += [
{
"bbox": box.tolist(),
"score": outputs[1][box_id],
"landmarks": outputs[2][box_id].reshape(-1, 2).tolist(),
}
]

im = albu.Compose([albu.LongestMaxSize(max_size=1280)])(image=raw_image)["image"]
cv2.imwrite("example.jpg", vis_annotations(im, annotations))


if __name__ == "__main__":
main()
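
One caveat about the parity check in `main`: `np.allclose` with default tolerances can fail even for a correct export, since PyTorch and onnxruntime kernels differ slightly in floating point. A tolerance-based comparison is a common alternative; the helper below is a sketch with illustrative tolerances, not part of this commit.

```
import numpy as np


def check_parity(torch_outputs, onnx_outputs, rtol=1e-3, atol=1e-5):
    """Compare torch and onnxruntime outputs with explicit tolerances.

    The rtol/atol values are illustrative assumptions, not from this commit.
    """
    for torch_out, onnx_out in zip(torch_outputs, onnx_outputs):
        np.testing.assert_allclose(torch_out.numpy(), onnx_out, rtol=rtol, atol=atol)
```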
8 changes: 3 additions & 5 deletions requirements.txt
@@ -1,5 +1,3 @@
albumentations
iglovikov_helper_functions
numpy
pillow
torch
albumentations==1.0.0
torch==1.9.0
torchvision==0.10.0