Skip to content
This repository was archived by the owner on Jul 10, 2025. It is now read-only.

Commit 95fac64

Browse files
Dipetternaus
andauthored
Return applied alpha for MixUp (#1572)
* Return applied alpha for MixUp * Added mix coef to tests --------- Co-authored-by: Vladimir Iglovikov <iglovikov@gmail.com> Co-authored-by: Vladimir Iglovikov <ternaus@users.noreply.github.com>
1 parent 597e02f commit 95fac64

File tree

4 files changed

+34
-10
lines changed

4 files changed

+34
-10
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ repos:
6767
- id: codespell
6868
additional_dependencies: ["tomli"]
6969
- repo: https://github.com/pre-commit/mirrors-mypy
70-
rev: v1.8.0
70+
rev: v1.9.0
7171
hooks:
7272
- id: mypy
7373
files: ^albumentations/

albumentations/augmentations/mixing/transforms.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class MixUp(ReferenceBasedTransform):
3737
- The returned dictionary must include an 'image' key with a numpy array value.
3838
- It may also include 'mask', 'global_label' each associated with numpy array values.
3939
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
40+
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
41+
Defaults to "mix_coef".
4042
alpha (float):
4143
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
4244
Higher values lead to more uniform mixing. Defaults to 0.4.
@@ -65,10 +67,12 @@ def __init__(
6567
reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None,
6668
read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None},
6769
alpha: float = 0.4,
70+
mix_coef_return_name: str = "mix_coef",
6871
always_apply: bool = False,
6972
p: float = 0.5,
7073
):
7174
super().__init__(always_apply, p)
75+
self.mix_coef_return_name = mix_coef_return_name
7276

7377
if alpha < 0:
7478
msg = "Alpha must be >= 0."
@@ -151,3 +155,9 @@ def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
151155
# If mix_data is not None, calculate mix_coef and apply read_fn
152156
mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere
153157
return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef}
158+
159+
def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
160+
res = super().apply_with_params(params, *args, **kwargs)
161+
if self.mix_coef_return_name:
162+
res[self.mix_coef_return_name] = params["mix_coef"]
163+
return res

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
deepdiff>=6.7.1
2-
mypy>=1.8.0
2+
mypy>=1.9.0
33
pre_commit>=3.5.0
44
pytest>=8.0.2
55
pytest_cov>=4.1.0

tests/test_mixing.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def complex_read_fn_image(x):
4040
"reference_data": complex_image_generator(),
4141
"read_fn": complex_read_fn_image})] )
4242
def test_image_only(augmentation_cls, params, image):
43-
aug = augmentation_cls(p=1, **params)
43+
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)
4444
data = aug(image=image)
4545
assert data["image"].dtype == np.uint8
4646

@@ -58,20 +58,25 @@ def test_image_only(augmentation_cls, params, image):
5858
]
5959
)
6060
def test_image_global_label(augmentation_cls, params, image, global_label):
61-
aug = augmentation_cls(p=1, **params)
61+
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)
6262

6363
data = aug(image=image, global_label=global_label)
6464

6565
assert data["image"].dtype == np.uint8
6666

67-
reference_item = params["read_fn"](aug.reference_data[0])
67+
reference_data = params["reference_data"][0]
68+
69+
reference_item = params["read_fn"](reference_data)
6870

6971
reference_image = reference_item["image"]
7072
reference_global_label = reference_item["global_label"]
7173

74+
mix_coef = data["mix_coef"]
75+
7276
mix_coeff_image = find_mix_coef(data["image"], image, reference_image)
7377
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_global_label)
7478

79+
assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
7580
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
7681
assert 0 <= mix_coeff_image <= 1
7782

@@ -85,16 +90,19 @@ def test_image_global_label(augmentation_cls, params, image, global_label):
8590
"read_fn": lambda x: x})]
8691
)
8792
def test_image_mask_global_label(augmentation_cls, params, image, mask, global_label):
88-
aug = augmentation_cls(p=1, **params)
93+
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)
8994

9095
data = aug(image=image, global_label=global_label, mask=mask)
9196

92-
assert data["image"].dtype == np.uint8
97+
reference_data = params["reference_data"][0]
9398

94-
mix_coeff_image = find_mix_coef(data["image"], image, aug.reference_data[0]["image"])
95-
mix_coeff_mask = find_mix_coef(data["mask"], mask, aug.reference_data[0]["mask"])
96-
mix_coeff_label = find_mix_coef(data["global_label"], global_label, aug.reference_data[0]["global_label"])
99+
mix_coef = data["mix_coef"]
97100

101+
mix_coeff_image = find_mix_coef(data["image"], image, reference_data["image"])
102+
mix_coeff_mask = find_mix_coef(data["mask"], mask, reference_data["mask"])
103+
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data["global_label"])
104+
105+
assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
98106
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
99107
assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
100108
assert 0 <= mix_coeff_image <= 1
@@ -115,6 +123,8 @@ def test_additional_targets(image, mask, global_label):
115123

116124
data = aug(image=image, global_label=global_label, mask=mask, image1=image1, global_label1=global_label1, mask1=mask1)
117125

126+
mix_coef = data["mix_coef"]
127+
118128
assert data["image"].dtype == np.uint8
119129

120130
mix_coeff_image = find_mix_coef(data["image"], image, reference_data[0]["image"])
@@ -125,6 +135,7 @@ def test_additional_targets(image, mask, global_label):
125135
mix_coeff_mask1 = find_mix_coef(data["mask1"], mask1, reference_data[0]["mask"])
126136
mix_coeff_label1 = find_mix_coef(data["global_label1"], global_label1, reference_data[0]["global_label"])
127137

138+
assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
128139
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
129140

130141
assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
@@ -176,6 +187,9 @@ def test_pipeline(augmentation_cls, params, image, mask, global_label):
176187

177188
assert data["image"].dtype == np.uint8
178189

190+
mix_coef = data["mix_coef"]
191+
179192
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data[0]["global_label"])
180193

194+
assert math.isclose(mix_coef, mix_coeff_label, abs_tol=0.01)
181195
assert 0 <= mix_coeff_label <= 1

0 commit comments

Comments
 (0)