Skip to content

Commit

Permalink
Added test for multiple targets
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Feb 29, 2024
1 parent abc1824 commit b08d8b2
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,48 +21,51 @@ class ReferenceImage(TypedDict):


class MixUp(DualTransform):
"""MixUp data augmentation for images, masks, and class labels.
"""Performs MixUp data augmentation, blending images, masks, and class labels with reference data.
This transformation performs the MixUp augmentation by blending an input image, mask, and class label
with another set from a predefined reference dataset. The blending is controlled by a parameter lambda,
sampled from a Beta distribution, dictating the proportion of the mix between the original and reference data.
The MixUp augmentation is known for improving model generalization by encouraging linear behavior between
classes and smoothing the decision boundaries. It can be applied not only to the images but also to the
segmentation masks and class labels, providing a comprehensive data augmentation strategy.
MixUp augmentation linearly combines an input (image, mask, and class label) with another set from a predefined
reference dataset. The mixing degree is controlled by a parameter λ (lambda), sampled from a Beta distribution.
This method is known for improving model generalization by promoting linear behavior between classes and
smoothing decision boundaries.
Reference:
Zhang, H., Cisse, M., Dauphin, Y.N., and Lopez-Paz, D., mixup: Beyond Empirical Risk Minimization,
ICLR 2018. https://arxiv.org/abs/1710.09412
Zhang, H., Cisse, M., Dauphin, Y.N., and Lopez-Paz, D. (2018). mixup: Beyond Empirical Risk Minimization.
In International Conference on Learning Representations. https://arxiv.org/abs/1710.09412
Args:
----
reference_data Sequence[ReferenceImage]: A sequence of dictionaries containing the reference
images, masks, and class labels for mixing. Each dictionary should have keys 'image', and optionally 'mask',
and 'class_label'. Defaults to an empty list, resulting in no operation if not provided.
read_fn Callable[[Any], Dict[str, np.ndarray]]: A function to load and process the data
from the reference_data dictionaries. It should accept one argument (one of the dictionaries) and
return a processed dictionary containing numpy arrays for keys 'image', and optionally 'mask'
and 'class_label'. The 'class_label' should be one-hot encoded if provided.
Defaults to a lambda function that acts as a no-op for simplification.
alpha (float): The alpha parameter of the Beta distribution used to sample the lambda value.
Must be greater than or equal to 0. Higher values make the distribution closer to uniform,
resulting in more balanced mixing. Defaults to 0.4.
p (float): Probability that the transform will be applied. Defaults to 0.5.
reference_data (Optional[Union[Generator[ReferenceImage, None, None], Sequence[ReferenceImage]]]):
A sequence or generator of dictionaries containing the reference data for mixing. Each dictionary
should contain:
- 'image': Mandatory key with an image array.
- 'mask': Optional key with a mask array.
- 'global_label': Optional key with a class label array.
If None or an empty sequence is provided, no operation is performed and a warning is issued.
read_fn (Callable[[ReferenceImage], Dict[str, Any]]):
A function to process items from reference_data. It should accept a dictionary from reference_data
and return a processed dictionary containing 'image', and optionally 'mask' and 'global_label',
each as numpy arrays. Defaults to a no-op lambda function.
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Higher values lead to more uniform mixing. Defaults to 0.4.
p (float):
The probability of applying the transformation. Defaults to 0.5.
Targets:
image, mask, global_label
- image: The input image to augment.
- mask: An optional segmentation mask corresponding to the input image.
- global_label: An optional global label associated with the input image.
Image types:
uint8, float32
- uint8, float32
Raises:
------
ValueError: If the alpha parameter is negative.
- ValueError: If the alpha parameter is negative.
Notes:
-----
- If no reference data is provided, this transform will issue a warning and act as a no-op.
- If no reference data is provided, a warning is issued, and the transform acts as a no-op.
"""

Expand Down

0 comments on commit b08d8b2

Please sign in to comment.