@@ -40,7 +40,7 @@ def complex_read_fn_image(x):
4040 "reference_data" : complex_image_generator (),
4141 "read_fn" : complex_read_fn_image })] )
4242def test_image_only (augmentation_cls , params , image ):
43- aug = augmentation_cls (p = 1 , ** params )
43+ aug = A . Compose ([ augmentation_cls (p = 1 , ** params )], p = 1 )
4444 data = aug (image = image )
4545 assert data ["image" ].dtype == np .uint8
4646
@@ -58,20 +58,25 @@ def test_image_only(augmentation_cls, params, image):
5858 ]
5959)
6060def test_image_global_label (augmentation_cls , params , image , global_label ):
61- aug = augmentation_cls (p = 1 , ** params )
61+ aug = A . Compose ([ augmentation_cls (p = 1 , ** params )], p = 1 )
6262
6363 data = aug (image = image , global_label = global_label )
6464
6565 assert data ["image" ].dtype == np .uint8
6666
67- reference_item = params ["read_fn" ](aug .reference_data [0 ])
67+ reference_data = params ["reference_data" ][0 ]
68+
69+ reference_item = params ["read_fn" ](reference_data )
6870
6971 reference_image = reference_item ["image" ]
7072 reference_global_label = reference_item ["global_label" ]
7173
74+ mix_coef = data ["mix_coef" ]
75+
7276 mix_coeff_image = find_mix_coef (data ["image" ], image , reference_image )
7377 mix_coeff_label = find_mix_coef (data ["global_label" ], global_label , reference_global_label )
7478
79+ assert math .isclose (mix_coef , mix_coeff_image , abs_tol = 0.01 )
7580 assert math .isclose (mix_coeff_image , mix_coeff_label , abs_tol = 0.01 )
7681 assert 0 <= mix_coeff_image <= 1
7782
@@ -85,16 +90,19 @@ def test_image_global_label(augmentation_cls, params, image, global_label):
8590 "read_fn" : lambda x : x })]
8691)
8792def test_image_mask_global_label (augmentation_cls , params , image , mask , global_label ):
88- aug = augmentation_cls (p = 1 , ** params )
93+ aug = A . Compose ([ augmentation_cls (p = 1 , ** params )], p = 1 )
8994
9095 data = aug (image = image , global_label = global_label , mask = mask )
9196
92- assert data [ "image" ]. dtype == np . uint8
97+ reference_data = params [ "reference_data" ][ 0 ]
9398
94- mix_coeff_image = find_mix_coef (data ["image" ], image , aug .reference_data [0 ]["image" ])
95- mix_coeff_mask = find_mix_coef (data ["mask" ], mask , aug .reference_data [0 ]["mask" ])
96- mix_coeff_label = find_mix_coef (data ["global_label" ], global_label , aug .reference_data [0 ]["global_label" ])
99+ mix_coef = data ["mix_coef" ]
97100
101+ mix_coeff_image = find_mix_coef (data ["image" ], image , reference_data ["image" ])
102+ mix_coeff_mask = find_mix_coef (data ["mask" ], mask , reference_data ["mask" ])
103+ mix_coeff_label = find_mix_coef (data ["global_label" ], global_label , reference_data ["global_label" ])
104+
105+ assert math .isclose (mix_coef , mix_coeff_image , abs_tol = 0.01 )
98106 assert math .isclose (mix_coeff_image , mix_coeff_label , abs_tol = 0.01 )
99107 assert math .isclose (mix_coeff_image , mix_coeff_mask , abs_tol = 0.01 )
100108 assert 0 <= mix_coeff_image <= 1
@@ -115,6 +123,8 @@ def test_additional_targets(image, mask, global_label):
115123
116124 data = aug (image = image , global_label = global_label , mask = mask , image1 = image1 , global_label1 = global_label1 , mask1 = mask1 )
117125
126+ mix_coef = data ["mix_coef" ]
127+
118128 assert data ["image" ].dtype == np .uint8
119129
120130 mix_coeff_image = find_mix_coef (data ["image" ], image , reference_data [0 ]["image" ])
@@ -125,6 +135,7 @@ def test_additional_targets(image, mask, global_label):
125135 mix_coeff_mask1 = find_mix_coef (data ["mask1" ], mask1 , reference_data [0 ]["mask" ])
126136 mix_coeff_label1 = find_mix_coef (data ["global_label1" ], global_label1 , reference_data [0 ]["global_label" ])
127137
138+ assert math .isclose (mix_coef , mix_coeff_image , abs_tol = 0.01 )
128139 assert math .isclose (mix_coeff_image , mix_coeff_label , abs_tol = 0.01 )
129140
130141 assert math .isclose (mix_coeff_image , mix_coeff_mask , abs_tol = 0.01 )
@@ -176,6 +187,9 @@ def test_pipeline(augmentation_cls, params, image, mask, global_label):
176187
177188 assert data ["image" ].dtype == np .uint8
178189
190+ mix_coef = data ["mix_coef" ]
191+
179192 mix_coeff_label = find_mix_coef (data ["global_label" ], global_label , reference_data [0 ]["global_label" ])
180193
194+ assert math .isclose (mix_coef , mix_coeff_label , abs_tol = 0.01 )
181195 assert 0 <= mix_coeff_label <= 1
0 commit comments