Skip to content

Commit 9f51893

Browse files
authored
fixes RandSpatialCropSamples random states (#1086)
Signed-off-by: Wenqi Li <[email protected]>
1 parent 06fb955 commit 9f51893

File tree

4 files changed

+114
-21
lines changed

4 files changed

+114
-21
lines changed

monai/transforms/croppad/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,13 @@ def __init__(
347347
self.num_samples = num_samples
348348
self.cropper = RandSpatialCrop(roi_size, random_center, random_size)
349349

350+
def set_random_state(
351+
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
352+
) -> "Randomizable":
353+
super().set_random_state(seed=seed, state=state)
354+
self.cropper.set_random_state(state=self.R)
355+
return self
356+
350357
def randomize(self, data: Optional[Any] = None) -> None:
351358
pass
352359

monai/transforms/croppad/dictionary.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,13 @@ def __init__(
307307
self.num_samples = num_samples
308308
self.cropper = RandSpatialCropd(keys, roi_size, random_center, random_size)
309309

310+
def set_random_state(
311+
self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
312+
) -> "Randomizable":
313+
super().set_random_state(seed=seed, state=state)
314+
self.cropper.set_random_state(state=self.R)
315+
return self
316+
310317
def randomize(self, data: Optional[Any] = None) -> None:
311318
pass
312319

tests/test_rand_spatial_crop_samples.py

Lines changed: 53 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,67 @@
1717
from monai.transforms import RandSpatialCropSamples
1818

1919
TEST_CASE_1 = [
20-
{"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True},
21-
np.random.randint(0, 2, size=[3, 3, 3, 3]),
22-
(3, 3, 3, 3),
20+
{"roi_size": [3, 3, 3], "num_samples": 4, "random_center": True, "random_size": False},
21+
np.arange(192).reshape(3, 4, 4, 4),
22+
[(3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],
23+
np.array(
24+
[
25+
[
26+
[[21, 22, 23], [25, 26, 27], [29, 30, 31]],
27+
[[37, 38, 39], [41, 42, 43], [45, 46, 47]],
28+
[[53, 54, 55], [57, 58, 59], [61, 62, 63]],
29+
],
30+
[
31+
[[85, 86, 87], [89, 90, 91], [93, 94, 95]],
32+
[[101, 102, 103], [105, 106, 107], [109, 110, 111]],
33+
[[117, 118, 119], [121, 122, 123], [125, 126, 127]],
34+
],
35+
[
36+
[[149, 150, 151], [153, 154, 155], [157, 158, 159]],
37+
[[165, 166, 167], [169, 170, 171], [173, 174, 175]],
38+
[[181, 182, 183], [185, 186, 187], [189, 190, 191]],
39+
],
40+
]
41+
),
2342
]
2443

2544
TEST_CASE_2 = [
26-
{"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False},
27-
np.random.randint(0, 2, size=[3, 3, 3, 3]),
28-
(3, 3, 3, 3),
45+
{"roi_size": [3, 3, 3], "num_samples": 8, "random_center": False, "random_size": True},
46+
np.arange(192).reshape(3, 4, 4, 4),
47+
[(3, 4, 4, 3), (3, 4, 3, 3), (3, 3, 4, 4), (3, 4, 4, 4), (3, 3, 3, 4), (3, 3, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3)],
48+
np.array(
49+
[
50+
[
51+
[[21, 22, 23], [25, 26, 27], [29, 30, 31]],
52+
[[37, 38, 39], [41, 42, 43], [45, 46, 47]],
53+
[[53, 54, 55], [57, 58, 59], [61, 62, 63]],
54+
],
55+
[
56+
[[85, 86, 87], [89, 90, 91], [93, 94, 95]],
57+
[[101, 102, 103], [105, 106, 107], [109, 110, 111]],
58+
[[117, 118, 119], [121, 122, 123], [125, 126, 127]],
59+
],
60+
[
61+
[[149, 150, 151], [153, 154, 155], [157, 158, 159]],
62+
[[165, 166, 167], [169, 170, 171], [173, 174, 175]],
63+
[[181, 182, 183], [185, 186, 187], [189, 190, 191]],
64+
],
65+
]
66+
),
2967
]
3068

3169

3270
class TestRandSpatialCropSamples(unittest.TestCase):
3371
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
34-
def test_shape(self, input_param, input_data, expected_shape):
35-
result = RandSpatialCropSamples(**input_param)(input_data)
36-
for item in result:
37-
self.assertTupleEqual(item.shape, expected_shape)
72+
def test_shape(self, input_param, input_data, expected_shape, expected_last_item):
73+
xform = RandSpatialCropSamples(**input_param)
74+
xform.set_random_state(1234)
75+
result = xform(input_data)
76+
77+
np.testing.assert_equal(len(result), input_param["num_samples"])
78+
for item, expected in zip(result, expected_shape):
79+
self.assertTupleEqual(item.shape, expected)
80+
np.testing.assert_allclose(result[-1], expected_last_item)
3881

3982

4083
if __name__ == "__main__":

tests/test_rand_spatial_crop_samplesd.py

Lines changed: 47 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,61 @@
1717
from monai.transforms import RandSpatialCropSamplesd
1818

1919
TEST_CASE_1 = [
20-
{"keys": ["img", "seg"], "num_samples": 4, "roi_size": [3, 3, 3], "random_center": True},
21-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3]), "seg": np.random.randint(0, 2, size=[3, 3, 3, 3])},
22-
(3, 3, 3, 3),
20+
{"keys": ["img", "seg"], "num_samples": 4, "roi_size": [2, 2, 2], "random_center": True},
21+
{"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)},
22+
[(3, 3, 3, 2), (3, 2, 2, 2), (3, 3, 3, 2), (3, 3, 2, 2)],
23+
{
24+
"img": np.array(
25+
[
26+
[[[0, 1], [3, 4]], [[9, 10], [12, 13]], [[18, 19], [21, 22]]],
27+
[[[27, 28], [30, 31]], [[36, 37], [39, 40]], [[45, 46], [48, 49]]],
28+
[[[54, 55], [57, 58]], [[63, 64], [66, 67]], [[72, 73], [75, 76]]],
29+
]
30+
),
31+
"seg": np.array(
32+
[
33+
[[[81, 80], [78, 77]], [[72, 71], [69, 68]], [[63, 62], [60, 59]]],
34+
[[[54, 53], [51, 50]], [[45, 44], [42, 41]], [[36, 35], [33, 32]]],
35+
[[[27, 26], [24, 23]], [[18, 17], [15, 14]], [[9, 8], [6, 5]]],
36+
]
37+
),
38+
},
2339
]
2440

2541
TEST_CASE_2 = [
26-
{"keys": ["img", "seg"], "num_samples": 8, "roi_size": [3, 3, 3], "random_center": False},
27-
{"img": np.random.randint(0, 2, size=[3, 3, 3, 3]), "seg": np.random.randint(0, 2, size=[3, 3, 3, 3])},
28-
(3, 3, 3, 3),
42+
{"keys": ["img", "seg"], "num_samples": 8, "roi_size": [2, 2, 3], "random_center": False},
43+
{"img": np.arange(81).reshape(3, 3, 3, 3), "seg": np.arange(81, 0, -1).reshape(3, 3, 3, 3)},
44+
[(3, 3, 3, 3), (3, 2, 3, 3), (3, 2, 2, 3), (3, 2, 3, 3), (3, 3, 3, 3), (3, 3, 3, 3), (3, 2, 2, 3), (3, 3, 2, 3)],
45+
{
46+
"img": np.array(
47+
[
48+
[[[0, 1, 2], [3, 4, 5]], [[9, 10, 11], [12, 13, 14]], [[18, 19, 20], [21, 22, 23]]],
49+
[[[27, 28, 29], [30, 31, 32]], [[36, 37, 38], [39, 40, 41]], [[45, 46, 47], [48, 49, 50]]],
50+
[[[54, 55, 56], [57, 58, 59]], [[63, 64, 65], [66, 67, 68]], [[72, 73, 74], [75, 76, 77]]],
51+
]
52+
),
53+
"seg": np.array(
54+
[
55+
[[[81, 80, 79], [78, 77, 76]], [[72, 71, 70], [69, 68, 67]], [[63, 62, 61], [60, 59, 58]]],
56+
[[[54, 53, 52], [51, 50, 49]], [[45, 44, 43], [42, 41, 40]], [[36, 35, 34], [33, 32, 31]]],
57+
[[[27, 26, 25], [24, 23, 22]], [[18, 17, 16], [15, 14, 13]], [[9, 8, 7], [6, 5, 4]]],
58+
]
59+
),
60+
},
2961
]
3062

3163

3264
class TestRandSpatialCropSamplesd(unittest.TestCase):
3365
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
34-
def test_shape(self, input_param, input_data, expected_shape):
35-
result = RandSpatialCropSamplesd(**input_param)(input_data)
36-
for item in result:
37-
self.assertTupleEqual(item["img"].shape, expected_shape)
38-
self.assertTupleEqual(item["seg"].shape, expected_shape)
66+
def test_shape(self, input_param, input_data, expected_shape, expected_last):
67+
xform = RandSpatialCropSamplesd(**input_param)
68+
xform.set_random_state(1234)
69+
result = xform(input_data)
70+
for item, expected in zip(result, expected_shape):
71+
self.assertTupleEqual(item["img"].shape, expected)
72+
self.assertTupleEqual(item["seg"].shape, expected)
73+
np.testing.assert_allclose(item["img"], expected_last["img"])
74+
np.testing.assert_allclose(item["seg"], expected_last["seg"])
3975

4076

4177
if __name__ == "__main__":

0 commit comments

Comments
 (0)