Skip to content

Commit

Permalink
Added mix coef to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Mar 12, 2024
1 parent 19e2e78 commit 1073bf4
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ repos:
- id: codespell
additional_dependencies: ["tomli"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
rev: v1.9.0
hooks:
- id: mypy
files: ^albumentations/
Expand Down
3 changes: 2 additions & 1 deletion albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ class MixUp(ReferenceBasedTransform):
- The returned dictionary must include an 'image' key with a numpy array value.
- It may also include 'mask', 'global_label' each associated with numpy array values.
Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it.
mix_coef_return_name (str): With this name will be returned applied alpha coefficient.
mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary.
Defaults to "mix_coef".
alpha (float):
The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0.
Higher values lead to more uniform mixing. Defaults to 0.4.
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
deepdiff>=6.7.1
mypy>=1.8.0
mypy>=1.9.0
pre_commit>=3.5.0
pytest>=8.0.2
pytest_cov>=4.1.0
Expand Down
30 changes: 22 additions & 8 deletions tests/test_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def complex_read_fn_image(x):
"reference_data": complex_image_generator(),
"read_fn": complex_read_fn_image})] )
def test_image_only(augmentation_cls, params, image):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)
data = aug(image=image)
assert data["image"].dtype == np.uint8

Expand All @@ -58,20 +58,25 @@ def test_image_only(augmentation_cls, params, image):
]
)
def test_image_global_label(augmentation_cls, params, image, global_label):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)

data = aug(image=image, global_label=global_label)

assert data["image"].dtype == np.uint8

reference_item = params["read_fn"](aug.reference_data[0])
reference_data = params["reference_data"][0]

reference_item = params["read_fn"](reference_data)

reference_image = reference_item["image"]
reference_global_label = reference_item["global_label"]

mix_coef = data["mix_coef"]

mix_coeff_image = find_mix_coef(data["image"], image, reference_image)
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_global_label)

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
assert 0 <= mix_coeff_image <= 1

Expand All @@ -85,16 +90,19 @@ def test_image_global_label(augmentation_cls, params, image, global_label):
"read_fn": lambda x: x})]
)
def test_image_mask_global_label(augmentation_cls, params, image, mask, global_label):
aug = augmentation_cls(p=1, **params)
aug = A.Compose([augmentation_cls(p=1, **params)], p=1)

data = aug(image=image, global_label=global_label, mask=mask)

assert data["image"].dtype == np.uint8
reference_data = params["reference_data"][0]

mix_coeff_image = find_mix_coef(data["image"], image, aug.reference_data[0]["image"])
mix_coeff_mask = find_mix_coef(data["mask"], mask, aug.reference_data[0]["mask"])
mix_coeff_label = find_mix_coef(data["global_label"], global_label, aug.reference_data[0]["global_label"])
mix_coef = data["mix_coef"]

mix_coeff_image = find_mix_coef(data["image"], image, reference_data["image"])
mix_coeff_mask = find_mix_coef(data["mask"], mask, reference_data["mask"])
mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data["global_label"])

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
assert 0 <= mix_coeff_image <= 1
Expand All @@ -115,6 +123,8 @@ def test_additional_targets(image, mask, global_label):

data = aug(image=image, global_label=global_label, mask=mask, image1=image1, global_label1=global_label1, mask1=mask1)

mix_coef = data["mix_coef"]

assert data["image"].dtype == np.uint8

mix_coeff_image = find_mix_coef(data["image"], image, reference_data[0]["image"])
Expand All @@ -125,6 +135,7 @@ def test_additional_targets(image, mask, global_label):
mix_coeff_mask1 = find_mix_coef(data["mask1"], mask1, reference_data[0]["mask"])
mix_coeff_label1 = find_mix_coef(data["global_label1"], global_label1, reference_data[0]["global_label"])

assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01)
assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01)

assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01)
Expand Down Expand Up @@ -176,6 +187,9 @@ def test_pipeline(augmentation_cls, params, image, mask, global_label):

assert data["image"].dtype == np.uint8

mix_coef = data["mix_coef"]

mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data[0]["global_label"])

assert math.isclose(mix_coef, mix_coeff_label, abs_tol=0.01)
assert 0 <= mix_coeff_label <= 1

0 comments on commit 1073bf4

Please sign in to comment.