Skip to content

Commit

Permalink
Fix pad if needed serialization (#1570)
Browse files Browse the repository at this point in the history
* Fix in serialization

* Fixes

* Update albumentations/core/serialization.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update albumentations/augmentations/transforms.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Fixes

* Fix in Docs generation for Readme

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
ternaus and sourcery-ai[bot] authored Mar 9, 2024
1 parent 5663816 commit f744c83
Show file tree
Hide file tree
Showing 12 changed files with 143 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ repos:
types: [python]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
rev: v0.3.2
hooks:
# Run the linter.
- id: ruff
Expand Down
3 changes: 0 additions & 3 deletions albumentations/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# Common classes
from .blur.functional import *
from .blur.transforms import *
from .crops.functional import *
from .crops.transforms import *

# New transformations goes to individual files listed below
from .domain_adaptation import *
from .domain_adaptation_functional import *
from .dropout.channel_dropout import *
Expand Down
2 changes: 1 addition & 1 deletion albumentations/augmentations/blur/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
preserve_shape,
)

__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"]
__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur", "defocus", "central_zoom", "zoom_blur"]

TWO = 2
EIGHT = 8
Expand Down
7 changes: 4 additions & 3 deletions albumentations/augmentations/blur/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from albumentations import random_utils
from albumentations.augmentations import functional as FMain
from albumentations.augmentations.blur import functional as F
from albumentations.core.transforms_interface import ImageOnlyTransform, to_tuple
from albumentations.core.types import ScaleFloatType, ScaleIntType

from . import functional as F

__all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"]


Expand Down Expand Up @@ -273,8 +274,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, n

return {"dxy": dxy}

def get_transform_init_args_names(self) -> Tuple[str, str, str]:
return ("sigma", "max_delta", "iterations")
def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
return ("sigma", "max_delta", "iterations", "mode")

@property
def targets_as_params(self) -> List[str]:
Expand Down
27 changes: 21 additions & 6 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from warnings import warn

