Skip to content

Commit

Permalink
Added example to mixeup docstring (#1576)
Browse files Browse the repository at this point in the history
* Added example to mixeup docstring

* Update albumentations/augmentations/mixing/transforms.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Update albumentations/augmentations/mixing/transforms.py

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

* Added example to mixeup docstring

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
  • Loading branch information
ternaus and sourcery-ai[bot] authored Mar 12, 2024
1 parent 95fac64 commit a4b9606
Showing 1 changed file with 46 additions and 1 deletion.
47 changes: 46 additions & 1 deletion albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class MixUp(ReferenceBasedTransform):
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask', 'global_label' each associated with numpy array values.
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
Defaults to "mix_coef".
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Expand All @@ -58,6 +58,51 @@ class MixUp(ReferenceBasedTransform):
Notes:
- If no reference data is provided, a warning is issued, and the transform acts as a no-op.
- Notes if images are in float32 format, they should be within [0, 1] range.
Example Usage:
import albumentations as A
import numpy as np
from albumentations.core.types import ReferenceImage
# Prepare reference data
# Note: This code generates random reference data for demonstration purposes only.
# In real-world applications, it's crucial to use meaningful and representative data.
# The quality and relevance of your input data significantly impact the effectiveness
# of the augmentation process. Ensure your data closely aligns with your specific
# use case and application requirements.
reference_data = [ReferenceImage(image=np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8),
mask=np.random.randint(0, 4, (100, 100, 1), dtype=np.uint8),
global_label=np.random.choice([0, 1], size=3)) for i in range(10)]
# In this example, the lambda function simply returns its input, which works well for
# data already in the expected format. For more complex scenarios, where the data might not be in
# the required format or additional processing is needed, a more sophisticated function can be implemented.
# Below is a hypothetical example where the input data is a file path, # and the function reads the image
# file, converts it to a specific format, and possibly performs other preprocessing steps.
# Example of a more complex read_fn that reads an image from a file path, converts it to RGB, and resizes it.
# def custom_read_fn(file_path):
# from PIL import Image
# image = Image.open(file_path).convert('RGB')
# image = image.resize((100, 100)) # Example resize, adjust as needed.
# return np.array(image)
# aug = A.Compose([A.RandomRotate90(), A.MixUp(p=1, reference_data=reference_data, read_fn=lambda x: x)])
# For simplicity, the original lambda function is used in this example.
# Replace `lambda x: x` with `custom_read_fn`if you need to process the data more extensively.
# Apply augmentations
image = np.empty([100, 100, 3], dtype=np.uint8)
mask = np.empty([100, 100], dtype=np.uint8)
global_label = np.array([0, 1, 0])
data = aug(image=image, global_label=global_label, mask=mask)
transformed_image = data["image"]
transformed_mask = data["mask"]
transformed_global_label = data["global_label"]
# Print applied mix coefficient
print(data["mix_coef"]) # Output: e.g., 0.9991580344142427
"""

_targets = (Targets.IMAGE, Targets.MASK, Targets.GLOBAL_LABEL)
Expand Down

0 comments on commit a4b9606

Please sign in to comment.