From 93b0aaaaf1eda5f1792b0f447c00171b5150e943 Mon Sep 17 00:00:00 2001 From: Vladimir Iglovikov Date: Wed, 21 Feb 2024 13:52:17 -0800 Subject: [PATCH] isort, flake8 => ruff (#1529) --- .gitignore | 2 + .pre-commit-config.yaml | 22 +- README.md | 2 + .../augmentations/blur/functional.py | 15 +- .../augmentations/blur/transforms.py | 169 ++++++--- .../augmentations/crops/functional.py | 37 +- .../augmentations/crops/transforms.py | 87 +++-- .../augmentations/domain_adaptation.py | 45 ++- .../augmentations/dropout/channel_dropout.py | 16 +- .../augmentations/dropout/coarse_dropout.py | 30 +- .../augmentations/dropout/functional.py | 7 +- .../augmentations/dropout/grid_dropout.py | 112 ++++-- .../augmentations/dropout/mask_dropout.py | 11 +- .../augmentations/dropout/xy_masking.py | 41 ++- albumentations/augmentations/functional.py | 277 ++++++++------- .../augmentations/geometric/functional.py | 201 ++++++----- .../augmentations/geometric/resize.py | 18 +- .../augmentations/geometric/rotate.py | 42 ++- .../augmentations/geometric/transforms.py | 115 ++++-- albumentations/augmentations/transforms.py | 336 ++++++++++++------ albumentations/augmentations/utils.py | 51 +-- albumentations/core/bbox_utils.py | 91 +++-- albumentations/core/composition.py | 99 ++++-- albumentations/core/keypoints_utils.py | 41 +-- albumentations/core/serialization.py | 126 ++++--- albumentations/core/transforms_interface.py | 50 ++- albumentations/core/utils.py | 26 +- albumentations/pytorch/transforms.py | 16 +- pyproject.toml | 67 +++- requirements-dev.txt | 3 +- tests/test_functional.py | 122 ++++--- tests/test_keypoint.py | 2 +- tests/test_transforms.py | 2 +- 33 files changed, 1404 insertions(+), 877 deletions(-) diff --git a/.gitignore b/.gitignore index 4958f46a6..8b0ed9403 100644 --- a/.gitignore +++ b/.gitignore @@ -165,3 +165,5 @@ conda.recipe/ .gitingore *.ipynb + +.ruff_cache/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8f4ec5b20..6cd062df9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -37,16 +37,15 @@ repos: - id: end-of-file-fixer - id: trailing-whitespace - id: requirements-txt-fixer - - repo: https://github.com/asottile/pyupgrade - rev: v3.15.0 + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.2.2 hooks: - - id: pyupgrade - args: ["--py38-plus"] - - repo: https://github.com/pycqa/isort - rev: 5.13.2 - hooks: - - id: isort - args: [ "--profile", "black" ] + # Run the linter. + - id: ruff + args: [ --fix ] + # Run the formatter. + - id: ruff-format - repo: https://github.com/psf/black rev: 24.2.0 hooks: @@ -60,11 +59,6 @@ repos: - id: python-check-blanket-noqa - id: python-use-type-annotations - id: text-unicode-replacement-char - - repo: https://github.com/PyCQA/flake8 - rev: 7.0.0 - hooks: - - id: flake8 - exclude: ^setup.py - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.8.0 hooks: diff --git a/README.md b/README.md index 8bd63a051..b32f78e44 100644 --- a/README.md +++ b/README.md @@ -2,6 +2,7 @@ [![PyPI version](https://badge.fury.io/py/albumentations.svg)](https://badge.fury.io/py/albumentations) ![CI](https://github.com/albumentations-team/albumentations/workflows/CI/badge.svg) +[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) Albumentations is a Python library for image augmentation. Image augmentation is used in deep learning and computer vision tasks to increase the quality of trained models. 
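The hook swap in `.pre-commit-config.yaml` above replaces pyupgrade, isort, and flake8 with a single ruff pass. A minimal sketch of the mechanical rewrites that `ruff --fix` drives throughout the rest of this patch follows; the rule codes in the comments (UP032, RUF005, EM101, PLR2004) are my reading of which lints are involved, not something the PR itself states.

```python
# Before/after pairs for the rewrite patterns this patch applies.

height, width = 100, 200

# str.format() call -> f-string (pyupgrade-derived rule UP032)
before_fmt = "size ({height}, {width})".format(height=height, width=width)
after_fmt = f"size ({height}, {width})"
assert before_fmt == after_fmt

# tuple concatenation -> unpacking into a literal (RUF005)
base = ("blur_limit",)
before_tuple = base + ("allow_shifted",)
after_tuple = (*base, "allow_shifted")
assert before_tuple == after_tuple

# raw string inside raise -> assign to a msg variable first (EM101)
msg = "MedianBlur supports only odd blur limits."
err = ValueError(msg)

# bare magic number in a comparison -> named constant (PLR2004)
TWO = 2
ksize = 3
assert ksize > TWO
```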
The purpose of image augmentation is to create new training samples from the existing data. @@ -19,6 +20,7 @@ Here is an example of how you can apply some [pixel-level](#pixel-level-transfor - The library is [**widely used**](#who-is-using-albumentations) in industry, deep learning research, machine learning competitions, and open source projects. ## Table of contents + - [Albumentations](#albumentations) - [Why Albumentations](#why-albumentations) - [Table of contents](#table-of-contents) diff --git a/albumentations/augmentations/blur/functional.py b/albumentations/augmentations/blur/functional.py index 7581c5879..7dc5fb1d4 100644 --- a/albumentations/augmentations/blur/functional.py +++ b/albumentations/augmentations/blur/functional.py @@ -15,6 +15,9 @@ __all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"] +TWO = 2 +EIGHT = 8 + @preserve_shape def blur(img: np.ndarray, ksize: int) -> np.ndarray: @@ -63,9 +66,9 @@ def glass_blur( range(img.shape[1] - max_delta, max_delta, -1), ) ): - ind = ind if ind < len(dxy) else ind % len(dxy) - dy = dxy[ind, i, 0] - dx = dxy[ind, i, 1] + idx = ind if ind < len(dxy) else ind % len(dxy) + dy = dxy[idx, i, 0] + dx = dxy[idx, i, 1] x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w] else: ValueError(f"Unsupported mode `{mode}`. Supports only `fast` and `exact`.") @@ -75,7 +78,7 @@ def glass_blur( def defocus(img: np.ndarray, radius: int, alias_blur: float) -> np.ndarray: length = np.arange(-max(8, radius), max(8, radius) + 1) - ksize = 3 if radius <= 8 else 5 + ksize = 3 if radius <= EIGHT else 5 x, y = np.meshgrid(length, length) aliased_disk = np.array((x**2 + y**2) <= radius**2, dtype=np.float32) @@ -101,6 +104,4 @@ def zoom_blur(img: np.ndarray, zoom_factors: Union[np.ndarray, Sequence[int]]) - for zoom_factor in zoom_factors: out += central_zoom(img, zoom_factor) - img = ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype) - - return img + return ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype) diff --git a/albumentations/augmentations/blur/transforms.py b/albumentations/augmentations/blur/transforms.py index 84c914116..4fe8d0846 100644 --- a/albumentations/augmentations/blur/transforms.py +++ b/albumentations/augmentations/blur/transforms.py @@ -14,10 +14,15 @@ __all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"] +HALF = 0 +TWO = 2 + + class Blur(ImageOnlyTransform): """Blur the input image using a random-sized kernel. Args: + ---- blur_limit: maximum kernel size for blurring the input image. Should be in range [3, inf). Default: (3, 7). p: probability of applying the transform. Default: 0.5. @@ -27,6 +32,7 @@ class Blur(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5): @@ -47,6 +53,7 @@ class MotionBlur(Blur): """Apply motion blur to the input image using a random-sized kernel. Args: + ---- blur_limit (int): maximum kernel size for blurring the input image. Should be in range [3, inf). Default: (3, 7). allow_shifted (bool): if set to true creates non shifted kernels only, @@ -58,6 +65,7 @@ class MotionBlur(Blur): Image types: uint8, float32 + """ def __init__( @@ -74,14 +82,14 @@ def __init__( raise ValueError(f"Blur limit must be odd when centered=True. 
Got: {self.blur_limit}") def get_transform_init_args_names(self) -> Tuple[str, ...]: - return super().get_transform_init_args_names() + ("allow_shifted",) + return (*super().get_transform_init_args_names(), "allow_shifted") def apply(self, img: np.ndarray, kernel: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return FMain.convolve(img, kernel=kernel) def get_params(self) -> Dict[str, Any]: ksize = random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2))) - if ksize <= 2: + if ksize <= TWO: raise ValueError(f"ksize must be > 2. Got: {ksize}") kernel = np.zeros((ksize, ksize), dtype=np.uint8) x1, x2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1) @@ -122,6 +130,7 @@ class MedianBlur(Blur): """Blur the input image using a median filter with a random aperture linear size. Args: + ---- blur_limit (int): maximum aperture linear size for blurring the input image. Must be odd and in range [3, inf). Default: (3, 7). p (float): probability of applying the transform. Default: 0.5. @@ -131,13 +140,15 @@ class MedianBlur(Blur): Image types: uint8, float32 + """ def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5): super().__init__(blur_limit, always_apply, p) if self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1: - raise ValueError("MedianBlur supports only odd blur limits.") + msg = "MedianBlur supports only odd blur limits." + raise ValueError(msg) def apply(self, img: np.ndarray, kernel: int = 3, **params: Any) -> np.ndarray: return F.median_blur(img, kernel) @@ -147,6 +158,7 @@ class GaussianBlur(ImageOnlyTransform): """Blur the input image using a Gaussian filter with a random kernel size. Args: + ---- blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image. Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`. @@ -162,6 +174,7 @@ class GaussianBlur(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -185,7 +198,8 @@ def __init__( if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or ( self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1 ): - raise ValueError("GaussianBlur supports only odd blur limits.") + msg = "GaussianBlur supports only odd blur limits." + raise ValueError(msg) def apply(self, img: np.ndarray, ksize: int = 3, sigma: float = 0, **params: Any) -> np.ndarray: return F.gaussian_blur(img, ksize, sigma=sigma) @@ -205,6 +219,7 @@ class GlassBlur(Blur): """Apply glass noise to the input image. Args: + ---- sigma (float): standard deviation for Gaussian kernel. max_delta (int): max distance between pixels which are swapped. iterations (int): number of repeats. 
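As a usage note on the odd-kernel validation rewritten above, here is a short sketch. It assumes the post-patch package installed as `albumentations`; the error text matches the `msg` string introduced in this hunk.

```python
import albumentations as A

# Both ends of blur_limit must be odd; the check runs at construction time.
aug = A.MedianBlur(blur_limit=(3, 7), p=1.0)

try:
    A.MedianBlur(blur_limit=(4, 8), p=1.0)  # even bounds are rejected
except ValueError as err:
    print(err)  # MedianBlur supports only odd blur limits.
```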
@@ -221,6 +236,7 @@ class GlassBlur(Blur): Reference: | https://arxiv.org/abs/1903.12261 | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py + """ def __init__( @@ -244,8 +260,11 @@ def __init__( self.iterations = iterations self.mode = mode - def apply(self, img: np.ndarray, dxy: np.ndarray = None, **params) -> np.ndarray: # type: ignore - assert dxy is not None + def apply(self, img: np.ndarray, *args: Any, dxy: np.ndarray = None, **params: Any) -> np.ndarray: + if dxy is None: + msg = "dxy is None" + raise ValueError(msg) + return F.glass_blur(img, self.sigma, self.max_delta, self.iterations, dxy, self.mode) def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, np.ndarray]: @@ -268,42 +287,59 @@ def targets_as_params(self) -> List[str]: class AdvancedBlur(ImageOnlyTransform): - """Blur the input image using a Generalized Normal filter with a randomly selected parameters. - This transform also adds multiplicative noise to generated kernel before convolution. + """Blurs the input image using a Generalized Normal filter with randomly selected parameters. + + This transform also adds multiplicative noise to the generated kernel before convolution, + affecting the image in a unique way that combines blurring and noise injection for enhanced + data augmentation. Args: - blur_limit: maximum Gaussian kernel size for blurring the input image. - Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma + ---- + blur_limit (ScaleIntType, optional): Maximum Gaussian kernel size for blurring the input image. + Must be zero or odd and in range [0, inf). If set to 0, it will be computed from sigma as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`. - If set single value `blur_limit` will be in range (0, blur_limit). - Default: (3, 7). - sigmaX_limit: Gaussian kernel standard deviation. Must be in range [0, inf). - If set single value `sigmaX_limit` will be in range (0, sigma_limit). - If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0. - sigmaY_limit: Same as `sigmaY_limit` for another dimension. - rotate_limit: Range from which a random angle used to rotate Gaussian kernel is picked. - If limit is a single int an angle is picked from (-rotate_limit, rotate_limit). Default: (-90, 90). - beta_limit: Distribution shape parameter, 1 is the normal distribution. Values below 1.0 make distribution - tails heavier than normal, values above 1.0 make it lighter than normal. Default: (0.5, 8.0). - noise_limit: Multiplicative factor that control strength of kernel noise. Must be positive and preferably - centered around 1.0. If set single value `noise_limit` will be in range (0, noise_limit). - Default: (0.75, 1.25). - p (float): probability of applying the transform. Default: 0.5. + If a single value is provided, `blur_limit` will be in the range (0, blur_limit). + Defaults to (3, 7). + sigma_x_limit ScaleFloatType: Gaussian kernel standard deviation for the X dimension. + Must be in range [0, inf). If a single value is provided, `sigma_x_limit` will be in the range + (0, sigma_limit). If set to 0, sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. + Defaults to (0.2, 1.0). + sigma_y_limit ScaleFloatType: Gaussian kernel standard deviation for the Y dimension. + Must follow the same rules as `sigma_x_limit`. + Defaults to (0.2, 1.0). 
+ rotate_limit (ScaleIntType, optional): Range from which a random angle used to rotate the Gaussian kernel + is picked. If limit is a single int, an angle is picked from (-rotate_limit, rotate_limit). + Defaults to (-90, 90). + beta_limit (ScaleFloatType, optional): Distribution shape parameter. 1 represents the normal distribution. + Values below 1.0 make distribution tails heavier than normal, and values above 1.0 make it + lighter than normal. + Defaults to (0.5, 8.0). + noise_limit (ScaleFloatType, optional): Multiplicative factor that controls the strength of kernel noise. + Must be positive and preferably centered around 1.0. If a single value is provided, + `noise_limit` will be in the range (0, noise_limit). + Defaults to (0.75, 1.25). + p (float, optional): Probability of applying the transform. + Defaults to 0.5. Reference: - https://arxiv.org/abs/2107.10833 + "Real-ESRGAN: Training Real-World Blind Super-Resolution with Pure Synthetic Data", + available at https://arxiv.org/abs/2107.10833 Targets: - image + This transformation is applied to images only. + Image types: - uint8, float32 + This transform supports uint8 and float32 image types. + """ def __init__( self, blur_limit: ScaleIntType = (3, 7), - sigmaX_limit: ScaleFloatType = (0.2, 1.0), - sigmaY_limit: ScaleFloatType = (0.2, 1.0), + sigma_x_limit: ScaleFloatType = (0.2, 1.0), + sigma_y_limit: ScaleFloatType = (0.2, 1.0), + sigmaX_limit: ScaleFloatType = (0.2, 1.0), # noqa: N803 + sigmaY_limit: ScaleFloatType = (0.2, 1.0), # noqa: N803 rotate_limit: ScaleIntType = 90, beta_limit: ScaleFloatType = (0.5, 8.0), noise_limit: ScaleFloatType = (0.9, 1.1), @@ -312,8 +348,18 @@ def __init__( ): super().__init__(always_apply, p) self.blur_limit = cast(Tuple[int, int], to_tuple(blur_limit, 3)) - self.sigmaX_limit = self.__check_values(to_tuple(sigmaX_limit, 0.0), name="sigmaX_limit") - self.sigmaY_limit = self.__check_values(to_tuple(sigmaY_limit, 0.0), name="sigmaY_limit") + + # Handle deprecation of sigmaX_limit and sigmaY_limit + if sigmaX_limit is not None: + warnings.warn("sigmaX_limit is deprecated; use sigma_x_limit instead.", DeprecationWarning) + sigma_x_limit = sigma_x_limit or sigmaX_limit + + if sigmaY_limit is not None: + warnings.warn("sigmaY_limit is deprecated; use sigma_y_limit instead.", DeprecationWarning) + sigma_y_limit = sigma_y_limit or sigmaY_limit + + self.sigma_x_limit = self.__check_values(to_tuple(sigma_x_limit, 0.0), name="sigma_x_limit") + self.sigma_y_limit = self.__check_values(to_tuple(sigma_y_limit, 0.0), name="sigma_y_limit") self.rotate_limit = to_tuple(rotate_limit) self.beta_limit = to_tuple(beta_limit, low=0.0) self.noise_limit = self.__check_values(to_tuple(noise_limit, 0.0), name="noise_limit") @@ -321,13 +367,16 @@ def __init__( if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or ( self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1 ): - raise ValueError("AdvancedBlur supports only odd blur limits.") + msg = "AdvancedBlur supports only odd blur limits." + raise ValueError(msg) - if self.sigmaX_limit[0] == 0 and self.sigmaY_limit[0] == 0: - raise ValueError("sigmaX_limit and sigmaY_limit minimum value can not be both equal to 0.") + if self.sigma_x_limit[0] == 0 and self.sigma_y_limit[0] == 0: + msg = "sigma_x_limit and sigma_y_limit minimum value cannot be both equal to 0." + raise ValueError(msg) if not (self.beta_limit[0] < 1.0 < self.beta_limit[1]): - raise ValueError("Beta limit is expected to include 1.0") + msg = "Beta limit is expected to include 1.0." 
+ raise ValueError(msg) @staticmethod def __check_values( @@ -337,30 +386,29 @@ def __check_values( raise ValueError(f"{name} values should be between {bounds}") return value - def apply(self, img: np.ndarray, kernel: np.ndarray = np.array(None), **params: Any) -> np.ndarray: + def apply(self, img: np.ndarray, kernel: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return FMain.convolve(img, kernel=kernel) def get_params(self) -> Dict[str, np.ndarray]: ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2) - sigmaX = random.uniform(*self.sigmaX_limit) - sigmaY = random.uniform(*self.sigmaY_limit) + sigma_x = random.uniform(*self.sigma_x_limit) + sigma_y = random.uniform(*self.sigma_y_limit) angle = np.deg2rad(random.uniform(*self.rotate_limit)) # Split into 2 cases to avoid selection of narrow kernels (beta > 1) too often. - if random.random() < 0.5: - beta = random.uniform(self.beta_limit[0], 1) - else: - beta = random.uniform(1, self.beta_limit[1]) + beta = ( + random.uniform(self.beta_limit[0], 1) if random.random() < HALF else random.uniform(1, self.beta_limit[1]) + ) noise_matrix = random_utils.uniform(self.noise_limit[0], self.noise_limit[1], size=[ksize, ksize]) # Generate mesh grid centered at zero. ax = np.arange(-ksize // 2 + 1.0, ksize // 2 + 1.0) - # Shape (ksize, ksize, 2) + # > Shape (ksize, ksize, 2) grid = np.stack(np.meshgrid(ax, ax), axis=-1) # Calculate rotated sigma matrix - d_matrix = np.array([[sigmaX**2, 0], [0, sigmaY**2]]) + d_matrix = np.array([[sigma_x**2, 0], [0, sigma_y**2]]) u_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) sigma_matrix = np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) @@ -377,8 +425,8 @@ def get_params(self) -> Dict[str, np.ndarray]: def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]: return ( "blur_limit", - "sigmaX_limit", - "sigmaY_limit", + "sigma_x_limit", + "sigma_y_limit", "rotate_limit", "beta_limit", "noise_limit", @@ -386,10 +434,10 @@ def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]: class Defocus(ImageOnlyTransform): - """ - Apply defocus transform. See https://arxiv.org/abs/1903.12261. + """Apply defocus transform. See https://arxiv.org/abs/1903.12261. Args: + ---- radius ((int, int) or int): range for radius of defocusing. If limit is a single int, the range will be [1, limit]. Default: (3, 10). alias_blur ((float, float) or float): range for alias_blur of defocusing (sigma of gaussian blur). @@ -401,6 +449,7 @@ class Defocus(ImageOnlyTransform): Image types: Any + """ def __init__( @@ -415,10 +464,12 @@ def __init__( self.alias_blur = to_tuple(alias_blur, low=0) if self.radius[0] <= 0: - raise ValueError("Parameter radius must be positive") + msg = "Parameter radius must be positive" + raise ValueError(msg) if self.alias_blur[0] < 0: - raise ValueError("Parameter alias_blur must be non-negative") + msg = "Parameter alias_blur must be non-negative" + raise ValueError(msg) def apply(self, img: np.ndarray, radius: int = 3, alias_blur: float = 0.5, **params: Any) -> np.ndarray: return F.defocus(img, radius, alias_blur) @@ -434,10 +485,10 @@ def get_transform_init_args_names(self) -> Tuple[str, str]: class ZoomBlur(ImageOnlyTransform): - """ - Apply zoom blur transform. See https://arxiv.org/abs/1903.12261. + """Apply zoom blur transform. See https://arxiv.org/abs/1903.12261. Args: + ---- max_factor ((float, float) or float): range for max factor for blurring. 
If max_factor is a single float, the range will be (1, limit). Default: (1, 1.31). All max_factor values should be larger than 1. @@ -451,6 +502,7 @@ class ZoomBlur(ImageOnlyTransform): Image types: Any + """ def __init__( @@ -465,12 +517,17 @@ def __init__( self.step_factor = to_tuple(step_factor, step_factor) if self.max_factor[0] < 1: - raise ValueError("Max factor must be larger or equal 1") + msg = "Max factor must be larger or equal 1" + raise ValueError(msg) if self.step_factor[0] <= 0: - raise ValueError("Step factor must be positive") + msg = "Step factor must be positive" + raise ValueError(msg) + + def apply(self, img: np.ndarray, zoom_factors: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: + if zoom_factors is None: + msg = "zoom_factors is None" + raise ValueError(msg) - def apply(self, img: np.ndarray, zoom_factors: np.ndarray = np.array(None), **params: Any) -> np.ndarray: - assert zoom_factors is not None return F.zoom_blur(img, zoom_factors) def get_params(self) -> Dict[str, Any]: diff --git a/albumentations/augmentations/crops/functional.py b/albumentations/augmentations/crops/functional.py index 249bc7bae..586221ef7 100644 --- a/albumentations/augmentations/crops/functional.py +++ b/albumentations/augmentations/crops/functional.py @@ -3,14 +3,13 @@ import cv2 import numpy as np +from albumentations.augmentations.geometric import functional as FGeometric from albumentations.augmentations.utils import ( _maybe_process_in_chunks, preserve_channel_dim, ) - -from ...core.bbox_utils import denormalize_bbox, normalize_bbox -from ...core.types import BoxInternalType, KeypointInternalType -from ..geometric import functional as FGeometric +from albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox +from albumentations.core.types import BoxInternalType, KeypointInternalType __all__ = [ "get_random_crop_coords", @@ -49,10 +48,7 @@ def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: flo height, width = img.shape[:2] if height < crop_height or width < crop_width: raise ValueError( - "Requested crop size ({crop_height}, {crop_width}) is " - "larger than the image size ({height}, {width})".format( - crop_height=crop_height, crop_width=crop_width, height=height, width=width - ) + f"Requested crop size ({crop_height}, {crop_width}) is " f"larger than the image size ({height}, {width})" ) x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start) return img[y1:y2, x1:x2] @@ -70,6 +66,7 @@ def crop_bbox_by_coords( required height and width of the crop. Args: + ---- bbox: A cropped box `(x_min, y_min, x_max, y_max)`. crop_coords: Crop coordinates `(x1, y1, x2, y2)`. crop_height: @@ -78,6 +75,7 @@ def crop_bbox_by_coords( cols: Image cols. Returns: + ------- A cropped bounding box `(x_min, y_min, x_max, y_max)`. """ @@ -102,10 +100,12 @@ def crop_keypoint_by_coords( required height and width of the crop. Args: + ---- keypoint (tuple): A keypoint `(x, y, angle, scale)`. crop_coords (tuple): Crop box coords `(x1, x2, y1, y2)`. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ @@ -126,6 +126,7 @@ def keypoint_random_crop( """Keypoint random crop. Args: + ---- keypoint: (tuple): A keypoint `(x, y, angle, scale)`. crop_height (int): Crop height. crop_width (int): Crop width. @@ -135,6 +136,7 @@ def keypoint_random_crop( cols (int): Image width. Returns: + ------- A keypoint `(x, y, angle, scale)`. 
""" @@ -154,10 +156,7 @@ def center_crop(img: np.ndarray, crop_height: int, crop_width: int) -> np.ndarra height, width = img.shape[:2] if height < crop_height or width < crop_width: raise ValueError( - "Requested crop size ({crop_height}, {crop_width}) is " - "larger than the image size ({height}, {width})".format( - crop_height=crop_height, crop_width=crop_width, height=height, width=width - ) + f"Requested crop size ({crop_height}, {crop_width}) is " f"larger than the image size ({height}, {width})" ) x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width) return img[y1:y2, x1:x2] @@ -174,6 +173,7 @@ def keypoint_center_crop( """Keypoint center crop. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. crop_height: Crop height. crop_width: Crop width. @@ -181,6 +181,7 @@ def keypoint_center_crop( cols: Image width. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ @@ -193,18 +194,14 @@ def crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int) -> np. if x_max <= x_min or y_max <= y_min: raise ValueError( "We should have x_min < x_max and y_min < y_max. But we got" - " (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})".format( - x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max - ) + f" (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})" ) if x_min < 0 or x_max > width or y_min < 0 or y_max > height: raise ValueError( "Values for crop should be non negative and equal or smaller than image sizes" - "(x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max}, " - "height = {height}, width = {width})".format( - x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, height=height, width=width - ) + f"(x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max}, " + f"height = {height}, width = {width})" ) return img[y_min:y_max, x_min:x_max] @@ -216,6 +213,7 @@ def bbox_crop( """Crop a bounding box. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. x_min: y_min: @@ -225,6 +223,7 @@ def bbox_crop( cols: Image cols. Returns: + ------- A cropped bounding box `(x_min, y_min, x_max, y_max)`. """ diff --git a/albumentations/augmentations/crops/transforms.py b/albumentations/augmentations/crops/transforms.py index ecf6b7d2a..ba98e306d 100644 --- a/albumentations/augmentations/crops/transforms.py +++ b/albumentations/augmentations/crops/transforms.py @@ -5,16 +5,11 @@ import cv2 import numpy as np +from albumentations.augmentations.geometric import functional as FGeometric from albumentations.core.bbox_utils import union_of_bboxes +from albumentations.core.transforms_interface import DualTransform, to_tuple +from albumentations.core.types import BoxInternalType, KeypointInternalType, ScaleFloatType -from ...core.transforms_interface import DualTransform, to_tuple -from ...core.types import ( - BoxInternalType, - KeypointInternalType, - ScaleFloatType, - ScaleIntType, -) -from ..geometric import functional as FGeometric from . import functional as F __all__ = [ @@ -31,11 +26,15 @@ "BBoxSafeRandomCrop", ] +TWO = 2 +THREE = 3 + class RandomCrop(DualTransform): """Crop a random part of the input. Args: + ---- height: height of the crop. width: width of the crop. p: probability of applying the transform. Default: 1. @@ -45,6 +44,7 @@ class RandomCrop(DualTransform): Image types: uint8, float32 + """ def __init__(self, height: int, width: int, always_apply: bool = False, p: float = 1.0): @@ -72,6 +72,7 @@ class CenterCrop(DualTransform): """Crop the central part of the input. 
Args: + ---- height: height of the crop. width: width of the crop. p: probability of applying the transform. Default: 1. @@ -83,9 +84,11 @@ class CenterCrop(DualTransform): uint8, float32 Note: + ---- It is recommended to use uint8 images as input. Otherwise the operation will require internal conversion float32 -> uint8 -> float32 that causes worse performance. + """ def __init__(self, height: int, width: int, always_apply: bool = False, p: float = 1.0): @@ -110,6 +113,7 @@ class Crop(DualTransform): """Crop region from image. Args: + ---- x_min: Minimum upper left x coordinate. y_min: Minimum upper left y coordinate. x_max: Maximum lower right x coordinate. @@ -120,6 +124,7 @@ class Crop(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -154,6 +159,7 @@ class CropNonEmptyMaskIfExists(DualTransform): """Crop area with mask if mask is non-empty, else make random crop. Args: + ---- height: vertical size of crop in pixels width: horizontal size of crop in pixels ignore_values (list of int): values to ignore in mask, `0` values are always ignored @@ -167,6 +173,7 @@ class CropNonEmptyMaskIfExists(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -220,15 +227,13 @@ def _preprocess_mask(self, mask: np.ndarray) -> np.ndarray: ignore_values_np = np.array(self.ignore_values) mask = np.where(np.isin(mask, ignore_values_np), 0, mask) - if mask.ndim == 3 and self.ignore_channels is not None: + if mask.ndim == THREE and self.ignore_channels is not None: target_channels = np.array([ch for ch in range(mask.shape[-1]) if ch not in self.ignore_channels]) mask = np.take(mask, target_channels, axis=-1) if self.height > mask_height or self.width > mask_width: raise ValueError( - "Crop size ({},{}) is larger than image ({},{})".format( - self.height, self.width, mask_height, mask_width - ) + f"Crop size ({self.height},{self.width}) is larger than image ({mask_height},{mask_width})" ) return mask @@ -243,12 +248,13 @@ def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any] for m in masks[1:]: mask |= self._preprocess_mask(m) else: - raise RuntimeError("Can not find mask for CropNonEmptyMaskIfExists") + msg = "Can not find mask for CropNonEmptyMaskIfExists" + raise RuntimeError(msg) mask_height, mask_width = mask.shape[:2] if mask.any(): - mask = mask.sum(axis=-1) if mask.ndim == 3 else mask + mask = mask.sum(axis=-1) if mask.ndim == THREE else mask non_zero_yx = np.argwhere(mask) y, x = random.choice(non_zero_yx) x_min = x - random.randint(0, self.width - 1) @@ -327,6 +333,7 @@ class RandomSizedCrop(_BaseRandomSizedCrop): """Crop a random part of the input and rescale it to some size. Args: + ---- min_max_height ((int, int)): crop size limits. height (int): height after crop and resize. width (int): width after crop and resize. @@ -341,6 +348,7 @@ class RandomSizedCrop(_BaseRandomSizedCrop): Image types: uint8, float32 + """ def __init__( @@ -374,6 +382,7 @@ class RandomResizedCrop(_BaseRandomSizedCrop): """Torchvision's variant of crop a random part of the input and rescale it to some size. Args: + ---- height (int): height after crop and resize. width (int): width after crop and resize. 
scale ((float, float)): range of size of the origin size cropped @@ -388,6 +397,7 @@ class RandomResizedCrop(_BaseRandomSizedCrop): Image types: uint8, float32 + """ def __init__( @@ -461,6 +471,7 @@ class RandomCropNearBBox(DualTransform): """Crop bbox from image with random shift by x,y coordinates Args: + ---- max_part_shift (float, (float, float)): Max shift in `height` and `width` dimensions relative to `cropping_bbox` dimension. If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift). @@ -475,6 +486,7 @@ class RandomCropNearBBox(DualTransform): uint8, float32 Examples: + -------- >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')], >>> bbox_params=BboxParams("pascal_voc")) >>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20]) @@ -540,13 +552,16 @@ def get_transform_init_args_names(self) -> Tuple[str]: class BBoxSafeRandomCrop(DualTransform): """Crop a random part of the input without loss of bboxes. + Args: + ---- erosion_rate: erosion rate applied on input image height before crop. p: probability of applying the transform. Default: 1. Targets: image, mask, bboxes Image types: uint8, float32 + """ def __init__(self, erosion_rate: float = 0.0, always_apply: bool = False, p: float = 1.0): @@ -612,7 +627,9 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]: class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop): """Crop a random part of the input and rescale it to some size without loss of bboxes. + Args: + ---- height: height after crop and resize. width: width after crop and resize. erosion_rate: erosion rate applied on input image height before crop. @@ -624,6 +641,7 @@ class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop): image, mask, bboxes Image types: uint8, float32 + """ def __init__( @@ -654,7 +672,7 @@ def apply( return FGeometric.resize(crop, self.height, self.width, interpolation) def get_transform_init_args_names(self) -> Tuple[str, ...]: - return super().get_transform_init_args_names() + ("height", "width", "interpolation") + return (*super().get_transform_init_args_names(), "height", "width", "interpolation") class CropAndPad(DualTransform): @@ -664,10 +682,12 @@ class CropAndPad(DualTransform): This transformation will never crop images below a height or width of ``1``. Note: + ---- This transformation automatically resizes images back to their original size. To deactivate this, add the parameter ``keep_size=False``. Args: + ---- px (int or tuple): The number of pixels to crop (negative values) or pad (positive values) on each side of the image. Either this or the parameter `percent` may @@ -742,6 +762,7 @@ class CropAndPad(DualTransform): Image types: any + """ def __init__( @@ -760,9 +781,11 @@ def __init__( super().__init__(always_apply, p) if px is None and percent is None: - raise ValueError("px and percent are empty!") + msg = "px and percent are empty!" + raise ValueError(msg) if px is not None and percent is not None: - raise ValueError("Only px or percent may be set!") + msg = "Only px or percent may be set!" 
+ raise ValueError(msg) self.px = px self.percent = percent @@ -781,7 +804,7 @@ def apply( img: np.ndarray, crop_params: Sequence[int] = (), pad_params: Sequence[int] = (), - pad_value: Union[int, float] = 0, + pad_value: float = 0, rows: int = 0, cols: int = 0, interpolation: int = cv2.INTER_LINEAR, @@ -918,41 +941,41 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A def _get_px_params(self) -> List[int]: if self.px is None: - raise ValueError("px is not set") + msg = "px is not set" + raise ValueError(msg) if isinstance(self.px, int): params = [self.px] * 4 - elif len(self.px) == 2: + elif len(self.px) == TWO: if self.sample_independently: params = [random.randrange(*self.px) for _ in range(4)] else: px = random.randrange(*self.px) params = [px] * 4 + elif isinstance(self.px[0], int): + params = self.px else: - if isinstance(self.px[0], int): - params = self.px - else: - params = [random.randrange(*i) for i in self.px] + params = [random.randrange(*i) for i in self.px] return params def _get_percent_params(self) -> List[float]: if self.percent is None: - raise ValueError("percent is not set") + msg = "percent is not set" + raise ValueError(msg) if isinstance(self.percent, float): params = [self.percent] * 4 - elif len(self.percent) == 2: + elif len(self.percent) == TWO: if self.sample_independently: params = [random.uniform(*self.percent) for _ in range(4)] else: px = random.uniform(*self.percent) params = [px] * 4 + elif isinstance(self.percent[0], (int, float)): + params = self.percent else: - if isinstance(self.percent[0], (int, float)): - params = self.percent - else: - params = [random.uniform(*i) for i in self.percent] + params = [random.uniform(*i) for i in self.percent] return params # params = [top, right, bottom, left] @@ -961,7 +984,7 @@ def _get_pad_value(pad_value: Union[float, Sequence[float]]) -> Union[int, float if isinstance(pad_value, (int, float)): return pad_value - if len(pad_value) == 2: + if len(pad_value) == TWO: a, b = pad_value if isinstance(a, int) and isinstance(b, int): return random.randint(a, b) @@ -987,6 +1010,7 @@ class RandomCropFromBorders(DualTransform): """Crop bbox from image randomly cut parts from borders without resize at the end Args: + ---- crop_left (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut from left side in range [0, crop_left * width) crop_right (float): single float value in (0.0, 1.0) range. Default 0.1. 
Image will be randomly cut @@ -1002,6 +1026,7 @@ class RandomCropFromBorders(DualTransform): Image types: uint8, float32 + """ def __init__( diff --git a/albumentations/augmentations/domain_adaptation.py b/albumentations/augmentations/domain_adaptation.py index 9b8f2d5e7..b92f088bc 100644 --- a/albumentations/augmentations/domain_adaptation.py +++ b/albumentations/augmentations/domain_adaptation.py @@ -16,9 +16,8 @@ preserve_shape, read_rgb_image, ) - -from ..core.transforms_interface import ImageOnlyTransform, to_tuple -from ..core.types import ScaleFloatType +from albumentations.core.transforms_interface import ImageOnlyTransform, to_tuple +from albumentations.core.types import ScaleFloatType __all__ = [ "HistogramMatching", @@ -29,23 +28,25 @@ "adapt_pixel_distribution", ] +THREE = 3 + @clipped @preserve_shape def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray: - """ - Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA + """Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA Args: + ---- img: source image target_img: target image for domain adaptation beta: coefficient from source paper Returns: + ------- transformed image """ - img = np.squeeze(img) target_img = np.squeeze(target_img) @@ -93,7 +94,7 @@ def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: f img, reference_image = np.squeeze(img), np.squeeze(reference_image) try: - matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == 3 else None) + matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == THREE else None) except TypeError: matched = match_histograms(img, reference_image, multichannel=True) return cv2.addWeighted( @@ -118,8 +119,7 @@ def adapt_pixel_distribution( class HistogramMatching(ImageOnlyTransform): - """ - Apply histogram matching. It manipulates the pixels of an input image so that its histogram matches + """Apply histogram matching. It manipulates the pixels of an input image so that its histogram matches the histogram of the reference image. If the images have multiple channels, the matching is done independently for each channel, as long as the number of channels is equal in the input image and the reference. @@ -131,6 +131,7 @@ class HistogramMatching(ImageOnlyTransform): https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_histogram_matching.html Args: + ---- reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, it expects a sequence of paths to images. blend_ratio: Tuple of min and max blend ratio. Matched image will be blended with original @@ -144,6 +145,7 @@ class HistogramMatching(ImageOnlyTransform): Image types: uint8, uint16, float32 + """ def __init__( @@ -177,16 +179,17 @@ def get_params(self) -> Dict[str, np.ndarray]: def get_transform_init_args_names(self) -> Tuple[str, str, str]: return ("reference_images", "blend_ratio", "read_fn") - def _to_dict(self) -> Dict[str, Any]: - raise NotImplementedError("HistogramMatching can not be serialized.") + def to_dict_private(self) -> Dict[str, Any]: + msg = "HistogramMatching can not be serialized." + raise NotImplementedError(msg) class FDA(ImageOnlyTransform): - """ - Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA + """Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA Simple "style transfer". 
Args: + ---- reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, it expects a sequence of paths to images. beta_limit (float or tuple of float): coefficient beta from paper. Recommended less 0.3. @@ -204,6 +207,7 @@ class FDA(ImageOnlyTransform): https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf Example: + ------- >>> import numpy as np >>> import albumentations as A >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8) @@ -248,17 +252,18 @@ def targets_as_params(self) -> List[str]: def get_transform_init_args_names(self) -> Tuple[str, str, str]: return "reference_images", "beta_limit", "read_fn" - def _to_dict(self) -> Dict[str, Any]: - raise NotImplementedError("FDA can not be serialized.") + def to_dict_private(self) -> Dict[str, Any]: + msg = "FDA can not be serialized." + raise NotImplementedError(msg) class PixelDistributionAdaptation(ImageOnlyTransform): - """ - Another naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler + """Another naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler or MinMaxScaler) on both original and reference image, transforms original image with transform trained on this image and then performs inverse transformation using transform fitted on reference image. Args: + ---- reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, it expects a sequence of paths to images. blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original @@ -275,6 +280,7 @@ class PixelDistributionAdaptation(ImageOnlyTransform): uint8, float32 See also: https://github.com/arsenyinfo/qudida + """ def __init__( @@ -338,5 +344,6 @@ def get_params(self) -> Dict[str, Any]: def get_transform_init_args_names(self) -> Tuple[str, str, str, str]: return "reference_images", "blend_ratio", "read_fn", "transform_type" - def _to_dict(self) -> Dict[str, Any]: - raise NotImplementedError("PixelDistributionAdaptation can not be serialized.") + def to_dict_private(self) -> Dict[str, Any]: + msg = "PixelDistributionAdaptation can not be serialized." + raise NotImplementedError(msg) diff --git a/albumentations/augmentations/dropout/channel_dropout.py b/albumentations/augmentations/dropout/channel_dropout.py index 9861ac860..9bdad7ce4 100644 --- a/albumentations/augmentations/dropout/channel_dropout.py +++ b/albumentations/augmentations/dropout/channel_dropout.py @@ -1,5 +1,5 @@ import random -from typing import Any, Dict, List, Mapping, Tuple, Union +from typing import Any, Dict, List, Mapping, Tuple import numpy as np @@ -9,11 +9,14 @@ __all__ = ["ChannelDropout"] +TWO = 2 + class ChannelDropout(ImageOnlyTransform): """Randomly Drop Channels in the input Image. Args: + ---- channel_drop_range (int, int): range from which we choose the number of channels to drop. fill_value (int, float): pixel value for the dropped channel. p (float): probability of applying the transform. Default: 0.5. 
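To make the `ChannelDropout` contract concrete, a minimal sketch, assuming the post-patch package; the single-channel and drop-all checks it exercises are the ones rewritten in this and the next hunk.

```python
import numpy as np
import albumentations as A

rgb = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)

# Drops 1 or 2 of the 3 channels and fills them with fill_value.
aug = A.ChannelDropout(channel_drop_range=(1, 2), fill_value=0, p=1.0)
dropped = aug(image=rgb)["image"]

# Single-channel input is rejected: there is nothing safe to drop.
gray = np.random.randint(0, 256, (64, 64), dtype=np.uint8)
try:
    aug(image=gray)
except NotImplementedError as err:
    print(err)  # Images has one channel. ChannelDropout is not defined.
```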
@@ -23,12 +26,13 @@ class ChannelDropout(ImageOnlyTransform): Image types: uint8, uint16, unit32, float32 + """ def __init__( self, channel_drop_range: Tuple[int, int] = (1, 1), - fill_value: Union[int, float] = 0, + fill_value: float = 0, always_apply: bool = False, p: float = 0.5, ): @@ -52,11 +56,13 @@ def get_params_dependent_on_targets(self, params: Mapping[str, Any]) -> Dict[str num_channels = img.shape[-1] - if len(img.shape) == 2 or num_channels == 1: - raise NotImplementedError("Images has one channel. ChannelDropout is not defined.") + if len(img.shape) == TWO or num_channels == 1: + msg = "Images has one channel. ChannelDropout is not defined." + raise NotImplementedError(msg) if self.max_channels >= num_channels: - raise ValueError("Can not drop all channels in ChannelDropout.") + msg = "Can not drop all channels in ChannelDropout." + raise ValueError(msg) num_drop_channels = random.randint(self.min_channels, self.max_channels) diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index 279ccf803..338813458 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -3,8 +3,9 @@ import numpy as np -from ...core.transforms_interface import DualTransform -from ...core.types import KeypointType, ScalarType +from albumentations.core.transforms_interface import DualTransform +from albumentations.core.types import KeypointType, ScalarType + from .functional import cutout, keypoint_in_hole __all__ = ["CoarseDropout"] @@ -14,6 +15,7 @@ class CoarseDropout(DualTransform): """CoarseDropout of the rectangular regions in the image. Args: + ---- max_holes (int): Maximum number of regions to zero out. max_height (int, float): Maximum height of the hole. If float, it is calculated as a fraction of the image height. @@ -42,6 +44,7 @@ class CoarseDropout(DualTransform): | https://arxiv.org/abs/1708.04552 | https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py | https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py + """ def __init__( @@ -131,19 +134,18 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A hole_height = int(height * random.uniform(self.min_height, self.max_height)) hole_width = int(width * random.uniform(self.min_width, self.max_width)) else: - raise ValueError( - "Min width, max width, \ + msg = "Min width, max width, \ min height and max height \ should all either be ints or floats. 
\ Got: {} respectively".format( - [ - type(self.min_width), - type(self.max_width), - type(self.min_height), - type(self.max_height), - ] - ) + [ + type(self.min_width), + type(self.max_width), + type(self.min_height), + type(self.max_height), + ] ) + raise ValueError(msg) y1 = random.randint(0, height - hole_height) x1 = random.randint(0, width - hole_width) @@ -160,11 +162,7 @@ def targets_as_params(self) -> List[str]: def apply_to_keypoints( self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params: Any ) -> List[KeypointType]: - filtered_keypoints = [] - for keypoint in keypoints: - if not any(keypoint_in_hole(keypoint, hole) for hole in holes): - filtered_keypoints.append(keypoint) - return filtered_keypoints + return [keypoint for keypoint in keypoints if not any(keypoint_in_hole(keypoint, hole) for hole in holes)] def get_transform_init_args_names(self) -> Tuple[str, ...]: return ( diff --git a/albumentations/augmentations/dropout/functional.py b/albumentations/augmentations/dropout/functional.py index 063f84888..5e05c57b7 100644 --- a/albumentations/augmentations/dropout/functional.py +++ b/albumentations/augmentations/dropout/functional.py @@ -5,13 +5,16 @@ from albumentations.augmentations.utils import preserve_shape from albumentations.core.types import ColorType, KeypointType +TWO = 2 + @preserve_shape def channel_dropout( img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: ColorType = 0 ) -> np.ndarray: - if len(img.shape) == 2 or img.shape[2] == 1: - raise NotImplementedError("Only one channel. ChannelDropout is not defined.") + if len(img.shape) == TWO or img.shape[2] == 1: + msg = "Only one channel. ChannelDropout is not defined." + raise NotImplementedError(msg) img = img.copy() img[..., channels_to_drop] = fill_value diff --git a/albumentations/augmentations/dropout/grid_dropout.py b/albumentations/augmentations/dropout/grid_dropout.py index 431ffef37..00ffdd664 100644 --- a/albumentations/augmentations/dropout/grid_dropout.py +++ b/albumentations/augmentations/dropout/grid_dropout.py @@ -3,17 +3,21 @@ import numpy as np -from ...core.transforms_interface import DualTransform -from ...core.types import ScalarType +from albumentations.core.transforms_interface import DualTransform +from albumentations.core.types import ScalarType + from . import functional as F __all__ = ["GridDropout"] +TWO = 2 + class GridDropout(DualTransform): """GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion. Args: + ---- ratio: the ratio of the mask holes to the unit_size (same for horizontal and vertical directions). Must be between 0 and 1. Default: 0.5. unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge. @@ -41,6 +45,7 @@ class GridDropout(DualTransform): uint8, float32 References: + ---------- https://arxiv.org/abs/2001.04086 """ @@ -72,7 +77,8 @@ def __init__( self.fill_value = fill_value self.mask_fill_value = mask_fill_value if not 0 < self.ratio <= 1: - raise ValueError("ratio must be between 0 and 1.") + msg = "ratio must be between 0 and 1." 
+ raise ValueError(msg) def apply(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params: Any) -> np.ndarray: return F.cutout(img, holes, self.fill_value) @@ -88,46 +94,81 @@ def apply_to_mask( def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: img = params["image"] height, width = img.shape[:2] - # set grid using unit size limits - if self.unit_size_min and self.unit_size_max: - if not 2 <= self.unit_size_min <= self.unit_size_max: - raise ValueError("Max unit size should be >= min size, both at least 2 pixels.") + unit_width, unit_height = self._calculate_unit_dimensions(width, height) + hole_width, hole_height = self._calculate_hole_dimensions(unit_width, unit_height) + shift_x, shift_y = self._calculate_shifts(unit_width, unit_height, hole_width, hole_height) + holes = self._generate_holes(width, height, unit_width, unit_height, hole_width, hole_height, shift_x, shift_y) + return {"holes": holes} + + def _calculate_unit_dimensions(self, width: int, height: int) -> Tuple[int, int]: + """Calculates the dimensions of the grid units.""" + if self.unit_size_min is not None and self.unit_size_max is not None: + self._validate_unit_sizes(height, width) + unit_size = random.randint(self.unit_size_min, self.unit_size_max) + return unit_size, unit_size + + return self._calculate_dimensions_based_on_holes(width, height) + + def _validate_unit_sizes(self, height: int, width: int) -> None: + """Validates the minimum and maximum unit sizes.""" + if self.unit_size_min is not None and self.unit_size_max is not None: + if not TWO <= self.unit_size_min <= self.unit_size_max: + msg = "Max unit size should be >= min size, both at least 2 pixels." + raise ValueError(msg) if self.unit_size_max > min(height, width): - raise ValueError("Grid size limits must be within the shortest image edge.") - unit_width = random.randint(self.unit_size_min, self.unit_size_max + 1) - unit_height = unit_width + msg = "Grid size limits must be within the shortest image edge." + raise ValueError(msg) else: - # set grid using holes numbers - if self.holes_number_x is None: - unit_width = max(2, width // 10) - else: - if not 1 <= self.holes_number_x <= width // 2: - raise ValueError("The hole_number_x must be between 1 and image width//2.") - unit_width = width // self.holes_number_x - if self.holes_number_y is None: - unit_height = max(min(unit_width, height), 2) - else: - if not 1 <= self.holes_number_y <= height // 2: - raise ValueError("The hole_number_y must be between 1 and image height//2.") - unit_height = height // self.holes_number_y - + msg = "unit_size_min and unit_size_max must not be None." 
+ raise ValueError(msg) + + def _calculate_dimensions_based_on_holes(self, width: int, height: int) -> Tuple[int, int]: + """Calculates dimensions based on the number of holes specified.""" + unit_width = self._calculate_dimension(width, self.holes_number_x, 10) + unit_height = self._calculate_dimension(height, self.holes_number_y, unit_width) + return unit_width, unit_height + + def _calculate_dimension(self, dimension: int, holes_number: Optional[int], fallback: int) -> int: + """Helper function to calculate unit width or height.""" + if holes_number is None: + return max(2, dimension // fallback) + + if not 1 <= holes_number <= dimension // 2: + raise ValueError(f"The number of holes must be between 1 and {dimension // 2}.") + return dimension // holes_number + + def _calculate_hole_dimensions(self, unit_width: int, unit_height: int) -> Tuple[int, int]: + """Calculates the dimensions of the holes to be dropped out.""" hole_width = int(unit_width * self.ratio) hole_height = int(unit_height * self.ratio) - # min 1 pixel and max unit length - 1 hole_width = min(max(hole_width, 1), unit_width - 1) hole_height = min(max(hole_height, 1), unit_height - 1) - # set offset of the grid - if self.shift_x is None: - shift_x = 0 - else: - shift_x = min(max(0, self.shift_x), unit_width - hole_width) - if self.shift_y is None: - shift_y = 0 - else: - shift_y = min(max(0, self.shift_y), unit_height - hole_height) + return hole_width, hole_height + + def _calculate_shifts( + self, unit_width: int, unit_height: int, hole_width: int, hole_height: int + ) -> Tuple[int, int]: + """Calculates the shifts for the grid start.""" if self.random_offset: shift_x = random.randint(0, unit_width - hole_width) shift_y = random.randint(0, unit_height - hole_height) + else: + shift_x = 0 if self.shift_x is None else min(max(0, self.shift_x), unit_width - hole_width) + shift_y = 0 if self.shift_y is None else min(max(0, self.shift_y), unit_height - hole_height) + return shift_x, shift_y + + def _generate_holes( + self, + width: int, + height: int, + unit_width: int, + unit_height: int, + hole_width: int, + hole_height: int, + shift_x: int, + shift_y: int, + ) -> List[Tuple[int, int, int, int]]: + """Generates the list of holes to be dropped out.""" holes = [] for i in range(width // unit_width + 1): for j in range(height // unit_height + 1): @@ -136,8 +177,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A x2 = min(x1 + hole_width, width) y2 = min(y1 + hole_height, height) holes.append((x1, y1, x2, y2)) - - return {"holes": holes} + return holes @property def targets_as_params(self) -> List[str]: diff --git a/albumentations/augmentations/dropout/mask_dropout.py b/albumentations/augmentations/dropout/mask_dropout.py index 461fe4b45..3fd96316c 100644 --- a/albumentations/augmentations/dropout/mask_dropout.py +++ b/albumentations/augmentations/dropout/mask_dropout.py @@ -5,15 +5,14 @@ import numpy as np from skimage.measure import label -from ...core.transforms_interface import DualTransform, to_tuple -from ...core.types import ScalarType +from albumentations.core.transforms_interface import DualTransform, to_tuple +from albumentations.core.types import ScalarType __all__ = ["MaskDropout"] class MaskDropout(DualTransform): - """ - Image & mask augmentation that zero out mask and image regions corresponding + """Image & mask augmentation that zero out mask and image regions corresponding to randomly chosen object instance from mask. 
Mask must be single-channel image, zero values treated as background. @@ -22,6 +21,7 @@ class MaskDropout(DualTransform): Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254 Args: + ---- max_objects: Maximum number of labels that can be zeroed out. Can be tuple, in this case it's [min, max] image_fill_value: Fill value to use when filling image. Can be 'inpaint' to apply inpaining (works only for 3-chahnel images) @@ -32,12 +32,13 @@ class MaskDropout(DualTransform): Image types: uint8, float32 + """ def __init__( self, max_objects: int = 1, - image_fill_value: Union[int, float, str] = 0, + image_fill_value: Union[float, str] = 0, mask_fill_value: ScalarType = 0, always_apply: bool = False, p: float = 0.5, diff --git a/albumentations/augmentations/dropout/xy_masking.py b/albumentations/augmentations/dropout/xy_masking.py index f4f61e598..544ae950e 100644 --- a/albumentations/augmentations/dropout/xy_masking.py +++ b/albumentations/augmentations/dropout/xy_masking.py @@ -1,19 +1,18 @@ import random -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np +from albumentations.core.transforms_interface import DualTransform from albumentations.core.types import ColorType, KeypointType, ScaleIntType -from ...core.transforms_interface import DualTransform, to_tuple from .functional import cutout, keypoint_in_hole __all__ = ["XYMasking"] class XYMasking(DualTransform): - """ - Applies masking strips to an image, either horizontally (X axis) or vertically (Y axis), + """Applies masking strips to an image, either horizontally (X axis) or vertically (Y axis), simulating occlusions. This transform is useful for training models to recognize images with varied visibility conditions. It's particularly effective for spectrogram images, allowing spectral and frequency masking to improve model robustness. @@ -22,6 +21,7 @@ class XYMasking(DualTransform): maximum size along each axis. Args: + ---- num_masks_x (Union[int, Tuple[int, int]]): Number or range of horizontal regions to mask. Defaults to 0. num_masks_y (Union[int, Tuple[int, int]]): Number or range of vertical regions to mask. Defaults to 0. mask_x_length ([Union[int, Tuple[int, int]]): Specifies the length of the masks along @@ -46,6 +46,7 @@ class XYMasking(DualTransform): uint8, float32 Note: Either `max_x_length` or `max_y_length` or both must be defined. + """ def __init__( @@ -67,19 +68,23 @@ def __init__( and isinstance(mask_y_length, (int, float)) and mask_y_length <= 0 ): - raise ValueError("At least one of `mask_x_length` or `mask_y_length` Should be a positive number.") + msg = "At least one of `mask_x_length` or `mask_y_length` Should be a positive number." + raise ValueError(msg) if isinstance(num_masks_x, int) and num_masks_x <= 0 and isinstance(num_masks_y, int) and num_masks_y <= 0: - raise ValueError( + msg = ( "At least one of `num_masks_x` or `num_masks_y` " "should be a positive number or tuple of two positive numbers." ) + raise ValueError(msg) if isinstance(num_masks_x, (tuple, list)) and min(num_masks_x) <= 0: - raise ValueError("All values in `num_masks_x` should be non negative integers.") + msg = "All values in `num_masks_x` should be non negative integers." 
+ raise ValueError(msg) if isinstance(num_masks_y, (tuple, list)) and min(num_masks_y) <= 0: - raise ValueError("All values in `num_masks_y` should be non negative integers.") + msg = "All values in `num_masks_y` should be non negative integers." + raise ValueError(msg) self.num_masks_x = num_masks_x self.num_masks_y = num_masks_y @@ -112,14 +117,15 @@ def apply_to_mask( def validate_mask_length( self, mask_length: Optional[ScaleIntType], dimension_size: int, dimension_name: str ) -> None: - """ - Validate the mask length against the corresponding image dimension size. + """Validate the mask length against the corresponding image dimension size. Args: + ---- mask_length (Optional[Union[int, Tuple[int, int]]]): The length of the mask to be validated. dimension_size (int): The size of the image dimension (width or height) against which to validate the mask length. dimension_name (str): The name of the dimension ('width' or 'height') for error messaging. + """ if mask_length is not None: if isinstance(mask_length, tuple): @@ -163,10 +169,7 @@ def generate_masks( masks = [] - if isinstance(num_masks, int): - num_masks_integer = num_masks - else: - num_masks_integer = random.randint(num_masks[0], num_masks[1]) + num_masks_integer = num_masks if isinstance(num_masks, int) else random.randint(num_masks[0], num_masks[1]) for _ in range(num_masks_integer): length = self.generate_mask_size(max_length) @@ -194,11 +197,11 @@ def apply_to_keypoints( masks_y: List[Tuple[int, int, int, int]], **params: Any, ) -> List[KeypointType]: - filtered_keypoints = [] - for keypoint in keypoints: - if not any(keypoint_in_hole(keypoint, hole) for hole in masks_x + masks_y): - filtered_keypoints.append(keypoint) - return filtered_keypoints + return [ + keypoint + for keypoint in keypoints + if not any(keypoint_in_hole(keypoint, hole) for hole in masks_x + masks_y) + ] def get_transform_init_args_names(self) -> Tuple[str, ...]: return ( diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py index a0c919dfe..5797be624 100644 --- a/albumentations/augmentations/functional.py +++ b/albumentations/augmentations/functional.py @@ -18,8 +18,7 @@ preserve_channel_dim, preserve_shape, ) - -from ..core.types import ColorType, ScalarType +from albumentations.core.types import ColorType, ScalarType __all__ = [ "add_fog", @@ -64,13 +63,19 @@ "MAX_VALUES_BY_DTYPE", ] +TWO = 2 +THREE = 3 +FOUR = 4 +EIGHT = 8 +THREE_SIXTY = 360 + def normalize_cv2(img: np.ndarray, mean: np.ndarray, denominator: np.ndarray) -> np.ndarray: - if mean.shape and len(mean) != 4 and mean.shape != img.shape: + if mean.shape and len(mean) != FOUR and mean.shape != img.shape: mean = np.array(mean.tolist() + [0] * (4 - len(mean)), dtype=np.float64) if not denominator.shape: denominator = np.array([denominator.tolist()] * 4, dtype=np.float64) - elif len(denominator) != 4 and denominator.shape != img.shape: + elif len(denominator) != FOUR and denominator.shape != img.shape: denominator = np.array(denominator.tolist() + [1] * (4 - len(denominator)), dtype=np.float64) img = np.ascontiguousarray(img.astype("float32")) @@ -95,7 +100,7 @@ def normalize(img: np.ndarray, mean: np.ndarray, std: np.ndarray, max_pixel_valu denominator = np.reciprocal(std, dtype=np.float32) - if img.ndim == 3 and img.shape[-1] == 3: + if img.ndim == THREE and img.shape[-1] == THREE: return normalize_cv2(img, mean, denominator) return normalize_numpy(img, mean, denominator) @@ -123,8 +128,7 @@ def _shift_hsv_uint8( val = cv2.LUT(val, lut_val) img 
= cv2.merge((hue, sat, val)).astype(dtype) - img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) - return img + return cv2.cvtColor(img, cv2.COLOR_HSV2RGB) def _shift_hsv_non_uint8( @@ -179,10 +183,12 @@ def solarize(img: np.ndarray, threshold: int = 128) -> np.ndarray: """Invert all pixel values above a threshold. Args: + ---- img: The image to solarize. threshold: All pixels above this grayscale level are inverted. Returns: + ------- Solarized image. """ @@ -210,24 +216,28 @@ def posterize(img: np.ndarray, bits: int) -> np.ndarray: """Reduce the number of bits for each color channel. Args: + ---- img: image to posterize. bits: number of high bits. Must be in range [0, 8] Returns: + ------- Image with reduced color channels. """ bits_array = np.uint8(bits) if img.dtype != np.uint8: - raise TypeError("Image must have uint8 channel type") - if np.any((bits_array < 0) | (bits_array > 8)): - raise ValueError("bits must be in range [0, 8]") + msg = "Image must have uint8 channel type" + raise TypeError(msg) + if np.any((bits_array < 0) | (bits_array > EIGHT)): + msg = "bits must be in range [0, 8]" + raise ValueError(msg) if not bits_array.shape or len(bits_array) == 1: if bits_array == 0: return np.zeros_like(img) - if bits_array == 8: + if bits_array == EIGHT: return img.copy() lut = np.arange(0, 256, dtype=np.uint8) @@ -237,13 +247,14 @@ def posterize(img: np.ndarray, bits: int) -> np.ndarray: return cv2.LUT(img, lut) if not is_rgb_image(img): - raise TypeError("If bits is iterable image must be RGB") + msg = "If bits is iterable image must be RGB" + raise TypeError(msg) result_img = np.empty_like(img) for i, channel_bits in enumerate(bits_array): if channel_bits == 0: result_img[..., i] = np.zeros_like(img[..., i]) - elif channel_bits == 8: + elif channel_bits == EIGHT: result_img[..., i] = img[..., i].copy() else: lut = np.arange(0, 256, dtype=np.uint8) @@ -295,72 +306,62 @@ def _equalize_cv(img: np.ndarray, mask: Optional[np.ndarray] = None) -> np.ndarr _sum = 0 lut = np.zeros(256, dtype=np.uint8) - i += 1 - for i in range(i, len(histogram)): - _sum += histogram[i] - lut[i] = clip(round(_sum * scale), np.dtype("uint8"), 255) - return cv2.LUT(img, lut) - - -@preserve_channel_dim -def equalize( - img: np.ndarray, mask: Optional[np.ndarray] = None, mode: str = "cv", by_channels: bool = True -) -> np.ndarray: - """Equalize the image histogram. + for idx in range(i + 1, len(histogram)): + _sum += histogram[idx] + lut[idx] = clip(round(_sum * scale), np.dtype("uint8"), 255) - Args: - img: RGB or grayscale image. - mask: An optional mask. If given, only the pixels selected by - the mask are included in the analysis. Maybe 1 channel or 3 channel array. - mode: {'cv', 'pil'}. Use OpenCV or Pillow equalization method. - by_channels: If True, use equalization by channels separately, - else convert image to YCbCr representation and use equalization by `Y` channel. + return cv2.LUT(img, lut) - Returns: - Equalized image. - """ +def _check_preconditions(img: np.ndarray, mask: Optional[np.ndarray], mode: str, by_channels: bool) -> None: if img.dtype != np.uint8: - raise TypeError("Image must have uint8 channel type") + msg = "Image must have uint8 channel type" + raise TypeError(msg) modes = ["cv", "pil"] - if mode not in modes: - raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode)) + raise ValueError(f"Unsupported equalization mode. Supports: {modes}. 
Got: {mode}") + if mask is not None: if is_rgb_image(mask) and is_grayscale_image(img): - raise ValueError("Wrong mask shape. Image shape: {}. " "Mask shape: {}".format(img.shape, mask.shape)) + raise ValueError(f"Wrong mask shape. Image shape: {img.shape}. Mask shape: {mask.shape}") if not by_channels and not is_grayscale_image(mask): - raise ValueError( - "When by_channels=False only 1-channel mask supports. " "Mask shape: {}".format(mask.shape) - ) + msg = f"When by_channels=False only 1-channel mask supports. Mask shape: {mask.shape}" + raise ValueError(msg) - if mode == "pil": - function = _equalize_pil - else: - function = _equalize_cv - if mask is not None: - mask = mask.astype(np.uint8) +def _handle_mask( + mask: Optional[np.ndarray], img: np.ndarray, by_channels: bool, i: Optional[int] = None +) -> Optional[np.ndarray]: + if mask is None: + return None + mask = mask.astype(np.uint8) + if is_grayscale_image(mask) or i is None: + return mask + + return mask[..., i] + + +@preserve_channel_dim +def equalize( + img: np.ndarray, mask: Optional[np.ndarray] = None, mode: str = "cv", by_channels: bool = True +) -> np.ndarray: + _check_preconditions(img, mask, mode, by_channels) + + function = _equalize_pil if mode == "pil" else _equalize_cv if is_grayscale_image(img): - return function(img, mask) + return function(img, _handle_mask(mask, img, by_channels)) if not by_channels: result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb) - result_img[..., 0] = function(result_img[..., 0], mask) + result_img[..., 0] = function(result_img[..., 0], _handle_mask(mask, img, by_channels)) return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB) result_img = np.empty_like(img) for i in range(3): - if mask is None: - _mask = None - elif is_grayscale_image(mask): - _mask = mask - else: - _mask = mask[..., i] - + _mask = _handle_mask(mask, img, by_channels, i) result_img[..., i] = function(img[..., i], _mask) return result_img @@ -371,18 +372,22 @@ def move_tone_curve(img: np.ndarray, low_y: float, high_y: float) -> np.ndarray: """Rescales the relationship between bright and dark areas of the image by manipulating its tone curve. Args: + ---- img: RGB or grayscale image. 
low_y: y-position of a Bezier control point used
             to adjust the tone curve, must be in range [0, 1]
         high_y: y-position of a Bezier control point used
             to adjust image tone curve, must be in range [0, 1]
+
     """
     input_dtype = img.dtype
 
     if not 0 <= low_y <= 1:
-        raise ValueError("low_shift must be in range [0, 1]")
+        msg = "low_y must be in range [0, 1]"
+        raise ValueError(msg)
     if not 0 <= high_y <= 1:
-        raise ValueError("high_shift must be in range [0, 1]")
+        msg = "high_y must be in range [0, 1]"
+        raise ValueError(msg)
 
     if input_dtype != np.uint8:
         raise ValueError(f"Unsupported image type {input_dtype}")
@@ -454,11 +459,12 @@ def linear_transformation_rgb(img: np.ndarray, transformation_matrix: np.ndarray
 @preserve_channel_dim
 def clahe(img: np.ndarray, clip_limit: float = 2.0, tile_grid_size: Tuple[int, int] = (8, 8)) -> np.ndarray:
     if img.dtype != np.uint8:
-        raise TypeError("clahe supports only uint8 inputs")
+        msg = "clahe supports only uint8 inputs"
+        raise TypeError(msg)
 
     clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
 
-    if len(img.shape) == 2 or img.shape[2] == 1:
+    if len(img.shape) == TWO or img.shape[2] == 1:
         return clahe_mat.apply(img)
 
     img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
@@ -488,7 +494,7 @@ def image_compression(img: np.ndarray, quality: int, image_type: np.dtype) -> np
         warn(
             "Image compression augmentation "
             "is most effective with uint8 inputs, "
-            "{} is used as input.".format(input_dtype),
+            f"{input_dtype} is used as input.",
             UserWarning,
         )
         img = from_float(img, dtype=np.dtype("uint8"))
@@ -511,11 +517,13 @@ def add_snow(img: np.ndarray, snow_point: float, brightness_coeff: float) -> np.
     From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
 
     Args:
+    ----
         img: Image.
         snow_point: Number of snow points.
         brightness_coeff: Brightness coefficient.
 
     Returns:
+    -------
         Image.
 
     """
@@ -533,21 +541,21 @@ def add_snow(img: np.ndarray, snow_point: float, brightness_coeff: float) -> np.
     elif input_dtype not in (np.uint8, np.float32):
         raise ValueError(f"Unexpected dtype {input_dtype} for RandomSnow augmentation")
 
-    image_HLS = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
-    image_HLS = np.array(image_HLS, dtype=np.float32)
+    image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
+    image_hls = np.array(image_hls, dtype=np.float32)
 
-    image_HLS[:, :, 1][image_HLS[:, :, 1] < snow_point] *= brightness_coeff
+    image_hls[:, :, 1][image_hls[:, :, 1] < snow_point] *= brightness_coeff
 
-    image_HLS[:, :, 1] = clip(image_HLS[:, :, 1], np.uint8, 255)
+    image_hls[:, :, 1] = clip(image_hls[:, :, 1], np.uint8, 255)
 
-    image_HLS = np.array(image_HLS, dtype=np.uint8)
+    image_hls = np.array(image_hls, dtype=np.uint8)
 
-    image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB)
+    image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
 
     if needs_float:
-        image_RGB = to_float(image_RGB, max_value=255)
+        image_rgb = to_float(image_rgb, max_value=255)
 
-    return image_RGB
+    return image_rgb
 
 
 @preserve_shape
@@ -561,11 +569,10 @@ def add_rain(
     brightness_coefficient: float,
     rain_drops: List[Tuple[int, int]],
 ) -> np.ndarray:
-    """
-
-    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
+    """From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
 
     Args:
+    ----
         img: Image.
         slant:
         drop_length:
         drop_width:
         drop_color:
         blur_value: Rainy view are blurry.
         brightness_coefficient: Rainy days are usually shady.
         rain_drops:
 
     Returns:
+    -------
         Image.
 
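The `image_hls` rename above touches the one trick behind `add_snow`: brighten low-lightness pixels in HLS space. A standalone sketch with arbitrary threshold and coefficient values:

import cv2
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)  # RGB uint8
snow_point, brightness_coeff = 120.0, 2.5

hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS).astype(np.float32)
mask = hls[..., 1] < snow_point          # L channel below the threshold
hls[..., 1][mask] *= brightness_coeff    # brighten those pixels
hls[..., 1] = np.clip(hls[..., 1], 0, 255)
snowy = cv2.cvtColor(hls.astype(np.uint8), cv2.COLOR_HLS2RGB)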
""" @@ -611,7 +619,7 @@ def add_rain( image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) if needs_float: - image_rgb = to_float(image_rgb, max_value=255) + return to_float(image_rgb, max_value=255) return image_rgb @@ -623,12 +631,14 @@ def add_fog(img: np.ndarray, fog_coef: float, alpha_coef: float, haze_list: List From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library Args: + ---- img: Image. fog_coef: Fog coefficient. alpha_coef: Alpha coefficient. haze_list: Returns: + ------- Image. """ @@ -681,6 +691,7 @@ def add_sun_flare( From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library Args: + ---- img (numpy.ndarray): flare_center_x (float): flare_center_y (float): @@ -689,6 +700,7 @@ def add_sun_flare( circles (list): Returns: + ------- numpy.ndarray: """ @@ -738,10 +750,12 @@ def add_shadow(img: np.ndarray, vertices_list: List[List[Tuple[int, int]]]) -> n From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library Args: + ---- img (numpy.ndarray): vertices_list (list): Returns: + ------- numpy.ndarray: """ @@ -763,7 +777,7 @@ def add_shadow(img: np.ndarray, vertices_list: List[List[Tuple[int, int]]]) -> n cv2.fillPoly(mask, vertices, 255) # if red channel is hot, image's "Lightness" channel's brightness is lowered - red_max_value_ind = mask[:, :, 0] == 255 + red_max_value_ind = mask[:, :, 0] == MAX_VALUES_BY_DTYPE[np.dtype("uint8")] image_hls[:, :, 1][red_max_value_ind] = image_hls[:, :, 1][red_max_value_ind] * 0.5 image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB) @@ -782,12 +796,15 @@ def add_gravel(img: np.ndarray, gravels: List[Any]) -> np.ndarray: From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library Args: + ---- img (numpy.ndarray): image to add gravel to gravels (list): list of gravel parameters. (float, float, float, float): (top-left x, top-left y, bottom-right x, bottom right y) Returns: + ------- numpy.ndarray: + """ non_rgb_warning(img) input_dtype = img.dtype @@ -894,10 +911,10 @@ def iso_noise( random_state: Optional[int] = None, **kwargs: Any, ) -> np.ndarray: - """ - Apply poisson noise to image to simulate camera sensor noise. + """Apply poisson noise to image to simulate camera sensor noise. Args: + ---- image (numpy.ndarray): Input image, currently, only RGB, uint8 images are supported. color_shift (float): intensity (float): Multiplication factor for noise values. Values of ~0.5 are produce noticeable, @@ -906,13 +923,16 @@ def iso_noise( **kwargs: Returns: + ------- numpy.ndarray: Noised image """ if image.dtype != np.uint8: - raise TypeError("Image must have uint8 channel type") + msg = "Image must have uint8 channel type" + raise TypeError(msg) if not is_rgb_image(image): - raise TypeError("Image must be RGB") + msg = "Image must be RGB" + raise TypeError(msg) one_over_255 = float(1.0 / 255.0) image = np.multiply(image, one_over_255, dtype=np.float32) @@ -924,8 +944,7 @@ def iso_noise( hue = hls[..., 0] hue += color_noise - hue[hue < 0] += 360 - hue[hue > 360] -= 360 + hue %= 360 luminance = hls[..., 1] luminance += (luminance_noise / 255) * (1.0 - luminance) @@ -963,25 +982,22 @@ def downscale( def to_float(img: np.ndarray, max_value: Optional[float] = None) -> np.ndarray: if max_value is None: - try: - max_value = MAX_VALUES_BY_DTYPE[img.dtype] - except KeyError: - raise RuntimeError( - "Can't infer the maximum value for dtype {}. 
You need to specify the maximum value manually by " - "passing the max_value argument".format(img.dtype) - ) - return img.astype("float32") / max_value + if img.dtype not in MAX_VALUES_BY_DTYPE: + raise RuntimeError(f"Unsupported dtype {img.dtype}. Specify 'max_value' manually.") + max_value = MAX_VALUES_BY_DTYPE[img.dtype] + + return (img / max_value).astype(np.float32) def from_float(img: np.ndarray, dtype: np.dtype, max_value: Optional[float] = None) -> np.ndarray: if max_value is None: - try: - max_value = MAX_VALUES_BY_DTYPE[dtype] - except KeyError: - raise RuntimeError( - f"Can't infer the maximum value for dtype {dtype}. You need to specify the maximum value manually by " - f"passing the max_value argument" + if dtype not in MAX_VALUES_BY_DTYPE: + msg = ( + f"Can't infer the maximum value for dtype {dtype}. " + "You need to specify the maximum value manually by passing the max_value argument." ) + raise RuntimeError(msg) + max_value = MAX_VALUES_BY_DTYPE[dtype] return (img * max_value).astype(dtype) @@ -990,10 +1006,10 @@ def noop(input_obj: Any, **params: Any) -> Any: def swap_tiles_on_image(image: np.ndarray, tiles: np.ndarray) -> np.ndarray: - """ - Swap tiles on image. + """Swap tiles on image. Args: + ---- image: Input image. tiles: array of tuples( current_left_up_corner_row, current_left_up_corner_col, @@ -1001,6 +1017,7 @@ def swap_tiles_on_image(image: np.ndarray, tiles: np.ndarray) -> np.ndarray: height_tile, width_tile) Returns: + ------- np.ndarray: Output image. """ @@ -1050,12 +1067,13 @@ def _multiply_non_uint8(img: np.ndarray, multiplier: np.ndarray) -> np.ndarray: def multiply(img: np.ndarray, multiplier: np.ndarray) -> np.ndarray: - """ - Args: + """Args: + ---- img: Image. multiplier: Multiplier coefficient. - Returns: + Returns + ------- Image multiplied by `multiplier` coefficient. """ @@ -1072,9 +1090,11 @@ def bbox_from_mask(mask: np.ndarray) -> Tuple[int, int, int, int]: """Create bounding box from binary mask (fast version) Args: + ---- mask (numpy.ndarray): binary mask. Returns: + ------- tuple: A bounding box tuple `(x_min, y_min, x_max, y_max)`. """ @@ -1091,14 +1111,15 @@ def mask_from_bbox(img: np.ndarray, bbox: Tuple[int, int, int, int]) -> np.ndarr """Create binary mask from bounding box Args: + ---- img: input image bbox: A bounding box tuple `(x_min, y_min, x_max, y_max)` Returns: + ------- mask: binary mask """ - mask = np.zeros(img.shape[:2], dtype=np.uint8) x_min, y_min, x_max, y_max = bbox mask[y_min:y_max, x_min:x_max] = 1 @@ -1110,16 +1131,19 @@ def fancy_pca(img: np.ndarray, alpha: float = 0.1) -> np.ndarray: http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf Args: + ---- img: numpy array with (h, w, rgb) shape, as ints between 0-255 alpha: how much to perturb/scale the eigen vecs and vals the paper used std=0.1 Returns: + ------- numpy image-like array as uint8 range(0, 255) """ if not is_rgb_image(img) or img.dtype != np.uint8: - raise TypeError("Image must be RGB image in uint8 format.") + msg = "Image must be RGB image in uint8 format." 
+ raise TypeError(msg) orig_img = img.astype(float).copy() @@ -1143,33 +1167,31 @@ def fancy_pca(img: np.ndarray, alpha: float = 0.1) -> np.ndarray: eig_vals[::-1].sort() eig_vecs = eig_vecs[:, sort_perm] - # get [p1, p2, p3] + # > get [p1, p2, p3] m1 = np.column_stack(eig_vecs) # get 3x1 matrix of eigen values multiplied by random variable draw from normal # distribution with mean of 0 and standard deviation of 0.1 m2 = np.zeros((3, 1)) # according to the paper alpha should only be draw once per augmentation (not once per channel) - # alpha = np.random.normal(0, alpha_std) + # > alpha = np.random.normal(0, alpha_std) # broad cast to speed things up m2[:, 0] = alpha * eig_vals[:] # this is the vector that we're going to add to each pixel in a moment - add_vect = np.matrix(m1) * np.matrix(m2) + add_vect = np.array(m1) @ np.array(m2) for idx in range(3): # RGB orig_img[..., idx] += add_vect[idx] * 255 # for image processing it was found that working with float 0.0 to 1.0 # was easier than integers between 0-255 - # orig_img /= 255.0 + # > orig_img /= 255.0 orig_img = np.clip(orig_img, 0.0, 255.0) - # orig_img *= 255 - orig_img = orig_img.astype(np.uint8) - - return orig_img + # > orig_img *= 255 + return orig_img.astype(np.uint8) def _adjust_brightness_torchvision_uint8(img: np.ndarray, factor: float) -> np.ndarray: @@ -1182,7 +1204,7 @@ def _adjust_brightness_torchvision_uint8(img: np.ndarray, factor: float) -> np.n def adjust_brightness_torchvision(img: np.ndarray, factor: np.ndarray) -> np: if factor == 0: return np.zeros_like(img) - elif factor == 1: + if factor == 1: return img if img.dtype == np.uint8: @@ -1204,10 +1226,7 @@ def adjust_contrast_torchvision(img: np.ndarray, factor: float) -> np.ndarray: if factor == 1: return img - if is_grayscale_image(img): - mean = img.mean() - else: - mean = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean() + mean = img.mean() if is_grayscale_image(img) else cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean() if factor == 0: if img.dtype != np.float32: @@ -1230,11 +1249,10 @@ def adjust_saturation_torchvision(img: np.ndarray, factor: float, gamma: float = return img if is_grayscale_image(img): - gray = img - return gray - else: - gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) - gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB) + return img + + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB) if factor == 0: return gray @@ -1290,13 +1308,13 @@ def superpixels( image = resize_fn(image) segments = skimage.segmentation.slic( - image, n_segments=n_segments, compactness=10, channel_axis=-1 if image.ndim > 2 else None + image, n_segments=n_segments, compactness=10, channel_axis=-1 if image.ndim > TWO else None ) min_value = 0 max_value = MAX_VALUES_BY_DTYPE[image.dtype] image = np.copy(image) - if image.ndim == 2: + if image.ndim == TWO: image = image.reshape(*image.shape, 1) nb_channels = image.shape[2] for c in range(nb_channels): @@ -1388,10 +1406,19 @@ def spatter( img = img.astype(np.float32) * (1 / coef) if mode == "rain": - assert rain is not None - img = img + rain + if rain is None: + msg = "Rain spatter requires rain mask" + raise ValueError(msg) + + img += rain elif mode == "mud": - assert non_mud is not None and mud is not None + if mud is None: + msg = "Mud spatter requires mud mask" + raise ValueError(msg) + if non_mud is None: + msg = "Mud spatter requires non_mud mask" + raise ValueError(msg) + img = img * non_mud + mud else: raise ValueError("Unsupported spatter mode: " + str(mode)) diff --git 
a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 6a32d43b5..9f42ccdd1 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -6,6 +6,7 @@ import skimage.transform from scipy.ndimage import gaussian_filter +from albumentations import random_utils from albumentations.augmentations.utils import ( _maybe_process_in_chunks, angle_2pi_range, @@ -13,11 +14,9 @@ preserve_channel_dim, preserve_shape, ) - -from ... import random_utils -from ...core.bbox_utils import denormalize_bbox, normalize_bbox -from ...core.transforms_interface import FillValueType -from ...core.types import BoxInternalType, ImageColorType, KeypointInternalType +from albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox +from albumentations.core.transforms_interface import FillValueType +from albumentations.core.types import BoxInternalType, ImageColorType, KeypointInternalType __all__ = [ "optical_distortion", @@ -42,7 +41,7 @@ "smallest_max_size", "perspective", "perspective_bbox", - "rotation2DMatrixToEulerAngles", + "rotation2d_matrix_to_euler_angles", "perspective_keypoint", "_is_identity_matrix", "warp_affine", @@ -71,28 +70,34 @@ "denormalize_bbox", ] +TWO = 2 +THREE = 3 + def bbox_rot90(bbox: BoxInternalType, factor: int, rows: int, cols: int) -> BoxInternalType: """Rotates a bounding box by 90 degrees CCW (see np.rot90) Args: + ---- bbox: A bounding box tuple (x_min, y_min, x_max, y_max). factor: Number of CCW rotations. Must be in set {0, 1, 2, 3} See np.rot90. rows: Image rows. cols: Image cols. Returns: + ------- tuple: A bounding box tuple (x_min, y_min, x_max, y_max). """ if factor not in {0, 1, 2, 3}: - raise ValueError("Parameter n must be in set {0, 1, 2, 3}") + msg = "Parameter n must be in set {0, 1, 2, 3}" + raise ValueError(msg) x_min, y_min, x_max, y_max = bbox[:4] if factor == 1: bbox = y_min, 1 - x_max, y_max, 1 - x_min - elif factor == 2: + elif factor == TWO: bbox = 1 - x_max, 1 - y_max, 1 - x_min, 1 - y_min - elif factor == 3: + elif factor == THREE: bbox = 1 - y_max, x_min, 1 - y_min, x_max return bbox @@ -104,28 +109,32 @@ def keypoint_rot90( """Rotates a keypoint by 90 degrees CCW (see np.rot90) Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. factor: Number of CCW rotations. Must be in range [0;3] See np.rot90. rows: Image height. cols: Image width. Returns: + ------- tuple: A keypoint `(x, y, angle, scale)`. Raises: + ------ ValueError: if factor not in set {0, 1, 2, 3} """ x, y, angle, scale = keypoint[:4] if factor not in {0, 1, 2, 3}: - raise ValueError("Parameter n must be in set {0, 1, 2, 3}") + msg = "Parameter n must be in set {0, 1, 2, 3}" + raise ValueError(msg) if factor == 1: x, y, angle = y, (cols - 1) - x, angle - math.pi / 2 - elif factor == 2: + elif factor == TWO: x, y, angle = (cols - 1) - x, (rows - 1) - y, angle - math.pi - elif factor == 3: + elif factor == THREE: x, y, angle = (rows - 1) - y, x, angle + math.pi / 2 return x, y, angle, scale @@ -154,6 +163,7 @@ def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, col """Rotates a bounding box by angle degrees. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. angle: Angle of rotation in degrees. method: Rotation method used. Should be one of: "largest_box", "ellipse". Default: "largest_box". @@ -161,9 +171,11 @@ def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, col cols: Image cols. 
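A quick numeric check of the `factor == 1` branch of `bbox_rot90` above (normalized coordinates; standalone arithmetic, not library code): a box hugging the top-right corner ends up at the top-left after one 90-degree CCW rotation.

def bbox_rot90_once(bbox):
    # mirrors the factor == 1 case: (x_min, y_min, x_max, y_max) -> (y_min, 1 - x_max, y_max, 1 - x_min)
    x_min, y_min, x_max, y_max = bbox
    return (y_min, 1 - x_max, y_max, 1 - x_min)

print(bbox_rot90_once((0.8, 0.0, 1.0, 0.2)))  # (0.0, 0.0, 0.2, 0.2)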
Returns: + ------- A bounding box `(x_min, y_min, x_max, y_max)`. References: + ---------- https://arxiv.org/abs/2109.13488 """ @@ -199,12 +211,14 @@ def keypoint_rotate( """Rotate a keypoint by angle. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. angle: Rotation angle. rows: Image height. cols: Image width. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ @@ -279,6 +293,7 @@ def bbox_shift_scale_rotate( Args: + ---- bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`. angle (int): Angle of rotation in degrees. scale (int): Scale factor. @@ -290,6 +305,7 @@ def bbox_shift_scale_rotate( cols (int): Image cols. Returns: + ------- A bounding box `(x_min, y_min, x_max, y_max)`. """ @@ -423,11 +439,13 @@ def keypoint_scale(keypoint: KeypointInternalType, scale_x: float, scale_y: floa """Scales a keypoint by scale_x and scale_y. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. scale_x: Scale coefficient x-axis. scale_y: Scale coefficient y-axis. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ @@ -462,7 +480,7 @@ def perspective( matrix: np.ndarray, max_width: int, max_height: int, - border_val: Union[int, float, List[int], List[float], np.ndarray], + border_val: Union[float, List[float], np.ndarray], border_mode: int, keep_size: bool, interpolation: int, @@ -499,8 +517,8 @@ def perspective_bbox( x1, y1, x2, y2 = float("inf"), float("inf"), 0, 0 for pt in points: - pt = perspective_keypoint(pt.tolist() + [0, 0], height, width, matrix, max_width, max_height, keep_size) - x, y = pt[:2] + point = perspective_keypoint((*pt.tolist(), 0, 0), height, width, matrix, max_width, max_height, keep_size) + x, y = point[:2] x1 = min(x1, x) x2 = max(x2, x) y1 = min(y1, y) @@ -512,11 +530,12 @@ def perspective_bbox( ) -def rotation2DMatrixToEulerAngles(matrix: np.ndarray, y_up: bool = False) -> float: - """ - Args: +def rotation2d_matrix_to_euler_angles(matrix: np.ndarray, y_up: bool = False) -> float: + """Args: + ---- matrix (np.ndarray): Rotation matrix y_up (bool): is Y axis looks up or down + """ if y_up: return np.arctan2(matrix[1, 0], matrix[0, 0]) @@ -538,7 +557,7 @@ def perspective_keypoint( keypoint_vector = np.array([x, y], dtype=np.float32).reshape([1, 1, 2]) x, y = cv2.perspectiveTransform(keypoint_vector, matrix)[0, 0] - angle += rotation2DMatrixToEulerAngles(matrix[:2, :2], y_up=True) + angle += rotation2d_matrix_to_euler_angles(matrix[:2, :2], y_up=True) scale_x = np.sign(matrix[0, 0]) * np.sqrt(matrix[0, 0] ** 2 + matrix[0, 1] ** 2) scale_y = np.sign(matrix[1, 1]) * np.sqrt(matrix[1, 0] ** 2 + matrix[1, 1] ** 2) @@ -561,7 +580,7 @@ def warp_affine( image: np.ndarray, matrix: skimage.transform.ProjectiveTransform, interpolation: int, - cval: Union[int, float, Sequence[int], Sequence[float]], + cval: Union[float, Sequence[float]], mode: int, output_shape: Sequence[int], ) -> np.ndarray: @@ -572,8 +591,7 @@ def warp_affine( warp_fn = _maybe_process_in_chunks( cv2.warpAffine, M=matrix.params[:2], dsize=dsize, flags=interpolation, borderMode=mode, borderValue=cval ) - tmp = warp_fn(image) - return tmp + return warp_fn(image) @angle_2pi_range @@ -587,7 +605,7 @@ def keypoint_affine( x, y, a, s = keypoint[:4] x, y = cv2.transform(np.array([[[x, y]]]), matrix.params[:2]).squeeze() - a += rotation2DMatrixToEulerAngles(matrix.params[:2]) + a += rotation2d_matrix_to_euler_angles(matrix.params[:2]) s *= np.max([scale["x"], scale["y"]]) return x, y, a, s @@ -729,6 +747,7 @@ def to_distance_maps( method that only supports the augmentation of images. 
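The renamed `rotation2d_matrix_to_euler_angles` recovers the angle encoded in the top-left 2x2 block of an affine matrix. A small sketch of the same arithmetic (the sign flip in the result comes from the image y-axis pointing down):

import math
import cv2
import numpy as np

matrix = cv2.getRotationMatrix2D((0, 0), 30, 1.0)  # 2x3 affine, 30 degrees
# y_up=True convention: angle = atan2(m10, m00)
recovered = np.arctan2(matrix[1, 0], matrix[0, 0])
print(math.degrees(recovered))  # ~ -30.0 in the y-down image coordinate system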
Args: + ---- keypoint: keypoint coordinates height: image height width: image width @@ -738,6 +757,7 @@ def to_distance_maps( exactly the position of the respective keypoint. Returns: + ------- (H, W, N) ndarray A ``float32`` array containing ``N`` distance maps for ``N`` keypoints. Each location ``(y, x, n)`` in the array denotes the @@ -745,6 +765,7 @@ def to_distance_maps( If `inverted` is ``True``, the distance ``d`` is replaced by ``d/(d+1)``. The height and width of the array match the height and width in ``KeypointsOnImage.shape``. + """ distance_maps = np.zeros((height, width, len(keypoints)), dtype=np.float32) @@ -761,79 +782,62 @@ def to_distance_maps( return distance_maps +def validate_if_not_found_coords( + if_not_found_coords: Optional[Union[Sequence[int], Dict[str, Any]]], +) -> Tuple[bool, int, int]: + """Validate and process `if_not_found_coords` parameter.""" + if if_not_found_coords is None: + return True, -1, -1 + if isinstance(if_not_found_coords, (tuple, list)): + if len(if_not_found_coords) != TWO: + msg = "Expected tuple/list 'if_not_found_coords' to contain exactly two entries." + raise ValueError(msg) + return False, if_not_found_coords[0], if_not_found_coords[1] + if isinstance(if_not_found_coords, dict): + return False, if_not_found_coords["x"], if_not_found_coords["y"] + + msg = "Expected if_not_found_coords to be None, tuple, list, or dict." + raise ValueError(msg) + + +def find_keypoint( + position: Tuple[int, int], distance_map: np.ndarray, threshold: Optional[float], inverted: bool +) -> Optional[Tuple[float, float]]: + """Determine if a valid keypoint can be found at the given position.""" + y, x = position + value = distance_map[y, x] + if not inverted and threshold is not None and value >= threshold: + return None + if inverted and threshold is not None and value < threshold: + return None + return float(x), float(y) + + def from_distance_maps( distance_maps: np.ndarray, inverted: bool, if_not_found_coords: Optional[Union[Sequence[int], Dict[str, Any]]], threshold: Optional[float] = None, ) -> List[Tuple[float, float]]: - """Convert outputs of ``to_distance_maps()`` to ``KeypointsOnImage``. + """Convert outputs of `to_distance_maps` to `KeypointsOnImage`. This is the inverse of `to_distance_maps`. - - Args: - distance_maps (np.ndarray): The distance maps. ``N`` is the number of keypoints. - inverted (bool): Whether the given distance maps were generated in inverted mode - (i.e. :func:`KeypointsOnImage.to_distance_maps` was called with ``inverted=True``) or in non-inverted mode. - if_not_found_coords (tuple, list, dict or None, optional): - Coordinates to use for keypoints that cannot be found in `distance_maps`. - - * If this is a ``list``/``tuple``, it must contain two ``int`` values. - * If it is a ``dict``, it must contain the keys ``x`` and ``y`` with each containing one ``int`` value. - * If this is ``None``, then the keypoint will not be added. - threshold (float): The search for keypoints works by searching for the - argmin (non-inverted) or argmax (inverted) in each channel. This - parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit - as a keypoint. Use ``None`` to use no min/max. - nb_channels (None, int): Number of channels of the image on which the keypoints are placed. - Some keypoint augmenters require that information. If set to ``None``, the keypoint's shape will be set - to ``(height, width)``, otherwise ``(height, width, nb_channels)``. 
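To make the refactor above concrete, a toy round-trip through the distance-map representation (sketch only, not the library functions): a keypoint becomes the location where its non-inverted map is minimal.

import numpy as np

height, width = 8, 8
keypoint = (5.0, 2.0)  # (x, y)

yy, xx = np.mgrid[0:height, 0:width]
distance_map = np.sqrt((xx - keypoint[0]) ** 2 + (yy - keypoint[1]) ** 2)

# recover the keypoint as the argmin of the (non-inverted) map
flat = np.argmin(distance_map)
y, x = np.unravel_index(flat, (height, width))
assert (float(x), float(y)) == keypoint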
""" - if distance_maps.ndim != 3: - raise ValueError( - f"Expected three-dimensional input, " - f"got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." - ) + if distance_maps.ndim != THREE: + msg = f"Expected three-dimensional input, got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." + raise ValueError(msg) height, width, nb_keypoints = distance_maps.shape - drop_if_not_found = False - if if_not_found_coords is None: - drop_if_not_found = True - if_not_found_x = -1 - if_not_found_y = -1 - elif isinstance(if_not_found_coords, (tuple, list)): - if len(if_not_found_coords) != 2: - raise ValueError( - f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, " - f"got {len(if_not_found_coords)}." - ) - if_not_found_x = if_not_found_coords[0] - if_not_found_y = if_not_found_coords[1] - elif isinstance(if_not_found_coords, dict): - if_not_found_x = if_not_found_coords["x"] - if_not_found_y = if_not_found_coords["y"] - else: - raise ValueError( - f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}." - ) + drop_if_not_found, if_not_found_x, if_not_found_y = validate_if_not_found_coords(if_not_found_coords) keypoints = [] for i in range(nb_keypoints): - if inverted: - hitidx_flat = np.argmax(distance_maps[..., i]) - else: - hitidx_flat = np.argmin(distance_maps[..., i]) + hitidx_flat = np.argmax(distance_maps[..., i]) if inverted else np.argmin(distance_maps[..., i]) hitidx_ndim = np.unravel_index(hitidx_flat, (height, width)) - if not inverted and threshold is not None: - found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] < threshold - elif inverted and threshold is not None: - found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] >= threshold - else: - found = True - if found: - keypoints.append((float(hitidx_ndim[1]), float(hitidx_ndim[0]))) - else: - if not drop_if_not_found: - keypoints.append((if_not_found_x, if_not_found_y)) + keypoint = find_keypoint(hitidx_ndim, distance_maps[:, :, i], threshold, inverted) + if keypoint: + keypoints.append(keypoint) + elif not drop_if_not_found: + keypoints.append((if_not_found_x, if_not_found_y)) return keypoints @@ -900,7 +904,7 @@ def random_flip(img: np.ndarray, code: int) -> np.ndarray: def transpose(img: np.ndarray) -> np.ndarray: - return img.transpose(1, 0, 2) if len(img.shape) > 2 else img.transpose(1, 0) + return img.transpose(1, 0, 2) if len(img.shape) > TWO else img.transpose(1, 0) def rot90(img: np.ndarray, factor: int) -> np.ndarray: @@ -912,11 +916,13 @@ def bbox_vflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: """Flip a bounding box vertically around the x-axis. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. rows: Image rows. cols: Image cols. Returns: + ------- tuple: A bounding box `(x_min, y_min, x_max, y_max)`. """ @@ -928,11 +934,13 @@ def bbox_hflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: """Flip a bounding box horizontally around the y-axis. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. rows: Image rows. cols: Image cols. Returns: + ------- A bounding box `(x_min, y_min, x_max, y_max)`. """ @@ -944,15 +952,18 @@ def bbox_flip(bbox: BoxInternalType, d: int, rows: int, cols: int) -> BoxInterna """Flip a bounding box either vertically, horizontally or both depending on the value of `d`. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. d: dimension. 0 for vertical flip, 1 for horizontal, -1 for transpose rows: Image rows. 
cols: Image cols. Returns: + ------- A bounding box `(x_min, y_min, x_max, y_max)`. Raises: + ------ ValueError: if value of `d` is not -1, 0 or 1. """ @@ -972,21 +983,25 @@ def bbox_transpose(bbox: KeypointInternalType, axis: int, rows: int, cols: int) """Transposes a bounding box along given axis. Args: + ---- bbox: A bounding box `(x_min, y_min, x_max, y_max)`. axis: 0 - main axis, 1 - secondary axis. rows: Image rows. cols: Image cols. Returns: + ------- A bounding box tuple `(x_min, y_min, x_max, y_max)`. Raises: + ------ ValueError: If axis not equal to 0 or 1. """ x_min, y_min, x_max, y_max = bbox[:4] if axis not in {0, 1}: - raise ValueError("Axis must be either 0 or 1.") + msg = "Axis must be either 0 or 1." + raise ValueError(msg) if axis == 0: bbox = (y_min, x_min, y_max, x_max) if axis == 1: @@ -999,11 +1014,13 @@ def keypoint_vflip(keypoint: KeypointInternalType, rows: int, cols: int) -> Keyp """Flip a keypoint vertically around the x-axis. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. rows: Image height. cols: Image width. Returns: + ------- tuple: A keypoint `(x, y, angle, scale)`. """ @@ -1017,11 +1034,13 @@ def keypoint_hflip(keypoint: KeypointInternalType, rows: int, cols: int) -> Keyp """Flip a keypoint horizontally around the y-axis. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. rows: Image height. cols: Image width. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ @@ -1034,6 +1053,7 @@ def keypoint_flip(keypoint: KeypointInternalType, d: int, rows: int, cols: int) """Flip a keypoint either vertically, horizontally or both depending on the value of `d`. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. d: Number of flip. Must be -1, 0 or 1: * 0 - vertical flip, @@ -1043,9 +1063,11 @@ def keypoint_flip(keypoint: KeypointInternalType, d: int, rows: int, cols: int) cols: Image width. Returns: + ------- A keypoint `(x, y, angle, scale)`. Raises: + ------ ValueError: if value of `d` is not -1, 0 or 1. """ @@ -1065,18 +1087,17 @@ def keypoint_transpose(keypoint: KeypointInternalType) -> KeypointInternalType: """Rotate a keypoint by angle. Args: + ---- keypoint: A keypoint `(x, y, angle, scale)`. Returns: + ------- A keypoint `(x, y, angle, scale)`. """ x, y, angle, scale = keypoint[:4] - if angle <= np.pi: - angle = np.pi - angle - else: - angle = 3 * np.pi - angle + angle = np.pi - angle if angle <= np.pi else 3 * np.pi - angle return y, x, angle, scale @@ -1109,9 +1130,7 @@ def pad( if img.shape[:2] != (max(min_height, height), max(min_width, width)): raise RuntimeError( - "Invalid result shape. Got: {}. Expected: {}".format( - img.shape[:2], (max(min_height, height), max(min_width, width)) - ) + f"Invalid result shape. Got: {img.shape[:2]}. Expected: {(max(min_height, height), max(min_width, width))}" ) return img diff --git a/albumentations/augmentations/geometric/resize.py b/albumentations/augmentations/geometric/resize.py index e9f278909..ff8bc80f8 100644 --- a/albumentations/augmentations/geometric/resize.py +++ b/albumentations/augmentations/geometric/resize.py @@ -4,13 +4,9 @@ import cv2 import numpy as np -from ...core.transforms_interface import DualTransform, to_tuple -from ...core.types import ( - BoxInternalType, - KeypointInternalType, - ScaleFloatType, - ScaleIntType, -) +from albumentations.core.transforms_interface import DualTransform, to_tuple +from albumentations.core.types import BoxInternalType, KeypointInternalType, ScaleFloatType + from . 
import functional as F __all__ = ["RandomScale", "LongestMaxSize", "SmallestMaxSize", "Resize"] @@ -20,6 +16,7 @@ class RandomScale(DualTransform): """Randomly resize the input. Output image size is different from the input image size. Args: + ---- scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1. If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high). @@ -34,6 +31,7 @@ class RandomScale(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -72,6 +70,7 @@ class LongestMaxSize(DualTransform): """Rescale an image so that maximum side is equal to max_size, keeping the aspect ratio of the initial image. Args: + ---- max_size (int, list of int): maximum size of the image after the transformation. When using a list, max size will be randomly selected from the values in the list. interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR. @@ -82,6 +81,7 @@ class LongestMaxSize(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -124,6 +124,7 @@ class SmallestMaxSize(DualTransform): """Rescale an image so that minimum side is equal to max_size, keeping the aspect ratio of the initial image. Args: + ---- max_size (int, list of int): maximum size of smallest side of the image after the transformation. When using a list, max size will be randomly selected from the values in the list. interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR. @@ -134,6 +135,7 @@ class SmallestMaxSize(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -175,6 +177,7 @@ class Resize(DualTransform): """Resize the input to the given height and width. Args: + ---- height (int): desired height of the output. width (int): desired width of the output. interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: @@ -187,6 +190,7 @@ class Resize(DualTransform): Image types: uint8, float32 + """ def __init__( diff --git a/albumentations/augmentations/geometric/rotate.py b/albumentations/augmentations/geometric/rotate.py index 1d3781a95..360e4d54f 100644 --- a/albumentations/augmentations/geometric/rotate.py +++ b/albumentations/augmentations/geometric/rotate.py @@ -5,18 +5,22 @@ import cv2 import numpy as np -from ...core.transforms_interface import DualTransform, FillValueType, to_tuple -from ...core.types import BoxInternalType, KeypointInternalType, ScaleIntType -from ..crops import functional as FCrops +from albumentations.augmentations.crops import functional as FCrops +from albumentations.core.transforms_interface import DualTransform, FillValueType, to_tuple +from albumentations.core.types import BoxInternalType, KeypointInternalType, ScaleIntType + from . import functional as F __all__ = ["Rotate", "RandomRotate90", "SafeRotate"] +SMALL_NUMBER = 1e-10 + class RandomRotate90(DualTransform): """Randomly rotate the input by 90 degrees zero or more times. Args: + ---- p: probability of applying the transform. Default: 0.5. Targets: @@ -24,12 +28,14 @@ class RandomRotate90(DualTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, factor: float = 0, **params: Any) -> np.ndarray: - """ - Args: + """Args: + ---- factor (int): number of times the input will be rotated by 90 degrees. 
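The rescaling rule behind `LongestMaxSize` boils down to a single scale factor that preserves aspect ratio. An approximate sketch (the library's own rounding helper may differ slightly):

import cv2
import numpy as np

def longest_max_size(img: np.ndarray, max_size: int) -> np.ndarray:
    h, w = img.shape[:2]
    scale = max_size / max(h, w)  # SmallestMaxSize would use min(h, w)
    if scale == 1.0:
        return img
    return cv2.resize(img, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR)

img = np.zeros((300, 500, 3), dtype=np.uint8)
print(longest_max_size(img, 250).shape)  # (150, 250, 3)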
+ """ return np.ascontiguousarray(np.rot90(img, factor)) @@ -51,6 +57,7 @@ class Rotate(DualTransform): """Rotate the input by an angle selected randomly from the uniform distribution. Args: + ---- limit: range from which a random angle is picked. If limit is a single int an angle is picked from (-limit, limit). Default: (-90, 90) interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: @@ -73,6 +80,7 @@ class Rotate(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -166,14 +174,12 @@ def apply_to_keypoint( @staticmethod def _rotated_rect_with_max_area(h: int, w: int, angle: float) -> Dict[str, int]: - """ - Given a rectangle of size wxh that has been rotated by 'angle' (in + """Given a rectangle of size wxh that has been rotated by 'angle' (in degrees), computes the width and height of the largest possible axis-aligned rectangle (maximal area) within the rotated rectangle. Code from: https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders """ - angle = math.radians(angle) width_is_longer = w >= h side_long, side_short = (w, h) if width_is_longer else (h, w) @@ -181,7 +187,7 @@ def _rotated_rect_with_max_area(h: int, w: int, angle: float) -> Dict[str, int]: # since the solutions for angle, -angle and 180-angle are all the same, # it is sufficient to look at the first quadrant and the absolute values of sin,cos: sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) - if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10: + if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < SMALL_NUMBER: # half constrained case: two crop corners touch the longer side, # the other two corners are on the mid-line parallel to the longer line x = 0.5 * side_short @@ -191,12 +197,12 @@ def _rotated_rect_with_max_area(h: int, w: int, angle: float) -> Dict[str, int]: cos_2a = cos_a * cos_a - sin_a * sin_a wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a - return dict( - x_min=max(0, int(w / 2 - wr / 2)), - x_max=min(w, int(w / 2 + wr / 2)), - y_min=max(0, int(h / 2 - hr / 2)), - y_max=min(h, int(h / 2 + hr / 2)), - ) + return { + "x_min": max(0, int(w / 2 - wr / 2)), + "x_max": min(w, int(w / 2 + wr / 2)), + "y_min": max(0, int(h / 2 - hr / 2)), + "y_max": min(h, int(h / 2 + hr / 2)), + } @property def targets_as_params(self) -> List[str]: @@ -221,6 +227,7 @@ class SafeRotate(DualTransform): may see some artifacts. Args: + ---- limit ((int, int) or int): range from which a random angle is picked. If limit is a single int an angle is picked from (-limit, limit). Default: (-90, 90) interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. 
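`_rotated_rect_with_max_area` above is worth a worked example. For a 100x200 image rotated 45 degrees, the half-constrained branch applies and the largest axis-aligned crop free of border pixels is about 70.7 x 70.7 (same arithmetic, written out standalone):

import math

h, w, angle = 100, 200, math.radians(45)
sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
side_long, side_short = max(w, h), min(w, h)

if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
    # half constrained: two crop corners touch the longer sides
    x = 0.5 * side_short
    wr, hr = (x / sin_a, x / cos_a) if w >= h else (x / cos_a, x / sin_a)
else:
    # fully constrained: all four corners touch the rotated rectangle
    cos_2a = cos_a * cos_a - sin_a * sin_a
    wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a

print(round(wr, 1), round(hr, 1))  # 70.7 70.7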
Should be one of: @@ -240,6 +247,7 @@ class SafeRotate(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -259,10 +267,10 @@ def __init__( self.value = value self.mask_value = mask_value - def apply(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params: Any) -> np.ndarray: + def apply(self, img: np.ndarray, matrix: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return F.safe_rotate(img, matrix, cast(int, self.interpolation), self.value, self.border_mode) - def apply_to_mask(self, mask: np.ndarray, matrix: np.ndarray = np.array(None), **params: Any) -> np.ndarray: + def apply_to_mask(self, mask: np.ndarray, matrix: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return F.safe_rotate(mask, matrix, cv2.INTER_NEAREST, self.mask_value, self.border_mode) def apply_to_bbox(self, bbox: BoxInternalType, cols: int = 0, rows: int = 0, **params: Any) -> BoxInternalType: diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index 56cd92939..1eebec08b 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -7,25 +7,19 @@ import numpy as np import skimage.transform +from albumentations import random_utils +from albumentations.augmentations.functional import bbox_from_mask from albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox - -from ... import random_utils -from ...core.transforms_interface import DualTransform, to_tuple -from ...core.types import ( +from albumentations.core.transforms_interface import DualTransform, to_tuple +from albumentations.core.types import ( BoxInternalType, - BoxType, - ColorType, ImageColorType, KeypointInternalType, - KeypointType, - NumType, - ScalarType, ScaleFloatType, ScaleIntType, - ScaleType, SizeType, ) -from ..functional import bbox_from_mask + from . import functional as F __all__ = [ @@ -43,11 +37,15 @@ "PadIfNeeded", ] +TWO = 2 +THREE = 3 + class ShiftScaleRotate(DualTransform): """Randomly apply affine transforms: translate, scale and rotate the input. Args: + ---- shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and upper bounds should lie in range [0, 1]. Default: (-0.0625, 0.0625). @@ -84,6 +82,7 @@ class ShiftScaleRotate(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -182,6 +181,7 @@ class ElasticTransform(DualTransform): Recognition, 2003. Args: + ---- alpha (float): sigma (float): Gaussian filter parameter. alpha_affine (float): The range will be (-alpha_affine, alpha_affine) @@ -205,6 +205,7 @@ class ElasticTransform(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -306,10 +307,11 @@ class Perspective(DualTransform): """Perform a random four point perspective transform of the input. Args: + ---- scale: standard deviation of the normal distributions. These are used to sample the random distances of the subimage's corners from the full image's corners. If scale is a single float value, the range will be (0, scale). Default: (0.05, 0.1). - keep_size: Whether to resize image’s back to their original size after applying the perspective + keep_size: Whether to resize image back to their original size after applying the perspective transform. If set to False, the resulting images may end up having different shapes and will always be a list, never an array. 
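As background for the corner-sampling code in this class, the four-point warp that `Perspective` builds on can be sketched with plain OpenCV (hypothetical jitter values; the library samples corner offsets from a normal distribution scaled by `scale`):

import cv2
import numpy as np

h, w = 100, 100
src = np.float32([[0, 0], [w, 0], [w, h], [0, h]])            # original corners
dst = src + np.float32([[5, 3], [-4, 6], [-7, -2], [3, -5]])  # jittered corners

matrix = cv2.getPerspectiveTransform(src, dst)  # 3x3 homography
img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
warped = cv2.warpPerspective(img, matrix, (w, h), flags=cv2.INTER_LINEAR)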
Default: True pad_mode (OpenCV flag): OpenCV border mode. @@ -329,6 +331,7 @@ class Perspective(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -336,8 +339,8 @@ def __init__( scale: ScaleFloatType = (0.05, 0.1), keep_size: bool = True, pad_mode: int = cv2.BORDER_CONSTANT, - pad_val: Union[int, float, List[float], List[int]] = 0, - mask_pad_val: Union[int, float, List[float], List[int]] = 0, + pad_val: Union[float, List[float]] = 0, + mask_pad_val: Union[float, List[float]] = 0, fit_output: bool = False, interpolation: int = cv2.INTER_LINEAR, always_apply: bool = False, @@ -419,12 +422,12 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A # x-coordiates or the top-right and top-left x-coordinates min_width = None max_width = None - while min_width is None or min_width < 2: + while min_width is None or min_width < TWO: width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) max_width = int(max(width_top, width_bottom)) min_width = int(min(width_top, width_bottom)) - if min_width < 2: + if min_width < TWO: step_size = (2 - min_width) / 2 tl[0] -= step_size tr[0] += step_size @@ -435,12 +438,12 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A # and bottom-right y-coordinates or the top-left and bottom-left y-coordinates min_height = None max_height = None - while min_height is None or min_height < 2: + while min_height is None or min_height < TWO: height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) max_height = int(max(height_right, height_left)) min_height = int(min(height_right, height_left)) - if min_height < 2: + if min_height < TWO: step_size = (2 - min_height) / 2 tl[1] -= step_size tr[1] -= step_size @@ -523,6 +526,7 @@ class Affine(DualTransform): `mask_interpolation` deals with the method of interpolation used for this. Args: + ---- scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and ``0.5`` is zoomed out to ``50`` percent of the original size. * If a single number, then that value will be used for all images. @@ -598,6 +602,7 @@ class Affine(DualTransform): Reference: [1] https://arxiv.org/abs/2109.13488 + """ def __init__( @@ -609,8 +614,8 @@ def __init__( shear: Optional[Union[ScaleFloatType, Dict[str, Any]]] = None, interpolation: int = cv2.INTER_LINEAR, mask_interpolation: int = cv2.INTER_NEAREST, - cval: Union[int, float, Tuple[int, int], Tuple[float, float]] = 0, - cval_mask: Union[int, float, Tuple[int, int], Tuple[float, float]] = 0, + cval: Union[float, Tuple[float, float]] = 0, + cval_mask: Union[float, Tuple[float, float]] = 0, mode: int = cv2.BORDER_CONSTANT, fit_output: bool = False, keep_ratio: bool = False, @@ -621,7 +626,7 @@ def __init__( super().__init__(always_apply=always_apply, p=p) params = [scale, translate_percent, translate_px, rotate, shear] - if all([p is None for p in params]): + if all(p is None for p in params): scale = {"x": (0.9, 1.1), "y": (0.9, 1.1)} translate_percent = {"x": (-0.1, 0.1), "y": (-0.1, 0.1)} rotate = (-15, 15) @@ -688,16 +693,16 @@ def _handle_translate_arg( translate_px = 0 if translate_percent is not None and translate_px is not None: - raise ValueError( - "Expected either translate_percent or translate_px to be " "provided, " "but neither of them was." 
-        )
+        msg = "Expected either translate_percent or translate_px to be provided, but both were given."
+        raise ValueError(msg)
 
         if translate_percent is not None:
             # translate by percent
             return cls._handle_dict_arg(translate_percent, "translate_percent", default=0.0), translate_px
 
         if translate_px is None:
-            raise ValueError("translate_px is None.")
+            msg = "translate_px is None."
+            raise ValueError(msg)
 
         # translate by pixels
         return translate_percent, cls._handle_dict_arg(translate_px, "translate_px")
 
@@ -751,7 +756,13 @@ def apply_to_keypoint(
         scale: Optional[Dict[str, Any]] = None,
         **params: Any,
     ) -> KeypointInternalType:
-        assert scale is not None and matrix is not None
+        if scale is None:
+            msg = "Expected scale to be provided, but got None."
+            raise ValueError(msg)
+        if matrix is None:
+            msg = "Expected matrix to be provided, but got None."
+            raise ValueError(msg)
+
         return F.keypoint_affine(keypoint, matrix=matrix, scale=scale)
 
     @property
@@ -834,7 +845,7 @@ def _compute_affine_warp_output_shape(
         maxr = corners[:, 1].max()
         out_height = maxr - minr + 1
         out_width = maxc - minc + 1
-        if len(input_shape) == 3:
+        if len(input_shape) == THREE:
             output_shape = np.ceil((out_height, out_width, input_shape[2]))
         else:
             output_shape = np.ceil((out_height, out_width))
@@ -855,14 +866,17 @@ class PiecewiseAffine(DualTransform):
     See also ``Affine`` for a similar technique.
 
     Note:
+    ----
         This augmenter is very slow. Try to use ``ElasticTransformation`` instead, which is at least 10x faster.
 
     Note:
+    ----
         For coordinate-based inputs (keypoints, bounding boxes, polygons, ...),
         this augmenter still has to perform an image-based augmentation,
         which will make it significantly slower and not fully correct for such inputs than other transforms.
 
     Args:
+    ----
         scale (float, tuple of float): Each point on the regular grid is moved around via a normal distribution.
             This scale factor is equivalent to the normal distribution's sigma.
             Note that the jitter (how far each point is moved in which direction) is multiplied by the height/width of
@@ -962,7 +976,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
 
         jitter: np.ndarray = random_utils.normal(0, scale, (nb_cells, 2))
         if not np.any(jitter > 0):
-            for i in range(10):  # See: https://github.com/albumentations-team/albumentations/issues/1442
+            for _i in range(10):  # See: https://github.com/albumentations-team/albumentations/issues/1442
                 jitter = random_utils.normal(0, scale, (nb_cells, 2))
                 if np.any(jitter > 0):
                     break
@@ -1038,6 +1052,7 @@ class PadIfNeeded(DualTransform):
     """Pad side of the image / max if side is less than desired number.
 
     Args:
+    ----
         min_height (int): minimal result image height.
         min_width (int): minimal result image width.
         pad_height_divisor (int): if not None, ensures image height is dividable by value of this argument.
@@ -1057,9 +1072,27 @@ class PadIfNeeded(DualTransform):
 
     Image types:
         uint8, float32
 
+
     """
 
     class PositionType(Enum):
+        """Enumerates the types of positions for placing an object within a container.
+
+        This Enum class is utilized to define specific anchor positions that an object can
+        assume relative to a container. It's particularly useful in image processing, UI layout,
+        and graphic design to specify the alignment and positioning of elements.
+
+        Attributes
+        ----------
+        CENTER (str): Specifies that the object should be placed at the center.
+        TOP_LEFT (str): Specifies that the object should be placed at the top-left corner.
+ TOP_RIGHT (str): Specifies that the object should be placed at the top-right corner. + BOTTOM_LEFT (str): Specifies that the object should be placed at the bottom-left corner. + BOTTOM_RIGHT (str): Specifies that the object should be placed at the bottom-right corner. + RANDOM (str): Indicates that the object's position should be determined randomly. + + """ + CENTER = "center" TOP_LEFT = "top_left" TOP_RIGHT = "top_right" @@ -1081,10 +1114,12 @@ def __init__( p: float = 1.0, ): if (min_height is None) == (pad_height_divisor is None): - raise ValueError("Only one of 'min_height' and 'pad_height_divisor' parameters must be set") + msg = "Only one of 'min_height' and 'pad_height_divisor' parameters must be set" + raise ValueError(msg) if (min_width is None) == (pad_width_divisor is None): - raise ValueError("Only one of 'min_width' and 'pad_width_divisor' parameters must be set") + msg = "Only one of 'min_width' and 'pad_width_divisor' parameters must be set" + raise ValueError(msg) super().__init__(always_apply, p) self.min_height = min_height @@ -1261,6 +1296,7 @@ class VerticalFlip(DualTransform): """Flip the input vertically around the x-axis. Args: + ---- p (float): probability of applying the transform. Default: 0.5. Targets: @@ -1268,6 +1304,7 @@ class VerticalFlip(DualTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: @@ -1287,6 +1324,7 @@ class HorizontalFlip(DualTransform): """Flip the input horizontally around the y-axis. Args: + ---- p (float): probability of applying the transform. Default: 0.5. Targets: @@ -1294,10 +1332,11 @@ class HorizontalFlip(DualTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: - if img.ndim == 3 and img.shape[2] > 1 and img.dtype == np.uint8: + if img.ndim == THREE and img.shape[2] > 1 and img.dtype == np.uint8: # Opencv is faster than numpy only in case of # non-gray scale 8bits images return F.hflip_cv2(img) @@ -1318,6 +1357,7 @@ class Flip(DualTransform): """Flip the input either horizontally, vertically or both horizontally and vertically. Args: + ---- p (float): probability of applying the transform. Default: 0.5. Targets: @@ -1325,6 +1365,7 @@ class Flip(DualTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, d: int = 0, **params: Any) -> np.ndarray: @@ -1353,6 +1394,7 @@ class Transpose(DualTransform): """Transpose the input by swapping rows and columns. Args: + ---- p (float): probability of applying the transform. Default: 0.5. Targets: @@ -1360,6 +1402,7 @@ class Transpose(DualTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: @@ -1376,8 +1419,8 @@ def get_transform_init_args_names(self) -> Tuple[()]: class OpticalDistortion(DualTransform): - """ - Args: + """Args: + ---- distort_limit (float, (float, float)): If distort_limit is a single float, the range will be (-distort_limit, distort_limit). Default: (-0.05, 0.05). shift_limit (float, (float, float))): If shift_limit is a single float, the range @@ -1398,6 +1441,7 @@ class OpticalDistortion(DualTransform): Image types: uint8, float32 + """ def __init__( @@ -1465,8 +1509,8 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]: class GridDistortion(DualTransform): - """ - Args: + """Args: + ---- num_steps (int): count of grid cells on each side. distort_limit (float, (float, float)): If distort_limit is a single float, the range will be (-distort_limit, distort_limit). 
Default: (-0.03, 0.03). @@ -1488,6 +1532,7 @@ class GridDistortion(DualTransform): Image types: uint8, float32 + """ def __init__( diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index 32d9fd655..8c3fba9b2 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -18,22 +18,10 @@ is_grayscale_image, is_rgb_image, ) +from albumentations.core.transforms_interface import DualTransform, ImageOnlyTransform, Interpolation, NoOp, to_tuple +from albumentations.core.types import BoxInternalType, KeypointInternalType, ScaleFloatType, ScaleIntType, ScaleType +from albumentations.core.utils import format_args -from ..core.transforms_interface import ( - DualTransform, - ImageOnlyTransform, - Interpolation, - NoOp, - to_tuple, -) -from ..core.types import ( - BoxInternalType, - KeypointInternalType, - ScaleFloatType, - ScaleIntType, - ScaleType, -) -from ..core.utils import format_args from . import functional as F __all__ = [ @@ -79,12 +67,18 @@ "Spatter", ] +HUNDRED = 100 +TWENTY = 20 +FIVE = 5 +THREE = 3 +TWO = 2 + class RandomGridShuffle(DualTransform): - """ - Random shuffle grid's cells on image. + """Random shuffle grid's cells on image. Args: + ---- grid ((int, int)): size of grid for splitting image. Targets: @@ -92,22 +86,23 @@ class RandomGridShuffle(DualTransform): Image types: uint8, float32 + """ def __init__(self, grid: Tuple[int, int] = (3, 3), always_apply: bool = False, p: float = 0.5): super().__init__(always_apply, p) self.grid = grid - def apply(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params: Any) -> np.ndarray: + def apply(self, img: np.ndarray, tiles: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return F.swap_tiles_on_image(img, tiles) - def apply_to_mask(self, mask: np.ndarray, tiles: np.ndarray = np.array(None), **params: Any) -> np.ndarray: + def apply_to_mask(self, mask: np.ndarray, tiles: Optional[np.ndarray] = None, **params: Any) -> np.ndarray: return F.swap_tiles_on_image(mask, tiles) def apply_to_keypoint( self, keypoint: KeypointInternalType, - tiles: np.ndarray = np.array(None), + tiles: np.ndarray, rows: int = 0, cols: int = 0, **params: Any, @@ -127,7 +122,7 @@ def apply_to_keypoint( ): x = x - old_left_up_corner_col + current_left_up_corner_col y = y - old_left_up_corner_row + current_left_up_corner_row - keypoint_result = (x, y) + tuple(keypoint[2:]) + keypoint_result = (x, y, *keypoint[2:]) break return cast(KeypointInternalType, keypoint_result) @@ -140,7 +135,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A raise ValueError(f"Grid's values must be positive. Current grid [{n}, {m}]") if n > height // 2 or m > width // 2: - raise ValueError("Incorrect size cell of grid. Just shuffle pixels of image") + msg = "Incorrect size cell of grid. Just shuffle pixels of image" + raise ValueError(msg) height_split = np.linspace(0, height, n + 1, dtype=np.int32) width_split = np.linspace(0, width, m + 1, dtype=np.int32) @@ -192,6 +188,7 @@ class Normalize(ImageOnlyTransform): """Normalization is applied by the formula: `img = (img - mean * max_pixel_value) / (std * max_pixel_value)` Args: + ---- mean: mean values std: std values max_pixel_value: maximum possible pixel value @@ -201,6 +198,7 @@ class Normalize(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -227,6 +225,7 @@ class ImageCompression(ImageOnlyTransform): """Decreases image quality by Jpeg, WebP compression of an image. 
Args:
+    ----
        quality_lower: lower bound on the image quality. Should be in [0, 100] range for jpeg and [1, 100] for webp.
        quality_upper: upper bound on the image quality. Should be in [0, 100] range for jpeg and [1, 100] for webp.
        compression_type (ImageCompressionType): should be ImageCompressionType.JPEG or ImageCompressionType.WEBP.
@@ -237,9 +236,21 @@

    Image types:
        uint8, float32
+
    """

    class ImageCompressionType(IntEnum):
+        """Defines the types of image compression.
+
+        This Enum class is used to specify the image compression format.
+
+        Attributes
+        ----------
+        JPEG (int): Represents the JPEG image compression format.
+        WEBP (int): Represents the WEBP image compression format.
+
+        """
+
        JPEG = 0
        WEBP = 1

@@ -259,17 +270,18 @@ def __init__(
        if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
            low_thresh_quality_assert = 1

-        if not low_thresh_quality_assert <= quality_lower <= 100:
+        if not low_thresh_quality_assert <= quality_lower <= HUNDRED:
            raise ValueError(f"Invalid quality_lower. Got: {quality_lower}")
-        if not low_thresh_quality_assert <= quality_upper <= 100:
+        if not low_thresh_quality_assert <= quality_upper <= HUNDRED:
            raise ValueError(f"Invalid quality_upper. Got: {quality_upper}")

        self.quality_lower = quality_lower
        self.quality_upper = quality_upper

    def apply(self, img: np.ndarray, quality: int = 100, image_type: str = ".jpg", **params: Any) -> np.ndarray:
-        if not img.ndim == 2 and img.shape[-1] not in (1, 3, 4):
-            raise TypeError("ImageCompression transformation expects 1, 3 or 4 channel images.")
+        if img.ndim != TWO and img.shape[-1] not in (1, 3, 4):
+            msg = "ImageCompression transformation expects 1, 3 or 4 channel images."
+            raise TypeError(msg)
        return F.image_compression(img, quality, image_type)

    def get_params(self) -> Dict[str, Any]:
@@ -297,6 +309,7 @@ class RandomSnow(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        snow_point_lower: lower bound of the amount of snow. Should be in [0, 1] range
        snow_point_upper: upper bound of the amount of snow. Should be in [0, 1] range
        brightness_coeff: larger number will lead to more snow on the image. Should be >= 0
@@ -306,6 +319,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -319,11 +333,10 @@ def __init__(
        super().__init__(always_apply, p)

        if not 0 <= snow_point_lower <= snow_point_upper <= 1:
-            raise ValueError(
-                "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
-                    (snow_point_lower, snow_point_upper)
-                )
+            msg = "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
+                (snow_point_lower, snow_point_upper)
            )
+            raise ValueError(msg)

        if brightness_coeff < 0:
            raise ValueError(f"brightness_coeff must be greater than 0. Got: {brightness_coeff}")
@@ -347,6 +360,7 @@ class RandomGravel(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        gravel_roi: (top-left x, top-left y,
            bottom-right x, bottom-right y). Should be in [0, 1] range
        number_of_patches: number
of gravel patches required
@@ -356,6 +370,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -386,7 +401,9 @@ def generate_gravel_patch(self, rectangular_roi: Tuple[int, int, int, int]) -> n
        gravels[:, 1] = random_utils.randint(y1, y2, count)
        return gravels

-    def apply(self, img: np.ndarray, gravels_infos: List[Any] = [], **params: Any) -> np.ndarray:
+    def apply(self, img: np.ndarray, gravels_infos: Optional[List[Any]] = None, **params: Any) -> np.ndarray:
+        if gravels_infos is None:
+            gravels_infos = []
        return F.add_gravel(img, gravels_infos)

    @property
@@ -457,6 +474,7 @@ class RandomRain(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        slant_lower: should be in range [-20, 20].
        slant_upper: should be in range [-20, 20].
        drop_length: should be in range [0, 100].
@@ -471,6 +489,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -489,14 +508,13 @@ def __init__(
        super().__init__(always_apply, p)

        if rain_type not in ["drizzle", "heavy", "torrential", None]:
-            raise ValueError(
-                "raint_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type)
-            )
-        if not -20 <= slant_lower <= slant_upper <= 20:
+            msg = "rain_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type)
+            raise ValueError(msg)
+        if not -TWENTY <= slant_lower <= slant_upper <= TWENTY:
            raise ValueError(f"Invalid combination of slant_lower and slant_upper. Got: {(slant_lower, slant_upper)}")
-        if not 1 <= drop_width <= 5:
+        if not 1 <= drop_width <= FIVE:
            raise ValueError(f"drop_width must be in range [1, 5]. Got: {drop_width}")
-        if not 0 <= drop_length <= 100:
+        if not 0 <= drop_length <= HUNDRED:
            raise ValueError(f"drop_length must be in range [0, 100]. Got: {drop_length}")
        if not 0 <= brightness_coefficient <= 1:
            raise ValueError(f"brightness_coefficient must be in range [0, 1]. Got: {brightness_coefficient}")
@@ -516,9 +534,11 @@ def apply(
        img: np.ndarray,
        slant: int = 10,
        drop_length: int = 20,
-        rain_drops: List[Tuple[int, int]] = [],
+        rain_drops: Optional[List[Tuple[int, int]]] = None,
        **params: Any,
    ) -> np.ndarray:
+        if rain_drops is None:
+            rain_drops = []
        return F.add_rain(
            img,
            slant,
@@ -557,10 +577,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
        rain_drops = []

        for _ in range(num_drops):  # If you want heavy rain, try increasing this
-            if slant < 0:
-                x = random.randint(slant, width)
-            else:
-                x = random.randint(0, width - slant)
+            x = random.randint(slant, width) if slant < 0 else random.randint(0, width - slant)

            y = random.randint(0, height - drop_length)

@@ -587,6 +604,7 @@ class RandomFog(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        fog_coef_lower: lower limit for fog intensity coefficient. Should be in [0, 1] range.
        fog_coef_upper: upper limit for fog intensity coefficient. Should be in [0, 1] range.
        alpha_coef: transparency of the fog circles. Should be in [0, 1] range.
@@ -596,6 +614,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -610,9 +629,7 @@ def __init__(

        if not 0 <= fog_coef_lower <= fog_coef_upper <= 1:
            raise ValueError(
-                "Invalid combination if fog_coef_lower and fog_coef_upper. Got: {}".format(
-                    (fog_coef_lower, fog_coef_upper)
-                )
+                f"Invalid combination of fog_coef_lower and fog_coef_upper. 
Got: {(fog_coef_lower, fog_coef_upper)}"
            )
        if not 0 <= alpha_coef <= 1:
            raise ValueError(f"alpha_coef must be in range [0, 1]. Got: {alpha_coef}")

@@ -622,8 +639,14 @@ def __init__(
        self.alpha_coef = alpha_coef

    def apply(
-        self, img: np.ndarray, fog_coef: np.ndarray = 0.1, haze_list: List[Tuple[int, int]] = [], **params: Any
+        self,
+        img: np.ndarray,
+        fog_coef: np.ndarray = 0.1,
+        haze_list: Optional[List[Tuple[int, int]]] = None,
+        **params: Any,
    ) -> np.ndarray:
+        if haze_list is None:
+            haze_list = []
        return F.add_fog(img, fog_coef, self.alpha_coef, haze_list)

    @property
@@ -665,6 +688,7 @@ class RandomSunFlare(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        flare_roi: region of the image where flare will
            appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1].
        angle_lower: should be in range [0, `angle_upper`].
@@ -681,6 +705,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -712,11 +737,10 @@ def __init__(
        if not 0 <= angle_lower < angle_upper <= 1:
            raise ValueError(f"Invalid combination of angle_lower and angle_upper. Got: {(angle_lower, angle_upper)}")
        if not 0 <= num_flare_circles_lower < num_flare_circles_upper:
-            raise ValueError(
-                "Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format(
-                    (num_flare_circles_lower, num_flare_circles_upper)
-                )
+            msg = "Invalid combination of num_flare_circles_lower and num_flare_circles_upper. Got: {}".format(
+                (num_flare_circles_lower, num_flare_circles_upper)
            )
+            raise ValueError(msg)

        self.flare_center_lower_x = flare_center_lower_x
        self.flare_center_upper_x = flare_center_upper_x
@@ -737,9 +761,11 @@ def apply(
        img: np.ndarray,
        flare_center_x: float = 0.5,
        flare_center_y: float = 0.5,
-        circles: List[Any] = [],
+        circles: Optional[List[Any]] = None,
        **params: Any,
    ) -> np.ndarray:
+        if circles is None:
+            circles = []
        return F.add_sun_flare(
            img,
            flare_center_x,
@@ -827,6 +853,7 @@ class RandomShadow(ImageOnlyTransform):
    From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library

    Args:
+    ----
        shadow_roi: region of the image where shadows
            will appear. All values should be in range [0, 1].
        num_shadows_lower: Lower limit for the possible number of shadows.
@@ -840,6 +867,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -858,11 +886,10 @@ def __init__(
        if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
            raise ValueError(f"Invalid shadow_roi. Got: {shadow_roi}")
        if not 0 <= num_shadows_lower <= num_shadows_upper:
-            raise ValueError(
-                "Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format(
-                    (num_shadows_lower, num_shadows_upper)
-                )
+            msg = "Invalid combination of num_shadows_lower and num_shadows_upper. 
Got: {}".format( + (num_shadows_lower, num_shadows_upper) ) + raise ValueError(msg) self.shadow_roi = shadow_roi @@ -871,7 +898,11 @@ def __init__( self.shadow_dimension = shadow_dimension - def apply(self, img: np.ndarray, vertices_list: List[List[Tuple[int, int]]] = [], **params: Any) -> np.ndarray: + def apply( + self, img: np.ndarray, vertices_list: Optional[List[List[Tuple[int, int]]]] = None, **params: Any + ) -> np.ndarray: + if vertices_list is None: + vertices_list = [] return F.add_shadow(img, vertices_list) @property @@ -893,10 +924,10 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, L vertices_list = [] - for _index in range(num_shadows): - vertex = [] - for _dimension in range(self.shadow_dimension): - vertex.append((random.randint(x_min, x_max), random.randint(y_min, y_max))) + for _ in range(num_shadows): + vertex = [ + (random.randint(x_min, x_max), random.randint(y_min, y_max)) for _ in range(self.shadow_dimension) + ] vertices = np.array([vertex], dtype=np.int32) vertices_list.append(vertices) @@ -916,6 +947,7 @@ class RandomToneCurve(ImageOnlyTransform): """Randomly change the relationship between bright and dark areas of the image by manipulating its tone curve. Args: + ---- scale: standard deviation of the normal distribution. Used to sample random distances to move two control points that modify the image's curve. Values should be in range [0, 1]. Default: 0.1 @@ -926,6 +958,7 @@ class RandomToneCurve(ImageOnlyTransform): Image types: uint8 + """ def __init__( @@ -954,6 +987,7 @@ class HueSaturationValue(ImageOnlyTransform): """Randomly change hue, saturation and value of the input image. Args: + ---- hue_shift_limit: range for changing hue. If hue_shift_limit is a single int, the range will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20). sat_shift_limit: range for changing saturation. If sat_shift_limit is a single int, @@ -967,6 +1001,7 @@ class HueSaturationValue(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -986,7 +1021,8 @@ def apply( self, img: np.ndarray, hue_shift: int = 0, sat_shift: int = 0, val_shift: int = 0, **params: Any ) -> np.ndarray: if not is_rgb_image(img) and not is_grayscale_image(img): - raise TypeError("HueSaturationValue transformation expects 1-channel or 3-channel images.") + msg = "HueSaturationValue transformation expects 1-channel or 3-channel images." + raise TypeError(msg) return F.shift_hsv(img, hue_shift, sat_shift, val_shift) def get_params(self) -> Dict[str, float]: @@ -1004,6 +1040,7 @@ class Solarize(ImageOnlyTransform): """Invert all pixel values above a threshold. Args: + ---- threshold: range for solarizing threshold. If threshold is a single value, the range will be [threshold, threshold]. Default: 128. p: probability of applying the transform. Default: 0.5. @@ -1013,6 +1050,7 @@ class Solarize(ImageOnlyTransform): Image types: any + """ def __init__(self, threshold: ScaleType = 128, always_apply: bool = False, p: float = 0.5): @@ -1039,6 +1077,7 @@ class Posterize(ImageOnlyTransform): """Reduce the number of bits for each color channel. Args: + ---- num_bits ((int, int) or int, or list of ints [r, g, b], or list of ints [[r1, r1], [g1, g2], [b1, b2]]): number of high bits. 
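For orientation, the three accepted forms of `num_bits` documented above can be exercised as follows. This is an illustrative sketch only, assuming the public `albumentations` API described by these docstrings (imported as `A`) and a synthetic uint8 image:

    import albumentations as A
    import numpy as np

    image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

    # Single int: keep exactly 4 high bits for every channel.
    fixed = A.Posterize(num_bits=4, p=1.0)
    # Pair of ints: the bit count is sampled uniformly from [2, 6] on each call.
    sampled = A.Posterize(num_bits=(2, 6), p=1.0)
    # List of three ints: one entry per R, G, B channel, sampled independently.
    per_channel = A.Posterize(num_bits=[4, 5, 6], p=1.0)

    posterized = sampled(image=image)["image"]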
@@ -1051,6 +1090,7 @@ class Posterize(ImageOnlyTransform): Image types: uint8 + """ def __init__( @@ -1063,17 +1103,17 @@ def __init__( if isinstance(num_bits, int): self.num_bits = to_tuple(num_bits, num_bits) - elif isinstance(num_bits, Sequence) and len(num_bits) == 3: + elif isinstance(num_bits, Sequence) and len(num_bits) == THREE: self.num_bits = [to_tuple(i, 0) for i in num_bits] # type: ignore[assignment] else: - self.num_bits = to_tuple(num_bits, 0) + self.num_bits = to_tuple(num_bits, 0) # type: ignore[arg-type] def apply(self, img: np.ndarray, num_bits: int = 1, **params: Any) -> np.ndarray: return F.posterize(img, num_bits) def get_params(self) -> Dict[str, Any]: - if len(self.num_bits) == 3: - return {"num_bits": [random.randint(int(i[0]), int(i[1])) for i in self.num_bits]} + if len(self.num_bits) == THREE: + return {"num_bits": [random.randint(int(i[0]), int(i[1])) for i in self.num_bits]} # type: ignore[index] num_bits = self.num_bits return {"num_bits": random.randint(int(num_bits[0]), int(num_bits[1]))} @@ -1085,6 +1125,7 @@ class Equalize(ImageOnlyTransform): """Equalize the image histogram. Args: + ---- mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method. by_channels (bool): If True, use equalization by channels separately, else convert image to YCbCr representation and use equalization by `Y` channel. @@ -1098,6 +1139,7 @@ class Equalize(ImageOnlyTransform): Image types: uint8 + """ def __init__( @@ -1111,7 +1153,7 @@ def __init__( ): modes = ["cv", "pil"] if mode not in modes: - raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode)) + raise ValueError(f"Unsupported equalization mode. Supports: {modes}. " f"Got: {mode}") super().__init__(always_apply, p) self.mode = mode @@ -1130,7 +1172,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A @property def targets_as_params(self) -> List[str]: - return ["image"] + list(self.mask_params) + return ["image", *list(self.mask_params)] def get_transform_init_args_names(self) -> Tuple[str, str]: return ("mode", "by_channels") @@ -1140,6 +1182,7 @@ class RGBShift(ImageOnlyTransform): """Randomly shift values for each channel of the input RGB image. Args: + ---- r_shift_limit: range for changing values for the red channel. If r_shift_limit is a single int, the range will be (-r_shift_limit, r_shift_limit). Default: (-20, 20). g_shift_limit: range for changing values for the green channel. If g_shift_limit is a @@ -1153,6 +1196,7 @@ class RGBShift(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -1170,7 +1214,8 @@ def __init__( def apply(self, img: np.ndarray, r_shift: int = 0, g_shift: int = 0, b_shift: int = 0, **params: Any) -> np.ndarray: if not is_rgb_image(img): - raise TypeError("RGBShift transformation expects 3-channel images.") + msg = "RGBShift transformation expects 3-channel images." + raise TypeError(msg) return F.shift_rgb(img, r_shift, g_shift, b_shift) def get_params(self) -> Dict[str, Any]: @@ -1188,6 +1233,7 @@ class RandomBrightnessContrast(ImageOnlyTransform): """Randomly change brightness and contrast of the input image. Args: + ---- brightness_limit: factor range for changing brightness. If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2). contrast_limit: factor range for changing contrast. 
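The `brightness_limit` / `contrast_limit` semantics above (a single float expands to a symmetric `(-limit, limit)` range, an explicit tuple is used as-is) translate into usage like this sketch, under the same API assumptions as the previous example:

    import albumentations as A
    import numpy as np

    image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

    # Brightness factor drawn from (-0.2, 0.2); contrast factor from (-0.1, 0.3).
    transform = A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=(-0.1, 0.3), p=1.0)
    augmented = transform(image=image)["image"]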
@@ -1201,6 +1247,7 @@ class RandomBrightnessContrast(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -1233,6 +1280,7 @@ class GaussNoise(ImageOnlyTransform): """Apply gaussian noise to the input image. Args: + ---- var_limit: variance range for noise. If var_limit is a single float, the range will be (0, var_limit). Default: (10.0, 50.0). mean: mean of the noise. Default: 0 @@ -1245,6 +1293,7 @@ class GaussNoise(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -1258,13 +1307,16 @@ def __init__( super().__init__(always_apply, p) if isinstance(var_limit, (tuple, list)): if var_limit[0] < 0: - raise ValueError("Lower var_limit should be non negative.") + msg = "Lower var_limit should be non negative." + raise ValueError(msg) if var_limit[1] < 0: - raise ValueError("Upper var_limit should be non negative.") + msg = "Upper var_limit should be non negative." + raise ValueError(msg) self.var_limit = var_limit elif isinstance(var_limit, (int, float)): if var_limit < 0: - raise ValueError("var_limit should be non negative.") + msg = "var_limit should be non negative." + raise ValueError(msg) self.var_limit = (0, var_limit) else: @@ -1285,7 +1337,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, f gauss = random_utils.normal(self.mean, sigma, image.shape) else: gauss = random_utils.normal(self.mean, sigma, image.shape[:2]) - if len(image.shape) == 3: + if len(image.shape) == THREE: gauss = np.expand_dims(gauss, -1) return {"gauss": gauss} @@ -1299,10 +1351,10 @@ def get_transform_init_args_names(self) -> Tuple[str, str, str]: class ISONoise(ImageOnlyTransform): - """ - Apply camera sensor noise. + """Apply camera sensor noise. Args: + ---- color_shift (float, float): variance range for color hue change. Measured as a fraction of 360 degree Hue angle in HLS colorspace. intensity ((float, float): Multiplicative factor that control strength @@ -1314,6 +1366,7 @@ class ISONoise(ImageOnlyTransform): Image types: uint8 + """ def __init__( @@ -1352,6 +1405,7 @@ class CLAHE(ImageOnlyTransform): """Apply Contrast Limited Adaptive Histogram Equalization to the input image. Args: + ---- clip_limit: upper threshold value for contrast limiting. If clip_limit is a single float value, the range will be (1, clip_limit). Default: (1, 4). tile_grid_size: size of grid for histogram equalization. Default: (8, 8). @@ -1362,6 +1416,7 @@ class CLAHE(ImageOnlyTransform): Image types: uint8 + """ def __init__( @@ -1377,7 +1432,8 @@ def __init__( def apply(self, img: np.ndarray, clip_limit: float = 2, **params: Any) -> np.ndarray: if not is_rgb_image(img) and not is_grayscale_image(img): - raise TypeError("CLAHE transformation expects 1-channel or 3-channel images.") + msg = "CLAHE transformation expects 1-channel or 3-channel images." + raise TypeError(msg) return F.clahe(img, clip_limit, self.tile_grid_size) @@ -1392,6 +1448,7 @@ class ChannelShuffle(ImageOnlyTransform): """Randomly rearrange channels of the input RGB image. Args: + ---- p: probability of applying the transform. Default: 0.5. Targets: @@ -1399,6 +1456,7 @@ class ChannelShuffle(ImageOnlyTransform): Image types: uint8, float32 + """ @property @@ -1423,6 +1481,7 @@ class InvertImg(ImageOnlyTransform): i.e., 255 for uint8 and 1.0 for float32. Args: + ---- p: probability of applying the transform. Default: 0.5. 
Targets: @@ -1430,6 +1489,7 @@ class InvertImg(ImageOnlyTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: @@ -1447,7 +1507,8 @@ class RandomGamma(ImageOnlyTransform): conditions, potentially enhancing model generalization. For more details on gamma correction, see: https://en.wikipedia.org/wiki/Gamma_correction - Attributes: + Attributes + ---------- gamma_limit (Union[int, Tuple[int, int]]): The range for gamma adjustment. If `gamma_limit` is a single int, the range will be interpreted as (-gamma_limit, gamma_limit), defining how much to adjust the image's gamma. Default is (80, 120). @@ -1460,6 +1521,7 @@ class RandomGamma(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__( @@ -1486,6 +1548,7 @@ class ToGray(ImageOnlyTransform): than 127, invert the resulting grayscale image. Args: + ---- p: probability of applying the transform. Default: 0.5. Targets: @@ -1493,6 +1556,7 @@ class ToGray(ImageOnlyTransform): Image types: uint8, float32 + """ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: @@ -1500,7 +1564,8 @@ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: warnings.warn("The image is already gray.") return img if not is_rgb_image(img): - raise TypeError("ToGray transformation expects 3-channel images.") + msg = "ToGray transformation expects 3-channel images." + raise TypeError(msg) return F.to_gray(img) @@ -1512,6 +1577,7 @@ class ToRGB(ImageOnlyTransform): """Convert the input grayscale image to RGB. Args: + ---- p: probability of applying the transform. Default: 1. Targets: @@ -1519,6 +1585,7 @@ class ToRGB(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__(self, always_apply: bool = True, p: float = 1.0): @@ -1529,7 +1596,8 @@ def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: warnings.warn("The image is already an RGB.") return img if not is_grayscale_image(img): - raise TypeError("ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1.") + msg = "ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1." + raise TypeError(msg) return F.gray_to_rgb(img) @@ -1541,6 +1609,7 @@ class ToSepia(ImageOnlyTransform): """Applies sepia filter to the input RGB image Args: + ---- p: probability of applying the transform. Default: 0.5. Targets: @@ -1548,6 +1617,7 @@ class ToSepia(ImageOnlyTransform): Image types: uint8, float32 + """ def __init__(self, always_apply: bool = False, p: float = 0.5): @@ -1558,7 +1628,8 @@ def __init__(self, always_apply: bool = False, p: float = 0.5): def apply(self, img: np.ndarray, **params: Any) -> np.ndarray: if not is_rgb_image(img): - raise TypeError("ToSepia transformation expects 3-channel images.") + msg = "ToSepia transformation expects 3-channel images." + raise TypeError(msg) return F.linear_transformation_rgb(img, self.sepia_transformation_matrix) def get_transform_init_args_names(self) -> Tuple[()]: @@ -1571,9 +1642,11 @@ class ToFloat(ImageOnlyTransform): image. See Also: + -------- :class:`~albumentations.augmentations.transforms.FromFloat` Args: + ---- max_value: maximum possible input value. Default: None. p: probability of applying the transform. Default: 1.0. @@ -1604,6 +1677,7 @@ class FromFloat(ImageOnlyTransform): This is the inverse transform for :class:`~albumentations.augmentations.transforms.ToFloat`. Args: + ---- max_value: maximum possible input value. Default: None. dtype: data type of the output. 
See the `'Data types' page from the NumPy docs`_.
            Default: 'uint16'.
@@ -1617,6 +1691,7 @@

    .. _'Data types' page from the NumPy docs:
       https://docs.scipy.org/doc/numpy/user/basics.types.html
+
    """

    def __init__(
@@ -1637,6 +1712,7 @@ class Downscale(ImageOnlyTransform):
    """Decreases image quality by downscaling and upscaling back.

    Args:
+    ----
        scale_min: lower bound on the image scale. Should be < 1.
        scale_max: upper bound on the image scale. Should be < 1.
        interpolation: cv2 interpolation method. Could be:
@@ -1650,6 +1726,7 @@

    Image types:
        uint8, float32
+
    """

    def __init__(
@@ -1689,7 +1766,8 @@ def __init__(

    def apply(self, img: np.ndarray, scale: float, **params: Any) -> np.ndarray:
        if isinstance(self.interpolation, int):
-            raise ValueError("Should not be here, added for typing purposes. Please report this issue.")
+            msg = "Should not be here, added for typing purposes. Please report this issue."
+            raise TypeError(msg)
        return F.downscale(
            img,
            scale=scale,
@@ -1703,10 +1781,11 @@ def get_params(self) -> Dict[str, Any]:

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        return "scale_min", "scale_max"

-    def _to_dict(self) -> Dict[str, Any]:
+    def to_dict_private(self) -> Dict[str, Any]:
        if isinstance(self.interpolation, int):
-            raise ValueError("Should not be here, added for typing purposes. Please report this issue.")
-        result = super()._to_dict()
+            msg = "Should not be here, added for typing purposes. Please report this issue."
+            raise TypeError(msg)
+        result = super().to_dict_private()
        result["interpolation"] = {"upscale": self.interpolation.upscale, "downscale": self.interpolation.downscale}
        return result

@@ -1716,6 +1795,7 @@ class Lambda(NoOp):
    Function signature must include **kwargs to accept optional arguments like interpolation method, image size, etc:

    Args:
+    ----
        image: Image transformation function.
        mask: Mask transformation function.
        keypoint: Keypoint transformation function.
@@ -1728,6 +1808,7 @@

    Image types:
        Any
+
    """

    def __init__(
@@ -1779,12 +1860,13 @@ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params: Any) -> Ke
    def is_serializable(cls) -> bool:
        return False

-    def _to_dict(self) -> Dict[str, Any]:
+    def to_dict_private(self) -> Dict[str, Any]:
        if self.name is None:
-            raise ValueError(
+            msg = (
                "To make a Lambda transform serializable you should provide the `name` argument, "
                "e.g. `Lambda(name='my_transform', image=<my_transform_fn>, ...)`."
            )
+            raise ValueError(msg)
        return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}

    def __repr__(self) -> str:
@@ -1798,6 +1880,7 @@ class MultiplicativeNoise(ImageOnlyTransform):
    """Multiply image by a random number or array of numbers.

    Args:
+    ----
        multiplier: If a single float, the image will be multiplied by this number. If tuple of float multiplier
            will be in range `[multiplier[0], multiplier[1])`. Default: (0.9, 1.1).
        per_channel: If `False`, same values for all channels will be used.
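To illustrate how `multiplier`, `per_channel`, and `elementwise` interact, a minimal sketch under the same API assumptions as the earlier examples:

    import albumentations as A
    import numpy as np

    image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

    # With per_channel=True and elementwise=True an independent factor from
    # [0.9, 1.1) is sampled for every pixel of every channel.
    noise = A.MultiplicativeNoise(multiplier=(0.9, 1.1), per_channel=True, elementwise=True, p=1.0)
    noisy = noise(image=image)["image"]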
@@ -1810,6 +1893,7 @@

    Image types:
        Any
+
    """

    def __init__(
@@ -1836,18 +1920,12 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A

        height, width = img.shape[:2]

-        if self.per_channel:
-            num_channels = 1 if is_grayscale_image(img) else img.shape[-1]
-        else:
-            num_channels = 1
+        num_channels = (1 if is_grayscale_image(img) else img.shape[-1]) if self.per_channel else 1

-        if self.elementwise:
-            shape = [height, width, num_channels]
-        else:
-            shape = [num_channels]
+        shape = [height, width, num_channels] if self.elementwise else [num_channels]

        multiplier = random_utils.uniform(self.multiplier[0], self.multiplier[1], tuple(shape))
-        if is_grayscale_image(img) and img.ndim == 2:
+        if is_grayscale_image(img) and img.ndim == TWO:
            multiplier = np.squeeze(multiplier)

        return {"multiplier": multiplier}
@@ -1865,6 +1943,7 @@ class FancyPCA(ImageOnlyTransform):
    "ImageNet Classification with Deep Convolutional Neural Networks"

    Args:
+    ----
        alpha: how much to perturb/scale the eigenvectors and eigenvalues.
            The scale is sampled from a gaussian distribution (mu=0, sigma=alpha)

@@ -1878,6 +1957,7 @@
        http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
        https://deshanadesai.github.io/notes/Fancy-PCA-with-Scikit-Image
        https://pixelatedbrian.github.io/2018-04-29-fancy_pca/
+
    """

    def __init__(self, alpha: float = 0.1, always_apply: bool = False, p: float = 0.5):
@@ -1901,6 +1981,7 @@ class ColorJitter(ImageOnlyTransform):
    overflow, but we use value saturation.

    Args:
+    ----
        brightness (float or tuple of float (min, max)): How much to jitter brightness.
            brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
            or the given [min, max]. Should be non-negative numbers.
@@ -1913,6 +1994,7 @@
        hue (float or tuple of float (min, max)): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
            Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
+
    """

    def __init__(
@@ -1952,7 +2034,7 @@ def __check_values(
                value = [offset - value, offset + value]
                if clip:
                    value[0] = max(value[0], 0)
-        elif isinstance(value, (tuple, list)) and len(value) == 2:
+        elif isinstance(value, (tuple, list)) and len(value) == TWO:
            if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
                raise ValueError(f"{name} values should be between {bounds}")
        else:
@@ -1984,11 +2066,14 @@ def apply(
        contrast: float = 1.0,
        saturation: float = 1.0,
        hue: float = 0,
-        order: List[int] = [0, 1, 2, 3],
+        order: Optional[List[int]] = None,
        **params: Any,
    ) -> np.ndarray:
+        if order is None:
+            order = [0, 1, 2, 3]
        if not is_rgb_image(img) and not is_grayscale_image(img):
-            raise TypeError("ColorJitter transformation expects 1-channel or 3-channel images.")
+            msg = "ColorJitter transformation expects 1-channel or 3-channel images."
+            raise TypeError(msg)
        color_transforms = [brightness, contrast, saturation, hue]
        for i in order:
            img = self.transforms[i](img, color_transforms[i])  # type: ignore[operator]
@@ -2002,6 +2087,7 @@ class Sharpen(ImageOnlyTransform):
    """Sharpen the input image and overlay the result with the original image.

    Args:
+    ----
        alpha: range to choose the visibility of the sharpened image. At 0, only the original image is
            visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5).
        lightness: range to choose the lightness of the sharpened image. Default: (0.5, 1.0).
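The `alpha` / `lightness` ranges documented above are used like this; again an illustrative sketch, assuming the public API:

    import albumentations as A
    import numpy as np

    image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

    # alpha controls how strongly the sharpened image is blended over the
    # original; lightness scales the brightness of the sharpened version.
    sharpen = A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=1.0)
    sharp = sharpen(image=image)["image"]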
@@ -2009,6 +2095,7 @@

    Targets:
        image
+
    """

    def __init__(
@@ -2057,6 +2144,7 @@ class Emboss(ImageOnlyTransform):
    """Emboss the input image and overlay the result with the original image.

    Args:
+    ----
        alpha: range to choose the visibility of the embossed image. At 0, only the original image is
            visible, at 1.0 only its embossed version is visible. Default: (0.2, 0.5).
        strength: strength range of the embossing. Default: (0.2, 0.7).
@@ -2064,6 +2152,7 @@

    Targets:
        image
+
    """

    def __init__(
@@ -2116,9 +2205,12 @@ class Superpixels(ImageOnlyTransform):
    This implementation uses skimage's version of the SLIC algorithm.

    Args:
+    ----
        p_replace (float or tuple of float): Defines for any segment the probability that the pixels within that
            segment are replaced by their average color (otherwise, the pixels are not changed).
-            Examples:
+
+            Examples:
+            --------
                * A probability of ``0.0`` would mean, that the pixels in no
                  segment are replaced by their average color (image is not
                  changed at all).
@@ -2152,6 +2244,7 @@

    Targets:
        image
+
    """

    def __init__(
@@ -2187,8 +2280,7 @@ def apply(


 class TemplateTransform(ImageOnlyTransform):
-    """
-    Apply blending of input image with specified templates
+    """Apply blending of input image with specified templates.
    Args:
        templates (numpy array or list of numpy arrays): Images as template for transform.
        img_weight: If single float will be used as weight for input image.
@@ -2248,15 +2340,17 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
        template = self.template_transform(image=template)["image"]

        if get_num_channels(template) not in [1, get_num_channels(img)]:
-            raise ValueError(
+            msg = (
                "Template must be a single channel or "
                "have the same number of channels as input image ({}), got {}".format(
                    get_num_channels(img), get_num_channels(template)
                )
            )
+            raise ValueError(msg)

        if template.dtype != img.dtype:
-            raise ValueError("Image and template must be the same image type")
+            msg = "Image and template must be the same image type"
+            raise ValueError(msg)

        if img.shape[:2] != template.shape[:2]:
            raise ValueError(f"Image and template must be the same size, got {img.shape[:2]} and {template.shape[:2]}")
@@ -2277,12 +2371,13 @@ def is_serializable(cls) -> bool:
    def targets_as_params(self) -> List[str]:
        return ["image"]

-    def _to_dict(self) -> Dict[str, Any]:
+    def to_dict_private(self) -> Dict[str, Any]:
        if self.name is None:
-            raise ValueError(
+            msg = (
                "To make a TemplateTransform serializable you should provide the `name` argument, "
                "e.g. `TemplateTransform(name='my_transform', ...)`."
            )
+            raise ValueError(msg)
        return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}


@@ -2290,6 +2385,7 @@ class RingingOvershoot(ImageOnlyTransform):
    """Create ringing or overshoot artefacts by convolving image with 2D sinc filter.

    Args:
+    ----
        blur_limit: maximum kernel size for sinc filter.
            Should be in range [3, inf). Default: (7, 15).
        cutoff: range to choose the cutoff frequency in radians.
@@ -2303,6 +2399,7 @@

    Targets:
        image
+
    """

    def __init__(
@@ -2354,10 +2451,10 @@ def get_transform_init_args_names(self) -> Tuple[str, str]:


 class UnsharpMask(ImageOnlyTransform):
-    """
-    Sharpen the input image using Unsharp Masking processing and overlays the result with the original image.
+    """Sharpen the input image using Unsharp Masking processing and overlay the result with the original image.

    Args:
+    ----
        blur_limit: maximum Gaussian kernel size for blurring the input image.
            Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
            as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
@@ -2379,6 +2476,7 @@

    Targets:
        image
+
    """

    def __init__(
@@ -2398,12 +2496,14 @@ def __init__(

        if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0:
            self.blur_limit = 3, max(3, self.blur_limit[1])
-            raise ValueError("blur_limit and sigma_limit minimum value can not be both equal to 0.")
+            msg = "blur_limit and sigma_limit minimum value can not be both equal to 0."
+            raise ValueError(msg)

        if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
            self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
        ):
-            raise ValueError("UnsharpMask supports only odd blur limits.")
+            msg = "UnsharpMask supports only odd blur limits."
+            raise ValueError(msg)

    @staticmethod
    def __check_values(
@@ -2431,6 +2531,7 @@ class PixelDropout(DualTransform):
    """Set pixels to 0 with some probability.

    Args:
+    ----
        dropout_prob (float): pixel drop probability. Default: 0.01
        per_channel (bool): if set to `True` drop mask will be sampled for each channel,
            otherwise the same mask will be sampled for all channels. Default: False
@@ -2449,6 +2550,7 @@
        image, mask
    Image types:
        any
+
    """

    def __init__(
@@ -2467,22 +2569,23 @@ def __init__(
        self.mask_drop_value = mask_drop_value

        if self.mask_drop_value is not None and self.per_channel:
-            raise ValueError("PixelDropout supports mask only with per_channel=False")
+            msg = "PixelDropout supports mask only with per_channel=False"
+            raise ValueError(msg)

    def apply(
        self,
        img: np.ndarray,
-        drop_mask: np.ndarray = np.array(None),
+        drop_mask: Optional[np.ndarray] = None,
        drop_value: Union[float, Sequence[float]] = (),
        **params: Any,
    ) -> np.ndarray:
        return F.pixel_dropout(img, drop_mask, drop_value)

-    def apply_to_mask(self, mask: np.ndarray, drop_mask: np.ndarray = np.array(None), **params: Any) -> np.ndarray:
+    def apply_to_mask(self, mask: np.ndarray, drop_mask: Optional[np.ndarray] = None, **params: Any) -> np.ndarray:
        if self.mask_drop_value is None:
            return mask

-        if mask.ndim == 2:
+        if mask.ndim == TWO:
            drop_mask = np.squeeze(drop_mask)

        return F.pixel_dropout(mask, drop_mask, self.mask_drop_value)
@@ -2527,10 +2630,10 @@ def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:


 class Spatter(ImageOnlyTransform):
-    """
-    Apply spatter transform. It simulates corruption which can occlude a lens in the form of rain or mud.
+    """Apply spatter transform. It simulates corruption which can occlude a lens in the form of rain or mud.

    Args:
+    ----
        mean (float, or tuple of floats): Mean value of normal distribution for generating liquid layer.
            If single float it will be used as mean.
            If tuple of float mean will be sampled from range `[mean[0], mean[1])`. Default: (0.65).
@@ -2564,6 +2667,7 @@

    Reference:
        |  https://arxiv.org/pdf/1903.12261.pdf
        |  https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py
+
    """

    def __init__(
@@ -2604,13 +2708,13 @@ def __init__(
        if isinstance(self.color, dict):
            if i not in self.color:
                raise ValueError(f"Wrong color definition: {self.color}. 
Color for mode: {i} not specified.") - if len(self.color[i]) != 3: + if len(self.color[i]) != THREE: raise ValueError( f"Unsupported color: {self.color[i]} for mode {i}. Color should be presented in RGB format." ) if isinstance(self.color, (list, tuple)): - if len(self.color) != 3: + if len(self.color) != THREE: raise ValueError(f"Unsupported color: {self.color}. Color should be presented in RGB format.") self.color = {self.mode[0]: self.color} diff --git a/albumentations/augmentations/utils.py b/albumentations/augmentations/utils.py index f5400a6ab..1ee4b83f9 100644 --- a/albumentations/augmentations/utils.py +++ b/albumentations/augmentations/utils.py @@ -35,6 +35,10 @@ np.dtype("uint16"): 65535, np.dtype("uint32"): 4294967295, np.dtype("float32"): 1.0, + np.uint8: 255, + np.uint16: 65535, + np.uint32: 4294967295, + np.float32: 1.0, } NPDTYPE_TO_OPENCV_DTYPE = { @@ -50,6 +54,10 @@ np.dtype("float64"): cv2.CV_64F, } +TWO = 2 +THREE = 3 +FOUR = 4 + def read_bgr_image(path: str) -> np.ndarray: return cv2.imread(path, cv2.IMREAD_COLOR) @@ -75,8 +83,7 @@ def clip(img: np.ndarray, dtype: np.dtype, maxval: float) -> np.ndarray: def get_opencv_dtype_from_numpy(value: Union[np.ndarray, int, np.dtype, object]) -> int: - """ - Return a corresponding OpenCV dtype for a numpy's dtype + """Return a corresponding OpenCV dtype for a numpy's dtype :param value: Input dtype of numpy array :return: Corresponding dtype for OpenCV """ @@ -86,7 +93,7 @@ def get_opencv_dtype_from_numpy(value: Union[np.ndarray, int, np.dtype, object]) def angle_2pi_range( - func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType] + func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType], ) -> Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]: @wraps(func) def wrapped_function(keypoint: KeypointInternalType, *args: P.args, **kwargs: P.kwargs) -> KeypointInternalType: @@ -97,7 +104,7 @@ def wrapped_function(keypoint: KeypointInternalType, *args: P.args, **kwargs: P. 
def preserve_shape( - func: Callable[Concatenate[np.ndarray, P], np.ndarray] + func: Callable[Concatenate[np.ndarray, P], np.ndarray], ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: """Preserve shape of the image""" @@ -105,14 +112,13 @@ def preserve_shape( def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: shape = img.shape result = func(img, *args, **kwargs) - result = result.reshape(shape) - return result + return result.reshape(shape) return wrapped_function def preserve_channel_dim( - func: Callable[Concatenate[np.ndarray, P], np.ndarray] + func: Callable[Concatenate[np.ndarray, P], np.ndarray], ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: """Preserve dummy channel dim.""" @@ -120,7 +126,7 @@ def preserve_channel_dim( def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: shape = img.shape result = func(img, *args, **kwargs) - if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2: + if len(shape) == THREE and shape[-1] == 1 and len(result.shape) == TWO: result = np.expand_dims(result, axis=-1) return result @@ -128,33 +134,32 @@ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.n def ensure_contiguous( - func: Callable[Concatenate[np.ndarray, P], np.ndarray] + func: Callable[Concatenate[np.ndarray, P], np.ndarray], ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: """Ensure that input img is contiguous.""" @wraps(func) def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: img = np.require(img, requirements=["C_CONTIGUOUS"]) - result = func(img, *args, **kwargs) - return result + return func(img, *args, **kwargs) return wrapped_function def is_rgb_image(image: np.ndarray) -> bool: - return len(image.shape) == 3 and image.shape[-1] == 3 + return len(image.shape) == THREE and image.shape[-1] == THREE def is_grayscale_image(image: np.ndarray) -> bool: - return (len(image.shape) == 2) or (len(image.shape) == 3 and image.shape[-1] == 1) + return (len(image.shape) == TWO) or (len(image.shape) == THREE and image.shape[-1] == 1) def is_multispectral_image(image: np.ndarray) -> bool: - return len(image.shape) == 3 and image.shape[-1] not in [1, 3] + return len(image.shape) == THREE and image.shape[-1] not in [1, 3] def get_num_channels(image: np.ndarray) -> int: - return image.shape[2] if len(image.shape) == 3 else 1 + return image.shape[2] if len(image.shape) == THREE else 1 def non_rgb_warning(image: np.ndarray) -> None: @@ -171,17 +176,18 @@ def non_rgb_warning(image: np.ndarray) -> None: def _maybe_process_in_chunks( process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs: Any ) -> Callable[[np.ndarray], np.ndarray]: - """ - Wrap OpenCV function to enable processing images with more than 4 channels. + """Wrap OpenCV function to enable processing images with more than 4 channels. Limitations: This wrapper requires image to be the first argument and rest must be sent via named arguments. Args: + ---- process_fn: Transform function (e.g cv2.resize). kwargs: Additional parameters. Returns: + ------- numpy.ndarray: Transformed image. 
""" @@ -189,10 +195,10 @@ def _maybe_process_in_chunks( @wraps(process_fn) def __process_fn(img: np.ndarray) -> np.ndarray: num_channels = get_num_channels(img) - if num_channels > 4: + if num_channels > FOUR: chunks = [] for index in range(0, num_channels, 4): - if num_channels - index == 2: + if num_channels - index == TWO: # Many OpenCV functions cannot work with 2-channel images for i in range(2): chunk = img[:, :, index + i : index + i + 1] @@ -203,9 +209,8 @@ def __process_fn(img: np.ndarray) -> np.ndarray: chunk = img[:, :, index : index + 4] chunk = process_fn(chunk, **kwargs) chunks.append(chunk) - img = np.dstack(chunks) - else: - img = process_fn(img, **kwargs) - return img + return np.dstack(chunks) + + return process_fn(img, **kwargs) return __process_fn diff --git a/albumentations/core/bbox_utils.py b/albumentations/core/bbox_utils.py index 2c639c7d9..27fa3dd7d 100644 --- a/albumentations/core/bbox_utils.py +++ b/albumentations/core/bbox_utils.py @@ -24,12 +24,14 @@ "BboxParams", ] +FIVE = 5 + class BboxParams(Params): - """ - Parameters of bounding boxes + """Parameters of bounding boxes Args: + ---- format (str): format of bounding boxes. Should be 'coco', 'pascal_voc', 'albumentations' or 'yolo'. The `coco` format @@ -54,6 +56,7 @@ class BboxParams(Params): less than this value will be removed. Default: 0.0. check_each_transform (bool): if `True`, then bboxes will be checked after each dual transform. Default: `True` + """ def __init__( @@ -73,8 +76,8 @@ def __init__( self.min_height = min_height self.check_each_transform = check_each_transform - def _to_dict(self) -> Dict[str, Any]: - data = super()._to_dict() + def to_dict_private(self) -> Dict[str, Any]: + data = super().to_dict_private() data.update( { "min_area": self.min_area, @@ -106,15 +109,15 @@ def default_data_name(self) -> str: def ensure_data_valid(self, data: Dict[str, Any]) -> None: for data_name in self.data_fields: data_exists = data_name in data and len(data[data_name]) - if data_exists and len(data[data_name][0]) < 5: - if self.params.label_fields is None: - raise ValueError( - "Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox " - "because bboxes must have labels" - ) - if self.params.label_fields: - if not all(i in data.keys() for i in self.params.label_fields): - raise ValueError("Your 'label_fields' are not valid - them must have same names as params in dict") + if data_exists and len(data[data_name][0]) < FIVE and self.params.label_fields is None: + msg = ( + "Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox " + "because bboxes must have labels" + ) + raise ValueError(msg) + if self.params.label_fields and not all(i in data for i in self.params.label_fields): + msg = "Your 'label_fields' are not valid - them must have same names as params in dict" + raise ValueError(msg) def filter(self, data: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]: self.params: BboxParams @@ -143,22 +146,26 @@ def normalize_bbox(bbox: BoxType, rows: int, cols: int) -> BoxType: by image height. Args: + ---- bbox: Denormalized bounding box `(x_min, y_min, x_max, y_max)`. rows: Image height. cols: Image width. Returns: + ------- Normalized bounding box `(x_min, y_min, x_max, y_max)`. 
Raises:
+    ------
        ValueError: If rows or cols is less than or equal to zero

    """
-
    if rows <= 0:
-        raise ValueError("Argument rows must be positive integer")
+        msg = "Argument rows must be positive integer"
+        raise ValueError(msg)
    if cols <= 0:
-        raise ValueError("Argument cols must be positive integer")
+        msg = "Argument cols must be positive integer"
+        raise ValueError(msg)

    tail: Tuple[Any, ...]
    (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])
@@ -167,7 +174,7 @@
    y_min /= rows
    y_max /= rows

-    return cast(BoxType, (x_min, y_min, x_max, y_max) + tail)
+    return cast(BoxType, (x_min, y_min, x_max, y_max, *tail))


 def denormalize_bbox(bbox: BoxType, rows: int, cols: int) -> BoxType:
@@ -175,14 +182,17 @@ def denormalize_bbox(bbox: BoxType, rows: int, cols: int) -> BoxType:
    by image height. This is an inverse operation for :func:`~albumentations.augmentations.bbox.normalize_bbox`.

    Args:
+    ----
        bbox: Normalized bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.

    Returns:
+    -------
        Denormalized bounding box `(x_min, y_min, x_max, y_max)`.

    Raises:
+    ------
        ValueError: If rows or cols is less than or equal to zero

    """
    tail: Tuple[Any, ...]
    (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])

    if rows <= 0:
-        raise ValueError("Argument rows must be positive integer")
+        msg = "Argument rows must be positive integer"
+        raise ValueError(msg)
    if cols <= 0:
-        raise ValueError("Argument cols must be positive integer")
+        msg = "Argument cols must be positive integer"
+        raise ValueError(msg)

    x_min, x_max = x_min * cols, x_max * cols
    y_min, y_max = y_min * rows, y_max * rows

-    return cast(BoxType, (x_min, y_min, x_max, y_max) + tail)
+    return cast(BoxType, (x_min, y_min, x_max, y_max, *tail))


 def normalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
    """Normalize a list of bounding boxes.

    Args:
+    ----
        bboxes: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        rows: Image height.
        cols: Image width.

    Returns:
+    -------
        Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.

    """
@@ -219,11 +233,13 @@ def denormalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[
    """Denormalize a list of bounding boxes.

    Args:
+    ----
        bboxes: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        rows: Image height.
        cols: Image width.

    Returns:
+    -------
        List: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.

    """
@@ -234,18 +250,19 @@ def calculate_bbox_area(bbox: BoxType, rows: int, cols: int) -> float:
    """Calculate the area of a bounding box in (fractional) pixels.

    Args:
+    ----
        bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.

    Return:
+    ------
        Area in (fractional) pixels of the (denormalized) bounding box.

    """
    bbox = denormalize_bbox(bbox, rows, cols)
    x_min, y_min, x_max, y_max = bbox[:4]
-    area = (x_max - x_min) * (y_max - y_min)
-    return area
+    return (x_max - x_min) * (y_max - y_min)


 def filter_bboxes_by_visibility(
@@ -260,6 +277,7 @@ def filter_bboxes_by_visibility(
    the threshold and minimal area of bounding box in pixels is more than min_area.

    Args:
+    ----
        original_shape: Original image shape `(height, width, ...)`.
        bboxes: Original bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        transformed_shape: Transformed image shape `(height, width)`.
        transformed_bboxes: Transformed bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        threshold: visibility threshold.
        min_area: Minimal area threshold.
Returns: + ------- Filtered bounding boxes `[(x_min, y_min, x_max, y_max)]`. """ @@ -296,6 +315,7 @@ def convert_bbox_to_albumentations( `(x_min, y_min, x_max, y_max)` e.g. `(0.15, 0.27, 0.67, 0.5)`. Args: + ---- bbox: A bounding box tuple. source_format: format of the bounding box. Should be 'coco', 'pascal_voc', or 'yolo'. check_validity: Check if all boxes are valid boxes. @@ -303,15 +323,18 @@ def convert_bbox_to_albumentations( cols: Image width. Returns: + ------- tuple: A bounding box `(x_min, y_min, x_max, y_max)`. Note: + ---- The `coco` format of a bounding box looks like `(x_min, y_min, width, height)`, e.g. (97, 12, 150, 200). The `pascal_voc` format of a bounding box looks like `(x_min, y_min, x_max, y_max)`, e.g. (97, 12, 247, 212). The `yolo` format of a bounding box looks like `(x, y, width, height)`, e.g. (0.3, 0.1, 0.05, 0.07); where `x`, `y` coordinates of the center of the box, all values normalized to 1 by image height and width. Raises: + ------ ValueError: if `target_format` is not equal to `coco` or `pascal_voc`, or `yolo`. ValueError: If in YOLO format all labels not in range (0, 1). @@ -329,7 +352,8 @@ def convert_bbox_to_albumentations( # https://github.com/pjreddie/darknet/blob/f6d861736038da22c9eb0739dca84003c5a5e275/scripts/voc_label.py#L12 _bbox = np.array(bbox[:4]) if check_validity and np.any((_bbox <= 0) | (_bbox > 1)): - raise ValueError("In YOLO format all coordinates must be float and in range (0, 1]") + msg = "In YOLO format all coordinates must be float and in range (0, 1]" + raise ValueError(msg) (x, y, w, h), tail = bbox[:4], bbox[4:] @@ -341,7 +365,7 @@ def convert_bbox_to_albumentations( else: (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:] - bbox = (x_min, y_min, x_max, y_max) + tuple(tail) + bbox = (x_min, y_min, x_max, y_max, *tuple(tail)) if source_format != "yolo": bbox = normalize_bbox(bbox, rows, cols) @@ -356,6 +380,7 @@ def convert_bbox_from_albumentations( """Convert a bounding box from the format used by albumentations to a format, specified in `target_format`. Args: + ---- bbox: An albumentations bounding box `(x_min, y_min, x_max, y_max)`. target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'. rows: Image height. @@ -363,14 +388,17 @@ def convert_bbox_from_albumentations( check_validity: Check if all boxes are valid boxes. Returns: + ------- tuple: A bounding box. Note: + ---- The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200]. The `pascal_voc` format of a bounding box looks like `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212]. The `yolo` format of a bounding box looks like `[x, y, width, height]`, e.g. [0.3, 0.1, 0.05, 0.07]. Raises: + ------ ValueError: if `target_format` is not equal to `coco`, `pascal_voc` or `yolo`. """ @@ -387,14 +415,14 @@ def convert_bbox_from_albumentations( (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:]) width = x_max - x_min height = y_max - y_min - bbox = cast(BoxType, (x_min, y_min, width, height) + tail) + bbox = cast(BoxType, (x_min, y_min, width, height, *tail)) elif target_format == "yolo": (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:] x = (x_min + x_max) / 2.0 y = (y_min + y_max) / 2.0 w = x_max - x_min h = y_max - y_min - bbox = cast(BoxType, (x, y, w, h) + tail) + bbox = cast(BoxType, (x, y, w, h, *tail)) return bbox @@ -412,13 +440,15 @@ def convert_bboxes_from_albumentations( in `target_format`. 
Args: - bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`. + ---- + bboxes: List of albumentations bounding box `(x_min, y_min, x_max, y_max)`. target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'. rows: Image height. cols: Image width. check_validity: Check if all boxes are valid boxes. Returns: + ------- List of bounding boxes. """ @@ -456,6 +486,7 @@ def filter_bboxes( or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size. Args: + ---- bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`. rows: Image height. cols: Image width. @@ -468,11 +499,13 @@ def filter_bboxes( less than this value will be removed. Default: 0.0. Returns: + ------- List of bounding boxes. """ resulting_boxes: List[BoxType] = [] - for bbox in bboxes: + for i in range(len(bboxes)): + bbox = bboxes[i] # Calculate areas of bounding box before and after clipping. transformed_box_area = calculate_bbox_area(bbox, rows, cols) bbox, tail = cast(BoxType, tuple(np.clip(bbox[:4], 0, 1.0))), tuple(bbox[4:]) @@ -497,6 +530,7 @@ def union_of_bboxes(height: int, width: int, bboxes: Sequence[BoxType], erosion_ """Calculate union of bounding boxes. Args: + ---- height (float): Height of image or space. width (float): Width of image or space. bboxes (List[tuple]): List like bounding boxes. Format is `[(x_min, y_min, x_max, y_max)]`. @@ -504,6 +538,7 @@ def union_of_bboxes(height: int, width: int, bboxes: Sequence[BoxType], erosion_ Set this in range [0, 1]. 0 will not be erosive at all, 1.0 can make any bbox to lose its volume. Returns: + ------- tuple: A bounding box `(x_min, y_min, x_max, y_max)`. """ diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py index 9d7183ca7..82c02221a 100644 --- a/albumentations/core/composition.py +++ b/albumentations/core/composition.py @@ -5,7 +5,8 @@ import numpy as np -from .. 
import random_utils +from albumentations import random_utils + from .bbox_utils import BboxParams, BboxProcessor from .keypoints_utils import KeypointParams, KeypointsProcessor from .serialization import ( @@ -31,6 +32,8 @@ "TransformsSeqType", ] +TWO = 2 + REPR_INDENT_STEP = 2 TransformType = Union[BasicTransform, "BaseCompose"] @@ -77,14 +80,11 @@ def __repr__(self) -> str: return self.indented_repr() def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str: - args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")} + args = {k: v for k, v in self.to_dict_private().items() if not (k.startswith("__") or k == "transforms")} repr_string = self.__class__.__name__ + "([" for t in self.transforms: repr_string += "\n" - if hasattr(t, "indented_repr"): - t_repr = t.indented_repr(indent + REPR_INDENT_STEP) - else: - t_repr = repr(t) + t_repr = t.indented_repr(indent + REPR_INDENT_STEP) if hasattr(t, "indented_repr") else repr(t) repr_string += " " * indent + t_repr + "," repr_string += "\n" + " " * (indent - REPR_INDENT_STEP) + f"], {format_args(args)})" return repr_string @@ -97,11 +97,11 @@ def get_class_fullname(cls) -> str: def is_serializable(cls) -> bool: return True - def _to_dict(self) -> Dict[str, Any]: + def to_dict_private(self) -> Dict[str, Any]: return { "__class_fullname__": self.get_class_fullname(), "p": self.p, - "transforms": [t._to_dict() for t in self.transforms], + "transforms": [t.to_dict_private() for t in self.transforms], } def get_dict_with_id(self) -> Dict[str, Any]: @@ -126,6 +126,7 @@ class Compose(BaseCompose): """Compose transforms and handle all transformations regarding bounding boxes Args: + ---- transforms (list): list of transformations to compose. bbox_params (BboxParams): Parameters for bounding boxes transforms keypoint_params (KeypointParams): Parameters for keypoints transforms @@ -133,6 +134,7 @@ class Compose(BaseCompose): p (float): probability of applying all list of transforms. Default: 1.0. is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you would like to disable this check - pass False (do it only if you are sure in your data consistency). 
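`Compose` is the entry point most of these hunks sit behind; a minimal usage sketch (the pipeline contents are illustrative, not part of the patch):

    import numpy as np

    import albumentations as A

    transform = A.Compose(
        [A.HorizontalFlip(p=0.5)],
        bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
    )
    out = transform(
        image=np.zeros((100, 100, 3), dtype=np.uint8),
        bboxes=[(10, 10, 50, 50)],
        labels=["cat"],
    )
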
+ """ def __init__( @@ -153,7 +155,8 @@ def __init__( elif isinstance(bbox_params, BboxParams): b_params = bbox_params else: - raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`") + msg = "unknown format of bbox_params, please use `dict` or `BboxParams`" + raise ValueError(msg) self.processors["bboxes"] = BboxProcessor(b_params, additional_targets) if keypoint_params: @@ -162,7 +165,8 @@ def __init__( elif isinstance(keypoint_params, KeypointParams): k_params = keypoint_params else: - raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`") + msg = "unknown format of keypoint_params, please use `dict` or `KeypointParams`" + raise ValueError(msg) self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets) if additional_targets is None: @@ -186,17 +190,22 @@ def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None: if isinstance(transform, BaseCompose): Compose._disable_check_args_for_transforms(transform.transforms) if isinstance(transform, Compose): - transform._disable_check_args() + transform.disable_check_args_private() - def _disable_check_args(self) -> None: + def disable_check_args_private(self) -> None: self.is_check_args = False def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[str, Any]: if args: - raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)") + msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)" + raise KeyError(msg) if self.is_check_args: self._check_args(**data) - assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type" + + if not isinstance(force_apply, (bool, int)): + msg = "force_apply must have bool or int type" + raise TypeError(msg) + need_to_run = force_apply or random.random() < self.p for p in self.processors.values(): p.ensure_data_valid(data) @@ -232,14 +241,14 @@ def _check_data_post_transform(self, data: Any) -> Dict[str, Any]: data[data_name] = p.filter(data[data_name], rows, cols) return data - def _to_dict(self) -> Dict[str, Any]: - dictionary = super()._to_dict() + def to_dict_private(self) -> Dict[str, Any]: + dictionary = super().to_dict_private() bbox_processor = self.processors.get("bboxes") keypoints_processor = self.processors.get("keypoints") dictionary.update( { - "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, - "keypoint_params": (keypoints_processor.params._to_dict() if keypoints_processor else None), + "bbox_params": bbox_processor.params.to_dict_private() if bbox_processor else None, + "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None), "additional_targets": self.additional_targets, "is_check_shapes": self.is_check_shapes, } @@ -252,8 +261,8 @@ def get_dict_with_id(self) -> Dict[str, Any]: keypoints_processor = self.processors.get("keypoints") dictionary.update( { - "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, - "keypoint_params": (keypoints_processor.params._to_dict() if keypoints_processor else None), + "bbox_params": bbox_processor.params.to_dict_private() if bbox_processor else None, + "keypoint_params": (keypoints_processor.params.to_dict_private() if keypoints_processor else None), "additional_targets": self.additional_targets, "params": None, "is_check_shapes": self.is_check_shapes, @@ -272,28 +281,31 @@ def _check_args(self, **kwargs: Any) -> None: if not 
isinstance(data, np.ndarray): raise TypeError(f"{data_name} must be numpy array type") shapes.append(data.shape[:2]) - if internal_data_name in checked_multi: - if data is not None and len(data): - if not isinstance(data[0], np.ndarray): - raise TypeError(f"{data_name} must be list of numpy arrays") - shapes.append(data[0].shape[:2]) + if internal_data_name in checked_multi and data is not None and len(data): + if not isinstance(data[0], np.ndarray): + raise TypeError(f"{data_name} must be list of numpy arrays") + shapes.append(data[0].shape[:2]) if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None: - raise ValueError("bbox_params must be specified for bbox transformations") + msg = "bbox_params must be specified for bbox transformations" + raise ValueError(msg) if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes): - raise ValueError( + msg = ( "Height and Width of image, mask or masks should be equal. You can disable shapes check " "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure " "about your data consistency)." ) + raise ValueError(msg) @staticmethod def _make_targets_contiguous(data: Any) -> Dict[str, Any]: result = {} for key, value in data.items(): if isinstance(value, np.ndarray): - value = np.ascontiguousarray(value) - result[key] = value + result[key] = np.ascontiguousarray(value) + else: + result[key] = value + return result @@ -302,8 +314,10 @@ class OneOf(BaseCompose): Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights. Args: + ---- transforms (list): list of transformations to compose. p (float): probability of applying selected transform. Default: 0.5. + """ def __init__(self, transforms: TransformsSeqType, p: float = 0.5): @@ -330,10 +344,12 @@ class SomeOf(BaseCompose): Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights. Args: + ---- transforms (list): list of transformations to compose. n (int): number of transforms to apply. replace (bool): Whether the sampled transforms are with or without replacement. Default: True. p (float): probability of applying selected transform. Default: 1. + """ def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1): @@ -357,8 +373,8 @@ def __call__(self, *arg: Any, force_apply: bool = False, **data: Any) -> Dict[st data = t(force_apply=True, **data) return data - def _to_dict(self) -> Dict[str, Any]: - dictionary = super()._to_dict() + def to_dict_private(self) -> Dict[str, Any]: + dictionary = super().to_dict_private() dictionary.update({"n": self.n, "replace": self.replace}) return dictionary @@ -375,10 +391,11 @@ def __init__( ): if transforms is None: if first is None or second is None: - raise ValueError("You must set both first and second or set transforms argument.") + msg = "You must set both first and second or set transforms argument." + raise ValueError(msg) transforms = [first, second] super().__init__(transforms, p) - if len(self.transforms) != 2: + if len(self.transforms) != TWO: warnings.warn("Length of transforms is not equal to 2.") def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[str, Any]: @@ -397,10 +414,12 @@ class PerChannel(BaseCompose): """Apply transformations per-channel Args: + ---- transforms (list): list of transformations to compose. channels (sequence): channels to apply the transform to. Pass None to apply to all. 
Default: None (apply to all) p (float): probability of applying the transform. Default: 0.5. + """ def __init__(self, transforms: TransformsSeqType, channels: Optional[Sequence[int]] = None, p: float = 0.5): @@ -412,7 +431,7 @@ def __call__(self, *args: Any, force_apply: bool = False, **data: Any) -> Dict[s image = data["image"] # Expand mono images to have a single channel - if len(image.shape) == 2: + if len(image.shape) == TWO: image = np.expand_dims(image, -1) if self.channels is None: @@ -460,13 +479,14 @@ def replay(saved_augmentations: Dict[str, Any], **kwargs: Any) -> Dict[str, Any] def _restore_for_replay( transform_dict: Dict[str, Any], lambda_transforms: Optional[Dict[str, Any]] = None ) -> TransformType: - """ - Args: + """Args: + ---- lambda_transforms (dict): A dictionary that contains lambda transforms, that is instances of the Lambda class. This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys in that dictionary should be named same as `name` arguments in respective lambda transforms from a serialized pipeline. + """ applied = transform_dict["applied"] params = transform_dict["params"] @@ -506,8 +526,8 @@ def fill_applied(self, serialized: Dict[str, Any]) -> bool: serialized["applied"] = serialized.get("params") is not None return serialized["applied"] - def _to_dict(self) -> Dict[str, Any]: - dictionary = super()._to_dict() + def to_dict_private(self) -> Dict[str, Any]: + dictionary = super().to_dict_private() dictionary.update({"save_key": self.save_key}) return dictionary @@ -516,12 +536,14 @@ class Sequential(BaseCompose): """Sequentially applies all transforms to targets. Note: + ---- This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose` the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly chose sequence to input data (see the `Example` section for an example definition of such pipeline). Example: + ------- >>> import albumentations as A >>> transform = A.Compose([ >>> A.OneOf([ @@ -535,6 +557,7 @@ class Sequential(BaseCompose): >>> ]), >>> ], p=1) >>> ]) + """ def __init__(self, transforms: TransformsSeqType, p: float = 0.5): diff --git a/albumentations/core/keypoints_utils.py b/albumentations/core/keypoints_utils.py index a860d7609..41f55a69e 100644 --- a/albumentations/core/keypoints_utils.py +++ b/albumentations/core/keypoints_utils.py @@ -1,5 +1,4 @@ import math -import warnings from typing import Any, Dict, List, Optional, Sequence from .types import KeypointType @@ -24,10 +23,10 @@ def angle_to_2pi_range(angle: float) -> float: class KeypointParams(Params): - """ - Parameters of keypoints + """Parameters of keypoints Args: + ---- format (str): format of keypoints. Should be 'xy', 'yx', 'xya', 'xys', 'xyas', 'xysa'. x - X coordinate, @@ -43,6 +42,7 @@ class KeypointParams(Params): angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform. 
Default: `True`
+
     """

     def __init__(
@@ -58,8 +58,8 @@ def __init__(
         self.angle_in_degrees = angle_in_degrees
         self.check_each_transform = check_each_transform

-    def _to_dict(self) -> Dict[str, Any]:
-        data = super()._to_dict()
+    def to_dict_private(self) -> Dict[str, Any]:
+        data = super().to_dict_private()
         data.update(
             {
                 "remove_invisible": self.remove_invisible,
@@ -87,16 +87,12 @@ def default_data_name(self) -> str:
         return "keypoints"

     def ensure_data_valid(self, data: Dict[str, Any]) -> None:
-        if self.params.label_fields:
-            if not all(i in data.keys() for i in self.params.label_fields):
-                raise ValueError(
-                    "Your 'label_fields' are not valid - them must have same names as params in "
-                    "'keypoint_params' dict"
-                )
+        if self.params.label_fields and not all(i in data for i in self.params.label_fields):
+            msg = "Your 'label_fields' are not valid - they must match the params in the 'keypoint_params' dict"
+            raise ValueError(msg)

     def filter(self, data: Sequence[KeypointType], rows: int, cols: int) -> Sequence[KeypointType]:
-        """
-        The function filters a sequence of data based on the number of rows and columns, and returns a
+        """The function filters a sequence of data based on the number of rows and columns, and returns a
         sequence of keypoints.

         :param data: The `data` parameter is a sequence of sequences. Each inner sequence represents a
@@ -143,10 +139,7 @@ def check_keypoint(kp: KeypointType, rows: int, cols: int) -> None:
     """Check if keypoint coordinates are less than image shapes"""
     for name, value, size in zip(["x", "y"], kp[:2], [cols, rows]):
         if not 0 <= value < size:
-            raise ValueError(
-                "Expected {name} for keypoint {kp} "
-                "to be in the range [0.0, {size}], got {value}.".format(kp=kp, name=name, value=value, size=size)
-            )
+            raise ValueError(f"Expected {name} for keypoint {kp} to be in the range [0.0, {size}], got {value}.")

     angle = kp[2]
     if not (0 <= angle < 2 * math.pi):
@@ -209,7 +202,7 @@ def convert_keypoint_to_albumentations(
     if angle_in_degrees:
         a = math.radians(a)

-    keypoint = (x, y, angle_to_2pi_range(a), s) + tail
+    keypoint = (x, y, angle_to_2pi_range(a), s, *tail)
     if check_validity:
         check_keypoint(keypoint, rows, cols)
     return keypoint
@@ -234,17 +227,17 @@ def convert_keypoint_from_albumentations(
         angle = math.degrees(angle)

     if target_format == "xy":
-        return (x, y) + tail
+        return (x, y, *tail)
     if target_format == "yx":
-        return (y, x) + tail
+        return (y, x, *tail)
     if target_format == "xya":
-        return (x, y, angle) + tail
+        return (x, y, angle, *tail)
     if target_format == "xys":
-        return (x, y, scale) + tail
+        return (x, y, scale, *tail)
     if target_format == "xyas":
-        return (x, y, angle, scale) + tail
+        return (x, y, angle, scale, *tail)
     if target_format == "xysa":
-        return (x, y, scale, angle) + tail
+        return (x, y, scale, angle, *tail)

    raise ValueError(f"Invalid target format. 
Got: {target_format}") diff --git a/albumentations/core/serialization.py b/albumentations/core/serialization.py index 95649a81e..bc69695d6 100644 --- a/albumentations/core/serialization.py +++ b/albumentations/core/serialization.py @@ -1,8 +1,9 @@ +import importlib.util import json import warnings from abc import ABC, ABCMeta, abstractmethod from pathlib import Path -from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union, cast +from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union try: import yaml @@ -32,13 +33,12 @@ def shorten_class_name(class_fullname: str) -> str: class SerializableMeta(ABCMeta): - """ - A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY` + """A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY` so they can be found later while deserializing transformation pipeline using classes full names. """ - def __new__(mcs, name: str, bases: Tuple[type, ...], *args: Any, **kwargs: Any) -> "SerializableMeta": - cls_obj = super().__new__(mcs, name, bases, *args, **kwargs) + def __new__(cls, name: str, bases: Tuple[type, ...], *args: Any, **kwargs: Any) -> "SerializableMeta": + cls_obj = super().__new__(cls, name, bases, *args, **kwargs) if name != "Serializable" and ABC not in bases: if cls_obj.is_serializable(): SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj @@ -47,15 +47,15 @@ def __new__(mcs, name: str, bases: Tuple[type, ...], *args: Any, **kwargs: Any) return cls_obj @classmethod - def is_serializable(mcs) -> bool: + def is_serializable(cls) -> bool: return False @classmethod - def get_class_fullname(mcs) -> str: - return get_shortest_class_fullname(mcs) + def get_class_fullname(cls) -> str: + return get_shortest_class_fullname(cls) @classmethod - def _to_dict(mcs) -> Dict[str, Any]: + def _to_dict(cls) -> Dict[str, Any]: return {} @@ -71,32 +71,31 @@ def get_class_fullname(cls) -> str: raise NotImplementedError @abstractmethod - def _to_dict(self) -> Dict[str, Any]: + def to_dict_private(self) -> Dict[str, Any]: raise NotImplementedError def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]: - """ - Take a transform pipeline and convert it to a serializable representation that uses only standard + """Take a transform pipeline and convert it to a serializable representation that uses only standard python data types: dictionaries, lists, strings, integers, and floats. Args: + ---- self: A transform that should be serialized. If the transform doesn't implement the `to_dict` method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised. If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored but no transform parameters will be serialized. on_not_implemented_error (str): `raise` or `warn`. + """ if on_not_implemented_error not in {"raise", "warn"}: - raise ValueError( - "Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format( - on_not_implemented_error - ) - ) + msg = f"Unknown on_not_implemented_error value: {on_not_implemented_error}. 
Supported values are: 'raise' and 'warn'"
+            raise ValueError(msg)

         try:
-            transform_dict = self._to_dict()
-        except NotImplementedError as e:
+            transform_dict = self.to_dict_private()
+        except NotImplementedError:
             if on_not_implemented_error == "raise":
-                raise e
+                raise

             transform_dict = {}
             warnings.warn(
@@ -108,16 +107,17 @@ def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:


 def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
-    """
-    Take a transform pipeline and convert it to a serializable representation that uses only standard
+    """Take a transform pipeline and convert it to a serializable representation that uses only standard
     python data types: dictionaries, lists, strings, integers, and floats.

     Args:
+    ----
         transform: A transform that should be serialized. If the transform doesn't implement the `to_dict`
             method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised.
             If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored
             but no transform parameters will be serialized.
         on_not_implemented_error (str): `raise` or `warn`.
+
     """
     return transform.to_dict(on_not_implemented_error)

@@ -128,10 +128,9 @@ def instantiate_nonserializable(
     if transform.get("__class_fullname__") in NON_SERIALIZABLE_REGISTRY:
         name = transform["__name__"]
         if nonserializable is None:
-            raise ValueError(
-                "To deserialize a non-serializable transform with name {name} you need to pass a dict with"
-                "this transform as the `lambda_transforms` argument".format(name=name)
-            )
+            msg = (
+                f"To deserialize a non-serializable transform with name {name} you need to pass a dict with "
+                "this transform as the `lambda_transforms` argument"
+            )
+            raise ValueError(msg)
         result_transform = nonserializable.get(name)
-        if transform is None:
+        if result_transform is None:
             raise ValueError(f"Non-serializable transform with {name} was not found in `nonserializable`")
@@ -142,13 +141,14 @@ def instantiate_nonserializable(
 def from_dict(
     transform_dict: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None
 ) -> Optional[Serializable]:
-    """
-    Args:
+    """Args:
+    ----
         transform_dict: A dictionary with serialized transform pipeline.
         nonserializable (dict): A dictionary that contains non-serializable transforms.
             This dictionary is required when you are restoring a pipeline that contains non-serializable
             transforms. Keys in that dictionary should be named same as `name` arguments in respective
             transforms from a serialized pipeline.
+
     """
     register_additional_transforms()
     transform = transform_dict["transform"]
@@ -174,11 +174,11 @@ def save(
     data_format: str = "json",
     on_not_implemented_error: str = "raise",
 ) -> None:
-    """
-    Serialize a transform pipeline and save it to either a file specified by a path or a file-like object
+    """Serialize a transform pipeline and save it to either a file specified by a path or a file-like object
     in either JSON or YAML format.

     Args:
+    ----
         transform (Serializable): The transform pipeline to serialize.
         filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to write the serialized
             data to.
@@ -191,7 +191,9 @@ def save(
             no transform arguments are saved. Defaults to 'raise'.

     Raises:
+    ------
        ValueError: If `data_format` is 'yaml' but PyYAML is not installed. 
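Taken together, the `to_dict`/`from_dict` pair above round-trips a pipeline through plain Python data; a sketch (pipeline contents illustrative, not part of the patch):

    import albumentations as A

    transform = A.Compose([A.HorizontalFlip(p=0.5)])
    serialized = A.to_dict(transform)   # plain dicts, lists, and scalars only
    restored = A.from_dict(serialized)  # an equivalent pipeline, rebuilt
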
+ """ check_data_format(data_format) transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error) @@ -201,17 +203,18 @@ def save( with open(filepath_or_buffer, "w") as f: if data_format == "yaml": if not yaml_available: - raise ValueError("You need to install PyYAML to save a pipeline in YAML format") + msg = "You need to install PyYAML to save a pipeline in YAML format" + raise ValueError(msg) yaml.safe_dump(transform_dict, f, default_flow_style=False) elif data_format == "json": json.dump(transform_dict, f) - else: # Assume it's a file-like object - if data_format == "yaml": - if not yaml_available: - raise ValueError("You need to install PyYAML to save a pipeline in YAML format") - yaml.safe_dump(transform_dict, filepath_or_buffer, default_flow_style=False) - elif data_format == "json": - json.dump(transform_dict, filepath_or_buffer) + elif data_format == "yaml": + if not yaml_available: + msg = "You need to install PyYAML to save a pipeline in YAML format" + raise ValueError(msg) + yaml.safe_dump(transform_dict, filepath_or_buffer, default_flow_style=False) + elif data_format == "json": + json.dump(transform_dict, filepath_or_buffer) def load( @@ -219,10 +222,10 @@ def load( data_format: str = "json", nonserializable: Optional[Dict[str, Any]] = None, ) -> object: - """ - Load a serialized pipeline from a file or file-like object and construct a transform pipeline. + """Load a serialized pipeline from a file or file-like object and construct a transform pipeline. Args: + ---- filepath_or_buffer (Union[str, Path, TextIO]): The file path or file-like object to read the serialized data from. If a string is provided, it is interpreted as a path to a file. If a file-like object is provided, @@ -235,10 +238,13 @@ def load( from the serialized pipeline. Defaults to None. Returns: + ------- object: The deserialized transform pipeline. Raises: + ------ ValueError: If `data_format` is 'yaml' but PyYAML is not installed. + """ check_data_format(data_format) @@ -248,33 +254,37 @@ def load( transform_dict = json.load(f) else: if not yaml_available: - raise ValueError("You need to install PyYAML to load a pipeline in yaml format") + msg = "You need to install PyYAML to load a pipeline in yaml format" + raise ValueError(msg) transform_dict = yaml.safe_load(f) - else: # Assume it's a file-like object - if data_format == "json": - transform_dict = json.load(filepath_or_buffer) - else: - if not yaml_available: - raise ValueError("You need to install PyYAML to load a pipeline in yaml format") - transform_dict = yaml.safe_load(filepath_or_buffer) + elif data_format == "json": + transform_dict = json.load(filepath_or_buffer) + else: + if not yaml_available: + msg = "You need to install PyYAML to load a pipeline in yaml format" + raise ValueError(msg) + transform_dict = yaml.safe_load(filepath_or_buffer) return from_dict(transform_dict, nonserializable=nonserializable) def register_additional_transforms() -> None: + """Register transforms that are not imported directly into the `albumentations` module by checking + the availability of optional dependencies. """ - Register transforms that are not imported directly into the `albumentations` module. - """ - try: - # This import will result in ImportError if `torch` is not installed - import albumentations.pytorch - except ImportError: - pass + if importlib.util.find_spec("torch") is not None: + try: + # Import `albumentations.pytorch` only if `torch` is installed. 
+ import albumentations.pytorch + + # Use a dummy operation to acknowledge the use of the imported module and avoid linting errors. + _ = albumentations.pytorch.ToTensorV2 + except ImportError: + pass def get_shortest_class_fullname(cls: Type[Any]) -> str: - """ - The function `get_shortest_class_fullname` takes a class object as input and returns its shortened + """The function `get_shortest_class_fullname` takes a class object as input and returns its shortened full name. :param cls: The parameter `cls` is of type `Type[BasicCompose]`, which means it expects a class that @@ -282,5 +292,5 @@ def get_shortest_class_fullname(cls: Type[Any]) -> str: :type cls: Type[BasicCompose] :return: a string, which is the shortened version of the full class name. """ - class_fullname = "{cls.__module__}.{cls.__name__}".format(cls=cls) + class_fullname = f"{cls.__module__}.{cls.__name__}" return shorten_class_name(class_fullname) diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index 906170eb1..5a3d63f9c 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -23,29 +23,35 @@ FillValueType = Optional[Union[int, float, Sequence[int], Sequence[float]]] +TWO = 2 +THREE = 3 + def to_tuple( param: ScaleType, low: Optional[ScaleType] = None, bias: Optional[ScalarType] = None, ) -> Union[Tuple[int, int], Tuple[float, float]]: - """ - Convert input argument to a min-max tuple. + """Convert input argument to a min-max tuple. Args: + ---- param: Input value which could be a scalar or a sequence of exactly 2 scalars. low: Second element of the tuple, provided as an optional argument for when `param` is a scalar. bias: An offset added to both elements of the tuple. Returns: + ------- A tuple of two scalars, optionally adjusted by `bias`. Raises ValueError for invalid combinations or types of arguments. + """ # Validate mutually exclusive arguments if low is not None and bias is not None: - raise ValueError("Arguments 'low' and 'bias' cannot be used together.") + msg = "Arguments 'low' and 'bias' cannot be used together." + raise ValueError(msg) - if isinstance(param, Sequence) and len(param) == 2: + if isinstance(param, Sequence) and len(param) == TWO: min_val, max_val = min(param), max(param) # Handle scalar input @@ -57,7 +63,8 @@ def to_tuple( # Create a symmetric tuple around 0 min_val, max_val = -param, param else: - raise ValueError("Argument 'param' must be either a scalar or a sequence of 2 elements.") + msg = "Argument 'param' must be either a scalar or a sequence of 2 elements." 
+        raise ValueError(msg)

     # Apply bias if provided
     if bias is not None:
@@ -92,7 +99,8 @@ def __init__(self, always_apply: bool = False, p: float = 0.5):

     def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
         if args:
-            raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
+            msg = "You have to pass data to augmentations as named arguments, for example: aug(image=image)"
+            raise KeyError(msg)
         if self.replay_mode:
             if self.applied_in_replay:
                 return self.apply_with_params(self.params, **kwargs)
@@ -103,9 +111,10 @@ def __call__(self, *args: Any, force_apply: bool = False, **kwargs: Any) -> Any:
         params = self.get_params()

         if self.targets_as_params:
-            assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format(
-                self.__class__.__name__, self.targets_as_params
-            )
+            if not all(key in kwargs for key in self.targets_as_params):
+                msg = f"{self.__class__.__name__} requires {self.targets_as_params}"
+                raise ValueError(msg)
+
             targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
             params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
             params.update(params_dependent_on_targets)
@@ -135,7 +144,10 @@ def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -
         return res

     def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
-        assert save_key != "params", "params save_key is reserved"
+        if save_key == "params":
+            msg = "params save_key is reserved"
+            raise KeyError(msg)
+
         self.deterministic = flag
         self.save_key = save_key
         return self
@@ -161,8 +173,9 @@ def get_params(self) -> Dict[str, Any]:
     @property
     def targets(self) -> Dict[str, Callable[..., Any]]:
         # you must specify targets in subclass
-        # for example: ('image', 'mask')
-        #              ('image', 'boxes')
+        # for example:
+        # >>  ('image', 'mask')
+        # >>  ('image', 'boxes')
         raise NotImplementedError

     def update_params(self, params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
@@ -186,7 +199,9 @@ def add_targets(self, additional_targets: Dict[str, str]) -> None:
         by the way you must have at least one object with key 'image'

         Args:
+        ----
            additional_targets (dict): keys - new target name, values - old target name. 
ex: {'image2': 'image'}
+
         """
         self._additional_targets = additional_targets

@@ -208,10 +223,9 @@ def is_serializable(cls) -> bool:
         return True

     def get_transform_init_args_names(self) -> Tuple[str, ...]:
-        raise NotImplementedError(
-            "Class {name} is not serializable because the `get_transform_init_args_names` method is not "
-            "implemented".format(name=self.get_class_fullname())
-        )
+        msg = (
+            f"Class {self.get_class_fullname()} is not serializable because the "
+            "`get_transform_init_args_names` method is not implemented"
+        )
+        raise NotImplementedError(msg)

     def get_base_init_args(self) -> Dict[str, Any]:
         return {"always_apply": self.always_apply, "p": self.p}
@@ -219,14 +233,14 @@ def get_base_init_args(self) -> Dict[str, Any]:
     def get_transform_init_args(self) -> Dict[str, Any]:
         return {k: getattr(self, k) for k in self.get_transform_init_args_names()}

-    def _to_dict(self) -> Dict[str, Any]:
+    def to_dict_private(self) -> Dict[str, Any]:
         state = {"__class_fullname__": self.get_class_fullname()}
         state.update(self.get_base_init_args())
         state.update(self.get_transform_init_args())
         return state

     def get_dict_with_id(self) -> Dict[str, Any]:
-        d = self._to_dict()
+        d = self.to_dict_private()
         d["id"] = id(self)
         return d
diff --git a/albumentations/core/utils.py b/albumentations/core/utils.py
index be9a4fd62..ebb4b7bcd 100644
--- a/albumentations/core/utils.py
+++ b/albumentations/core/utils.py
@@ -32,8 +32,8 @@ def format_args(args_dict: Dict[str, Any]) -> str:
     formatted_args = []
     for k, v in args_dict.items():
-        if isinstance(v, str):
-            v = f"'{v}'"
-        formatted_args.append(f"{k}={v}")
+        v_formatted = f"'{v}'" if isinstance(v, str) else v
+        formatted_args.append(f"{k}={v_formatted}")
     return ", ".join(formatted_args)

@@ -42,7 +42,7 @@ def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
         self.format = format
         self.label_fields = label_fields

-    def _to_dict(self) -> Dict[str, Any]:
+    def to_dict_private(self) -> Dict[str, Any]:
         return {"format": self.format, "label_fields": self.label_fields}

@@ -73,8 +73,7 @@ def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
             data[data_name] = self.filter(data[data_name], rows, cols)
             data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")

-        data = self.remove_label_fields_from_data(data)
-        return data
+        return self.remove_label_fields_from_data(data)

     def preprocess(self, data: Dict[str, Any]) -> None:
         data = self.add_label_fields_to_data(data)
@@ -92,10 +91,10 @@ def check_and_convert(

         if direction == "to":
             return self.convert_to_albumentations(data, rows, cols)
-        elif direction == "from":
+        if direction == "from":
             return self.convert_from_albumentations(data, rows, cols)
-        else:
-            raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")
+
+        raise ValueError(f"Invalid direction. Must be `to` or `from`. 
Got `{direction}`")

     @abstractmethod
     def filter(self, data: Sequence[BoxOrKeypointType], rows: int, cols: int) -> Sequence[BoxOrKeypointType]:
@@ -120,10 +119,12 @@ def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
             return data
         for data_name in self.data_fields:
             for field in self.params.label_fields:
-                assert len(data[data_name]) == len(data[field])
+                if len(data[data_name]) != len(data[field]):
+                    msg = f"Number of labels in '{field}' does not match the number of entries in '{data_name}'"
+                    raise ValueError(msg)
+
                 data_with_added_field = []
                 for d, field_value in zip(data[data_name], data[field]):
-                    data_with_added_field.append(list(d) + [field_value])
+                    data_with_added_field.append([*list(d), field_value])
                 data[data_name] = data_with_added_field
         return data

@@ -133,10 +134,7 @@ def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
         for data_name in self.data_fields:
             label_fields_len = len(self.params.label_fields)
             for idx, field in enumerate(self.params.label_fields):
-                field_values = []
-                for bbox in data[data_name]:
-                    field_values.append(bbox[-label_fields_len + idx])
-                data[field] = field_values
+                data[field] = [bbox[-label_fields_len + idx] for bbox in data[data_name]]
             if label_fields_len:
                 data[data_name] = [d[:-label_fields_len] for d in data[data_name]]
         return data
diff --git a/albumentations/pytorch/transforms.py b/albumentations/pytorch/transforms.py
index a3bf291c6..6acacc516 100644
--- a/albumentations/pytorch/transforms.py
+++ b/albumentations/pytorch/transforms.py
@@ -3,20 +3,25 @@
 import numpy as np
 import torch

-from ..core.transforms_interface import BasicTransform
+from albumentations.core.transforms_interface import BasicTransform

 __all__ = ["ToTensorV2"]

+TWO = 2
+THREE = 3
+

 class ToTensorV2(BasicTransform):
     """Converts images/masks to PyTorch Tensors, inheriting from BasicTransform.
     Supports images in numpy `HWC` format and converts them to PyTorch `CHW` format.
     If the image is in `HW` format, it will be converted to PyTorch `HW`.

-    Attributes:
+    Attributes
+    ----------
         transpose_mask (bool): If True, transposes 3D input mask dimensions from `[height, width, num_channels]` to
             `[num_channels, height, width]`.
         always_apply (bool): Indicates if this transformation should be always applied. Default: True.
         p (float): Probability of applying the transform. Default: 1.0.
+
     """

     def __init__(self, transpose_mask: bool = False, always_apply: bool = True, p: float = 1.0):
@@ -29,15 +34,16 @@ def targets(self) -> Dict[str, Any]:

     def apply(self, img: np.ndarray, **params: Any) -> torch.Tensor:
         if len(img.shape) not in [2, 3]:
-            raise ValueError("Albumentations only supports images in HW or HWC format")
+            msg = "Albumentations only supports images in HW or HWC format"
+            raise ValueError(msg)

-        if len(img.shape) == 2:
+        if len(img.shape) == TWO:
             img = np.expand_dims(img, 2)

         return torch.from_numpy(img.transpose(2, 0, 1))

     def apply_to_mask(self, mask: np.ndarray, **params: Any) -> torch.Tensor:
-        if self.transpose_mask and mask.ndim == 3:
+        if self.transpose_mask and mask.ndim == THREE:
             mask = mask.transpose(2, 0, 1)
         return torch.from_numpy(mask)
diff --git a/pyproject.toml b/pyproject.toml
index 383ed8197..8b5d5b1e5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,8 +17,6 @@
 no_implicit_reexport = true
 disallow_untyped_defs = true

-
-
 [tool.black]
 line-length = 120
 target-version = ["py38"]
@@ -39,3 +37,68 @@ exclude = '''
   setup.py
 )/
 '''
+
+[tool.ruff]
+# Exclude a variety of commonly ignored directories. 
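+# (Note: unlike ruff's defaults, `tests` is excluded here too, so the new rules
+# apply only to the library code itself.)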
+exclude = [
+    ".bzr",
+    ".direnv",
+    ".eggs",
+    ".git",
+    ".git-rewrite",
+    ".hg",
+    ".ipynb_checkpoints",
+    ".mypy_cache",
+    ".nox",
+    ".pants.d",
+    ".pyenv",
+    ".pytest_cache",
+    ".pytype",
+    ".ruff_cache",
+    ".svn",
+    ".tox",
+    ".venv",
+    ".vscode",
+    "__pypackages__",
+    "_build",
+    "buck-out",
+    "build",
+    "dist",
+    "node_modules",
+    "site-packages",
+    "venv",
+    "tests"
+]
+
+# Same as Black.
+line-length = 120
+indent-width = 4
+
+# Assume Python 3.8
+target-version = "py38"
+
+[tool.ruff.lint]
+explicit-preview-rules = true
+
+select = ["F", "E", "W", "C90", "I", "N", "D", "UP", "YTT", "ANN", "ASYNC", "TRIO", "S", "BLE", "FBT", "B", "A", "COM", "CPY", "C4", "DTZ", "T10", "DJ", "EM", "EXE", "ISC", "ICN", "G", "INP", "PIE", "T20", "PYI", "PT", "Q", "RSE", "RET", "SLF", "SLOT", "SIM", "TID", "TCH", "INT", "ARG", "PTH", "TD", "FIX", "ERA", "PD", "PGH", "PL", "TRY", "FLY", "NPY", "PERF", "FURB", "LOG", "RUF"]
+ignore = ["D211", "D213", "ANN101", "D107", "ANN401", "D102", "D103", "ANN204", "ARG002", "D104", "S311", "F403", "PLR0913", "FBT001", "FBT002", "ISC001", "COM812", "ANN102", "D100", "D205", "D101", "EM102", "TRY003", "D401", "A002", "D105", "D415", "D400", "D202", "D203", "PTH123", "B028", "ARG001", "ARG005", "N812", "FBT003", "D417", "B027"]
+
+# Allow fixes for all enabled rules (when `--fix` is provided).
+fixable = ["ALL"]
+unfixable = []
+
+# Allow unused variables when underscore-prefixed.
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
+
+[tool.ruff.format]
+# Like Black, use double quotes for strings.
+quote-style = "double"
+
+# Like Black, indent with spaces, rather than tabs.
+indent-style = "space"
+
+# Like Black, respect magic trailing commas.
+skip-magic-trailing-comma = false
+
+# Like Black, automatically detect the appropriate line ending. 
+line-ending = "auto" diff --git a/requirements-dev.txt b/requirements-dev.txt index ce2f0b1a0..6445bd9b5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,9 +1,8 @@ black==24.2.0 deepdiff==6.7.1 -flake8==7.0.0 -isort==5.13.2 mypy==1.8.0 pre_commit>=3.5.0 +ruff==0.2.2 types-pkg-resources types-PyYAML types-setuptools diff --git a/tests/test_functional.py b/tests/test_functional.py index 75aae9275..3d9caa011 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -1,17 +1,12 @@ -from __future__ import absolute_import - import cv2 import numpy as np import pytest -from numpy.testing import assert_array_almost_equal_nulp +from numpy.testing import assert_array_almost_equal_nulp, assert_almost_equal import albumentations as A import albumentations.augmentations.functional as F import albumentations.augmentations.geometric.functional as FGeometric -from albumentations.augmentations.utils import ( - get_opencv_dtype_from_numpy, - is_multispectral_image, -) +from albumentations.augmentations.utils import get_opencv_dtype_from_numpy, is_multispectral_image, MAX_VALUES_BY_DTYPE from albumentations.core.bbox_utils import filter_bboxes from tests.utils import convert_2d_to_target_format @@ -439,49 +434,88 @@ def test_gamma_float_equal_uint8(): assert (np.abs(img - img_f) <= 1).all() -@pytest.mark.parametrize(["dtype", "divider"], [(np.uint8, 255), (np.uint16, 65535), (np.uint32, 4294967295)]) -def test_to_float_without_max_value_specified(dtype, divider): +@pytest.mark.parametrize( + ["dtype", "expected_divider", "max_value"], + [ + (np.uint8, 255, None), + (np.uint16, 65535, None), + (np.uint32, 4294967295, None), + (np.float32, 1.0, None), + (np.int16, None, 32767), # Unsupported dtype with max_value provided + ], +) +def test_to_float(dtype, expected_divider, max_value): img = np.ones((100, 100, 3), dtype=dtype) - expected = img.astype("float32") / divider - assert_array_almost_equal_nulp(F.to_float(img), expected) - + if expected_divider is not None: + expected = (img.astype(np.float32) / expected_divider).astype(np.float32) + else: + # For unsupported dtype with max_value, use max_value for conversion + expected = (img.astype(np.float32) / max_value).astype(np.float32) -@pytest.mark.parametrize("max_value", [255.0, 65535.0, 4294967295.0]) -def test_to_float_with_max_value_specified(max_value): - img = np.ones((100, 100, 3), dtype=np.uint16) - expected = img.astype("float32") / max_value - assert_array_almost_equal_nulp(F.to_float(img, max_value=max_value), expected) + actual = F.to_float(img, max_value=max_value) + assert_almost_equal(actual, expected, decimal=6) + assert actual.dtype == np.float32, "Resulting dtype is not float32." -def test_to_float_unknown_dtype(): - img = np.ones((100, 100, 3), dtype=np.int16) +@pytest.mark.parametrize("dtype", [np.float64, np.int64]) +def test_to_float_raises_for_unsupported_dtype_without_max_value(dtype): + img = np.ones((100, 100, 3), dtype=dtype) with pytest.raises(RuntimeError) as exc_info: F.to_float(img) - assert str(exc_info.value) == ( - "Can't infer the maximum value for dtype int16. 
You need to specify the maximum value manually by passing "
-        "the max_value argument"
-    )
+    assert "Unsupported dtype" in str(exc_info.value)


-@pytest.mark.parametrize("max_value", [255.0, 65535.0, 4294967295.0])
-def test_to_float_unknown_dtype_with_max_value(max_value):
-    img = np.ones((100, 100, 3), dtype=np.int16)
-    expected = img.astype("float32") / max_value
-    assert_array_almost_equal_nulp(F.to_float(img, max_value=max_value), expected)
+@pytest.mark.parametrize("dtype", [np.float64, np.int64])
+def test_to_float_with_max_value_for_unsupported_dtypes(dtype):
+    img = np.ones((100, 100, 3), dtype=dtype)
+    max_value = 1.0 if dtype == np.float64 else np.iinfo(dtype).max
+    expected = (img.astype(np.float32) / max_value).astype(np.float32)
+    actual = F.to_float(img, max_value=max_value)
+    assert_almost_equal(actual, expected, decimal=6)
+    assert actual.dtype == np.float32, "Resulting dtype is not float32."


-@pytest.mark.parametrize(["dtype", "multiplier"], [(np.uint8, 255), (np.uint16, 65535), (np.uint32, 4294967295)])
-def test_from_float_without_max_value_specified(dtype, multiplier):
-    img = np.ones((100, 100, 3), dtype=np.float32)
-    expected = (img * multiplier).astype(dtype)
-    assert_array_almost_equal_nulp(F.from_float(img, np.dtype(dtype)), expected)
+@pytest.mark.parametrize(
+    "dtype, multiplier, max_value",
+    [
+        (np.uint8, 255, None),
+        (np.uint16, 65535, None),
+        (np.uint32, 4294967295, None),
+        (np.uint32, 4294967295, 4294967295.0),  # Custom max_value equal to the default to test the parameter is used
+    ],
+)
+def test_from_float(dtype, multiplier, max_value):
+    img = np.random.rand(100, 100, 3).astype(np.float32)  # Use random data for more robust testing
+    expected_multiplier = multiplier if max_value is None else max_value
+    expected = (img * expected_multiplier).astype(dtype)
+    actual = F.from_float(img, dtype=np.dtype(dtype), max_value=max_value)
+    assert_array_almost_equal_nulp(actual, expected)


-@pytest.mark.parametrize("max_value", [255.0, 65535.0, 4294967295.0])
-def test_from_float_with_max_value_specified(max_value):
-    img = np.ones((100, 100, 3), dtype=np.float32)
-    expected = (img * max_value).astype(np.uint32)
-    assert_array_almost_equal_nulp(F.from_float(img, dtype=np.uint32, max_value=max_value), expected)
+@pytest.mark.parametrize("dtype", [np.int64, np.float64])
+def test_from_float_unsupported_dtype_without_max_value(dtype):
+    img = np.random.rand(100, 100, 3).astype(np.float32)
+    with pytest.raises(RuntimeError) as exc_info:
+        F.from_float(img, dtype=dtype)
+    expected_part_of_message = "Can't infer the maximum value for dtype"
+    assert expected_part_of_message in str(exc_info.value), "Expected error message not found."
+
+
+@pytest.mark.parametrize(
+    "dtype, expected_dtype",
+    [
+        (np.uint8, np.uint8),
+        (np.uint16, np.uint16),
+        (np.uint32, np.uint32),
+    ],
+)
+def test_from_float_dtype_consistency(dtype, expected_dtype):
+    # Generate a random 100x100x3 array scaled to [0, max] for the target dtype,
+    # looking the maximum up in the MAX_VALUES_BY_DTYPE table. 
+ img = np.random.rand(100, 100, 3) * MAX_VALUES_BY_DTYPE[dtype] + actual = F.from_float(img.astype(np.float32), dtype=dtype) + assert actual.dtype == expected_dtype, f"Expected dtype {expected_dtype} but got {actual.dtype}" @pytest.mark.parametrize("target", ["image", "mask"]) @@ -532,10 +566,14 @@ def test_from_float_unknown_dtype(): img = np.ones((100, 100, 3), dtype=np.float32) with pytest.raises(RuntimeError) as exc_info: F.from_float(img, np.dtype(np.int16)) - assert str(exc_info.value) == ( + expected_message = ( "Can't infer the maximum value for dtype int16. You need to specify the maximum value manually by passing " "the max_value argument" ) + actual_message = str(exc_info.value) + assert ( + expected_message in actual_message or actual_message in expected_message + ), f"Expected part of the error message to be: '{expected_message}', got: '{actual_message}'" @pytest.mark.parametrize("target", ["image", "mask"]) @@ -871,14 +909,12 @@ def test_equalize_checks(): mask = np.random.randint(0, 1, [256, 256, 3], dtype=bool) with pytest.raises(ValueError) as exc_info: F.equalize(img, mask=mask) - assert str(exc_info.value) == "Wrong mask shape. Image shape: {}. Mask shape: {}".format(img.shape, mask.shape) + assert str(exc_info.value) == f"Wrong mask shape. Image shape: {img.shape}. Mask shape: {mask.shape}" img = np.random.randint(0, 255, [256, 256, 3], dtype=np.uint8) with pytest.raises(ValueError) as exc_info: F.equalize(img, mask=mask, by_channels=False) - assert str(exc_info.value) == "When by_channels=False only 1-channel mask supports. " "Mask shape: {}".format( - mask.shape - ) + assert str(exc_info.value) == f"When by_channels=False only 1-channel mask supports. Mask shape: {mask.shape}" img = np.random.random([256, 256, 3]) with pytest.raises(TypeError) as exc_info: diff --git a/tests/test_keypoint.py b/tests/test_keypoint.py index f0dae917c..83497a0ee 100644 --- a/tests/test_keypoint.py +++ b/tests/test_keypoint.py @@ -15,7 +15,7 @@ @pytest.mark.parametrize( - ["kp", "source_format", "expected"], + ("kp", "source_format", "expected"), [ ((20, 30), "xy", (20, 30, 0, 0)), (np.array([20, 30]), "xy", (20, 30, 0, 0)), diff --git a/tests/test_transforms.py b/tests/test_transforms.py index d898bb8c8..f5330315a 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -1035,7 +1035,7 @@ def test_advanced_blur_float_uint8_diff_less_than_two(val_uint8): [ [{"blur_limit": (2, 5)}], [{"blur_limit": (3, 6)}], - [{"sigmaX_limit": (0.0, 1.0), "sigmaY_limit": (0.0, 1.0)}], + [{"sigma_x_limit": (0.0, 1.0), "sigma_y_limit": (0.0, 1.0)}], [{"beta_limit": (0.1, 0.9)}], [{"beta_limit": (1.1, 8.0)}], ],