diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a3c1806c5..53044df60 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -46,7 +46,7 @@ repos: types: [python] - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.3.0 + rev: v0.3.2 hooks: # Run the linter. - id: ruff diff --git a/albumentations/augmentations/__init__.py b/albumentations/augmentations/__init__.py index 7c0c7cdcb..32a0ad9c2 100644 --- a/albumentations/augmentations/__init__.py +++ b/albumentations/augmentations/__init__.py @@ -1,10 +1,7 @@ -# Common classes from .blur.functional import * from .blur.transforms import * from .crops.functional import * from .crops.transforms import * - -# New transformations goes to individual files listed below from .domain_adaptation import * from .domain_adaptation_functional import * from .dropout.channel_dropout import * diff --git a/albumentations/augmentations/blur/functional.py b/albumentations/augmentations/blur/functional.py index 7dc5fb1d4..9f43e0987 100644 --- a/albumentations/augmentations/blur/functional.py +++ b/albumentations/augmentations/blur/functional.py @@ -13,7 +13,7 @@ preserve_shape, ) -__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"] +__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur", "defocus", "central_zoom", "zoom_blur"] TWO = 2 EIGHT = 8 diff --git a/albumentations/augmentations/blur/transforms.py b/albumentations/augmentations/blur/transforms.py index 5aba209da..cbc3b7f75 100644 --- a/albumentations/augmentations/blur/transforms.py +++ b/albumentations/augmentations/blur/transforms.py @@ -7,10 +7,11 @@ from albumentations import random_utils from albumentations.augmentations import functional as FMain -from albumentations.augmentations.blur import functional as F from albumentations.core.transforms_interface import ImageOnlyTransform, to_tuple from albumentations.core.types import ScaleFloatType, ScaleIntType +from . import functional as F + __all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"] @@ -273,8 +274,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, n return {"dxy": dxy} - def get_transform_init_args_names(self) -> Tuple[str, str, str]: - return ("sigma", "max_delta", "iterations") + def get_transform_init_args_names(self) -> Tuple[str, str, str, str]: + return ("sigma", "max_delta", "iterations", "mode") @property def targets_as_params(self) -> List[str]: diff --git a/albumentations/augmentations/crops/transforms.py b/albumentations/augmentations/crops/transforms.py index 6d3c8ba44..4b5d0d617 100644 --- a/albumentations/augmentations/crops/transforms.py +++ b/albumentations/augmentations/crops/transforms.py @@ -1,6 +1,7 @@ import math import random from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from warnings import warn import cv2 import numpy as np @@ -481,7 +482,8 @@ class RandomCropNearBBox(DualTransform): to `cropping_bbox` dimension. If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift). Default (0.3, 0.3). - cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox` + cropping_bbox_key (str): Additional target key for cropping box. Default `cropping_bbox`. + cropping_box_key (str): [Deprecated] Use `cropping_bbox_key` instead. p (float): probability of applying the transform. Default: 1. Targets: @@ -491,7 +493,7 @@ class RandomCropNearBBox(DualTransform): uint8, float32 Examples: - >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')], + >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_bbox_key='test_box')], >>> bbox_params=BboxParams("pascal_voc")) >>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20]) @@ -502,13 +504,26 @@ class RandomCropNearBBox(DualTransform): def __init__( self, max_part_shift: ScaleFloatType = (0.3, 0.3), - cropping_box_key: str = "cropping_bbox", + cropping_bbox_key: str = "cropping_bbox", + cropping_box_key: Optional[str] = None, # Deprecated always_apply: bool = False, p: float = 1.0, ): super().__init__(always_apply, p) self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift) - self.cropping_bbox_key = cropping_box_key + + # Check for deprecated parameter and issue warning + if cropping_box_key is not None: + warn( + "The parameter 'cropping_box_key' is deprecated and will be removed in future versions. " + "Use 'cropping_bbox_key' instead.", + DeprecationWarning, + stacklevel=2, + ) + # Ensure the new parameter is used even if the old one is passed + cropping_bbox_key = cropping_box_key + + self.cropping_bbox_key = cropping_bbox_key if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1: raise ValueError(f"Invalid max_part_shift. Got: {max_part_shift}") @@ -552,8 +567,8 @@ def apply_to_keypoint( def targets_as_params(self) -> List[str]: return [self.cropping_bbox_key] - def get_transform_init_args_names(self) -> Tuple[str]: - return ("max_part_shift",) + def get_transform_init_args_names(self) -> Tuple[str, str]: + return ("max_part_shift", "cropping_bbox_key") class BBoxSafeRandomCrop(DualTransform): diff --git a/albumentations/augmentations/dropout/functional.py b/albumentations/augmentations/dropout/functional.py index 5e05c57b7..59807e3b7 100644 --- a/albumentations/augmentations/dropout/functional.py +++ b/albumentations/augmentations/dropout/functional.py @@ -7,6 +7,8 @@ TWO = 2 +__all__ = ["cutout", "channel_dropout", "keypoint_in_hole"] + @preserve_shape def channel_dropout( diff --git a/albumentations/augmentations/geometric/transforms.py b/albumentations/augmentations/geometric/transforms.py index 111fb1897..f2095e180 100644 --- a/albumentations/augmentations/geometric/transforms.py +++ b/albumentations/augmentations/geometric/transforms.py @@ -1253,6 +1253,7 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]: "min_width", "pad_height_divisor", "pad_width_divisor", + "position", "border_mode", "value", "mask_value", diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py index d08c83dee..06c920911 100644 --- a/albumentations/augmentations/transforms.py +++ b/albumentations/augmentations/transforms.py @@ -341,8 +341,9 @@ def __init__( super().__init__(always_apply, p) if not 0 <= snow_point_lower <= snow_point_upper <= 1: - msg = "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format( - (snow_point_lower, snow_point_upper) + msg = ( + "Invalid combination of snow_point_lower and snow_point_upper. " + f"Got: {(snow_point_lower, snow_point_upper)}" ) raise ValueError(msg) if brightness_coeff < 0: @@ -741,8 +742,9 @@ def __init__( if not 0 <= angle_lower < angle_upper <= 1: raise ValueError(f"Invalid combination of angle_lower nad angle_upper. Got: {(angle_lower, angle_upper)}") if not 0 <= num_flare_circles_lower < num_flare_circles_upper: - msg = "Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format( - (num_flare_circles_lower, num_flare_circles_upper) + msg = ( + "Invalid combination of num_flare_circles_lower and num_flare_circles_upper. " + f"Got: {(num_flare_circles_lower, num_flare_circles_upper)}" ) raise ValueError(msg) @@ -889,9 +891,8 @@ def __init__( if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1: raise ValueError(f"Invalid shadow_roi. Got: {shadow_roi}") if not 0 <= num_shadows_lower <= num_shadows_upper: - msg = "Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format( - (num_shadows_lower, num_shadows_upper) - ) + msg = "Invalid combination of num_shadows_lower nad num_shadows_upper. " + f"Got: {(num_shadows_lower, num_shadows_upper)}" raise ValueError(msg) self.shadow_roi = shadow_roi @@ -1171,8 +1172,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A def targets_as_params(self) -> List[str]: return ["image", *list(self.mask_params)] - def get_transform_init_args_names(self) -> Tuple[str, str]: - return ("mode", "by_channels") + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ("mode", "by_channels", "mask", "mask_params") class RGBShift(ImageOnlyTransform): @@ -2325,9 +2326,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A if get_num_channels(template) not in [1, get_num_channels(img)]: msg = ( "Template must be a single channel or " - "has the same number of channels as input image ({}), got {}".format( - get_num_channels(img), get_num_channels(template) - ) + "has the same number of channels as input " + f"image ({get_num_channels(img)}), got {get_num_channels(template)}" ) raise ValueError(msg) diff --git a/albumentations/core/serialization.py b/albumentations/core/serialization.py index 7a1581181..89b82b206 100644 --- a/albumentations/core/serialization.py +++ b/albumentations/core/serialization.py @@ -2,6 +2,8 @@ import json import warnings from abc import ABC, ABCMeta, abstractmethod +from collections.abc import Mapping, Sequence +from enum import Enum from pathlib import Path from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union @@ -98,9 +100,10 @@ def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]: transform_dict = {} warnings.warn( - "Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. " - "Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' " - "method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__) + f"Got NotImplementedError while trying to serialize {self}. Object arguments are not preserved. " + f"Implement either '{self.__class__.__name__}.get_transform_init_args_names' " + f"or '{self.__class__.__name__}.get_transform_init_args' " + "method to make the transform serializable" ) return {"__version__": __version__, "transform": transform_dict} @@ -165,6 +168,17 @@ def check_data_format(data_format: str) -> None: raise ValueError(f"Unknown data_format {data_format}. Supported formats are: 'json' and 'yaml'") +def serialize_enum(obj: Any) -> Any: + """Recursively search for Enum objects and convert them to their value. + Also handle any Mapping or Sequence types. + """ + if isinstance(obj, Mapping): + return {k: serialize_enum(v) for k, v in obj.items()} + if isinstance(obj, Sequence) and not isinstance(obj, str): # exclude strings since they're also sequences + return [serialize_enum(v) for v in obj] + return obj.value if isinstance(obj, Enum) else obj + + def save( transform: "Serializable", filepath_or_buffer: Union[str, Path, TextIO], @@ -192,6 +206,7 @@ def save( """ check_data_format(data_format) transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error) + transform_dict = serialize_enum(transform_dict) # Determine whether to write to a file or a file-like object if isinstance(filepath_or_buffer, (str, Path)): # It's a filepath diff --git a/requirements-dev.txt b/requirements-dev.txt index d79ee9e32..8c464a74d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -4,7 +4,7 @@ pre_commit>=3.5.0 pytest>=8.0.2 pytest_cov>=4.1.0 requests>=2.31.0 -ruff>=0.3.0 +ruff>=0.3.2 tomli>=2.0.1 types-pkg-resources types-PyYAML diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 019642ed2..7db363b1e 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -6,6 +6,7 @@ import numpy as np import pytest from deepdiff import DeepDiff +import inspect import albumentations as A import albumentations.augmentations.geometric.functional as FGeometric @@ -43,6 +44,13 @@ "mask_fill_value": 1, "fill_value": 0, }, + A.PadIfNeeded: { + "min_height": 512, + "min_width": 512, + "border_mode": 0, + "value": [124, 116, 104], + "position": "top_left" + } }, except_augmentations={ A.RandomCropNearBBox, @@ -403,7 +411,14 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m "mask_fill_value": 1, "fill_value": 0, }, - ] + ], + [A.PadIfNeeded, { + "min_height": 512, + "min_width": 512, + "border_mode": 0, + "value": [124, 116, 104], + "position": "top_left" + }] ] AUGMENTATION_CLS_EXCEPT = { @@ -480,6 +495,13 @@ def test_augmentations_serialization_to_file_with_custom_parameters( A.Resize: {"height": 10, "width": 10}, A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10}, A.BBoxSafeRandomCrop: {"erosion_rate": 0.6}, + A.PadIfNeeded: { + "min_height": 512, + "min_width": 512, + "border_mode": 0, + "value": [124, 116, 104], + "position": "top_left" + } }, except_augmentations={ A.RandomCropNearBBox, @@ -538,6 +560,13 @@ def test_augmentations_for_bboxes_serialization( "fill_value": 0, "mask_fill_value": 1, }, + A.PadIfNeeded: { + "min_height": 512, + "min_width": 512, + "border_mode": 0, + "value": [124, 116, 104], + "position": "top_left" + } }, except_augmentations={ A.RandomCropNearBBox, @@ -861,8 +890,8 @@ def test_serialization_conversion_without_totensor(transform_file_name, data_for buffer.close() assert ( - DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {} - ), "The loaded transform is not equal to the original one" + DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {} + ), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}" set_seed(seed) image1 = transform(image=image)["image"] @@ -898,8 +927,8 @@ def test_serialization_conversion_with_totensor(transform_file_name, data_format buffer.close() # Ensure the buffer is closed after use assert ( - DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {} - ), "The loaded transform is not equal to the original one" + DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {} + ), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}" set_seed(seed) image1 = transform(image=image)["image"] @@ -959,3 +988,51 @@ def test_template_transform_serialization(image, template, seed, p): deserialized_aug_data = deserialized_aug(image=image) assert np.array_equal(aug_data["image"], deserialized_aug_data["image"]) + + +@pytest.mark.parametrize( ["augmentation_cls", "params"], get_transforms(custom_arguments={ + A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10}, + A.CenterCrop: {"height": 10, "width": 10}, + A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10}, + A.RandomCrop: {"height": 10, "width": 10}, + A.RandomResizedCrop: {"height": 10, "width": 10}, + A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10}, + A.CropAndPad: {"px": 10}, + A.Resize: {"height": 10, "width": 10}, + A.XYMasking: { + "num_masks_x": (1, 3), + "num_masks_y": 3, + "mask_x_length": (10, 20), + "mask_y_length": 10, + "fill_value": 0, + "mask_fill_value": 1, + }, + A.PadIfNeeded: { + "min_height": 512, + "min_width": 512, + "border_mode": 0, + "value": [124, 116, 104], + "position": "top_left" + }, + A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10} + }, except_augmentations={ + A.FDA, + A.HistogramMatching, + A.PixelDistributionAdaptation, + A.Lambda, + A.TemplateTransform, + A.MixUp, + A.ShiftScaleRotate, + },) ) +def test_augmentations_serialization(augmentation_cls, params): + instance = augmentation_cls(**params) + + # Retrieve the constructor's parameters, except 'self', "always_apply"\ + init_params = inspect.signature(augmentation_cls.__init__).parameters + expected_args = set(init_params.keys()) - {'self', "always_apply"} + + # Retrieve the arguments reported by the instance's get_transform_init_args_names + reported_args = set(instance.to_dict()["transform"].keys()) - {'__class_fullname__', "always_apply"} + + # Check if the reported arguments match the expected arguments + assert expected_args == reported_args, f"Mismatch in {augmentation_cls.__name__}: Expected {expected_args}, got {reported_args}" diff --git a/tools/make_transforms_docs.py b/tools/make_transforms_docs.py index 9c972377e..c010ee22a 100644 --- a/tools/make_transforms_docs.py +++ b/tools/make_transforms_docs.py @@ -167,16 +167,6 @@ def check_docs(filepath, image_only_transforms_links, dual_transforms_table, mix ) raise ValueError(msg) - if image_only_transforms_links not in text: - msg = "Image only transforms links are outdated." - raise ValueError(msg) - if dual_transforms_table not in text: - msg = "Dual transforms table are outdated." - raise ValueError(msg) - if mixing_transforms_table not in text: - msg = "Mixing transforms table are outdated." - raise ValueError(msg) - def main() -> None: