Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix pad if needed serialization #1570

Merged
merged 6 commits into from
Mar 9, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ repos:
types: [python]
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.3.0
rev: v0.3.2
hooks:
# Run the linter.
- id: ruff
Expand Down
3 changes: 0 additions & 3 deletions albumentations/augmentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
# Common classes
from .blur.functional import *
from .blur.transforms import *
from .crops.functional import *
from .crops.transforms import *

# New transformations goes to individual files listed below
from .domain_adaptation import *
from .domain_adaptation_functional import *
from .dropout.channel_dropout import *
Expand Down
2 changes: 1 addition & 1 deletion albumentations/augmentations/blur/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
preserve_shape,
)

__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"]
__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur", "defocus", "central_zoom", "zoom_blur"]

TWO = 2
EIGHT = 8
Expand Down
7 changes: 4 additions & 3 deletions albumentations/augmentations/blur/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@

from albumentations import random_utils
from albumentations.augmentations import functional as FMain
from albumentations.augmentations.blur import functional as F
from albumentations.core.transforms_interface import ImageOnlyTransform, to_tuple
from albumentations.core.types import ScaleFloatType, ScaleIntType

from . import functional as F

__all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"]


Expand Down Expand Up @@ -273,8 +274,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, n

return {"dxy": dxy}

def get_transform_init_args_names(self) -> Tuple[str, str, str]:
return ("sigma", "max_delta", "iterations")
def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
return ("sigma", "max_delta", "iterations", "mode")

