Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Mar 9, 2024
1 parent 6cf2641 commit 3497f95
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
19 changes: 17 additions & 2 deletions albumentations/augmentations/crops/transforms.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import math
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
from warnings import warn

import cv2
import numpy as np
Expand Down Expand Up @@ -481,7 +482,8 @@ class RandomCropNearBBox(DualTransform):
to `cropping_bbox` dimension.
If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift).
Default (0.3, 0.3).
cropping_bbox_key (str): Additional target key for cropping box. Default `cropping_bbox`
cropping_bbox_key (str): Additional target key for cropping box. Default `cropping_bbox`.
cropping_box_key (str): [Deprecated] Use `cropping_bbox_key` instead.
p (float): probability of applying the transform. Default: 1.
Targets:
Expand All @@ -491,7 +493,7 @@ class RandomCropNearBBox(DualTransform):
uint8, float32
Examples:
>>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')],
>>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_bbox_key='test_box')],
>>> bbox_params=BboxParams("pascal_voc"))
>>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20])
Expand All @@ -503,11 +505,24 @@ def __init__(
self,
max_part_shift: ScaleFloatType = (0.3, 0.3),
cropping_bbox_key: str = "cropping_bbox",
cropping_box_key: Optional[str] = None, # Deprecated
always_apply: bool = False,
p: float = 1.0,
):
super().__init__(always_apply, p)
self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift)

# Check for deprecated parameter and issue warning
if cropping_box_key is not None:
warn(
"The parameter 'cropping_box_key' is deprecated and will be removed in future versions. "
"Use 'cropping_bbox_key' instead.",
DeprecationWarning,
stacklevel=2,
)
# Ensure the new parameter is used even if the old one is passed
cropping_bbox_key = cropping_box_key

self.cropping_bbox_key = cropping_bbox_key

if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1:
Expand Down
9 changes: 6 additions & 3 deletions albumentations/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import json
import warnings
from abc import ABC, ABCMeta, abstractmethod
from collections.abc import Mapping, Sequence
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, TextIO, Tuple, Type, Union
Expand Down Expand Up @@ -168,10 +169,12 @@ def check_data_format(data_format: str) -> None:


def serialize_enum(obj: Any) -> Any:
"""Recursively search for Enum objects and convert them to their value."""
if isinstance(obj, dict):
"""Recursively search for Enum objects and convert them to their value.
Also handle any Mapping or Sequence types.
"""
if isinstance(obj, Mapping):
return {k: serialize_enum(v) for k, v in obj.items()}
if isinstance(obj, list):
if isinstance(obj, Sequence) and not isinstance(obj, str): # exclude strings since they're also sequences
return [serialize_enum(v) for v in obj]
return obj.value if isinstance(obj, Enum) else obj

Expand Down

0 comments on commit 3497f95

Please sign in to comment.