diff --git a/truenet/true_net/truenet_test_function.py b/truenet/true_net/truenet_test_function.py
index daac9e4..f1ee69e 100644
--- a/truenet/true_net/truenet_test_function.py
+++ b/truenet/true_net/truenet_test_function.py
@@ -32,10 +32,10 @@ def main(sub_name_dicts, eval_params, intermediate=False, model_dir=None,
     use_cpu = eval_params['Use_CPU']
     if use_cpu is True:
         device = torch.device("cpu")
-        print('testfunction:device used:' + device)
+        print('testfunction:device used:' + str(device))
     else:
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        print('testfunction:device used:' + device)
+        print('testfunction:device used:' + str(device))
 
     nclass = eval_params['Nclass']
     num_channels = eval_params['Numchannels']
@@ -82,64 +82,64 @@ def main(sub_name_dicts, eval_params, intermediate=False, model_dir=None,
     for sub in range(len(sub_name_dicts)):
         if verbose:
             print('Predicting output for subject ' + str(sub+1) + '...', flush=True)
-
+
         test_sub_dict = [sub_name_dicts[sub]]
         basename = test_sub_dict[0]['basename']
-
+
         probs_combined = []
         flair_path = test_sub_dict[0]['flair_path']
         flair_hdr = nib.load(flair_path).header
-        probs_axial = truenet_evaluate.evaluate_truenet(test_sub_dict, model_axial, eval_params, device,
+        probs_axial = truenet_evaluate.evaluate_truenet(test_sub_dict, model_axial, eval_params, device,
                                                         mode='axial', verbose=verbose)
-        probs_axial = truenet_data_postprocessing.resize_to_original_size(probs_axial, test_sub_dict,
+        probs_axial = truenet_data_postprocessing.resize_to_original_size(probs_axial, test_sub_dict,
                                                                           plane='axial')
         probs_combined.append(probs_axial)
-
+
         if intermediate:
             save_path = os.path.join(output_dir,'Predicted_probmap_truenet_' + basename + '_axial.nii.gz')
             preds_axial = truenet_data_postprocessing.get_final_3dvolumes(probs_axial, test_sub_dict)
             if verbose:
                 print('Saving the intermediate Axial prediction ...', flush=True)
-
+
             newhdr = flair_hdr.copy()
             newobj = nib.nifti1.Nifti1Image(preds_axial, None, header=newhdr)
-            nib.save(newobj, save_path)
-
-        probs_sagittal = truenet_evaluate.evaluate_truenet(test_sub_dict, model_sagittal, eval_params, device,
+            nib.save(newobj, save_path)
+
+        probs_sagittal = truenet_evaluate.evaluate_truenet(test_sub_dict, model_sagittal, eval_params, device,
                                                            mode='sagittal', verbose=verbose)
-        probs_sagittal = truenet_data_postprocessing.resize_to_original_size(probs_sagittal, test_sub_dict,
+        probs_sagittal = truenet_data_postprocessing.resize_to_original_size(probs_sagittal, test_sub_dict,
                                                                              plane='sagittal')
         probs_combined.append(probs_sagittal)
-
+
         if intermediate:
             save_path = os.path.join(output_dir,'Predicted_probmap_truenet_' + basename + '_sagittal.nii.gz')
             preds_sagittal = truenet_data_postprocessing.get_final_3dvolumes(probs_sagittal, test_sub_dict)
             if verbose:
                 print('Saving the intermediate Sagittal prediction ...', flush=True)
-
+
             newhdr = flair_hdr.copy()
             newobj = nib.nifti1.Nifti1Image(preds_sagittal, None, header=newhdr)
-            nib.save(newobj, save_path)
-
-        probs_coronal = truenet_evaluate.evaluate_truenet(test_sub_dict, model_coronal, eval_params, device,
-                                                          mode='coronal', verbose=verbose)
-        probs_coronal = truenet_data_postprocessing.resize_to_original_size(probs_coronal, test_sub_dict,
+            nib.save(newobj, save_path)
+
+        probs_coronal = truenet_evaluate.evaluate_truenet(test_sub_dict, model_coronal, eval_params, device,
+                                                          mode='coronal', verbose=verbose)
+        probs_coronal = truenet_data_postprocessing.resize_to_original_size(probs_coronal, test_sub_dict,
                                                                             plane='coronal')
         probs_combined.append(probs_coronal)
-
+
         if intermediate:
             save_path = os.path.join(output_dir,'Predicted_probmap_truenet_' + basename + '_coronal.nii.gz')
             preds_coronal = truenet_data_postprocessing.get_final_3dvolumes(probs_coronal, test_sub_dict)
             if verbose:
                 print('Saving the intermediate Coronal prediction ...', flush=True)
-
+
             newhdr = flair_hdr.copy()
             newobj = nib.nifti1.Nifti1Image(preds_coronal, None, header=newhdr)
-            nib.save(newobj, save_path)
-
+            nib.save(newobj, save_path)
+
         probs_combined = np.array(probs_combined)
         prob_mean = np.mean(probs_combined,axis=0)
-
+
         save_path = os.path.join(output_dir,'Predicted_probmap_truenet_' + basename + '.nii.gz')
         pred_mean = truenet_data_postprocessing.get_final_3dvolumes(prob_mean, test_sub_dict)
         if verbose:
@@ -147,7 +147,7 @@ def main(sub_name_dicts, eval_params, intermediate=False, model_dir=None,
         newhdr = flair_hdr.copy()
         newobj = nib.nifti1.Nifti1Image(pred_mean, None, header=newhdr)
-        nib.save(newobj, save_path)
-
+        nib.save(newobj, save_path)
+
     if verbose:
         print('Testing complete for all subjects!', flush=True)
 
diff --git a/truenet/utils/truenet_utils.py b/truenet/utils/truenet_utils.py
index 85cfe73..355ea91 100644
--- a/truenet/utils/truenet_utils.py
+++ b/truenet/utils/truenet_utils.py
@@ -22,7 +22,7 @@ def select_train_val_names(data_path,val_numbers):
     :return:
     '''
     val_ids = random.choices(list(np.arange(len(data_path))),k=val_numbers)
-    train_ids = np.setdiff1d(np.arange(len(data_path)),val_ids)
+    train_ids = np.setdiff1d(np.arange(len(data_path)),val_ids)
    data_path_train = [data_path[ind] for ind in train_ids]
    data_path_val = [data_path[ind] for ind in val_ids]
    return data_path_train,data_path_val,val_ids
@@ -40,7 +40,7 @@ def freeze_layer_for_finetuning(model, layer_to_ft, verbose=False):
     model_layers_tobe_ftd = []
     for layer_id in layer_to_ft:
         model_layers_tobe_ftd.append(model_layer_names[layer_id-1])
-
+
     for name, child in model.module.named_children():
         if name in model_layers_tobe_ftd:
             if verbose:
@@ -54,24 +54,24 @@ def freeze_layer_for_finetuning(model, layer_to_ft, verbose=False):
                 print(name + ' is frozen', flush=True)
             for param in child.parameters():
                 param.requires_grad = False
-
+
     return model
 
 
 def loading_model(model_name, model, device, mode='weights'):
     if mode == 'weights':
         if device == 'cpu':
-            print('utils:device used:' + device)
+            print('utils:device used:' + str(device))
             axial_state_dict = torch.load(model_name, map_location='cpu')
         else:
-            print('utils:device used:' + device)
+            print('utils:device used:' + str(device))
             axial_state_dict = torch.load(model_name)
     else:
         if device == 'cpu':
-            print('utils:device used:' + device)
+            print('utils:device used:' + str(device))
             ckpt = torch.load(model_name, map_location='cpu')
         else:
-            print('utils:device used:' + device)
+            print('utils:device used:' + str(device))
             ckpt = torch.load(model_name)
         axial_state_dict = ckpt['model_state_dict']
 
@@ -175,10 +175,3 @@ def save_checkpoint(self, val_loss, val_acc, best_val_acc, model, epoch, optimiz
         else:
             if self.verbose:
                 print('Validation loss increased; Exiting without saving the model ...')
-
-
-
-
-
-
-
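Reviewer note, not part of the patch: apart from the trailing-whitespace cleanup, the substantive change in both files is the same one-line fix, wrapping the `torch.device` object in `str()` before concatenating it into the log message. The sketch below is purely illustrative of why the pre-patch lines fail; it assumes only that `torch` is installed.

```python
# Illustrative sketch (not part of the patch): concatenating a str with a
# torch.device object raises TypeError, so the log line must convert it first.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    # Pre-patch form: str + torch.device is not supported.
    print('testfunction:device used:' + device)
except TypeError as err:
    print('pre-patch form fails:', err)

# Patched form: explicit conversion (an f-string would work equally well).
print('testfunction:device used:' + str(device))
```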