diff --git a/README.md b/README.md index c01f42b24..2104722f9 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,7 @@ Spatial-level transforms will simultaneously change both an input image as well | [SmallestMaxSize](https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/#albumentations.augmentations.geometric.resize.SmallestMaxSize) | ✓ | ✓ | ✓ | ✓ | | [Transpose](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.Transpose) | ✓ | ✓ | ✓ | ✓ | | [VerticalFlip](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.VerticalFlip) | ✓ | ✓ | ✓ | ✓ | +| [XYMasking](https://albumentations.ai/docs/api_reference/augmentations/dropout/xy_masking/#albumentations.augmentations.dropout.xy_masking.XYMasking) | ✓ | ✓ | | ✓ | ## A few more examples of augmentations diff --git a/albumentations/augmentations/__init__.py b/albumentations/augmentations/__init__.py index de00b12b8..0064d6ec6 100644 --- a/albumentations/augmentations/__init__.py +++ b/albumentations/augmentations/__init__.py @@ -11,6 +11,7 @@ from .dropout.functional import * from .dropout.grid_dropout import * from .dropout.mask_dropout import * +from .dropout.xy_masking import * from .functional import * from .geometric.functional import * from .geometric.resize import * diff --git a/albumentations/augmentations/dropout/__init__.py b/albumentations/augmentations/dropout/__init__.py index 1a2eb9463..33978fe2f 100644 --- a/albumentations/augmentations/dropout/__init__.py +++ b/albumentations/augmentations/dropout/__init__.py @@ -2,3 +2,4 @@ from .coarse_dropout import * from .grid_dropout import * from .mask_dropout import * +from .xy_masking import * diff --git a/albumentations/augmentations/dropout/coarse_dropout.py b/albumentations/augmentations/dropout/coarse_dropout.py index 6a5f5c48b..279ccf803 100644 --- a/albumentations/augmentations/dropout/coarse_dropout.py +++ b/albumentations/augmentations/dropout/coarse_dropout.py @@ -5,7 +5,7 @@ from ...core.transforms_interface import DualTransform from ...core.types import KeypointType, ScalarType -from .functional import cutout +from .functional import cutout, keypoint_in_hole __all__ = ["CoarseDropout"] @@ -79,7 +79,8 @@ def __init__( if not 0 < self.min_width <= self.max_width: raise ValueError(f"Invalid combination of min_width and max_width. Got: {[min_width, max_width]}") - def check_range(self, dimension: ScalarType) -> None: + @staticmethod + def check_range(dimension: ScalarType) -> None: if isinstance(dimension, float) and not 0 <= dimension < 1.0: raise ValueError(f"Invalid value {dimension}. 
If using floats, the value should be in the range [0.0, 1.0)") @@ -108,7 +109,7 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A height, width = img.shape[:2] holes = [] - for _n in range(random.randint(self.min_holes, self.max_holes)): + for _ in range(random.randint(self.min_holes, self.max_holes)): if all( [ isinstance(self.min_height, int), @@ -156,20 +157,14 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A def targets_as_params(self) -> List[str]: return ["image"] - def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool: - x1, y1, x2, y2 = hole - x, y = keypoint[:2] - return x1 <= x < x2 and y1 <= y < y2 - def apply_to_keypoints( self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params: Any ) -> List[KeypointType]: - result = set(keypoints) - for hole in holes: - for kp in keypoints: - if self._keypoint_in_hole(kp, hole): - result.discard(kp) - return list(result) + filtered_keypoints = [] + for keypoint in keypoints: + if not any(keypoint_in_hole(keypoint, hole) for hole in holes): + filtered_keypoints.append(keypoint) + return filtered_keypoints def get_transform_init_args_names(self) -> Tuple[str, ...]: return ( diff --git a/albumentations/augmentations/dropout/functional.py b/albumentations/augmentations/dropout/functional.py index d3f17752e..063f84888 100644 --- a/albumentations/augmentations/dropout/functional.py +++ b/albumentations/augmentations/dropout/functional.py @@ -3,13 +3,12 @@ import numpy as np from albumentations.augmentations.utils import preserve_shape - -__all__ = ["channel_dropout"] +from albumentations.core.types import ColorType, KeypointType @preserve_shape def channel_dropout( - img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: Union[int, float] = 0 + img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: ColorType = 0 ) -> np.ndarray: if len(img.shape) == 2 or img.shape[2] == 1: raise NotImplementedError("Only one channel. 
ChannelDropout is not defined.")
@@ -19,11 +18,18 @@ def channel_dropout(
     return img
 
 
-def cutout(
-    img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: Union[int, float] = 0
-) -> np.ndarray:
-    # Make a copy of the input image since we don't want to modify it directly
+def cutout(img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: ColorType = 0) -> np.ndarray:
     img = img.copy()
+    # Convert fill_value to a NumPy array for consistent broadcasting
+    if isinstance(fill_value, (tuple, list)):
+        fill_value = np.array(fill_value)
+
     for x1, y1, x2, y2 in holes:
         img[y1:y2, x1:x2] = fill_value
     return img
+
+
+def keypoint_in_hole(keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
+    x, y = keypoint[:2]
+    x1, y1, x2, y2 = hole
+    return x1 <= x < x2 and y1 <= y < y2
diff --git a/albumentations/augmentations/dropout/xy_masking.py b/albumentations/augmentations/dropout/xy_masking.py
new file mode 100644
index 000000000..f4f61e598
--- /dev/null
+++ b/albumentations/augmentations/dropout/xy_masking.py
@@ -0,0 +1,211 @@
+import random
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+
+import numpy as np
+
+from albumentations.core.types import ColorType, KeypointType, ScaleIntType
+
+from ...core.transforms_interface import DualTransform, to_tuple
+from .functional import cutout, keypoint_in_hole
+
+__all__ = ["XYMasking"]
+
+
+class XYMasking(DualTransform):
+    """
+    Applies masking strips to an image along the X axis, the Y axis, or both, simulating occlusions.
+    This transform is useful for training models to recognize images with varied visibility conditions.
+    It is particularly effective for spectrogram images, where masking along the time and frequency
+    axes improves model robustness.
+
+    At least one of `mask_x_length` or `mask_y_length` must be specified, dictating the maximum
+    mask size along the corresponding axis.
+
+    Args:
+        num_masks_x (Union[int, Tuple[int, int]]): Number or range of masks placed along the X axis
+            (full-height strips). Defaults to 0.
+        num_masks_y (Union[int, Tuple[int, int]]): Number or range of masks placed along the Y axis
+            (full-width strips). Defaults to 0.
+        mask_x_length (Union[int, Tuple[int, int]]): Length of the masks along the X (horizontal) axis.
+            An integer sets a fixed mask length; a tuple of two integers (min, max) makes the length
+            of each mask a random value chosen from this range, allowing variable-length masks
+            in the horizontal direction.
+        mask_y_length (Union[int, Tuple[int, int]]): Height of the masks along the Y (vertical) axis.
+            Similar to `mask_x_length`, an integer sets a fixed mask height, while a tuple (min, max)
+            selects a random height within the range for each mask, allowing masks of various sizes
+            in the vertical direction.
+        fill_value (Union[int, float, List[int], List[float]]): Value used to fill the masked regions
+            of the image. Defaults to 0.
+        mask_fill_value (Optional[Union[int, float, List[int], List[float]]]): Value used to fill the
+            masked regions of the target mask. If `None`, the mask is left unchanged. Defaults to 0.
+        p (float): Probability of applying the transform. Defaults to 0.5.
+
+    Targets:
+        image, mask, keypoints
+
+    Image types:
+        uint8, float32
+
+    Note: Either `mask_x_length` or `mask_y_length`, or both, must be defined.
+ """ + + def __init__( + self, + num_masks_x: ScaleIntType = 0, + num_masks_y: ScaleIntType = 0, + mask_x_length: ScaleIntType = 0, + mask_y_length: ScaleIntType = 0, + fill_value: ColorType = 0, + mask_fill_value: ColorType = 0, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + + if ( + isinstance(mask_x_length, (int, float)) + and mask_x_length <= 0 + and isinstance(mask_y_length, (int, float)) + and mask_y_length <= 0 + ): + raise ValueError("At least one of `mask_x_length` or `mask_y_length` Should be a positive number.") + + if isinstance(num_masks_x, int) and num_masks_x <= 0 and isinstance(num_masks_y, int) and num_masks_y <= 0: + raise ValueError( + "At least one of `num_masks_x` or `num_masks_y` " + "should be a positive number or tuple of two positive numbers." + ) + + if isinstance(num_masks_x, (tuple, list)) and min(num_masks_x) <= 0: + raise ValueError("All values in `num_masks_x` should be non negative integers.") + + if isinstance(num_masks_y, (tuple, list)) and min(num_masks_y) <= 0: + raise ValueError("All values in `num_masks_y` should be non negative integers.") + + self.num_masks_x = num_masks_x + self.num_masks_y = num_masks_y + + self.mask_x_length = mask_x_length + self.mask_y_length = mask_y_length + self.fill_value = fill_value + self.mask_fill_value = mask_fill_value + + def apply( + self, + img: np.ndarray, + masks_x: List[Tuple[int, int, int, int]], + masks_y: List[Tuple[int, int, int, int]], + **params: Any, + ) -> np.ndarray: + return cutout(img, masks_x + masks_y, self.fill_value) + + def apply_to_mask( + self, + mask: np.ndarray, + masks_x: List[Tuple[int, int, int, int]], + masks_y: List[Tuple[int, int, int, int]], + **params: Any, + ) -> np.ndarray: + if self.mask_fill_value is None: + return mask + return cutout(mask, masks_x + masks_y, self.mask_fill_value) + + def validate_mask_length( + self, mask_length: Optional[ScaleIntType], dimension_size: int, dimension_name: str + ) -> None: + """ + Validate the mask length against the corresponding image dimension size. + + Args: + mask_length (Optional[Union[int, Tuple[int, int]]]): The length of the mask to be validated. + dimension_size (int): The size of the image dimension (width or height) + against which to validate the mask length. + dimension_name (str): The name of the dimension ('width' or 'height') for error messaging. 
+ """ + if mask_length is not None: + if isinstance(mask_length, tuple): + if mask_length[0] < 0 or mask_length[1] > dimension_size: + raise ValueError( + f"{dimension_name} range {mask_length} is out of valid range [0, {dimension_size}]" + ) + elif mask_length < 0 or mask_length > dimension_size: + raise ValueError(f"{dimension_name} {mask_length} exceeds image {dimension_name} {dimension_size}") + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, List[Tuple[int, int, int, int]]]: + img = params["image"] + height, width = img.shape[:2] + + # Use the helper method to validate mask lengths against image dimensions + self.validate_mask_length(self.mask_x_length, width, "mask_x_length") + self.validate_mask_length(self.mask_y_length, height, "mask_y_length") + + masks_x = self.generate_masks(self.num_masks_x, width, height, self.mask_x_length, axis="x") + masks_y = self.generate_masks(self.num_masks_y, width, height, self.mask_y_length, axis="y") + + return {"masks_x": masks_x, "masks_y": masks_y} + + @staticmethod + def generate_mask_size(mask_length: Union[ScaleIntType]) -> int: + if isinstance(mask_length, int): + return mask_length # Use fixed size or adjust to dimension size + + return random.randint(min(mask_length), max(mask_length)) + + def generate_masks( + self, + num_masks: ScaleIntType, + width: int, + height: int, + max_length: Optional[ScaleIntType], + axis: str, + ) -> List[Tuple[int, int, int, int]]: + if max_length is None or max_length == 0 or isinstance(num_masks, (int, float)) and num_masks == 0: + return [] + + masks = [] + + if isinstance(num_masks, int): + num_masks_integer = num_masks + else: + num_masks_integer = random.randint(num_masks[0], num_masks[1]) + + for _ in range(num_masks_integer): + length = self.generate_mask_size(max_length) + + if axis == "x": + x1 = random.randint(0, width - length) + y1 = 0 + x2, y2 = x1 + length, height + else: # axis == 'y' + y1 = random.randint(0, height - length) + x1 = 0 + x2, y2 = width, y1 + length + + masks.append((x1, y1, x2, y2)) + return masks + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def apply_to_keypoints( + self, + keypoints: Sequence[KeypointType], + masks_x: List[Tuple[int, int, int, int]], + masks_y: List[Tuple[int, int, int, int]], + **params: Any, + ) -> List[KeypointType]: + filtered_keypoints = [] + for keypoint in keypoints: + if not any(keypoint_in_hole(keypoint, hole) for hole in masks_x + masks_y): + filtered_keypoints.append(keypoint) + return filtered_keypoints + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ( + "num_masks_x", + "num_masks_y", + "mask_x_length", + "mask_y_length", + "fill_value", + "mask_fill_value", + ) diff --git a/albumentations/core/transforms_interface.py b/albumentations/core/transforms_interface.py index 10d7894fe..906170eb1 100644 --- a/albumentations/core/transforms_interface.py +++ b/albumentations/core/transforms_interface.py @@ -10,6 +10,7 @@ from .types import ( BoxInternalType, BoxType, + ColorType, KeypointInternalType, KeypointType, ScalarType, @@ -74,8 +75,8 @@ def __init__(self, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_ class BasicTransform(Serializable): call_backup = None interpolation: Union[int, Interpolation] - fill_value: ScalarType - mask_fill_value: Optional[ScalarType] + fill_value: ColorType + mask_fill_value: Optional[ColorType] def __init__(self, always_apply: bool = False, p: float = 0.5): self.p = p diff --git a/albumentations/core/types.py 
b/albumentations/core/types.py index 815dd9464..78ff3702b 100644 --- a/albumentations/core/types.py +++ b/albumentations/core/types.py @@ -3,7 +3,7 @@ import numpy as np ScalarType = Union[int, float] -ColorType = Union[int, float, Tuple[int, int, int], Tuple[float, float, float]] +ColorType = Union[int, float, Sequence[int], Sequence[float]] SizeType = Sequence[int] BoxInternalType = Tuple[float, float, float, float] diff --git a/tests/test_augmentations.py b/tests/test_augmentations.py index a4f0d5dc3..674d047aa 100644 --- a/tests/test_augmentations.py +++ b/tests/test_augmentations.py @@ -96,6 +96,14 @@ def test_image_only_augmentations_with_float_values(augmentation_cls, params, fl A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10}, A.CropAndPad: {"px": 10}, A.Resize: {"height": 10, "width": 10}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "fill_value": 0, + "mask_fill_value": 1, + }, }, except_augmentations={A.RandomCropNearBBox, A.RandomSizedBBoxSafeCrop, A.BBoxSafeRandomCrop}, ), @@ -119,6 +127,14 @@ def test_dual_augmentations(augmentation_cls, params, image, mask): A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10}, A.CropAndPad: {"px": 10}, A.Resize: {"height": 10, "width": 10}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={A.RandomCropNearBBox, A.RandomSizedBBoxSafeCrop, A.BBoxSafeRandomCrop}, ), @@ -158,6 +174,14 @@ def test_dual_augmentations_with_float_values(augmentation_cls, params, float_im A.TemplateTransform: { "templates": np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={A.RandomCropNearBBox, A.RandomSizedBBoxSafeCrop, A.BBoxSafeRandomCrop}, ), @@ -200,6 +224,14 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask): A.TemplateTransform: { "templates": np.random.uniform(low=0.0, high=1.0, size=(100, 100, 3)).astype(np.float32), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.CLAHE, @@ -239,6 +271,14 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i A.TemplateTransform: { "templates": np.random.randint(low=0, high=256, size=(224, 224), dtype=np.uint8), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.ChannelDropout, @@ -315,6 +355,14 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima A.TemplateTransform: { "templates": np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.RandomCropNearBBox, @@ -404,6 +452,14 @@ def test_mask_fill_value(augmentation_cls, params): A.TemplateTransform: { "templates": np.random.randint(0, 256, (100, 100, 6), dtype=np.uint8), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, 
+ "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.CLAHE, @@ -467,6 +523,14 @@ def test_multichannel_image_augmentations(augmentation_cls, params): A.TemplateTransform: { "templates": np.random.uniform(0.0, 1.0, (100, 100, 6)).astype(np.float32), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.CLAHE, @@ -521,6 +585,14 @@ def test_float_multichannel_image_augmentations(augmentation_cls, params): A.TemplateTransform: { "templates": np.random.randint(0, 1, (100, 100), dtype=np.uint8), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.CLAHE, @@ -579,6 +651,14 @@ def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params A.TemplateTransform: { "templates": np.random.uniform(0.0, 1.0, (100, 100, 1)).astype(np.float32), }, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.CLAHE, @@ -800,7 +880,7 @@ def test_perspective_valid_keypoints_after_transform(seed: int, scale: float, h: def test_pixel_domain_adaptation(kind): img_uint8 = np.random.randint(low=100, high=200, size=(100, 100, 3), dtype=np.uint8) ref_img_uint8 = np.random.randint(low=0, high=100, size=(100, 100, 3), dtype=np.uint8) - img_float, ref_img_float = [x.astype("float32") / 255.0 for x in (img_uint8, ref_img_uint8)] + img_float, ref_img_float = (x.astype("float32") / 255.0 for x in (img_uint8, ref_img_uint8)) for img, ref_img in ((img_uint8, ref_img_uint8), (img_float, ref_img_float)): adapter = A.PixelDistributionAdaptation( @@ -854,6 +934,14 @@ def test_pixel_domain_adaptation(kind): A.Resize: {"height": 10, "width": 10}, A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10}, A.BBoxSafeRandomCrop: {"erosion_rate": 0.5}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, ), ) diff --git a/tests/test_functional_cutout.py b/tests/test_functional_cutout.py new file mode 100644 index 000000000..65fa4d94e --- /dev/null +++ b/tests/test_functional_cutout.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest + +from albumentations.augmentations.dropout.functional import cutout + + +@pytest.mark.parametrize( + "img, fill_value", + [ + # Single-channel image, fill_value is a number + (np.zeros((10, 10), dtype=np.uint8), 255), + # Multi-channel image with different channel counts, fill_value is a number (applied to all channels) + (np.zeros((10, 10, 3), dtype=np.uint8), 255), + # Multi-channel image, fill_value is a tuple with different values for different channels + (np.zeros((10, 10, 3), dtype=np.uint8), (128, 128, 128)), + # Multi-channel image, fill_value as list with different values + (np.zeros((10, 10, 2), dtype=np.uint8), [64, 192]), + # Multi-channel image, fill_value as np.ndarray with different values + (np.zeros((10, 10, 3), dtype=np.uint8), np.array([32, 64, 96], dtype=np.uint8)), + ], +) +def test_cutout_with_various_fill_values(img, fill_value): + holes = [(2, 2, 5, 5)] + result = cutout(img, holes, fill_value=fill_value) + + # Compute expected result + expected_result = img.copy() + for x1, y1, x2, y2 in holes: + if 
isinstance(fill_value, (int, float)): + fill_array = np.array(fill_value, dtype=img.dtype) + else: + fill_array = np.array(fill_value, dtype=img.dtype).reshape(-1) + if img.ndim == 2: # Single-channel image + expected_result[y1:y2, x1:x2] = fill_array + else: # Multi-channel image + fill_shape = (y2 - y1, x2 - x1, img.shape[2]) if img.ndim == 3 else (y2 - y1, x2 - x1) + expected_fill = np.full(fill_shape, fill_array, dtype=img.dtype) + expected_result[y1:y2, x1:x2] = expected_fill[: y2 - y1, : x2 - x1] + + # Check the filled values + assert np.all(result == expected_result), "The result does not match the expected output." diff --git a/tests/test_serialization.py b/tests/test_serialization.py index aca08f2a4..eaf9a51d2 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -36,6 +36,14 @@ A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10}, A.CropAndPad: {"px": 10}, A.Resize: {"height": 10, "width": 10}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={ A.RandomCropNearBBox, @@ -385,6 +393,17 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m ], [A.Defocus, {"radius": (5, 7), "alias_blur": (0.2, 0.6)}], [A.ZoomBlur, {"max_factor": (1.56, 1.7), "step_factor": (0.02, 0.04)}], + [ + A.XYMasking, + { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, + ], ] AUGMENTATION_CLS_EXCEPT = { @@ -476,6 +495,7 @@ def test_augmentations_serialization_to_file_with_custom_parameters( A.MaskDropout, A.OpticalDistortion, A.TemplateTransform, + A.XYMasking, }, ), ) @@ -508,6 +528,14 @@ def test_augmentations_for_bboxes_serialization( A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10}, A.CropAndPad: {"px": 10}, A.Resize: {"height": 10, "width": 10}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "fill_value": 0, + "mask_fill_value": 1, + }, }, except_augmentations={ A.RandomCropNearBBox, diff --git a/tests/test_transforms.py b/tests/test_transforms.py index 14c64848e..d898bb8c8 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -156,6 +156,14 @@ def test_elastic_transform_interpolation(monkeypatch, interpolation): A.CropAndPad: {"px": 10}, A.Resize: {"height": 10, "width": 10}, A.PixelDropout: {"dropout_prob": 0.5, "mask_drop_value": 10, "drop_value": 20}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": (1, 3), + "mask_x_length": 10, + "mask_y_length": 10, + "mask_fill_value": 1, + "fill_value": 0, + }, }, except_augmentations={A.RandomCropNearBBox, A.RandomSizedBBoxSafeCrop, A.BBoxSafeRandomCrop, A.PixelDropout}, ), @@ -188,13 +196,12 @@ def test_binary_mask_interpolation(augmentation_cls, params): A.BBoxSafeRandomCrop, A.CropAndPad, A.PixelDropout, + A.XYMasking, }, ), ) def test_semantic_mask_interpolation(augmentation_cls, params): - """Checks whether transformations based on DualTransform does not introduce a mask interpolation artifacts. 
-    Note: IAAAffine, IAAPiecewiseAffine, IAAPerspective does not properly operate if mask has values other than {0;1}
-    """
+    """Checks whether transformations based on DualTransform do not introduce mask interpolation artifacts."""
     aug = augmentation_cls(p=1, **params)
     image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8)
     mask = np.random.randint(low=0, high=4, size=(100, 100), dtype=np.uint8) * 64
@@ -223,6 +230,14 @@ def __test_multiprocessing_support_proc(args):
             A.TemplateTransform: {
                 "templates": np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8),
            },
+            A.XYMasking: {
+                "num_masks_x": (1, 3),
+                "num_masks_y": (1, 3),
+                "mask_x_length": 10,
+                "mask_y_length": 10,
+                "mask_fill_value": 1,
+                "fill_value": 0,
+            },
         },
         except_augmentations={
             A.RandomCropNearBBox,
diff --git a/tests/utils.py b/tests/utils.py
index 3d0118ae0..508024aa3 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -2,7 +2,7 @@
 import random
 import typing
 from io import StringIO
-from typing import Optional, Set, Type
+from typing import Optional, Set
 
 import numpy as np
 
@@ -26,7 +26,7 @@ def convert_2d_to_target_format(arrays, target):
     if target == "image_4_channels":
         return convert_2d_to_3d(arrays, num_channels=4)
 
-    raise ValueError("Unknown target {}".format(target))
+    raise ValueError(f"Unknown target {target}")
 
 
 class InMemoryFile(StringIO):
@@ -73,7 +73,7 @@ def get_filtered_transforms(
 
     result = []
 
-    for name, cls in inspect.getmembers(albumentations):
+    for _, cls in inspect.getmembers(albumentations):
         if not inspect.isclass(cls) or not issubclass(cls, (albumentations.BasicTransform, albumentations.BaseCompose)):
             continue
 
@@ -83,14 +83,7 @@ def get_filtered_transforms(
 
         if not issubclass(cls, base_classes) or any(cls == i for i in base_classes) or cls in except_augmentations:
            continue
 
-        try:
-            if issubclass(cls, albumentations.BasicIAATransform):
-                continue
-        except AttributeError:
-            pass
-
         result.append((cls, custom_arguments.get(cls, {})))
-
     return result
diff --git a/tools/make_transforms_docs.py b/tools/make_transforms_docs.py
index cb7475216..fbcd24e7c 100644
@@ -10,7 +10,6 @@
 IGNORED_CLASSES = {
     "BasicTransform",
     "DualTransform",
-    "ImageOnlyIAATransform",
     "ImageOnlyTransform",
 }
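Usage example (not part of the diff): a minimal sketch of how the new `XYMasking` transform could be wired into a pipeline, based on the constructor arguments introduced above. The parameter values, keypoint format, and input arrays are illustrative assumptions, not taken from this PR.

```python
import numpy as np
import albumentations as A

# Illustrative values chosen only to exercise the new arguments.
transform = A.Compose(
    [
        A.XYMasking(
            num_masks_x=(1, 3),      # 1-3 masks placed along the X axis
            num_masks_y=(1, 3),      # 1-3 masks placed along the Y axis
            mask_x_length=(10, 20),  # per-mask length drawn from this range
            mask_y_length=10,        # fixed mask height
            fill_value=0,            # value written into masked image regions
            mask_fill_value=1,       # value written into masked regions of the target mask
            p=1.0,
        )
    ],
    keypoint_params=A.KeypointParams(format="xy"),
)

image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
mask = np.zeros((100, 100), dtype=np.uint8)
keypoints = [(50, 50), (5, 5)]

result = transform(image=image, mask=mask, keypoints=keypoints)
# Keypoints that fall inside a masked region are dropped by XYMasking.apply_to_keypoints,
# so len(result["keypoints"]) may be smaller than len(keypoints).
```

Because `fill_value` now accepts a sequence (see the `ColorType` change and `tests/test_functional_cutout.py`), a per-channel fill such as `fill_value=(128, 0, 0)` should also work for multi-channel images.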