Skip to content

Commit

Permalink
ENH: Assert bundle tractograms and data in bundling script
Browse files Browse the repository at this point in the history
Assert bundle tractograms and data in bundling script according to the
set of expected bundles:
- Make sure that all config files involved contain all the bundles
  existing in the anatomy file.
- Filter any other tractogram file in the atlas folder.

Add the necessary helper methods and definitions.

Add the corresponding tests.
  • Loading branch information
jhlegarreta committed May 6, 2023
1 parent 9130346 commit 4d6a6e5
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 15 deletions.
39 changes: 24 additions & 15 deletions scripts/ae_bundle_streamlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
from tractolearn.models.autoencoding_utils import encode_data
from tractolearn.models.model_pool import get_model
from tractolearn.tractoio.utils import (
assert_bundle_datum_exists,
assert_tractogram_exists,
filter_filenames,
load_ref_anat_image,
load_streamlines,
read_data_from_json_file,
Expand Down Expand Up @@ -331,15 +334,26 @@ def _build_arg_parser():
add_overwrite_arg(parser)
add_verbose_arg(parser)

return parser.parse_args()
return parser


def main():
args = _build_arg_parser()
device = torch.device(args.device)

parser = _build_arg_parser()
args = parser.parse_args()

print(args)

streamline_classes = read_data_from_json_file(args.anatomy_file)

# Get the bundles of interest
boi = list(streamline_classes.keys())

assert_tractogram_exists(parser, args.atlas_path, boi)

thresholds = read_data_from_json_file(args.thresholds_file)
assert_bundle_datum_exists(parser, thresholds, boi)

if exists(args.output):
if not args.overwrite:
print(
Expand All @@ -366,6 +380,8 @@ def main():
f"Please specify a number between 1 and 30. Got {args.num_neighbors}. "
)

device = torch.device(args.device)

logging.info(args)

_set_up_logger(pjoin(args.output, LoggerKeys.logger_file_basename.name))
Expand All @@ -379,35 +395,28 @@ def main():
model.load_state_dict(state_dict)
model.eval()

streamline_classes = read_data_from_json_file(args.anatomy_file)

thresholds = read_data_from_json_file(args.thresholds_file)

latent_atlas_all = np.empty((0, 32))
y_latent_atlas_all = np.empty((0,))

atlas_file = os.listdir(args.atlas_path)
# Filter the atlas filenames according to the bundles of interest
foi = filter_filenames(args.atlas_path, boi)

logger.info("Loading atlas files ...")

for f in tqdm(atlas_file):
for f in tqdm(foi):

key = f.split(".")[-2]

assert (
key in thresholds.keys()
), f"[!] Threshold: {key} not in threshold file"

X_a_not_flipped, y_a_not_flipped = load_streamlines(
pjoin(args.atlas_path, f),
f,
args.common_space_reference,
streamline_classes[key],
resample=True,
num_points=256,
)

X_a_flipped, y_a_flipped = load_streamlines(
pjoin(args.atlas_path, f),
f,
args.common_space_reference,
streamline_classes[key],
resample=True,
Expand Down
14 changes: 14 additions & 0 deletions tractolearn/tractoio/file_extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-

import enum

fname_sep = "."


class DictDataExtensions(enum.Enum):
JSON = "json"


class TractogramExtensions(enum.Enum):
TCK = "tck"
TRK = "trk"
152 changes: 152 additions & 0 deletions tractolearn/tractoio/tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import tempfile

from tractolearn.tractoio.file_extensions import (
DictDataExtensions,
TractogramExtensions,
fname_sep,
)
from tractolearn.tractoio.utils import (
compose_filename,
filter_filenames,
identify_missing_bundle,
identify_missing_tractogram,
read_data_from_json_file,
save_data_to_json_file,
)


def test_identify_missing_bundle(tmp_path):

with tempfile.NamedTemporaryFile(
suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
) as f:

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

bundle_data = dict({"CC_Fr_1": 1.0, "CST_L": 2.0, "AC": 3.0})
expected = sorted(set(bundle_names).difference(bundle_data.keys()))

save_data_to_json_file(bundle_data, f.name)

data = read_data_from_json_file(
os.path.join(tmp_path, os.listdir(tmp_path)[0])
)

obtained = identify_missing_bundle(data, bundle_names)

assert obtained == expected

with tempfile.NamedTemporaryFile(
suffix=fname_sep + DictDataExtensions.JSON.value, dir=tmp_path
) as f:

# Target bundle names
bundle_names = ["Cu", "PrCu"]

bundle_data = dict({"Cu": 2.0})
expected = sorted(set(bundle_names).difference(bundle_data.keys()))

save_data_to_json_file(bundle_data, f.name)

data = read_data_from_json_file(
os.path.join(tmp_path, os.listdir(tmp_path)[0])
)

obtained = identify_missing_bundle(data, bundle_names)

assert obtained == expected


def test_identify_missing_tractogram(tmp_path):

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

# Create some files in the temporary path
file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected = sorted(set(bundle_names).difference(file_rootnames))

obtained = identify_missing_tractogram(tmp_path, bundle_names)

assert obtained == expected

# Target bundle names
bundle_names = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in bundle_names
]

# Create some files in the temporary path
file_rootnames = ["Cu", "PrCu"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected = sorted(set(bundle_names).difference(file_rootnames))

obtained = identify_missing_tractogram(tmp_path, bundle_names)

assert obtained == expected


def test_filter_fnames(tmp_path):

# Target bundle names
bundle_names = ["CC_Fr_1", "CST_L", "AC"]

# Create some files in the temporary path
file_rootnames = ["CC_Fr_1", "CC_Fr_2", "AC"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected_rootnames = ["AC", "CC_Fr_1"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in expected_rootnames
]

obtained = filter_filenames(tmp_path, bundle_names)

assert obtained == expected

# Target bundle names
bundle_names = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in bundle_names
]

# Create some files in the temporary path
file_rootnames = ["Cu", "PrCu"]
fnames = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in file_rootnames
]
[open(val, "w") for val in fnames]

expected_rootnames = ["Cu"]
expected = [
compose_filename(tmp_path, val, TractogramExtensions.TRK.value)
for val in expected_rootnames
]

obtained = filter_filenames(tmp_path, bundle_names)

assert obtained == expected
Loading

0 comments on commit 4d6a6e5

Please sign in to comment.