Skip to content

Commit 6362b3b

Browse files
committed
Updated with new patch_overlap function and error handling
1 parent e3ec1f5 commit 6362b3b

File tree

1 file changed

+12
-4
lines changed

1 file changed

+12
-4
lines changed

rhtorch/utilities/modules.py

+12-4
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515

1616

1717
def recursive_find_python_class(name, folder=None,
18-
current_module="rhtorch.models"):
18+
current_module="rhtorch.models", exit_if_not_found=True):
1919

2020
# Set default search path to root modules
2121
if folder is None:
22-
folder = [os.path.join(rhtorch.__path__[0], 'models')]
22+
folder = [os.path.join(rhtorch.__path__[0], *current_module.split('.')[1:])]
2323

2424
tr = None
2525
for importer, modname, ispkg in pkgutil.iter_modules(folder):
@@ -34,12 +34,12 @@ def recursive_find_python_class(name, folder=None,
3434
if ispkg:
3535
next_current_module = current_module + '.' + modname
3636
tr = recursive_find_python_class(name, folder=[os.path.join(
37-
folder[0], modname)], current_module=next_current_module)
37+
folder[0], modname)], current_module=next_current_module, exit_if_not_found=exit_if_not_found)
3838

3939
if tr is not None:
4040
break
4141

42-
if tr is None:
42+
if tr is None and exit_if_not_found:
4343
sys.exit(f"Could not find module {name}")
4444

4545
return tr
@@ -82,3 +82,11 @@ def find_best_checkpoint(ckpt_dir: Union[str, Path],
8282
best_score = val_loss
8383
best_path = ckpt['best_model_path']
8484
return best_path
85+
86+
87+
""" Calculate the overlap for torchio inference
88+
If the patch is e.g. 192,192,16 and a patch spacing is 16,16,4
89+
it will return 576 patches that overlaps with the center voxel
90+
"""
91+
def calculate_patch_overlap(patch_shape, patch_spacing):
92+
return [p_shape - p_space for p_shape, p_space in zip(patch_shape, patch_spacing)]

0 commit comments

Comments
 (0)