Skip to content

Commit

Permalink
Return applied alpha for MixUp
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet committed Mar 10, 2024
1 parent b1a79c2 commit 19e2e78
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class MixUp(ReferenceBasedTransform):
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask', 'global_label' each associated with numpy array values.
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
mix_coef_return_name (str): With this name will be returned applied alpha coefficient.
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Higher values lead to more uniform mixing. Defaults to 0.4.
Expand Down Expand Up @@ -65,10 +66,12 @@ def __init__(
reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None,
read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None},
alpha: float = 0.4,
mix_coef_return_name: str = "mix_coef",
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.mix_coef_return_name = mix_coef_return_name

if alpha < 0:
msg = "Alpha must be >= 0."
Expand Down Expand Up @@ -151,3 +154,9 @@ def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
# If mix_data is not None, calculate mix_coef and apply read_fn
mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere
return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef}

def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]:
res = super().apply_with_params(params, *args, **kwargs)
if self.mix_coef_return_name:
res[self.mix_coef_return_name] = params["mix_coef"]
return res

0 comments on commit 19e2e78

Please sign in to comment.