diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 276dcd167..8c79045fc 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -114,6 +114,54 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
             print('Using torch.compile')
             self.network = torch.compile(self.network)
 
+    def load_from_checkpoint(self, checkpoint_path: str):
+        """Load the model from a single exported checkpoint file"""
+
+        # load the full checkpoint
+        model_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
+
+        # load the dataset and plans
+        dataset_json = model_checkpoint['dataset']
+        plans_json = model_checkpoint['plans']
+        plans_manager = PlansManager(plans_json)
+
+        # load the model parameters
+        parameters = []
+        checkpoint_name = "final"  # always use the final checkpoint for now
+        for i, k in enumerate(sorted(model_checkpoint['folds'])):
+            checkpoint = model_checkpoint['folds'][k][checkpoint_name]
+
+            if i == 0:  # use the first fold to get the trainer and configuration name
+                trainer_name = checkpoint['trainer_name']
+                configuration_name = checkpoint['init_args']['configuration']
+                inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
+                    'inference_allowed_mirroring_axes' in checkpoint.keys() else None
+
+            parameters.append(checkpoint['network_weights'])
+
+        configuration_manager = plans_manager.get_configuration(configuration_name)
+
+        # restore the network
+        num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
+        trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
+                                                    trainer_name, 'nnunetv2.training.nnUNetTrainer')
+        network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
+                                                           num_input_channels, enable_deep_supervision=False)
+        self.plans_manager = plans_manager
+        self.configuration_manager = configuration_manager
+        self.list_of_parameters = parameters
+        self.network = network
+        self.dataset_json = dataset_json
+        self.trainer_name = trainer_name
+        self.allowed_mirroring_axes = inference_allowed_mirroring_axes
+        self.label_manager = plans_manager.get_label_manager(dataset_json)
+        if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
+                and not isinstance(self.network, OptimizedModule):
+            print('Using torch.compile')
+            self.network = torch.compile(self.network)
+
     def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
                               configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
                               dataset_json: dict, trainer_name: str,
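For context, a minimal usage sketch of the new method (not part of the diff): the paths are placeholders, the predictor is constructed with its usual defaults, and load_from_checkpoint stands in for initialize_from_trained_model_folder when working from the single exported file.

    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor = nnUNetPredictor()  # default tiling/mirroring/device settings
    # restores plans, dataset.json and the 'final' weights of every fold from one file
    predictor.load_from_checkpoint("model_checkpoint.pth")  # placeholder path
    predictor.predict_from_files("input_folder", "output_folder")  # placeholder paths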
diff --git a/nnunetv2/model_sharing/entry_points.py b/nnunetv2/model_sharing/entry_points.py
index 1ab7c9351..6843cdf30 100644
--- a/nnunetv2/model_sharing/entry_points.py
+++ b/nnunetv2/model_sharing/entry_points.py
@@ -1,5 +1,5 @@
 from nnunetv2.model_sharing.model_download import download_and_install_from_url
-from nnunetv2.model_sharing.model_export import export_pretrained_model
+from nnunetv2.model_sharing.model_export import export_pretrained_model, export_model_checkpoint
 from nnunetv2.model_sharing.model_import import install_model_from_zip_file
 
 
@@ -59,3 +59,19 @@ def export_pretrained_model_entry():
     export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
                             plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
                             export_crossval_predictions=args.exp_cv_preds)
+
+
+def export_model_to_checkpoint():
+    import argparse
+    parser = argparse.ArgumentParser(description="Export an nnU-Net model as a single .pth checkpoint file")
+    parser.add_argument("--path", type=str, required=True, help="path to the nnU-Net model directory")
+    parser.add_argument("--checkpoint_path", type=str, required=False, default=None, help="path to save the checkpoint (defaults to <path>/<checkpoint_name>)")
+    parser.add_argument("--checkpoint_name", type=str, required=False, default="model_checkpoint.pth", help="file name of the checkpoint")
+
+    args = parser.parse_args()
+
+    export_model_checkpoint(
+        path=args.path,
+        checkpoint_path=args.checkpoint_path,
+        checkpoint_name=args.checkpoint_name,
+    )
\ No newline at end of file
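With the console script registered in pyproject.toml below, this entry point would be invoked as, for example (the model path is illustrative):

    nnUNetv2_export_model_to_checkpoint --path nnUNet_results/Dataset001_Example/nnUNetTrainer__nnUNetPlans__3d_fullres

The directory must contain dataset.json, plans.json and the fold_* subdirectories; the combined checkpoint is written into that directory unless --checkpoint_path is given.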
diff --git a/nnunetv2/model_sharing/model_export.py b/nnunetv2/model_sharing/model_export.py
index 51eb455f2..8fb45bac1 100644
--- a/nnunetv2/model_sharing/model_export.py
+++ b/nnunetv2/model_sharing/model_export.py
@@ -1,5 +1,7 @@
+import glob
 import zipfile
 
+import torch
 from nnunetv2.utilities.file_path_utilities import *
 
 
@@ -119,6 +121,87 @@ def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: st
             zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))
     print('Done')
 
+
+def export_model_checkpoint(
+    path: str,
+    checkpoint_path: str = None,
+    checkpoint_name: str = "model_checkpoint.pth",
+) -> None:
+    """Save an nnU-Net model directory as a single .pth checkpoint file.
+
+    Args:
+        path: path to the nnU-Net model directory
+        checkpoint_path: where to save the checkpoint (defaults to <path>/<checkpoint_name>)
+        checkpoint_name: file name of the checkpoint
+    """
+    # nnU-Net model directory structure for an ensemble:
+    # model/
+    #     dataset.json
+    #     plans.json
+    #     fold_n/
+    #         checkpoint_best.pth
+    #         checkpoint_final.pth
+
+    # we want to convert it to a single .pth file with the following structure:
+    # model_checkpoint.pth
+    #     dataset: dataset.json
+    #     plans: plans.json
+    #     folds:
+    #         fold_n:
+    #             best: checkpoint_best.pth
+    #             final: checkpoint_final.pth
+
+    # this makes the model more portable and easier to load
+
+    def load_json(json_path: str):
+        with open(json_path, "r") as f:
+            return json.load(f)
+
+    # confirm that the path is an nnU-Net model directory
+    if not os.path.isdir(path):
+        raise ValueError(f"{path} is not a directory")
+    if not os.path.exists(os.path.join(path, "dataset.json")):
+        raise ValueError(f"{path} does not contain a dataset.json file")
+    if not os.path.exists(os.path.join(path, "plans.json")):
+        raise ValueError(f"{path} does not contain a plans.json file")
+
+    print(f"Exporting model checkpoint from {path}...")
+
+    model_checkpoint = {}
+
+    # paths
+    dataset_json_path = os.path.join(path, "dataset.json")
+    plans_json_path = os.path.join(path, "plans.json")
+
+    # load the dataset and plans
+    print("Loading dataset and plans configurations...")
+    model_checkpoint["dataset"] = load_json(dataset_json_path)
+    model_checkpoint["plans"] = load_json(plans_json_path)
+
+    # load the folds
+    model_checkpoint["folds"] = {}
+
+    # get all the fold directories
+    fold_dirs = sorted(glob.glob(os.path.join(path, "fold_*")))
+    print(f"Found {len(fold_dirs)} folds...")
+    for fold_dir in fold_dirs:
+        fold_name = os.path.basename(fold_dir)
+        print(f"Processing fold {fold_name}...")
+
+        # load the best/final checkpoints
+        best_checkpoint_path = os.path.join(fold_dir, "checkpoint_best.pth")
+        final_checkpoint_path = os.path.join(fold_dir, "checkpoint_final.pth")
+
+        model_checkpoint["folds"][fold_name] = {
+            "best": torch.load(best_checkpoint_path, map_location=torch.device("cpu")),
+            "final": torch.load(final_checkpoint_path, map_location=torch.device("cpu")),
+        }
+
+    # save as a single torch checkpoint
+    if checkpoint_path is None:
+        checkpoint_path = os.path.join(path, checkpoint_name)
+    torch.save(model_checkpoint, checkpoint_path)
+    print(f"Exported model checkpoint to {checkpoint_path}")
+
+
 if __name__ == '__main__':
     export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True,
                             folds=(0, ))
diff --git a/pyproject.toml b/pyproject.toml
index 91bc31563..2b1514c3d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -74,6 +74,7 @@ nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_gener
 nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
 nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
 nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
+nnUNetv2_export_model_to_checkpoint = "nnunetv2.model_sharing.entry_points:export_model_to_checkpoint"
 nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
 nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
 nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
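To illustrate the resulting file layout, a small round-trip sketch (the file name is a placeholder; the keys follow the structure written by export_model_checkpoint above):

    import torch

    ckpt = torch.load("model_checkpoint.pth", map_location="cpu")
    print(sorted(ckpt.keys()))    # ['dataset', 'folds', 'plans']
    print(sorted(ckpt["folds"]))  # e.g. ['fold_0', 'fold_1', 'fold_2', 'fold_3', 'fold_4']
    # each fold keeps its original per-fold checkpoints unmodified
    final = ckpt["folds"]["fold_0"]["final"]
    print(final["trainer_name"], final["init_args"]["configuration"])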