diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 53044df60..56b165c6d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -67,7 +67,7 @@ repos: - id: codespell additional_dependencies: ["tomli"] - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.8.0 + rev: v1.9.0 hooks: - id: mypy files: ^albumentations/ diff --git a/albumentations/augmentations/mixing/transforms.py b/albumentations/augmentations/mixing/transforms.py index 0155721ac..8c39fd854 100644 --- a/albumentations/augmentations/mixing/transforms.py +++ b/albumentations/augmentations/mixing/transforms.py @@ -37,6 +37,8 @@ class MixUp(ReferenceBasedTransform): - The returned dictionary must include an 'image' key with a numpy array value. - It may also include 'mask', 'global_label' each associated with numpy array values. Defaults to a function that assumes input dictionary contains numpy arrays and directly returns it. + mix_coef_return_name (str): Name used for the applied alpha coefficient in the returned dictionary. + Defaults to "mix_coef". alpha (float): The alpha parameter for the Beta distribution, influencing the mix's balance. Must be ≥ 0. Higher values lead to more uniform mixing. Defaults to 0.4. @@ -65,10 +67,12 @@ def __init__( reference_data: Optional[Union[Generator[ReferenceImage, None, None], Sequence[Any]]] = None, read_fn: Callable[[ReferenceImage], Any] = lambda x: {"image": x, "mask": None, "class_label": None}, alpha: float = 0.4, + mix_coef_return_name: str = "mix_coef", always_apply: bool = False, p: float = 0.5, ): super().__init__(always_apply, p) + self.mix_coef_return_name = mix_coef_return_name if alpha < 0: msg = "Alpha must be >= 0." @@ -151,3 +155,9 @@ def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]: # If mix_data is not None, calculate mix_coef and apply read_fn mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef} + + def apply_with_params(self, params: Dict[str, Any], *args: Any, **kwargs: Any) -> Dict[str, Any]: + res = super().apply_with_params(params, *args, **kwargs) + if self.mix_coef_return_name: + res[self.mix_coef_return_name] = params["mix_coef"] + return res diff --git a/requirements-dev.txt b/requirements-dev.txt index 8c464a74d..8a12d863d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,5 @@ deepdiff>=6.7.1 -mypy>=1.8.0 +mypy>=1.9.0 pre_commit>=3.5.0 pytest>=8.0.2 pytest_cov>=4.1.0 diff --git a/tests/test_mixing.py b/tests/test_mixing.py index f57b79433..7b9d3a459 100644 --- a/tests/test_mixing.py +++ b/tests/test_mixing.py @@ -40,7 +40,7 @@ def complex_read_fn_image(x): "reference_data": complex_image_generator(), "read_fn": complex_read_fn_image})] ) def test_image_only(augmentation_cls, params, image): - aug = augmentation_cls(p=1, **params) + aug = A.Compose([augmentation_cls(p=1, **params)], p=1) data = aug(image=image) assert data["image"].dtype == np.uint8 @@ -58,20 +58,25 @@ def test_image_only(augmentation_cls, params, image): ] ) def test_image_global_label(augmentation_cls, params, image, global_label): - aug = augmentation_cls(p=1, **params) + aug = A.Compose([augmentation_cls(p=1, **params)], p=1) data = aug(image=image, global_label=global_label) assert data["image"].dtype == np.uint8 - reference_item = params["read_fn"](aug.reference_data[0]) + reference_data = params["reference_data"][0] + + reference_item = params["read_fn"](reference_data) reference_image = reference_item["image"] reference_global_label = reference_item["global_label"] + mix_coef = data["mix_coef"] + mix_coeff_image = find_mix_coef(data["image"], image, reference_image) mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_global_label) + assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01) assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01) assert 0 <= mix_coeff_image <= 1 @@ -85,16 +90,19 @@ def test_image_global_label(augmentation_cls, params, image, global_label): "read_fn": lambda x: x})] ) def test_image_mask_global_label(augmentation_cls, params, image, mask, global_label): - aug = augmentation_cls(p=1, **params) + aug = A.Compose([augmentation_cls(p=1, **params)], p=1) data = aug(image=image, global_label=global_label, mask=mask) - assert data["image"].dtype == np.uint8 + reference_data = params["reference_data"][0] - mix_coeff_image = find_mix_coef(data["image"], image, aug.reference_data[0]["image"]) - mix_coeff_mask = find_mix_coef(data["mask"], mask, aug.reference_data[0]["mask"]) - mix_coeff_label = find_mix_coef(data["global_label"], global_label, aug.reference_data[0]["global_label"]) + mix_coef = data["mix_coef"] + mix_coeff_image = find_mix_coef(data["image"], image, reference_data["image"]) + mix_coeff_mask = find_mix_coef(data["mask"], mask, reference_data["mask"]) + mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data["global_label"]) + + assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01) assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01) assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01) assert 0 <= mix_coeff_image <= 1 @@ -115,6 +123,8 @@ def test_additional_targets(image, mask, global_label): data = aug(image=image, global_label=global_label, mask=mask, image1=image1, global_label1=global_label1, mask1=mask1) + mix_coef = data["mix_coef"] + assert data["image"].dtype == np.uint8 mix_coeff_image = find_mix_coef(data["image"], image, reference_data[0]["image"]) @@ -125,6 +135,7 @@ def test_additional_targets(image, mask, global_label): mix_coeff_mask1 = find_mix_coef(data["mask1"], mask1, reference_data[0]["mask"]) mix_coeff_label1 = find_mix_coef(data["global_label1"], global_label1, reference_data[0]["global_label"]) + assert math.isclose(mix_coef, mix_coeff_image, abs_tol=0.01) assert math.isclose(mix_coeff_image, mix_coeff_label, abs_tol=0.01) assert math.isclose(mix_coeff_image, mix_coeff_mask, abs_tol=0.01) @@ -176,6 +187,9 @@ def test_pipeline(augmentation_cls, params, image, mask, global_label): assert data["image"].dtype == np.uint8 + mix_coef = data["mix_coef"] + mix_coeff_label = find_mix_coef(data["global_label"], global_label, reference_data[0]["global_label"]) + assert math.isclose(mix_coef, mix_coeff_label, abs_tol=0.01) assert 0 <= mix_coeff_label <= 1