diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 93dbca5..f725750 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -13,38 +13,30 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.6, 3.7, 3.8] + python-version: [3.6, 3.7, 3.8, 3.9] steps: - uses: actions/checkout@v2 - name: Set up Python - uses: actions/setup-python@v1.1.1 + uses: actions/setup-python@v2 with: python-version: ${{ matrix.python-version }} - - name: Cache pip - uses: actions/cache@v1 - with: - path: ~/.cache/pip # This path is specific to Ubuntu - # Look to see if there is a cache hit for the corresponding requirements file - key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} - restore-keys: | - ${{ runner.os }}-pip- - ${{ runner.os }}- # You can test your matrix by printing the current Python version - name: Display Python version run: python -c "import sys; print(sys.version)" - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -r requirements.txt - pip install black flake8 mypy pytest hypothesis + pip install -r requirements_dev.txt - name: Run black run: black --check . - name: Run flake8 run: flake8 + - name: Run Pylint + run: pylint retinaface - name: Run Mypy run: mypy retinaface -# - name: tests -# run: | -# pip install .[tests] -# pytest + - name: tests + run: | + pip install .[tests] + pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a138ce1..7eb2083 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,56 +1,60 @@ -exclude: _pb2\.py$ repos: -- repo: https://github.com/pre-commit/mirrors-isort - rev: f0001b2 # Use the revision sha / tag you want to point at - hooks: - - id: isort - args: ["--profile", "black"] -- repo: https://github.com/psf/black - rev: 20.8b1 - hooks: - - id: black -- repo: https://github.com/asottile/yesqa - rev: v1.1.0 - hooks: - - id: yesqa - additional_dependencies: - - flake8-bugbear==20.1.4 - - flake8-builtins==1.5.2 - - flake8-comprehensions==3.2.2 - - flake8-tidy-imports==4.1.0 - - flake8==3.7.9 -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 - hooks: - - id: check-docstring-first - - id: check-json - - id: check-merge-conflict - - id: check-yaml - - id: debug-statements - - id: end-of-file-fixer - - id: trailing-whitespace - - id: flake8 - - id: requirements-txt-fixer -- repo: https://github.com/pre-commit/mirrors-pylint - rev: d230ffd - hooks: - - id: pylint + - repo: https://github.com/asottile/pyupgrade + rev: v2.19.4 + hooks: + - id: pyupgrade + args: [ "--py38-plus" ] + - repo: https://github.com/pre-commit/mirrors-isort + rev: 1ba6bfc # Use the revision sha / tag you want to point at + hooks: + - id: isort + args: ["--profile", "black"] + - repo: https://github.com/psf/black + rev: 21.6b0 + hooks: + - id: black + - repo: https://gitlab.com/pycqa/flake8 + rev: 3.9.2 + hooks: + - id: flake8 + language_version: python3 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.0.1 + hooks: + - id: check-docstring-first + - id: check-json + - id: check-merge-conflict + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace + - id: requirements-txt-fixer + - repo: https://github.com/pre-commit/mirrors-pylint + rev: 56b3cb4 + hooks: + - id: pylint + args: + - --max-line-length=120 + - --ignore-imports=yes + - -d duplicate-code + - repo: https://github.com/pre-commit/pygrep-hooks + rev: v1.9.0 + hooks: + - id: python-check-mock-methods + - id: 
python-use-type-annotations + - id: python-check-blanket-noqa + - id: python-use-type-annotations + - id: text-unicode-replacement-char + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 9feadeb + hooks: + - id: mypy + exclude: ^tests/ args: - - --max-line-length=119 - - --ignore-imports=yes - - -d duplicate-code -- repo: https://github.com/asottile/pyupgrade - rev: v2.7.3 - hooks: - - id: pyupgrade - args: ['--py37-plus'] -- repo: https://github.com/pre-commit/pygrep-hooks - rev: v1.5.1 - hooks: - - id: python-check-mock-methods - - id: python-use-type-annotations -- repo: https://github.com/pre-commit/mirrors-mypy - rev: 9feadeb - hooks: - - id: mypy - args: [--ignore-missing-imports, --warn-no-return, --warn-redundant-casts, --disallow-incomplete-defs] + [ + --disallow-untyped-defs, + --check-untyped-defs, + --warn-redundant-casts, + --no-implicit-optional, + --strict-optional + ] diff --git a/.pylintrc b/.pylintrc index 259d22b..fa7df2a 100644 --- a/.pylintrc +++ b/.pylintrc @@ -148,7 +148,12 @@ disable=print-statement, too-few-public-methods, attribute-defined-outside-init, too-many-locals, - too-many-arguments + too-many-arguments, + too-many-instance-attributes, + unused-argument, + no-member, + arguments-differ, + super-init-not-called # Enable the message, report, category or checker with the given id(s). You can # either give multiple identifier separated by comma (,) or put this option diff --git a/README.md b/README.md index 0775e00..bce8f34 100644 --- a/README.md +++ b/README.md @@ -28,10 +28,10 @@ Todo: * Horizontal Flip is not implemented in Albumentations * Spatial transforms like rotations or transpose are not implemented yet. -Color transforms are defined in the config. +Color transforms defined in the config. ### Added mAP calculation for validation -In order to track thr progress, mAP metric is calculated on validation. +In order to track the progress, mAP metric is calculated on validation. ## Installation @@ -102,6 +102,11 @@ You can convert the default labels of the WiderFaces to the json of the propper ## Training +### Install dependencies +``` +pip install -r requirements.txt +pip install -r requirements_dev.txt +``` ### Define config Example configs could be found at [retinaface/configs](retinaface/configs) @@ -183,3 +188,10 @@ python -m torch.distributed.launch --nproc_per_node= retinaface/infere https://retinaface.herokuapp.com/ Code for the web app: https://github.com/ternaus/retinaface_demo + +### Converting to ONNX +The inference could be sped up on CPU by converting the model to ONNX. 
+ +``` +Ex: python -m converters.to_onnx -m 1280 -o retinaface1280.onnx +``` diff --git a/converters/__init__.py b/converters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/converters/to_onnx.py b/converters/to_onnx.py new file mode 100644 index 0000000..15b71a0 --- /dev/null +++ b/converters/to_onnx.py @@ -0,0 +1,153 @@ +import argparse +from typing import Dict, List, Tuple, Union + +import albumentations as albu +import cv2 +import numpy as np +import onnx +import onnxruntime as ort +import torch +from torch import nn +from torch.nn import functional as F +from torch.utils import model_zoo +from torchvision.ops import nms + +from retinaface.box_utils import decode, decode_landm +from retinaface.network import RetinaFace +from retinaface.prior_box import priorbox +from retinaface.utils import tensor_from_rgb_image, vis_annotations + +state_dict = model_zoo.load_url( + "https://github.com/ternaus/retinaface/releases/download/0.01/retinaface_resnet50_2020-07-20-f168fae3c.zip", + progress=True, + map_location="cpu", +) + + +class M(nn.Module): + def __init__(self, max_size: int = 1280): + super().__init__() + self.model = RetinaFace( + name="Resnet50", + pretrained=False, + return_layers={"layer2": 1, "layer3": 2, "layer4": 3}, + in_channels=256, + out_channels=256, + ) + self.model.load_state_dict(state_dict) + + self.max_size = max_size + + self.scale_landmarks = torch.from_numpy(np.tile([self.max_size, self.max_size], 5)) + self.scale_bboxes = torch.from_numpy(np.tile([self.max_size, self.max_size], 2)) + + self.prior_box = priorbox( + min_sizes=[[16, 32], [64, 128], [256, 512]], + steps=[8, 16, 32], + clip=False, + image_size=(self.max_size, self.max_size), + ) + self.nms_threshold: float = 0.4 + self.variance = [0.1, 0.2] + self.confidence_threshold: float = 0.7 + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + loc, conf, land = self.model(x) + + conf = F.softmax(conf, dim=-1) + + boxes = decode(loc.data[0], self.prior_box, self.variance) + + boxes *= self.scale_bboxes + scores = conf[0][:, 1] + + landmarks = decode_landm(land.data[0], self.prior_box, self.variance) + landmarks *= self.scale_landmarks + + # ignore low scores + valid_index = torch.where(scores > self.confidence_threshold)[0] + boxes = boxes[valid_index] + landmarks = landmarks[valid_index] + scores = scores[valid_index] + + # do NMS + keep = nms(boxes, scores, self.nms_threshold) + boxes = boxes[keep, :] + + landmarks = landmarks[keep] + scores = scores[keep] + return boxes, scores, landmarks + + +def prepare_image(image: np.ndarray, max_size: int = 1280) -> np.ndarray: + image = albu.Compose([albu.LongestMaxSize(max_size=max_size), albu.Normalize(p=1)])(image=image)["image"] + + height, width = image.shape[:2] + + return cv2.copyMakeBorder(image, 0, max_size - height, 0, max_size - width, borderType=cv2.BORDER_CONSTANT) + + +def main() -> None: + parser = argparse.ArgumentParser() + arg = parser.add_argument + arg( + "-m", + "--max_size", + type=int, + help="Size of the input image. 
The onnx model will predict on (max_size, max_size)", + required=True, + ) + + arg("-o", "--output_file", type=str, help="Path to save onnx model.", required=True) + args = parser.parse_args() + + raw_image = cv2.imread("tests/data/13.jpg") + + image = prepare_image(raw_image, args.max_size) + + x = tensor_from_rgb_image(image).unsqueeze(0).float() + + model = M(max_size=args.max_size) + model.eval() + with torch.no_grad(): + out_torch = model(x) + + torch.onnx.export( + model, + x, + args.output_file, + verbose=True, + opset_version=12, + input_names=["input"], + export_params=True, + do_constant_folding=True, + ) + + onnx_model = onnx.load(args.output_file) + onnx.checker.check_model(onnx_model) + + ort_session = ort.InferenceSession(args.output_file) + + outputs = ort_session.run(None, {"input": np.expand_dims(np.transpose(image, (2, 0, 1)), 0)}) + + for i in range(3): + if not np.allclose(out_torch[i].numpy(), outputs[i]): + raise ValueError("torch and onnx models do not match!") + + annotations: List[Dict[str, List[Union[float, List[float]]]]] = [] + + for box_id, box in enumerate(outputs[0]): + annotations += [ + { + "bbox": box.tolist(), + "score": outputs[1][box_id], + "landmarks": outputs[2][box_id].reshape(-1, 2).tolist(), + } + ] + + im = albu.Compose([albu.LongestMaxSize(max_size=1280)])(image=raw_image)["image"] + cv2.imwrite("example.jpg", vis_annotations(im, annotations)) + + +if __name__ == "__main__": + main() diff --git a/requirements.txt b/requirements.txt index 19a0380..5141a85 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,3 @@ -albumentations -iglovikov_helper_functions -numpy -pillow -torch +albumentations==1.0.0 +torch==1.9.0 +torchvision==0.10.0 diff --git a/requirements_dev.txt b/requirements_dev.txt new file mode 100644 index 0000000..f9669bd --- /dev/null +++ b/requirements_dev.txt @@ -0,0 +1,50 @@ +black==21.6b0 +flake8==3.9.2 +flake8-bandit==2.1.2 +flake8-breakpoint==1.1.0 +flake8-bugbear==21.3.1 +flake8-builtins==1.5.3 +flake8-colors==0.1.9 +flake8-comprehensions==3.3.1 +flake8-debugger==4.0.0 +flake8-docstrings==1.5.0 +flake8-eradicate==1.0.0 +flake8-executable==2.1.1 +flake8-fixme==1.1.1 +flake8-graphql==0.2.5 +flake8-implicit-str-concat==0.2.0 +flake8-logging-format==0.6.0 +flake8-mock==0.3 +flake8-pathlib==0.1.3 +flake8-plugin-utils==1.3.1 +flake8-polyfill==1.0.2 +flake8-print==4.0.0 +flake8-printf-formatting==1.1.2 +flake8-pytest==1.3 +flake8-pytest-style==1.3.0 +flake8-raise==0.0.5 +flake8-requests==0.4.0 +flake8-requirements==1.3.3 +flake8-string-format==0.3.0 +flake8-tidy-imports==4.2.1 +flake8-todo==0.7 +flake8-tuple==0.4.1 +iglovikov-helper-functions==0.0.53 +isort==5.8.0 +mypy==0.910 +onnx==1.9.0 +pep8-naming==0.11.1 +pip==21.1.3 +pre-commit==2.13.0 +pylint==2.8.2 +pytest==6.2.4 +pytest-asyncio==0.15.1 +pytest-clarity==0.3.0a0 +pytest-cov==2.11.1 +pytest-mock==3.6.1 +pytest-parallel==0.1.0 +pytest-sugar==0.9.4 +pytest-xdist==2.2.1 +pytorch_lightning==1.3.8 +types-PyYAML==5.4.3 +types-requests==2.25.0 diff --git a/retinaface/box_utils.py b/retinaface/box_utils.py index 6962516..e7db45c 100644 --- a/retinaface/box_utils.py +++ b/retinaface/box_utils.py @@ -5,8 +5,9 @@ def point_form(boxes: torch.Tensor) -> torch.Tensor: - """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation for comparison - to point form ground truth data. + """Convert prior_boxes to (x_min, y_min, x_max, y_max) representation. + + For comparison to point form ground truth data. Args: boxes: center-size default boxes from priorbox layers. 
@@ -18,6 +19,7 @@ def point_form(boxes: torch.Tensor) -> torch.Tensor: def center_size(boxes: torch.Tensor) -> torch.Tensor: """Convert prior_boxes to (cx, cy, w, h) representation for comparison to center-size form ground truth data. + Args: boxes: point_form boxes Return: @@ -27,7 +29,8 @@ def center_size(boxes: torch.Tensor) -> torch.Tensor: def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: - """We resize both tensors to [A,B,2] without new malloc: + """We resize both tensors to [A,B,2] without new malloc. + [A, 2] -> [A, 1, 2] -> [A, B, 2] [B, 2] -> [1, B, 2] -> [A, B, 2] Then we compute the area of intersect between box_a and box_b. @@ -37,17 +40,19 @@ def intersect(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: Return: intersection area, Shape: [A, B]. """ - A = box_a.size(0) - B = box_b.size(0) - max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2)) - min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2)) + a = box_a.size(0) + b = box_b.size(0) + max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(a, b, 2), box_b[:, 2:].unsqueeze(0).expand(a, b, 2)) + min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(a, b, 2), box_b[:, :2].unsqueeze(0).expand(a, b, 2)) inter = torch.clamp((max_xy - min_xy), min=0) return inter[:, :, 0] * inter[:, :, 1] def jaccard(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: - """Compute the jaccard overlap of two sets of boxes. The jaccard overlap is simply the intersection over - union of two boxes. Here we operate on ground truth boxes and default boxes. + """Computes the jaccard overlap of two sets of boxes. + + The jaccard overlap is simply the intersection over union of two boxes. + Here we operate on ground truth boxes and default boxes. E.g.: A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B) Args: @@ -64,9 +69,7 @@ def jaccard(box_a: torch.Tensor, box_b: torch.Tensor) -> torch.Tensor: def matrix_iof(a: np.ndarray, b: np.ndarray) -> np.ndarray: - """ - return iof of a and b, numpy version for data augmentation - """ + """Returns iof of a and b, numpy version for data augmentation.""" lt = np.maximum(a[:, np.newaxis, :2], b[:, :2]) rb = np.minimum(a[:, np.newaxis, 2:], b[:, 2:]) @@ -87,8 +90,10 @@ def match( landmarks_t: torch.Tensor, batch_id: int, ) -> None: - """Match each prior box with the ground truth box of the highest jaccard overlap, encode the bounding - boxes, then return the matched indices corresponding to both confidence and location preds. + """Match each prior box with the ground truth box of the highest jaccard overlap. + + Eencode the bounding boxes, then return the matched indices corresponding to both + confidence and location preds. Args: threshold: The overlap threshold used when matching boxes. @@ -143,19 +148,19 @@ def match( landmarks_t[batch_id] = landmarks_gt -def encode(matched, priors, variances): - """Encode the variances from the priorbox layers into the ground truth boxes - we have matched (based on jaccard overlap) with the prior boxes. +def encode(matched: torch.Tensor, priors: torch.Tensor, variances: List[float]) -> torch.Tensor: + """Encodes the variances from the priorbox layers into the ground truth boxes we have matched. + + (based on jaccard overlap) with the prior boxes. Args: - matched: (tensor) Coords of ground truth for each prior in point-form + matched: Coords of ground truth for each prior in point-form Shape: [num_priors, 4]. 
- priors: (tensor) Prior boxes in center-offset form + priors: Prior boxes in center-offset form Shape: [num_priors,4]. - variances: (list[float]) Variances of priorboxes + variances: Variances of priorboxes Return: - encoded boxes (tensor), Shape: [num_priors, 4] + encoded boxes, Shape: [num_priors, 4] """ - # dist b/t match center and prior's center g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2] # encode variance @@ -170,7 +175,8 @@ def encode(matched, priors, variances): def encode_landm( matched: torch.Tensor, priors: torch.Tensor, variances: Union[List[float], Tuple[float, float]] ) -> torch.Tensor: - """Encode the variances from the priorbox layers into the ground truth boxes we have matched + """Encodes the variances from the priorbox layers into the ground truth boxes we have matched. + (based on jaccard overlap) with the prior boxes. Args: matched: Coords of ground truth for each prior in point-form @@ -181,7 +187,6 @@ def encode_landm( Return: encoded landmarks, Shape: [num_priors, 10] """ - # dist b/t match center and prior's center matched = torch.reshape(matched, (matched.size(0), 5, 2)) priors_cx = priors[:, 0].unsqueeze(1).expand(matched.size(0), 5).unsqueeze(2) @@ -200,7 +205,8 @@ def encode_landm( def decode( loc: torch.Tensor, priors: torch.Tensor, variances: Union[List[float], Tuple[float, float]] ) -> torch.Tensor: - """Decode locations from predictions using priors to undo the encoding we did for offset regression at train time. + """Decodes locations from predictions using priors to undo the encoding we did for offset regression at train time. + Args: loc: location predictions for loc layers, Shape: [num_priors, 4] @@ -225,7 +231,8 @@ def decode( def decode_landm( pre: torch.Tensor, priors: torch.Tensor, variances: Union[List[float], Tuple[float, float]] ) -> torch.Tensor: - """Decode landmarks from predictions using priors to undo the encoding we did for offset regression at train time. + """Decodes landmarks from predictions using priors to undo the encoding we did for offset regression at train time. + Args: pre: landmark predictions for loc layers, Shape: [num_priors, 10] @@ -248,8 +255,9 @@ def decode_landm( def log_sum_exp(x: torch.Tensor) -> torch.Tensor: - """Utility function for computing log_sum_exp while determining This will be used to determine unaveraged - confidence loss across all examples in a batch. + """Computes log_sum_exp. + + This will be used to determine unaveraged confidence loss across all examples in a batch. Args: x: conf_preds from conf layers """ diff --git a/retinaface/data_augment.py b/retinaface/data_augment.py index 0d2b4dc..9f90649 100644 --- a/retinaface/data_augment.py +++ b/retinaface/data_augment.py @@ -9,7 +9,8 @@ def random_crop( image: np.ndarray, boxes: np.ndarray, labels: np.ndarray, landm: np.ndarray, img_dim: int ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, bool]: - """ + """Crop random patch. 
+ if random.uniform(0, 1) <= 0.2: scale = 1.0 else: @@ -116,7 +117,7 @@ def _pad_to_square(image: np.ndarray, pad_image_flag: bool) -> np.ndarray: class Preproc: - def __init__(self, img_dim): + def __init__(self, img_dim: int) -> None: self.img_dim = img_dim def __call__(self, image: np.ndarray, targets: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: diff --git a/retinaface/dataset.py b/retinaface/dataset.py index 8c60b03..5403c30 100644 --- a/retinaface/dataset.py +++ b/retinaface/dataset.py @@ -28,7 +28,7 @@ def __init__( self.transform = transform self.rotate90 = rotate90 - with open(label_path) as f: + with label_path.open() as f: labels = json.load(f) self.labels = [x for x in labels if (image_path / x["file_name"]).exists()] @@ -125,8 +125,7 @@ def random_rotate_90(image: np.ndarray, annotations: np.ndarray) -> Tuple[np.nda def detection_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]: - """Custom collate fn for dealing with batches of images that have a different - number of associated object annotations (bounding boxes). + """Custom collate fn for dealing with batches of images that have a different number of boxes. Arguments: batch: (tuple) A tuple of tensor images and lists of annotations diff --git a/retinaface/inference.py b/retinaface/inference.py index 716f57d..9204e68 100644 --- a/retinaface/inference.py +++ b/retinaface/inference.py @@ -1,7 +1,7 @@ import argparse import json from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union import albumentations as albu import cv2 @@ -13,12 +13,10 @@ import yaml from albumentations.core.serialization import from_dict from iglovikov_helper_functions.config_parsing.utils import object_from_dict -from iglovikov_helper_functions.dl.pytorch.utils import ( - state_dict_from_disk, - tensor_from_rgb_image, -) +from iglovikov_helper_functions.dl.pytorch.utils import state_dict_from_disk from iglovikov_helper_functions.utils.image_utils import pad_to_size, unpad_from_size from PIL import Image +from torch import nn from torch.nn import functional as F from torch.utils.data import Dataset from torch.utils.data.distributed import DistributedSampler @@ -26,10 +24,10 @@ from tqdm import tqdm from retinaface.box_utils import decode, decode_landm -from retinaface.utils import vis_annotations +from retinaface.utils import tensor_from_rgb_image, vis_annotations -def get_args(): +def get_args() -> Any: parser = argparse.ArgumentParser() arg = parser.add_argument arg("-i", "--input_path", type=Path, help="Path with images.", required=True) @@ -51,7 +49,9 @@ def get_args(): class InferenceDataset(Dataset): - def __init__(self, file_paths: List[Path], max_size: int, transform: albu.Compose) -> None: + def __init__( + self, file_paths: List[Path], max_size: int, transform: albu.Compose + ) -> None: # pylint: disable=W0231 self.file_paths = file_paths self.transform = transform self.max_size = max_size @@ -85,7 +85,7 @@ def __getitem__(self, idx: int) -> Optional[Dict[str, Any]]: } -def unnormalize(image): +def unnormalize(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]: mean = [0.485, 0.456, 0.406] std = [0.229, 0.224, 0.225] @@ -98,21 +98,21 @@ def unnormalize(image): def process_predictions( - prediction, - original_shapes, - input_shape, - pads, - confidence_threshold, - nms_threshold, - prior_box, - variance, - keep_top_k, -): + prediction: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + original_shapes: List[Tuple[int, int]], + 
input_shape: Tuple[int, int, int, int], + pads: Tuple[int, int, int, int], + confidence_threshold: float, + nms_threshold: float, + prior_box: torch.Tensor, + variance: Tuple[float, float], + keep_top_k: bool, +) -> List[List[Dict[str, Union[float, List[float]]]]]: loc, conf, land = prediction conf = F.softmax(conf, dim=-1) - result: List[List[Dict[str, Union[List, float]]]] = [] + result: List[List[Dict[str, Union[List[float], float]]]] = [] batch_size, _, image_height, image_width = input_shape @@ -173,7 +173,7 @@ def process_predictions( annotations += [ { "bbox": bbox.tolist(), - "score": scores[crop_id], + "score": float(scores[crop_id]), "landmarks": landmarks[crop_id].reshape(-1, 2).tolist(), } ] @@ -183,11 +183,11 @@ def process_predictions( return result -def main(): +def main() -> None: args = get_args() torch.distributed.init_process_group(backend="nccl") - with open(args.config_path) as f: + with args.config_path.open() as f: hparams = yaml.load(f, Loader=yaml.SafeLoader) hparams.update( @@ -233,7 +233,7 @@ def main(): dataset = InferenceDataset(file_paths, max_size=args.max_size, transform=from_dict(hparams["test_aug"])) - sampler = DistributedSampler(dataset, shuffle=False) + sampler: DistributedSampler = DistributedSampler(dataset, shuffle=False) dataloader = torch.utils.data.DataLoader( dataset, @@ -248,7 +248,7 @@ def main(): predict(dataloader, model, hparams, device) -def predict(dataloader, model, hparams, device): +def predict(dataloader: torch.utils.data.DataLoader, model: nn.Module, hparams: dict, device: torch.device) -> None: model.eval() if hparams["local_rank"] == 0: @@ -307,7 +307,7 @@ def predict(dataloader, model, hparams, device): (hparams["output_label_path"] / folder_name).mkdir(exist_ok=True, parents=True) result_path = hparams["output_label_path"] / folder_name / f"{file_id}.json" - with open(result_path, "w") as f: + with result_path.open("w") as f: json.dump(predictions, f, indent=2) if hparams["visualize"]: @@ -322,7 +322,7 @@ def predict(dataloader, model, hparams, device): unpadded["image"].astype(np.uint8), (original_image_width, original_image_height) ) - image = vis_annotations(image, annotations=annotations) + image = vis_annotations(image, annotations=annotations) # type: ignore (hparams["output_vis_path"] / folder_name).mkdir(exist_ok=True, parents=True) result_path = hparams["output_vis_path"] / folder_name / f"{file_id}.jpg" diff --git a/retinaface/multibox_loss.py b/retinaface/multibox_loss.py index d25e5d6..8a728b0 100644 --- a/retinaface/multibox_loss.py +++ b/retinaface/multibox_loss.py @@ -4,11 +4,12 @@ import torch.nn.functional as F from torch import nn -from retinaface.box_utils import match, log_sum_exp +from retinaface.box_utils import log_sum_exp, match class MultiBoxLoss(nn.Module): - """SSD Weighted Loss Function + """SSD Weighted Loss Function. + Compute Targets: 1) Produce Confidence Target Indices by matching ground truth boxes with (default) 'priorboxes' that have jaccard index > threshold parameter. @@ -56,7 +57,8 @@ def __init__( def forward( self, predictions: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], targets: torch.Tensor ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Multibox Loss + """Multibox Loss. + Args: predictions: A tuple containing locations predictions, confidence predictions, and prior boxes from SSD net. @@ -67,7 +69,6 @@ def forward( targets: Ground truth boxes and labels_gt for a batch, shape: [batch_size, num_objs, 5] (last box_index is the label). 
""" - locations_data, confidence_data, landmark_data = predictions priors = self.priors.to(targets[0].device) @@ -99,11 +100,10 @@ def forward( box_index, ) - # landmark Loss (Smooth L1) - # Shape: [batch, num_priors, 10] + # landmark Loss (Smooth L1) Shape: [batch, num_priors, 10] positive_1 = conf_t > torch.zeros_like(conf_t) num_positive_landmarks = positive_1.long().sum(1, keepdim=True) - N1 = max(num_positive_landmarks.data.sum().float(), 1) + n1 = max(num_positive_landmarks.data.sum().float(), 1) # type: ignore pos_idx1 = positive_1.unsqueeze(positive_1.dim()).expand_as(landmark_data) landmarks_p = landmark_data[pos_idx1].view(-1, 10) landmarks_t = landmarks_t[pos_idx1].view(-1, 10) @@ -112,8 +112,7 @@ def forward( positive = conf_t != torch.zeros_like(conf_t) conf_t[positive] = 1 - # Localization Loss (Smooth L1) - # Shape: [batch, num_priors, 4] + # Localization Loss (Smooth L1) Shape: [batch, num_priors, 4] pos_idx = positive.unsqueeze(positive.dim()).expand_as(locations_data) loc_p = locations_data[pos_idx].view(-1, 4) boxes_t = boxes_t[pos_idx].view(-1, 4) @@ -140,6 +139,6 @@ def forward( loss_c = F.cross_entropy(conf_p, targets_weighted, reduction="sum") # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N - N = max(num_pos.data.sum().float(), 1) + n = max(num_pos.data.sum().float(), 1) # type: ignore - return loss_l / N, loss_c / N, loss_landm / N1 + return loss_l / n, loss_c / n, loss_landm / n1 diff --git a/retinaface/network.py b/retinaface/network.py index b1808a1..c5201b6 100644 --- a/retinaface/network.py +++ b/retinaface/network.py @@ -11,7 +11,7 @@ class ClassHead(nn.Module): def __init__(self, in_channels: int = 512, num_anchors: int = 3) -> None: super().__init__() - self.conv1x1 = nn.Conv2d(in_channels, num_anchors * 2, kernel_size=(1, 1), stride=1, padding=0) + self.conv1x1 = nn.Conv2d(in_channels, num_anchors * 2, kernel_size=(1, 1), stride=(1, 1), padding=0) def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.conv1x1(x) @@ -22,7 +22,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BboxHead(nn.Module): def __init__(self, in_channels: int = 512, num_anchors: int = 3): super().__init__() - self.conv1x1 = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=(1, 1), stride=1, padding=0) + self.conv1x1 = nn.Conv2d(in_channels, num_anchors * 4, kernel_size=(1, 1), stride=(1, 1), padding=0) def forward(self, x: torch.Tensor) -> torch.Tensor: out = self.conv1x1(x) diff --git a/retinaface/pre_trained_models.py b/retinaface/pre_trained_models.py index 7942637..b828307 100644 --- a/retinaface/pre_trained_models.py +++ b/retinaface/pre_trained_models.py @@ -8,7 +8,7 @@ models = { "resnet50_2020-07-20": model( - url="https://github.com/ternaus/retinaface/releases/download/0.01/retinaface_resnet50_2020-07-20-f168fae3c.zip", # noqa: E501 + url="https://github.com/ternaus/retinaface/releases/download/0.01/retinaface_resnet50_2020-07-20-f168fae3c.zip", # noqa: E501 pylint: disable=C0301 model=Model, ) } diff --git a/retinaface/predict_single.py b/retinaface/predict_single.py index c4986c0..8a029eb 100644 --- a/retinaface/predict_single.py +++ b/retinaface/predict_single.py @@ -1,19 +1,19 @@ -""" -There is a lot of post processing of the predictions. 
-""" +"""There is a lot of post processing of the predictions.""" +from collections import OrderedDict from typing import Dict, List, Union import albumentations as A import numpy as np import torch -from iglovikov_helper_functions.dl.pytorch.utils import tensor_from_rgb_image -from iglovikov_helper_functions.utils.image_utils import pad_to_size, unpad_from_size from torch.nn import functional as F from torchvision.ops import nms from retinaface.box_utils import decode, decode_landm from retinaface.network import RetinaFace from retinaface.prior_box import priorbox +from retinaface.utils import tensor_from_rgb_image + +ROUNDING_DIGITS = 2 class Model: @@ -28,49 +28,49 @@ def __init__(self, max_size: int = 960, device: str = "cpu") -> None: self.device = device self.transform = A.Compose([A.LongestMaxSize(max_size=max_size, p=1), A.Normalize(p=1)]) self.max_size = max_size - self.prior_box = priorbox( - min_sizes=[[16, 32], [64, 128], [256, 512]], - steps=[8, 16, 32], - clip=False, - image_size=(self.max_size, self.max_size), - ).to(device) self.variance = [0.1, 0.2] - def load_state_dict(self, state_dict: Dict[str, torch.Tensor]) -> None: + def load_state_dict(self, state_dict: OrderedDict) -> None: self.model.load_state_dict(state_dict) - def eval(self): + def eval(self) -> None: # noqa: A003 self.model.eval() def predict_jsons( - self, image: np.array, confidence_threshold: float = 0.7, nms_threshold: float = 0.4 + self, image: np.ndarray, confidence_threshold: float = 0.7, nms_threshold: float = 0.4 ) -> List[Dict[str, Union[List, float]]]: with torch.no_grad(): original_height, original_width = image.shape[:2] - scale_landmarks = torch.from_numpy(np.tile([self.max_size, self.max_size], 5)).to(self.device) - scale_bboxes = torch.from_numpy(np.tile([self.max_size, self.max_size], 2)).to(self.device) - transformed_image = self.transform(image=image)["image"] - paded = pad_to_size(target_size=(self.max_size, self.max_size), image=transformed_image) + transformed_height, transformed_width = transformed_image.shape[:2] + transformed_size = (transformed_width, transformed_height) - pads = paded["pads"] + scale_landmarks = torch.from_numpy(np.tile(transformed_size, 5)).to(self.device) + scale_bboxes = torch.from_numpy(np.tile(transformed_size, 2)).to(self.device) - torched_image = tensor_from_rgb_image(paded["image"]).to(self.device) + prior_box = priorbox( + min_sizes=[[16, 32], [64, 128], [256, 512]], + steps=[8, 16, 32], + clip=False, + image_size=transformed_image.shape[:2], + ).to(self.device) - loc, conf, land = self.model(torched_image.unsqueeze(0)) + torched_image = tensor_from_rgb_image(transformed_image).to(self.device) + + loc, conf, land = self.model(torched_image.unsqueeze(0)) # pylint: disable=E1102 conf = F.softmax(conf, dim=-1) annotations: List[Dict[str, Union[List, float]]] = [] - boxes = decode(loc.data[0], self.prior_box, self.variance) + boxes = decode(loc.data[0], prior_box, self.variance) boxes *= scale_bboxes scores = conf[0][:, 1] - landmarks = decode_landm(land.data[0], self.prior_box, self.variance) + landmarks = decode_landm(land.data[0], prior_box, self.variance) landmarks *= scale_landmarks # ignore low scores @@ -79,34 +79,25 @@ def predict_jsons( landmarks = landmarks[valid_index] scores = scores[valid_index] - # Sort from high to low - order = scores.argsort(descending=True) - boxes = boxes[order] - landmarks = landmarks[order] - scores = scores[order] - # do NMS keep = nms(boxes, scores, nms_threshold) - boxes = boxes[keep, :].int() + boxes = boxes[keep, :] if 
boxes.shape[0] == 0: return [{"bbox": [], "score": -1, "landmarks": []}] landmarks = landmarks[keep] - scores = scores[keep].cpu().numpy().astype(np.float64) - boxes = boxes.cpu().numpy() - landmarks = landmarks.cpu().numpy() - landmarks = landmarks.reshape([-1, 2]) - - unpadded = unpad_from_size(pads, bboxes=boxes, keypoints=landmarks) + scores = scores[keep].cpu().numpy().astype(float) - resize_coeff = max(original_height, original_width) / self.max_size + boxes_np = boxes.cpu().numpy() + landmarks_np = landmarks.cpu().numpy() + resize_coeff = original_height / transformed_height - boxes = (unpadded["bboxes"] * resize_coeff).astype(int) - landmarks = (unpadded["keypoints"].reshape(-1, 10) * resize_coeff).astype(int) + boxes *= resize_coeff + landmarks_np = landmarks_np.reshape(-1, 10) * resize_coeff - for box_id, bbox in enumerate(boxes): + for box_id, bbox in enumerate(boxes_np): x_min, y_min, x_max, y_max = bbox x_min = np.clip(x_min, 0, original_width - 1) @@ -123,9 +114,11 @@ def predict_jsons( annotations += [ { - "bbox": bbox.tolist(), - "score": scores[box_id], - "landmarks": landmarks[box_id].reshape(-1, 2).tolist(), + "bbox": np.round(bbox.astype(float), ROUNDING_DIGITS).tolist(), + "score": np.round(scores, ROUNDING_DIGITS)[box_id], + "landmarks": np.round(landmarks_np[box_id].astype(float), ROUNDING_DIGITS) + .reshape(-1, 2) + .tolist(), } ] diff --git a/retinaface/prior_box.py b/retinaface/prior_box.py index 9594f97..ae70b3e 100644 --- a/retinaface/prior_box.py +++ b/retinaface/prior_box.py @@ -1,13 +1,14 @@ from itertools import product from math import ceil +from typing import List, Tuple import torch -def priorbox(min_sizes, steps, clip, image_size): +def priorbox(min_sizes: List[List[int]], steps: List[int], clip: bool, image_size: Tuple[int, int]) -> torch.Tensor: feature_maps = [[ceil(image_size[0] / step), ceil(image_size[1] / step)] for step in steps] - anchors = [] + anchors: List[float] = [] for k, f in enumerate(feature_maps): t_min_sizes = min_sizes[k] for i, j in product(range(f[0]), range(f[1])): diff --git a/retinaface/train.py b/retinaface/train.py index a0d7dc7..30c9e8a 100644 --- a/retinaface/train.py +++ b/retinaface/train.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Callable, Dict, List, Tuple, Union import numpy as np import pytorch_lightning as pl @@ -13,7 +13,9 @@ from albumentations.core.serialization import from_dict from iglovikov_helper_functions.config_parsing.utils import object_from_dict from iglovikov_helper_functions.metrics.map import recall_precision +from pytorch_lightning.core.optimizer import LightningOptimizer from pytorch_lightning.loggers import WandbLogger +from torch.optim import Optimizer from torch.utils.data import DataLoader from torchvision.ops import nms @@ -27,21 +29,16 @@ TRAIN_LABEL_PATH = Path(os.environ["TRAIN_LABEL_PATH"]) VAL_LABEL_PATH = Path(os.environ["VAL_LABEL_PATH"]) -print("TRAIN_IMAGE_PATH = ", TRAIN_IMAGE_PATH) -print("VAL_IMAGE_PATH = ", VAL_IMAGE_PATH) -print("TRAIN_LABEL_PATH = ", TRAIN_LABEL_PATH) -print("VAL_LABEL_PATH = ", VAL_LABEL_PATH) - -def get_args(): +def get_args() -> Any: parser = argparse.ArgumentParser() arg = parser.add_argument arg("-c", "--config_path", type=Path, help="Path to the config.", required=True) return parser.parse_args() -class RetinaFace(pl.LightningModule): - def __init__(self, config): +class RetinaFace(pl.LightningModule): # pylint: disable=R0901 + def __init__(self, 
config: Adict[str, Any]) -> None: super().__init__() self.config = config @@ -52,13 +49,13 @@ def __init__(self, config): self.loss = object_from_dict(self.config.loss, priors=self.prior_box) - def setup(self, state=0): # pylint: disable=W0613 + def setup(self, state=0) -> None: # type: ignore self.preproc = Preproc(img_dim=self.config.image_size[0]) - def forward(self, batch): + def forward(self, batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # type: ignore return self.model(batch) - def train_dataloader(self): + def train_dataloader(self) -> DataLoader: result = DataLoader( FaceDetectionDataset( label_path=TRAIN_LABEL_PATH, @@ -74,10 +71,9 @@ def train_dataloader(self): drop_last=False, collate_fn=detection_collate, ) - print("Len train dataloader = ", len(result)) return result - def val_dataloader(self): + def val_dataloader(self) -> DataLoader: result = DataLoader( FaceDetectionDataset( label_path=VAL_LABEL_PATH, @@ -93,20 +89,21 @@ def val_dataloader(self): drop_last=True, collate_fn=detection_collate, ) - print("Len val dataloader = ", len(result)) return result - def configure_optimizers(self): + def configure_optimizers( + self, + ) -> Tuple[Callable[[bool], Union[Optimizer, List[Optimizer], List[LightningOptimizer]]], List[Any]]: optimizer = object_from_dict( self.config.optimizer, params=[x for x in self.model.parameters() if x.requires_grad] ) scheduler = object_from_dict(self.config.scheduler, optimizer=optimizer) - self.optimizers = [optimizer] - return self.optimizers, [scheduler] + self.optimizers = [optimizer] # type: ignore + return self.optimizers, [scheduler] # type: ignore - def training_step(self, batch, batch_idx): # pylint: disable=W0613 + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int): # type: ignore images = batch["image"] targets = batch["annotation"] @@ -128,7 +125,7 @@ def training_step(self, batch, batch_idx): # pylint: disable=W0613 return total_loss - def validation_step(self, batch, batch_idx): # pylint: disable=W0613 + def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int): # type: ignore images = batch["image"] image_height = images.shape[2] @@ -220,18 +217,18 @@ def validation_epoch_end(self, outputs: List) -> None: _, _, average_precision = recall_precision(result_gt, result_predictions, 0.5) - self.log("epoch", self.trainer.current_epoch, on_step=False, on_epoch=True, logger=True) + self.log("epoch", self.trainer.current_epoch, on_step=False, on_epoch=True, logger=True) # type: ignore self.log("val_loss", average_precision, on_step=False, on_epoch=True, logger=True) def _get_current_lr(self) -> torch.Tensor: # type: ignore - lr = [x["lr"] for x in self.optimizers[0].param_groups][0] + lr = [x["lr"] for x in self.optimizers[0].param_groups][0] # type: ignore return torch.from_numpy(np.array([lr]))[0].to(self.device) -def main(): +def main() -> None: args = get_args() - with open(args.config_path) as f: + with args.config_path.open() as f: config = Adict(yaml.load(f, Loader=yaml.SafeLoader)) pl.trainer.seed_everything(config.seed) diff --git a/retinaface/utils.py b/retinaface/utils.py index 6f0f4a8..1f046d4 100644 --- a/retinaface/utils.py +++ b/retinaface/utils.py @@ -2,6 +2,7 @@ import cv2 import numpy as np +import torch def vis_annotations(image: np.ndarray, annotations: List[Dict[str, Any]]) -> np.ndarray: @@ -13,12 +14,17 @@ def vis_annotations(image: np.ndarray, annotations: List[Dict[str, Any]]) -> np. 
colors = [(255, 0, 0), (128, 255, 0), (255, 178, 102), (102, 128, 255), (0, 255, 255)] for landmark_id, (x, y) in enumerate(landmarks): - vis_image = cv2.circle(vis_image, (x, y), radius=3, color=colors[landmark_id], thickness=3) + vis_image = cv2.circle(vis_image, (int(x), int(y)), radius=3, color=colors[landmark_id], thickness=3) - x_min, y_min, x_max, y_max = annotation["bbox"] + x_min, y_min, x_max, y_max = (int(tx) for tx in annotation["bbox"]) x_min = np.clip(x_min, 0, x_max - 1) y_min = np.clip(y_min, 0, y_max - 1) vis_image = cv2.rectangle(vis_image, (x_min, y_min), (x_max, y_max), color=(0, 255, 0), thickness=2) return vis_image + + +def tensor_from_rgb_image(image: np.ndarray) -> torch.Tensor: + image = np.ascontiguousarray(np.transpose(image, (2, 0, 1))) + return torch.from_numpy(image) diff --git a/setup.cfg b/setup.cfg index c2f3c34..e34c821 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,15 @@ [flake8] -max-line-length = 119 -exclude =.git,__pycache__,docs/source/conf.py,build,dist -ignore = I101,I201,F401,F403,S001,D100,D101,D102,D103,D104,D105,D106,D107,D200,D205,D400,W504,E203,D202,D401,W503 +max-line-length = 120 +exclude =.git,__pycache__,docs/source/conf.py,build,dist,tests +ignore = I101,D100,D101,D102,D103,D104,D105,D107,D401,E203,I900,N802,N806,N812,W503,S311,S605,S607 [mypy] ignore_missing_imports = True +disallow_untyped_defs = True +check_untyped_defs = True +warn_redundant_casts = True +no_implicit_optional = True +strict_optional = True + +[mypy-tests.*] +ignore_errors = True diff --git a/setup.py b/setup.py index 90e6bde..ae8797c 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,9 @@ -import io import os import re import sys +from pathlib import Path from shutil import rmtree -from typing import Tuple, List +from typing import List, Tuple from setuptools import Command, find_packages, setup @@ -14,18 +14,18 @@ email = "iglovikov@gmail.com" author = "Vladimir Iglovikov" requires_python = ">=3.0.0" -current_dir = os.path.abspath(os.path.dirname(__file__)) +current_dir = Path(__file__).absolute().parent -def get_version(): - version_file = os.path.join(current_dir, "retinaface", "__init__.py") - with io.open(version_file, encoding="utf-8") as f: - return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1) +def get_version() -> str: + version_file = current_dir / "retinaface" / "__init__.py" + with version_file.open(encoding="utf-8") as f: + return re.search(r'^__version__ = [\'"]([^\'"]*)[\'"]', f.read(), re.M).group(1) # type: ignore # What packages are required for this module to be executed? 
try: - with open(os.path.join(current_dir, "requirements.txt"), encoding="utf-8") as f: + with (current_dir / "requirements.txt").open(encoding="utf-8") as f: required = f.read().split("\n") except FileNotFoundError: required = [] @@ -38,16 +38,16 @@ def get_version(): about = {"__version__": version} -def get_test_requirements(): +def get_test_requirements() -> List[str]: requirements = ["pytest"] if sys.version_info < (3, 3): requirements.append("mock") return requirements -def get_long_description(): - base_dir = os.path.abspath(os.path.dirname(__file__)) - with io.open(os.path.join(base_dir, "README.md"), encoding="utf-8") as f: +def get_long_description() -> str: + base_dir = Path(__file__).absolute().parent + with (base_dir / "README.md").open(encoding="utf-8") as f: return f.read() @@ -58,17 +58,17 @@ class UploadCommand(Command): user_options: List[Tuple] = [] @staticmethod - def status(s): + def status(s: str) -> None: """Print things in bold.""" - print(s) + print(s) # noqa: T001 - def initialize_options(self): + def initialize_options(self) -> None: pass - def finalize_options(self): + def finalize_options(self) -> None: pass - def run(self): + def run(self) -> None: try: self.status("Removing previous builds...") rmtree(os.path.join(current_dir, "dist")) @@ -82,7 +82,7 @@ def run(self): os.system("twine upload dist/*") self.status("Pushing git tags...") - os.system("git tag v{}".format(about["__version__"])) + os.system(f"git tag v{about['__version__']}") os.system("git push --tags") sys.exit() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..bf04e04 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +from pathlib import Path +from typing import Union + +import cv2 +import numpy as np + + +def load_rgb(file_path: Union[str, Path]) -> np.ndarray: + image = cv2.imread(str(file_path)) + + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +images = { + "with_faces": { + "image": load_rgb("tests/data/13.jpg"), + "faces": [ + { + "bbox": [256.9, 93.64, 336.79, 201.76], + "score": 1.0, + "landmarks": [ + [286.17, 134.94], + [323.32, 135.28], + [309.15, 161.34], + [283.74, 168.48], + [320.72, 168.48], + ], + }, + { + "bbox": [436.62, 118.5, 510.04, 211.13], + "score": 1.0, + "landmarks": [[460.96, 155.7], [494.47, 154.35], [480.52, 175.92], [464.73, 188.05], [491.9, 187.53]], + }, + { + "bbox": [657.3, 156.87, 729.81, 245.78], + "score": 1.0, + "landmarks": [[665.64, 187.11], [696.5, 196.97], [670.65, 214.76], [666.92, 220.2], [689.45, 228.91]], + }, + ], + }, + "with_no_faces": { + "image": load_rgb("tests/data/no_face.jpg"), + "faces": [{"bbox": [], "score": -1, "landmarks": []}], + }, +} diff --git a/tests/data/13.jpg b/tests/data/13.jpg new file mode 100644 index 0000000..36945c7 Binary files /dev/null and b/tests/data/13.jpg differ diff --git a/tests/data/no_face.jpg b/tests/data/no_face.jpg new file mode 100644 index 0000000..9d0c74f Binary files /dev/null and b/tests/data/no_face.jpg differ diff --git a/tests/test_retinaface.py b/tests/test_retinaface.py new file mode 100644 index 0000000..34c114a --- /dev/null +++ b/tests/test_retinaface.py @@ -0,0 +1,22 @@ +import pytest + +from retinaface.pre_trained_models import get_model +from tests.conftest import images + +max_size = 1280 + + +@pytest.mark.parametrize( + ["image", "faces"], + [ + (images["with_faces"]["image"], images["with_faces"]["faces"]), + (images["with_no_faces"]["image"], 
images["with_no_faces"]["faces"]), + ], +) +def test_predict_jsons(image, faces): + model = get_model("resnet50_2020-07-20", max_size=max_size) + model.eval() + + result = model.predict_jsons(image) + + assert result == faces