Skip to content

Commit

Permalink
Added example to mixeup docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Mar 12, 2024
1 parent 95fac64 commit 5a7635b
Showing 1 changed file with 25 additions and 1 deletion.
26 changes: 25 additions & 1 deletion albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class MixUp(ReferenceBasedTransform):
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask', 'global_label' each associated with numpy array values.
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
Defaults to "mix_coef".
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Expand All @@ -58,6 +58,30 @@ class MixUp(ReferenceBasedTransform):
Notes:
- If no reference data is provided, a warning is issued, and the transform acts as a no-op.
- Notes if images are in float32 format, they should be within [0, 1] range.
Example Usage:
import albumentations as A
import numpy as np
from albumentations.core.types import ReferenceImage
# Prepare reference data
reference_data = [ReferenceImage(image=np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8),
mask=np.random.randint(0, 4, (100, 100, 1), dtype=np.uint8),
global_label=np.random.choice([0, 1], size=3)) for i in range(10)]
aug = A.Compose([A.RandomRotate90(), A.MixUp(p=1, reference_data=reference_data, read_fn=lambda x: x)])
# Apply augmentations
image = np.empty([100, 100, 3], dtype=np.uint8)
mask = np.empty([100, 100], dtype=np.uint8)
global_label = np.array([0, 1, 0])
data = aug(image=image, global_label=global_label, mask=mask)
transformed_image = data["image"]
transformed_mask = data["mask"]
transformed_global_label = data["global_label"]
# Print applied mix coefficient
print(data["mix_coef"]) # Output: e.g., 0.9991580344142427
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.GLOBAL_LABEL)
Expand Down

0 comments on commit 5a7635b

Please sign in to comment.