Commit 4d6a6e5

ENH: Assert bundle tractograms and data in bundling script

Assert bundle tractograms and data in the bundling script according to the set of expected bundles:
- Make sure that all config files involved contain all the bundles existing in the anatomy file.
- Filter out any other tractogram file in the atlas folder.
Add the necessary helper methods and definitions. Add the corresponding tests.
1 parent 9130346 commit 4d6a6e5
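The two assertion helpers added to tractolearn.tractoio.utils are only referenced by name on this page. Judging from how the script calls them below (argument parser, atlas path or threshold dictionary, and the list of expected bundles) and from the identify_missing_* helpers exercised in the new tests, a minimal sketch of their intent could look like this; the bodies and the parser.error reporting are assumptions, only the call signatures come from the diff:

from tractolearn.tractoio.utils import (
    identify_missing_bundle,
    identify_missing_tractogram,
)


def assert_tractogram_exists(parser, path, bundle_names):
    # Abort via the argument parser if any expected bundle has no tractogram
    # file in the given folder.
    missing = identify_missing_tractogram(path, bundle_names)
    if missing:
        parser.error(f"Missing tractogram files in {path}: {missing}")


def assert_bundle_datum_exists(parser, data, bundle_names):
    # Abort via the argument parser if any expected bundle has no entry in the
    # given dictionary (e.g. the thresholds config).
    missing = identify_missing_bundle(data, bundle_names)
    if missing:
        parser.error(f"Missing bundle entries: {missing}")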

File tree

4 files changed: +362 -15 lines changed

scripts/ae_bundle_streamlines.py

Lines changed: 24 additions & 15 deletions
@@ -29,6 +29,9 @@
 from tractolearn.models.autoencoding_utils import encode_data
 from tractolearn.models.model_pool import get_model
 from tractolearn.tractoio.utils import (
+    assert_bundle_datum_exists,
+    assert_tractogram_exists,
+    filter_filenames,
     load_ref_anat_image,
     load_streamlines,
     read_data_from_json_file,
@@ -331,15 +334,26 @@ def _build_arg_parser():
     add_overwrite_arg(parser)
     add_verbose_arg(parser)
 
-    return parser.parse_args()
+    return parser
 
 
 def main():
-    args = _build_arg_parser()
-    device = torch.device(args.device)
+
+    parser = _build_arg_parser()
+    args = parser.parse_args()
 
     print(args)
 
+    streamline_classes = read_data_from_json_file(args.anatomy_file)
+
+    # Get the bundles of interest
+    boi = list(streamline_classes.keys())
+
+    assert_tractogram_exists(parser, args.atlas_path, boi)
+
+    thresholds = read_data_from_json_file(args.thresholds_file)
+    assert_bundle_datum_exists(parser, thresholds, boi)
+
     if exists(args.output):
         if not args.overwrite:
             print(
@@ -366,6 +380,8 @@ def main():
                 f"Please specify a number between 1 and 30. Got {args.num_neighbors}. "
             )
 
+    device = torch.device(args.device)
+
     logging.info(args)
 
     _set_up_logger(pjoin(args.output, LoggerKeys.logger_file_basename.name))
@@ -379,35 +395,28 @@ def main():
     model.load_state_dict(state_dict)
     model.eval()
 
-    streamline_classes = read_data_from_json_file(args.anatomy_file)
-
-    thresholds = read_data_from_json_file(args.thresholds_file)
-
     latent_atlas_all = np.empty((0, 32))
     y_latent_atlas_all = np.empty((0,))
 
-    atlas_file = os.listdir(args.atlas_path)
+    # Filter the atlas filenames according to the bundles of interest
+    foi = filter_filenames(args.atlas_path, boi)
 
     logger.info("Loading atlas files ...")
 
-    for f in tqdm(atlas_file):
+    for f in tqdm(foi):
 
         key = f.split(".")[-2]
 
-        assert (
-            key in thresholds.keys()
-        ), f"[!] Threshold: {key} not in threshold file"
-
         X_a_not_flipped, y_a_not_flipped = load_streamlines(
-            pjoin(args.atlas_path, f),
+            f,
             args.common_space_reference,
             streamline_classes[key],
            resample=True,
             num_points=256,
         )
 
         X_a_flipped, y_a_flipped = load_streamlines(
-            pjoin(args.atlas_path, f),
+            f,
             args.common_space_reference,
             streamline_classes[key],
             resample=True,
tractolearn/tractoio/file_extensions.py (new file)

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# -*- coding: utf-8 -*-
+
+import enum
+
+fname_sep = "."
+
+
+class DictDataExtensions(enum.Enum):
+    JSON = "json"
+
+
+class TractogramExtensions(enum.Enum):
+    TCK = "tck"
+    TRK = "trk"
Lines changed: 152 additions & 0 deletions
@@ -0,0 +1,152 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+import os
+import tempfile
+
+from tractolearn.tractoio.file_extensions import (
+    DictDataExtensions,
+    TractogramExtensions,
+    fname_sep,
+)
+from tractolearn.tractoio.utils import (
+    compose_filename,
+    filter_filenames,
+    identify_missing_bundle,
+    identify_missing_tractogram,
+    read_data_from_json_file,
+    save_data_to_json_file,
+)
+
+
+def test_identify_missing_bundle(tmp_path):
+
+    with tempfile.NamedTemporaryFile(
+        suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
+    ) as f:
+
+        # Target bundle names
+        bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+        bundle_data = dict({"CC_Fr_1": 1.0, "CST_L": 2.0, "AC": 3.0})
+        expected = sorted(set(bundle_names).difference(bundle_data.keys()))
+
+        save_data_to_json_file(bundle_data, f.name)
+
+        data = read_data_from_json_file(
+            os.path.join(tmp_path, os.listdir(tmp_path)[0])
+        )
+
+        obtained = identify_missing_bundle(data, bundle_names)
+
+        assert obtained == expected
+
+    with tempfile.NamedTemporaryFile(
+        suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
+    ) as f:
+
+        # Target bundle names
+        bundle_names = ["Cu", "PrCu"]
+
+        bundle_data = dict({"Cu": 2.0})
+        expected = sorted(set(bundle_names).difference(bundle_data.keys()))
+
+        save_data_to_json_file(bundle_data, f.name)
+
+        data = read_data_from_json_file(
+            os.path.join(tmp_path, os.listdir(tmp_path)[0])
+        )
+
+        obtained = identify_missing_bundle(data, bundle_names)
+
+        assert obtained == expected
+
+
+def test_identify_missing_tractogram(tmp_path):
+
+    # Target bundle names
+    bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+    # Create some files in the temporary path
+    file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w") for val in fnames]
+
+    expected = sorted(set(bundle_names).difference(file_rootnames))
+
+    obtained = identify_missing_tractogram(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+    # Target bundle names
+    bundle_names = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in bundle_names
+    ]
+
+    # Create some files in the temporary path
+    file_rootnames = ["Cu", "PrCu"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w") for val in fnames]
+
+    expected = sorted(set(bundle_names).difference(file_rootnames))
+
+    obtained = identify_missing_tractogram(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+
+def test_filter_fnames(tmp_path):
+
+    # Target bundle names
+    bundle_names = ["CC_Fr_1", "CST_L", "AC"]
+
+    # Create some files in the temporary path
+    file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w") for val in fnames]
+
+    expected_rootnames = ["AC", "CC_Fr_1"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in expected_rootnames
+    ]
+
+    obtained = filter_filenames(tmp_path, bundle_names)
+
+    assert obtained == expected
+
+    # Target bundle names
+    bundle_names = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in bundle_names
+    ]
+
+    # Create some files in the temporary path
+    file_rootnames = ["Cu", "PrCu"]
+    fnames = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in file_rootnames
+    ]
+    [open(val, "w") for val in fnames]
+
+    expected_rootnames = ["Cu"]
+    expected = [
+        compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
+        for val in expected_rootnames
+    ]
+
+    obtained = filter_filenames(tmp_path, bundle_names)
+
+    assert obtained == expected
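The compose_filename, identify_missing_bundle, identify_missing_tractogram and filter_filenames helpers exercised above are not shown on this page. Going only by how the tests call them, a minimal sketch consistent with the expected values could be the following; the bodies are assumptions, not the actual implementation:

import os

from tractolearn.tractoio.file_extensions import fname_sep


def compose_filename(path, rootname, extension):
    # Join a directory, a root name and an extension,
    # e.g. ("/tmp", "AC", "trk") -> "/tmp/AC.trk".
    return os.path.join(path, rootname + fname_sep + extension)


def identify_missing_bundle(data, bundle_names):
    # Expected bundle names that have no entry in the loaded dictionary.
    return sorted(set(bundle_names).difference(data.keys()))


def identify_missing_tractogram(path, bundle_names):
    # Expected bundle names that have no file in the given folder.
    rootnames = [os.path.splitext(f)[0] for f in os.listdir(path)]
    return sorted(set(bundle_names).difference(rootnames))


def filter_filenames(path, bundle_names):
    # Full paths of the files in the folder whose root name is a bundle of interest.
    return sorted(
        os.path.join(path, f)
        for f in os.listdir(path)
        if os.path.splitext(f)[0] in bundle_names
    )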