import cv2
import numpy as np
Expand Down Expand Up @@ -481,7 +482,8 @@ class RandomCropNearBBox(DualTransform):
to `cropping_bbox` dimension.
If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift).
Default (0.3, 0.3).
cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox`
cropping_bbox_key (str): Additional target key for cropping box. Default `cropping_bbox`.
cropping_box_key (str): [Deprecated] Use `cropping_bbox_key` instead.
p (float): probability of applying the transform. Default: 1.
Targets:
Expand All @@ -491,7 +493,7 @@ class RandomCropNearBBox(DualTransform):
uint8, float32
Examples:
>>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')],
>>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_bbox_key='test_box')],
>>> bbox_params=BboxParams("pascal_voc"))
>>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20])
Expand All @@ -502,13 +504,26 @@ class RandomCropNearBBox(DualTransform):
def __init__(
self,
max_part_shift: ScaleFloatType = (0.3, 0.3),
cropping_box_key: str = "cropping_bbox",
cropping_bbox_key: str = "cropping_bbox",
cropping_box_key: Optional[str] = None, # Deprecated
always_apply: bool = False,
p: float = 1.0,
):
super().__init__(always_apply, p)
self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift)
self.cropping_bbox_key = cropping_box_key

# Check for deprecated parameter and issue warning
if cropping_box_key is not None:
warn(
"The parameter 'cropping_box_key' is deprecated and will be removed in future versions. "
"Use 'cropping_bbox_key' instead.",
DeprecationWarning,
stacklevel=2,
)
# Backward compatibility: when the deprecated argument is supplied, its value takes precedence over cropping_bbox_key
cropping_bbox_key = cropping_box_key

self.cropping_bbox_key = cropping_bbox_key

if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1:
raise ValueError(f"Invalid max_part_shift. Got: {max_part_shift}")
Expand Down Expand Up @@ -552,8 +567,8 @@ def apply_to_keypoint(
def targets_as_params(self) -> List[str]:
return [self.cropping_bbox_key]

def get_transform_init_args_names(self) -> Tuple[str]:
return ("max_part_shift",)
def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("max_part_shift", "cropping_bbox_key")


class BBoxSafeRandomCrop(DualTransform):
Expand Down
2 changes: 2 additions & 0 deletions albumentations/augmentations/dropout/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

TWO = 2

__all__ = ["cutout", "channel_dropout", "keypoint_in_hole"]


@preserve_shape
def channel_dropout(
Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,7 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]:
"min_width",
"pad_height_divisor",
"pad_width_divisor",
"position",
"border_mode",
"value",
"mask_value",
Expand Down
24 changes: 12 additions & 12 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,9 @@ def __init__(
super().__init__(always_apply, p)

if not 0 <= snow_point_lower <= snow_point_upper <= 1:
msg = "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
(snow_point_lower, snow_point_upper)
msg = (
"Invalid combination of snow_point_lower and snow_point_upper. "
f"Got: {(snow_point_lower, snow_point_upper)}"
)
raise ValueError(msg)
if brightness_coeff < 0:
Expand Down Expand Up @@ -741,8 +742,9 @@ def __init__(
if not 0 <= angle_lower < angle_upper <= 1:
raise ValueError(f"Invalid combination of angle_lower nad angle_upper. Got: {(angle_lower, angle_upper)}")
if not 0 <= num_flare_circles_lower < num_flare_circles_upper:
msg = "Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format(
(num_flare_circles_lower, num_flare_circles_upper)
msg = (
"Invalid combination of num_flare_circles_lower and num_flare_circles_upper. "
f"Got: {(num_flare_circles_lower, num_flare_circles_upper)}"
)
raise ValueError(msg)

Expand Down Expand Up @@ -889,9 +891,8 @@ def __init__(
if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
raise ValueError(f"Invalid shadow_roi. Got: {shadow_roi}")
if not 0 <= num_shadows_lower <= num_shadows_upper:
msg = "Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format(
(num_shadows_lower, num_shadows_upper)
)
# Both literals must live inside one parenthesised expression. Without the
# parentheses the f-string is a separate no-op statement and the offending
# values never reach the error message.
msg = (
    "Invalid combination of num_shadows_lower and num_shadows_upper. "
    f"Got: {(num_shadows_lower, num_shadows_upper)}"
)
raise ValueError(msg)

self.shadow_roi = shadow_roi
Expand Down Expand Up @@ -1171,8 +1172,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
def targets_as_params(self) -> List[str]:
return ["image", *list(self.mask_params)]

def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("mode", "by_channels")
def get_transform_init_args_names(self) -> Tuple[str, ...]:
return ("mode", "by_channels", "mask", "mask_params")


class RGBShift(ImageOnlyTransform):
Expand Down Expand Up @@ -2325,9 +2326,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
if get_num_channels(template) not in [1, get_num_channels(img)]:
msg = (
"Template must be a single channel or "
"has the same number of channels as input image ({}), got {}".format(
get_num_channels(img), get_num_channels(template)
)
"has the same number of channels as input "
f"image ({get_num_channels(img)}), got {get_num_channels(template)}"
)
raise ValueError(msg)

Expand Down
21 changes: 18 additions & 3 deletions albumentations/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import json
import warnings
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union

Expand Down Expand Up @@ -98,9 +100,10 @@ def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:

transform_dict = {}
warnings.warn(
"Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
"Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
"method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
f"Got NotImplementedError while trying to serialize {self}. Object arguments are not preserved. "
f"Implement either '{self.__class__.__name__}.get_transform_init_args_names' "
f"or '{self.__class__.__name__}.get_transform_init_args' "
"method to make the transform serializable"
)
return {"__version__": __version__, "transform": transform_dict}

Expand Down Expand Up @@ -165,6 +168,17 @@ def check_data_format(data_format: str) -> None:
raise ValueError(f"Unknown data_format {data_format}. Supported formats are: 'json' and 'yaml'")


def serialize_enum(obj: Any) -> Any:
    """Return a copy of ``obj`` in which every Enum member is replaced by its value.

    Containers are walked recursively:

    * a Mapping becomes a plain ``dict`` with the same keys and converted values
      (keys themselves are left untouched);
    * any Sequence other than ``str`` (lists, tuples, ...) becomes a ``list`` of
      converted items;
    * an Enum member becomes ``member.value``;
    * everything else is returned as is.
    """
    if isinstance(obj, Mapping):
        converted = {}
        for key, item in obj.items():
            converted[key] = serialize_enum(item)
        return converted

    # Strings are Sequences too, but must be kept whole rather than split
    # into a list of characters.
    if isinstance(obj, Sequence) and not isinstance(obj, str):
        return list(map(serialize_enum, obj))

    if isinstance(obj, Enum):
        return obj.value
    return obj


def save(
transform: "Serializable",
filepath_or_buffer: Union[str, Path, TextIO],
Expand Down Expand Up @@ -192,6 +206,7 @@ def save(
"""
check_data_format(data_format)
transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
transform_dict = serialize_enum(transform_dict)

# Determine whether to write to a file or a file-like object
if isinstance(filepath_or_buffer, (str, Path)): # It's a filepath
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pre_commit>=3.5.0
pytest>=8.0.2
pytest_cov>=4.1.0
requests>=2.31.0
ruff>=0.3.0
ruff>=0.3.2
tomli>=2.0.1
types-pkg-resources
types-PyYAML
Expand Down
87 changes: 82 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from deepdiff import DeepDiff
import inspect

import albumentations as A
import albumentations.augmentations.geometric.functional as FGeometric
Expand Down Expand Up @@ -43,6 +44,13 @@
"mask_fill_value": 1,
"fill_value": 0,
},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -403,7 +411,14 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
"mask_fill_value": 1,
"fill_value": 0,
},
]
],
[A.PadIfNeeded, {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}]
]

AUGMENTATION_CLS_EXCEPT = {
Expand Down Expand Up @@ -480,6 +495,13 @@ def test_augmentations_serialization_to_file_with_custom_parameters(
A.Resize: {"height": 10, "width": 10},
A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10},
A.BBoxSafeRandomCrop: {"erosion_rate": 0.6},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -538,6 +560,13 @@ def test_augmentations_for_bboxes_serialization(
"fill_value": 0,
"mask_fill_value": 1,
},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -861,8 +890,8 @@ def test_serialization_conversion_without_totensor(transform_file_name, data_for
buffer.close()

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {}
), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}"

set_seed(seed)
image1 = transform(image=image)["image"]
Expand Down Expand Up @@ -898,8 +927,8 @@ def test_serialization_conversion_with_totensor(transform_file_name, data_format
buffer.close() # Ensure the buffer is closed after use

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {}
), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}"

set_seed(seed)
image1 = transform(image=image)["image"]
Expand Down Expand Up @@ -959,3 +988,51 @@ def test_template_transform_serialization(image, template, seed, p):
deserialized_aug_data = deserialized_aug(image=image)

assert np.array_equal(aug_data["image"], deserialized_aug_data["image"])


# NOTE(review): an earlier test in this module appears to use the same name
# (`test_augmentations_serialization(augmentation_cls, params, p, seed, image, mask)`).
# If so, this later definition shadows it at import time and the earlier one
# is never collected - confirm and consider giving this test a distinct name.
@pytest.mark.parametrize(
    ["augmentation_cls", "params"],
    get_transforms(
        custom_arguments={
            A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
            A.CenterCrop: {"height": 10, "width": 10},
            A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
            A.RandomCrop: {"height": 10, "width": 10},
            A.RandomResizedCrop: {"height": 10, "width": 10},
            A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
            A.CropAndPad: {"px": 10},
            A.Resize: {"height": 10, "width": 10},
            A.XYMasking: {
                "num_masks_x": (1, 3),
                "num_masks_y": 3,
                "mask_x_length": (10, 20),
                "mask_y_length": 10,
                "fill_value": 0,
                "mask_fill_value": 1,
            },
            A.PadIfNeeded: {
                "min_height": 512,
                "min_width": 512,
                "border_mode": 0,
                "value": [124, 116, 104],
                "position": "top_left",
            },
            A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10},
        },
        except_augmentations={
            A.FDA,
            A.HistogramMatching,
            A.PixelDistributionAdaptation,
            A.Lambda,
            A.TemplateTransform,
            A.MixUp,
            A.ShiftScaleRotate,
        },
    ),
)
def test_augmentations_serialization(augmentation_cls, params):
    """Check that the serialized form of a transform lists every constructor argument.

    The parameter names of ``augmentation_cls.__init__`` (minus ``self`` and
    ``always_apply``) must equal the keys of ``to_dict()["transform"]`` (minus
    ``__class_fullname__`` and ``always_apply``).
    """
    instance = augmentation_cls(**params)

    # Constructor parameter names, ignoring 'self' and 'always_apply'.
    signature = inspect.signature(augmentation_cls.__init__)
    expected_args = {name for name in signature.parameters if name not in ("self", "always_apply")}

    # Keys the instance actually emits when it is serialized with to_dict().
    serialized = instance.to_dict()["transform"]
    reported_args = {key for key in serialized if key not in ("__class_fullname__", "always_apply")}

    # Both sets must be identical, otherwise an argument is lost (or invented) on serialization.
    assert expected_args == reported_args, f"Mismatch in {augmentation_cls.__name__}: Expected {expected_args}, got {reported_args}"
Loading

0 comments on commit f744c83

Please sign in to comment.