Skip to content

Commit

Permalink
updated version
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Mar 4, 2024
1 parent c2f2449 commit bf49148
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 9 deletions.
2 changes: 1 addition & 1 deletion albumentations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.4.0"
__version__ = "1.4.1"

from .augmentations import *
from .core.composition import *
Expand Down
23 changes: 15 additions & 8 deletions albumentations/augmentations/mixing/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,21 +126,28 @@ def get_transform_init_args_names(self) -> Tuple[str, ...]:
return "reference_data", "alpha"

def get_params(self) -> Dict[str, Union[None, float, Dict[str, Any]]]:
if self.reference_data and isinstance(self.reference_data, Sequence):
mix_idx = random.randint(0, len(self.reference_data) - 1)
mix_data = self.reference_data[mix_idx]
elif self.reference_data and isinstance(self.reference_data, Iterator):
mix_data = None
# Check if reference_data is not empty and is a sequence (list, tuple, np.array)
if isinstance(self.reference_data, Sequence) and not isinstance(self.reference_data, (str, bytes)):
if len(self.reference_data) > 0: # Additional check to ensure it's not empty
mix_idx = random.randint(0, len(self.reference_data) - 1)
mix_data = self.reference_data[mix_idx]
# Check if reference_data is an iterator or generator
elif isinstance(self.reference_data, Iterator):
try:
mix_data = next(self.reference_data) # Get the next item from the iterator
mix_data = next(self.reference_data) # Attempt to get the next item
except StopIteration:
warn(
"Reference data iterator/generator has been exhausted. "
"Further mixing augmentations will not be applied.",
RuntimeWarning,
)
return {"mix_data": {}, "mix_coef": 1}
else:

# If mix_data is None or empty after the above checks, return default values
if mix_data is None:
return {"mix_data": {}, "mix_coef": 1}
mix_coef = beta(self.alpha, self.alpha) if mix_data else 1

return {"mix_data": self.read_fn(mix_data) if mix_data else None, "mix_coef": mix_coef}
# If mix_data is not None, calculate mix_coef and apply read_fn
mix_coef = beta(self.alpha, self.alpha) # Assuming beta is defined elsewhere
return {"mix_data": self.read_fn(mix_data), "mix_coef": mix_coef}
4 changes: 4 additions & 0 deletions tests/test_mixing.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def complex_read_fn_image(x):
(A.MixUp, {
"reference_data": [1],
"read_fn": lambda x: {"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)}},
),
(A.MixUp, {
"reference_data": np.array([1]),
"read_fn": lambda x: {"image": np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)}},
),
(A.MixUp, {
"reference_data": None,
Expand Down

0 comments on commit bf49148

Please sign in to comment.