Skip to content

Commit

Permalink
issue1587
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus authored and zetyquickly committed Apr 1, 2024
1 parent 191b68f commit 644f937
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 9 deletions.
80 changes: 71 additions & 9 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,11 @@ def apply_to_keypoint(


class RandomSizedCrop(_BaseRandomSizedCrop):
"""Crop a random part of the input and rescale it to some size.
"""Crop a random portion of the input and rescale it to a specific size.
Args:
min_max_height ((int, int)): crop size limits.
height (int): height after crop and resize.
width (int): width after crop and resize.
size (tuple[int]): target size for the output image, i.e. (height, width) after crop and resize
w2h_ratio (float): aspect ratio of crop.
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
Expand All @@ -355,17 +354,49 @@ class RandomSizedCrop(_BaseRandomSizedCrop):
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
_size_len = 2

def __init__(
self,
min_max_height: Tuple[int, int],
height: int,
width: int,
# NOTE @zetyquickly: when (width, height) are deprecated, make 'size' non optional
size: Optional[Union[int, Tuple[int, int]]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
w2h_ratio: float = 1.0,
interpolation: int = cv2.INTER_LINEAR,
always_apply: bool = False,
p: float = 1.0,
):
if isinstance(size, tuple):
if len(size) != self._size_len:
message = "Size must be a tuple of two integers (height, width)."
raise ValueError(message)
height, width = size
elif size is None:
if height is None or width is None:
message = "If 'size' is not provided, both 'height' and 'width' must be specified."
raise ValueError(message)
size = (height, width)
warn(
"Initializing with 'height' and 'width' is deprecated. "
"Please use a tuple (height, width) for the 'size' argument.",
DeprecationWarning,
stacklevel=2,
)
else:
if width is None:
message = "When 'size' is an integer, 'width' must be provided."
raise ValueError(message)
height = size
warn(
"Initializing with 'size' as an integer and a separate 'width' is deprecated. "
"Please use a tuple (height, width) for the 'size' argument.",
DeprecationWarning,
stacklevel=2,
)

super().__init__(height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p)

self.min_max_height = min_max_height
Expand All @@ -388,8 +419,7 @@ class RandomResizedCrop(_BaseRandomSizedCrop):
"""Torchvision's variant of crop a random part of the input and rescale it to some size.
Args:
height (int): height after crop and resize.
width (int): width after crop and resize.
size (tuple[int]): target size for the output image, i.e. (height, width) after crop and resize
scale ((float, float)): range of size of the origin size cropped
ratio ((float, float)): range of aspect ratio of the origin aspect ratio cropped
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
Expand All @@ -406,17 +436,49 @@ class RandomResizedCrop(_BaseRandomSizedCrop):
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.BBOXES, Targets.KEYPOINTS)
_size_len = 2

def __init__(
self,
height: int,
width: int,
# NOTE @zetyquickly: when (width, height) are deprecated, make 'size' non optional
size: Optional[Union[int, Tuple[int, int]]] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
scale: Tuple[float, float] = (0.08, 1.0),
ratio: Tuple[float, float] = (0.75, 1.3333333333333333),
interpolation: int = cv2.INTER_LINEAR,
always_apply: bool = False,
p: float = 1.0,
):
if isinstance(size, tuple):
if len(size) != self._size_len:
message = "Size must be a tuple of two integers (height, width)."
raise ValueError(message)
height, width = size
elif size is None:
if height is None or width is None:
message = "If 'size' is not provided, both 'height' and 'width' must be specified."
raise ValueError(message)
size = (height, width)
warn(
"Initializing with 'height' and 'width' is deprecated. "
"Please use a tuple (height, width) for the 'size' argument.",
DeprecationWarning,
stacklevel=2,
)
else:
if width is None:
message = "When 'size' is an integer, 'width' must be provided."
raise ValueError(message)
height = size
warn(
"Initializing with 'size' as an integer and a separate 'width' is deprecated. "
"Please use a tuple (height, width) for the 'size' argument.",
DeprecationWarning,
stacklevel=2,
)

super().__init__(height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p)
self.scale = scale
self.ratio = ratio
Expand Down
61 changes: 61 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import cv2
import numpy as np
import pytest
import warnings
from torchvision import transforms as torch_transforms

import albumentations as A
import albumentations.augmentations.functional as F
Expand Down Expand Up @@ -1353,3 +1355,62 @@ def test_spatter_incorrect_color(unsupported_color, mode, message):
A.Spatter(mode=mode, color=unsupported_color)

assert str(exc_info.value).startswith(message)

@pytest.mark.parametrize("height, width", [(100, 200), (200, 100)])
@pytest.mark.parametrize("scale", [(0.08, 1.0), (0.5, 1.0)])
@pytest.mark.parametrize("ratio", [(0.75, 1.33), (1.0, 1.0)])
def test_random_crop_interfaces_vs_torchvision(height, width, scale, ratio):
# NOTE: below will fail when height, width is no longer expected as first two positional arguments
transform_albu = A.RandomResizedCrop(height, width, scale=scale, ratio=ratio, p=1)
transform_albu_new = A.RandomResizedCrop(size=(height, width), scale=scale, ratio=ratio, p=1)

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
transformed_image_albu = transform_albu(image=image)['image']
transformed_image_albu_new = transform_albu_new(image=image)['image']

# PyTorch equivalent operation
transform_pt = torch_transforms.RandomResizedCrop(size=(height, width), scale=scale, ratio=ratio)
image_pil = torch_transforms.functional.to_pil_image(image)
transformed_image_pt = transform_pt(image_pil)

transformed_image_pt_np = np.array(transformed_image_pt)
assert transformed_image_albu.shape == transformed_image_pt_np.shape
assert transformed_image_albu_new.shape == transformed_image_pt_np.shape

# NOTE: below will fail when height, width is no longer expected as second and third positional arguments
transform_albu = A.RandomSizedCrop((128, 224), height, width, p=1.0)
transform_albu_new = A.RandomSizedCrop(min_max_height=(128, 224), size=(height, width), p=1.0)
transformed_image_albu = transform_albu(image=image)['image']
transformed_image_albu_new = transform_albu_new(image=image)['image']
assert transformed_image_albu.shape == transformed_image_pt_np.shape
assert transformed_image_albu_new.shape == transformed_image_pt_np.shape

# NOTE: below will fail when height, width is no longer expected as first two positional arguments
transform_albu = A.RandomResizedCrop(height, width, scale=scale, ratio=ratio, p=1)
transform_albu_height_is_size = A.RandomResizedCrop(size=height, width=width, scale=scale, ratio=ratio, p=1)

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
transformed_image_albu = transform_albu(image=image)['image']
transform_albu_height_is_size = transform_albu_new(image=image)['image']
assert transformed_image_albu.shape == transformed_image_pt_np.shape
assert transform_albu_height_is_size.shape == transformed_image_pt_np.shape

@pytest.mark.parametrize("size, width, height, expected_warning", [
((100, 200), None, None, None),
(None, 200, 100, DeprecationWarning),
(100, None, None, ValueError),
])
def test_deprecation_warnings(size, width, height, expected_warning):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
if expected_warning == ValueError:
with pytest.raises(ValueError):
A.RandomResizedCrop(size=size, width=width, height=height)
else:
A.RandomResizedCrop(size=size, width=width, height=height)
if expected_warning is DeprecationWarning:
assert len(w) == 1
assert issubclass(w[-1].category, expected_warning)
else:
assert not w
warnings.resetwarnings()

0 comments on commit 644f937

Please sign in to comment.