15
15
16
16
17
17
def recursive_find_python_class (name , folder = None ,
18
- current_module = "rhtorch.models" ):
18
+ current_module = "rhtorch.models" , exit_if_not_found = True ):
19
19
20
20
# Set default search path to root modules
21
21
if folder is None :
22
- folder = [os .path .join (rhtorch .__path__ [0 ], 'models' )]
22
+ folder = [os .path .join (rhtorch .__path__ [0 ], * current_module . split ( '.' )[ 1 :] )]
23
23
24
24
tr = None
25
25
for importer , modname , ispkg in pkgutil .iter_modules (folder ):
@@ -34,12 +34,12 @@ def recursive_find_python_class(name, folder=None,
34
34
if ispkg :
35
35
next_current_module = current_module + '.' + modname
36
36
tr = recursive_find_python_class (name , folder = [os .path .join (
37
- folder [0 ], modname )], current_module = next_current_module )
37
+ folder [0 ], modname )], current_module = next_current_module , exit_if_not_found = exit_if_not_found )
38
38
39
39
if tr is not None :
40
40
break
41
41
42
- if tr is None :
42
+ if tr is None and exit_if_not_found :
43
43
sys .exit (f"Could not find module { name } " )
44
44
45
45
return tr
@@ -82,3 +82,11 @@ def find_best_checkpoint(ckpt_dir: Union[str, Path],
82
82
best_score = val_loss
83
83
best_path = ckpt ['best_model_path' ]
84
84
return best_path
85
+
86
+
87
+ """ Calculate the overlap for torchio inference
88
+ If the patch is e.g. 192,192,16 and a patch spacing is 16,16,4
89
+ it will return 576 patches that overlaps with the center voxel
90
+ """
91
+ def calculate_patch_overlap (patch_shape , patch_spacing ):
92
+ return [p_shape - p_space for p_shape , p_space in zip (patch_shape , patch_spacing )]
0 commit comments