Skip to content

Commit

Permalink
Return applied alpha for MixUp (#1572)
Browse files Browse the repository at this point in the history
* Return applied alpha for MixUp

* Added mix coef to tests

---------

Co-authored-by: Vladimir Iglovikov <[email protected]>
Co-authored-by: Vladimir Iglovikov <[email protected]>
  • Loading branch information
3 people authored Mar 12, 2024
1 parent 597e02f commit 95fac64
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ repos:
- id: codespell
additional_dependencies: ["tomli"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.9.0
hooks:
- id: mypy
files: ^albumentations/
Expand Down
10 changes: 10 additions & 0 deletions albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class MixUp(ReferenceBasedTransform):
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask', 'global_label' each associated with numpy array values.
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
Defaults to "mix_coef".
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Higher values lead to more uniform mixing. Defaults to 0.4.
Expand Down Expand Up @@ -65,10 +67,12 @@ def __init__(
reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None,
read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None},
alpha: float = 0.4,
mix_coef_return_name: str = "mix_coef",
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.mix_coef_return_name = mix_coef_return_name

if alpha < 0:
msg = "Alpha must be >= 0."
Expand Down Expand Up @@ -151,3 +155,9 @@ def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
# If mix_data is not None, calculate mix_coef and apply read_fn
mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere
return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef}

def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
res = super().apply_with_params(params, *args, **kwargs)
if self.mix_coef_return_name:
res[self.mix_coef_return_name] = params["mix_coef"]
return res
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
deepdiff>=6.7.1
mypy>=1.8.0
mypy>=1.9.0
pre_commit>=3.5.0
pytest>=8.0.2
pytest_cov>=4.1.0
Expand Down
30 changes: 22 additions & 8 deletions tests/test_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def complex_read_fn_image(x):
"reference_data": complex_image_generator(),
"read_fn": complex_read_fn_image})] )
def test_image_only(augmentation_cls, params, image):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)
data = aug(image=image)
assert data["image"].dtype == np.uint8

Expand All @@ -58,20 +58,25 @@ def test_image_only(augmentation_cls, params, image):
]
)
def test_image_global_label(augmentation_cls, params, image, global_label):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)

data = aug(image=image, global_label=global_label)

assert data["image"].dtype == np.uint8

reference_item = params["read_fn"](aug.reference_data[0])
reference_data = params["reference_data"][0]

reference_item = params["read_fn"](reference_data)

reference_image = reference_item["image"]
reference_global_label = reference_item["global_label"]

mix_coef = data["mix_coef"]

mix_coeff_image = find_mix_coef(data["image"], image, reference_image)
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_global_label)

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
assert 0 <= mix_coeff_image <= 1

Expand All @@ -85,16 +90,19 @@ def test_image_global_label(augmentation_cls, params, image, global_label):
"read_fn": lambda x: x})]
)
def test_image_mask_global_label(augmentation_cls, params, image, mask, global_label):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)

data = aug(image=image, global_label=global_label, mask=mask)

assert data["image"].dtype == np.uint8
reference_data = params["reference_data"][0]

mix_coeff_image = find_mix_coef(data["image"], image, aug.reference_data[0]["image"])
mix_coeff_mask = find_mix_coef(data["mask"], mask, aug.reference_data[0]["mask"])
mix_coeff_label = find_mix_coef(data["global_label"], global_label, aug.reference_data[0]["global_label"])
mix_coef = data["mix_coef"]

mix_coeff_image = find_mix_coef(data["image"], image, reference_data["image"])
mix_coeff_mask = find_mix_coef(data["mask"], mask, reference_data["mask"])
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data["global_label"])

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
assert 0 <= mix_coeff_image <= 1
Expand All @@ -115,6 +123,8 @@ def test_additional_targets(image, mask, global_label):

data = aug(image=image, global_label=global_label, mask=mask, image1=image1, global_label1=global_label1, mask1=mask1)

mix_coef = data["mix_coef"]

assert data["image"].dtype == np.uint8

mix_coeff_image = find_mix_coef(data["image"], image, reference_data[0]["image"])
Expand All @@ -125,6 +135,7 @@ def test_additional_targets(image, mask, global_label):
mix_coeff_mask1 = find_mix_coef(data["mask1"], mask1, reference_data[0]["mask"])
mix_coeff_label1 = find_mix_coef(data["global_label1"], global_label1, reference_data[0]["global_label"])

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)

assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
Expand Down Expand Up @@ -176,6 +187,9 @@ def test_pipeline(augmentation_cls, params, image, mask, global_label):

assert data["image"].dtype == np.uint8

mix_coef = data["mix_coef"]

mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data[0]["global_label"])

assert math.isclose(mix_coef, mix_coeff_label, abs_tol=0.01)
assert 0 <= mix_coeff_label <= 1

0 comments on commit 95fac64

Please sign in to comment.