@property
def targets_as_params(self) -> List[str]:
Expand Down
10 changes: 5 additions & 5 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,7 +481,7 @@ class RandomCropNearBBox(DualTransform):
to `cropping_bbox` dimension.
If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift).
Default (0.3, 0.3).
cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox`
cropping_bbox_key (str): Additional target key for cropping box. Default `cropping_bbox`
p (float): probability of applying the transform. Default: 1.

Targets:
Expand All @@ -502,13 +502,13 @@ class RandomCropNearBBox(DualTransform):
def __init__(
self,
max_part_shift: ScaleFloatType = (0.3, 0.3),
cropping_box_key: str = "cropping_bbox",
cropping_bbox_key: str = "cropping_bbox",
always_apply: bool = False,
p: float = 1.0,
):
super().__init__(always_apply, p)
self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift)
self.cropping_bbox_key = cropping_box_key
self.cropping_bbox_key = cropping_bbox_key

if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1:
raise ValueError(f"Invalid max_part_shift. Got: {max_part_shift}")
Expand Down Expand Up @@ -552,8 +552,8 @@ def apply_to_keypoint(
def targets_as_params(self) -> List[str]:
return [self.cropping_bbox_key]

def get_transform_init_args_names(self) -> Tuple[str]:
return ("max_part_shift",)
def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("max_part_shift", "cropping_bbox_key")


class BBoxSafeRandomCrop(DualTransform):
Expand Down
2 changes: 2 additions & 0 deletions albumentations/augmentations/dropout/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@

TWO = 2

__all__ = ["cutout", "channel_dropout", "keypoint_in_hole"]


@preserve_shape
def channel_dropout(
Expand Down
1 change: 1 addition & 0 deletions albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,7 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]:
"min_width",
"pad_height_divisor",
"pad_width_divisor",
"position",
"border_mode",
"value",
"mask_value",
Expand Down
24 changes: 12 additions & 12 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,8 +341,9 @@ def __init__(
super().__init__(always_apply, p)

if not 0 <= snow_point_lower <= snow_point_upper <= 1:
msg = "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
(snow_point_lower, snow_point_upper)
msg = (
"Invalid combination of snow_point_lower and snow_point_upper. "
f"Got: {(snow_point_lower, snow_point_upper)}"
)
raise ValueError(msg)
if brightness_coeff < 0:
Expand Down Expand Up @@ -741,8 +742,9 @@ def __init__(
if not 0 <= angle_lower < angle_upper <= 1:
raise ValueError(f"Invalid combination of angle_lower nad angle_upper. Got: {(angle_lower, angle_upper)}")
if not 0 <= num_flare_circles_lower < num_flare_circles_upper:
msg = "Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format(
(num_flare_circles_lower, num_flare_circles_upper)
msg = (
"Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. "
f"Got: {(num_flare_circles_lower, num_flare_circles_upper)}"
)
raise ValueError(msg)

Expand Down Expand Up @@ -889,9 +891,8 @@ def __init__(
if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
raise ValueError(f"Invalid shadow_roi. Got: {shadow_roi}")
if not 0 <= num_shadows_lower <= num_shadows_upper:
msg = "Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format(
(num_shadows_lower, num_shadows_upper)
)
msg = "Invalid combination of num_shadows_lower nad num_shadows_upper. "
f"Got: {(num_shadows_lower, num_shadows_upper)}"
raise ValueError(msg)

self.shadow_roi = shadow_roi
Expand Down Expand Up @@ -1171,8 +1172,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
def targets_as_params(self) -> List[str]:
return ["image", *list(self.mask_params)]

def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("mode", "by_channels")
def get_transform_init_args_names(self) -> Tuple[str, ...]:
return ("mode", "by_channels", "mask", "mask_params")


class RGBShift(ImageOnlyTransform):
Expand Down Expand Up @@ -2325,9 +2326,8 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A
if get_num_channels(template) not in [1, get_num_channels(img)]:
msg = (
"Template must be a single channel or "
"has the same number of channels as input image ({}), got {}".format(
get_num_channels(img), get_num_channels(template)
)
"has the same number of channels as input "
f"image ({get_num_channels(img)}), got {get_num_channels(template)}"
)
raise ValueError(msg)

Expand Down
20 changes: 17 additions & 3 deletions albumentations/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import warnings
from abc import ABC, ABCMeta, abstractmethod
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union

Expand Down Expand Up @@ -98,9 +99,10 @@ def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:

transform_dict = {}
warnings.warn(
"Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
"Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
"method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
f"Got NotImplementedError while trying to serialize {self}. Object arguments are not preserved. "
f"Implement either '{self.__class__.__name__}.get_transform_init_args_names' "
f"or '{self.__class__.__name__}.get_transform_init_args' "
"method to make the transform serializable"
)
return {"__version__": __version__, "transform": transform_dict}

Expand Down Expand Up @@ -165,6 +167,17 @@ def check_data_format(data_format: str) -> None:
raise ValueError(f"Unknown data_format {data_format}. Supported formats are: 'json' and 'yaml'")


def serialize_enum(obj: Any) -> Any:
"""Recursively search for Enum objects and convert them to their value."""
if isinstance(obj, dict):
return {k: serialize_enum(v) for k, v in obj.items()}
if isinstance(obj, list):
return [serialize_enum(v) for v in obj]
if isinstance(obj, Enum):
return obj.value # Convert Enum to its value
return obj


def save(
transform: "Serializable",
filepath_or_buffer: Union[str, Path, TextIO],
Expand Down Expand Up @@ -192,6 +205,7 @@ def save(
"""
check_data_format(data_format)
transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
transform_dict = serialize_enum(transform_dict)

# Determine whether to write to a file or a file-like object
if isinstance(filepath_or_buffer, (str, Path)): # It's a filepath
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pre_commit>=3.5.0
pytest>=8.0.2
pytest_cov>=4.1.0
requests>=2.31.0
ruff>=0.3.0
ruff>=0.3.2
tomli>=2.0.1
types-pkg-resources
types-PyYAML
Expand Down
87 changes: 82 additions & 5 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pytest
from deepdiff import DeepDiff
import inspect

import albumentations as A
import albumentations.augmentations.geometric.functional as FGeometric
Expand Down Expand Up @@ -43,6 +44,13 @@
"mask_fill_value": 1,
"fill_value": 0,
},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -403,7 +411,14 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
"mask_fill_value": 1,
"fill_value": 0,
},
]
],
[A.PadIfNeeded, {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}]
]

AUGMENTATION_CLS_EXCEPT = {
Expand Down Expand Up @@ -480,6 +495,13 @@ def test_augmentations_serialization_to_file_with_custom_parameters(
A.Resize: {"height": 10, "width": 10},
A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10},
A.BBoxSafeRandomCrop: {"erosion_rate": 0.6},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -538,6 +560,13 @@ def test_augmentations_for_bboxes_serialization(
"fill_value": 0,
"mask_fill_value": 1,
},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
}
},
except_augmentations={
A.RandomCropNearBBox,
Expand Down Expand Up @@ -861,8 +890,8 @@ def test_serialization_conversion_without_totensor(transform_file_name, data_for
buffer.close()

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {}
), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}"

set_seed(seed)
image1 = transform(image=image)["image"]
Expand Down Expand Up @@ -898,8 +927,8 @@ def test_serialization_conversion_with_totensor(transform_file_name, data_format
buffer.close() # Ensure the buffer is closed after use

assert (
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict()) == {}
), "The loaded transform is not equal to the original one"
DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)]) == {}
), f"The loaded transform is not equal to the original one {DeepDiff(transform.to_dict(), transform_from_buffer.to_dict(), ignore_type_in_groups=[(tuple, list)])}"

set_seed(seed)
image1 = transform(image=image)["image"]
Expand Down Expand Up @@ -959,3 +988,51 @@ def test_template_transform_serialization(image, template, seed, p):
deserialized_aug_data = deserialized_aug(image=image)

assert np.array_equal(aug_data["image"], deserialized_aug_data["image"])


@pytest.mark.parametrize( ["augmentation_cls", "params"], get_transforms(custom_arguments={
A.Crop: {"y_min": 0, "y_max": 10, "x_min": 0, "x_max": 10},
A.CenterCrop: {"height": 10, "width": 10},
A.CropNonEmptyMaskIfExists: {"height": 10, "width": 10},
A.RandomCrop: {"height": 10, "width": 10},
A.RandomResizedCrop: {"height": 10, "width": 10},
A.RandomSizedCrop: {"min_max_height": (4, 8), "height": 10, "width": 10},
A.CropAndPad: {"px": 10},
A.Resize: {"height": 10, "width": 10},
A.XYMasking: {
"num_masks_x": (1, 3),
"num_masks_y": 3,
"mask_x_length": (10, 20),
"mask_y_length": 10,
"fill_value": 0,
"mask_fill_value": 1,
},
A.PadIfNeeded: {
"min_height": 512,
"min_width": 512,
"border_mode": 0,
"value": [124, 116, 104],
"position": "top_left"
},
A.RandomSizedBBoxSafeCrop: {"height": 10, "width": 10}
}, except_augmentations={
A.FDA,
A.HistogramMatching,
A.PixelDistributionAdaptation,
A.Lambda,
A.TemplateTransform,
A.MixUp,
A.ShiftScaleRotate,
},) )
def test_augmentations_serialization(augmentation_cls, params):
instance = augmentation_cls(**params)

# Retrieve the constructor's parameters, except 'self', "always_apply"\
init_params = inspect.signature(augmentation_cls.__init__).parameters
expected_args = set(init_params.keys()) - {'self', "always_apply"}

# Retrieve the arguments reported by the instance's get_transform_init_args_names
reported_args = set(instance.to_dict()["transform"].keys()) - {'__class_fullname__', "always_apply"}

# Check if the reported arguments match the expected arguments
assert expected_args == reported_args, f"Mismatch in {augmentation_cls.__name__}: Expected {expected_args}, got {reported_args}"
Loading