Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add PyTorch Checkpoint Export and Load #1865

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions nnunetv2/inference/predict_from_raw_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,54 @@ def initialize_from_trained_model_folder(self, model_training_output_dir: str,
print('Using torch.compile')
self.network = torch.compile(self.network)

def load_from_checkpoint(self, checkpoint_path: str):
"""Load model from single checkpoint"""

# load the full checkpoint
model_checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))

# load the dataset and plans
dataset_json = model_checkpoint['dataset']
plans_json = model_checkpoint['plans']
plans_manager = PlansManager(plans_json)

# load the model parameters
parameters = []
checkpoint_name = "final" # always use final checkpoint for now
for i, k in enumerate(sorted(model_checkpoint['folds'])):

checkpoint = model_checkpoint['folds'][k][checkpoint_name]

if i == 0: # use first fold to get trainer and configuration name
trainer_name = checkpoint['trainer_name']
configuration_name = checkpoint['init_args']['configuration']
inference_allowed_mirroring_axes = checkpoint['inference_allowed_mirroring_axes'] if \
'inference_allowed_mirroring_axes' in checkpoint.keys() else None

parameters.append(checkpoint['network_weights'])

configuration_manager = plans_manager.get_configuration(configuration_name)

# restore network
num_input_channels = determine_num_input_channels(plans_manager, configuration_manager, dataset_json)
trainer_class = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
trainer_name, 'nnunetv2.training.nnUNetTrainer')
network = trainer_class.build_network_architecture(plans_manager, dataset_json, configuration_manager,
num_input_channels, enable_deep_supervision=False)
self.plans_manager = plans_manager
self.configuration_manager = configuration_manager
self.list_of_parameters = parameters
self.network = network
self.dataset_json = dataset_json
self.trainer_name = trainer_name
self.allowed_mirroring_axes = inference_allowed_mirroring_axes
self.label_manager = plans_manager.get_label_manager(dataset_json)
if ('nnUNet_compile' in os.environ.keys()) and (os.environ['nnUNet_compile'].lower() in ('true', '1', 't')) \
and not isinstance(self.network, OptimizedModule):
print('Using torch.compile')
self.network = torch.compile(self.network)


def manual_initialization(self, network: nn.Module, plans_manager: PlansManager,
configuration_manager: ConfigurationManager, parameters: Optional[List[dict]],
dataset_json: dict, trainer_name: str,
Expand Down
18 changes: 17 additions & 1 deletion nnunetv2/model_sharing/entry_points.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from nnunetv2.model_sharing.model_download import download_and_install_from_url
from nnunetv2.model_sharing.model_export import export_pretrained_model
from nnunetv2.model_sharing.model_export import export_pretrained_model, export_model_checkpoint
from nnunetv2.model_sharing.model_import import install_model_from_zip_file


Expand Down Expand Up @@ -59,3 +59,19 @@ def export_pretrained_model_entry():
export_pretrained_model(dataset_name_or_id=args.d, output_file=args.o, configurations=args.c, trainer=args.tr,
plans_identifier=args.p, folds=args.f, strict=not args.not_strict, save_checkpoints=args.chk,
export_crossval_predictions=args.exp_cv_preds)


def export_model_to_checkpoint():
import argparse
parser = argparse.ArgumentParser(description="Export nnunet model checkpoint as a single .pth file")
parser.add_argument("--path", type=str, help="path to nnunet model directory")
parser.add_argument("--checkpoint_path", type=str, help="path to save the checkpoint", required=False, default=None)
parser.add_argument("--checkpoint_name", type=str, help="name of the checkpoint", required=False, default="model_checkpoint.pth",)

args = parser.parse_args()

export_model_checkpoint(
path=args.path,
checkpoint_path=args.checkpoint_path,
checkpoint_name=args.checkpoint_name,
)
83 changes: 83 additions & 0 deletions nnunetv2/model_sharing/model_export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import glob
import zipfile

import torch
from nnunetv2.utilities.file_path_utilities import *


Expand Down Expand Up @@ -119,6 +121,87 @@ def export_pretrained_model(dataset_name_or_id: Union[int, str], output_file: st
zipf.write(inference_information_txt_file, os.path.relpath(inference_information_txt_file, nnUNet_results))
print('Done')

def export_model_checkpoint(
path: str,
checkpoint_path: str = None,
checkpoint_name: str = "model_checkpoint.pth",
) -> None:
"""Save NNUNet model checkpoint as a single .pth file
args:
path: path to the nnunet model directory

"""
# nnunet model directory structure for ensemble:
# model
# dataset.json
# plans.json
# fold_n:
# checkpoint_best.pth
# checkpoint_final.pth

# we want to convert it to a single .pth file with the following structure:
# model_checkpoint.pth
# dataset: dataset.json
# plans: plans.json
# fold_n:
# best: checkpoint_best.pth
# final: checkpoint_final.pth

# this makes it more portable and easier to load

def load_json(path: str):
with open(path, "r") as f:
return json.load(f)

# confirm that the path is a nnunet model directory
if not os.path.isdir(path):
raise ValueError(f"{path} is not a directory")
if not os.path.exists(os.path.join(path, "dataset.json")):
raise ValueError(f"{path} does not contain a dataset.json file")
if not os.path.exists(os.path.join(path, "plans.json")):
raise ValueError(f"{path} does not contain a plans.json file")

print(f"Exporting model checkpoint from {path}...")

model_checkpoint = {}

# paths
dataset_json_path = os.path.join(path, "dataset.json")
plan_json_path = os.path.join(path, "plans.json")

# load the dataset and plans
print("Loading dataset and plans configurations...")
model_checkpoint["dataset"] = load_json(dataset_json_path)
model_checkpoint["plans"] = load_json(plan_json_path)

# load the folds
model_checkpoint["folds"] = {}

# get all the fold directories,
fold_dirs = sorted(glob.glob(os.path.join(path, "fold_*")))
print(f"Found {len(fold_dirs)} folds...")
for fold_dir in fold_dirs:
fold_name = os.path.basename(fold_dir)
print(f"Processing fold {fold_name}...")

# load the best/ final checkpoint
best_checkpoint_path = os.path.join(fold_dir, "checkpoint_best.pth")
final_checkpoint_path = os.path.join(fold_dir, "checkpoint_final.pth")

model_checkpoint["folds"][fold_name] = {
"best": torch.load(best_checkpoint_path, map_location=torch.device("cpu")),
"final": torch.load(
final_checkpoint_path, map_location=torch.device("cpu")
),
}

# save as single torch checkpoint
if checkpoint_path is None:
checkpoint_path = os.path.join(path, checkpoint_name)
torch.save(model_checkpoint, checkpoint_path)
print(f"Exported model checkpoint to {checkpoint_path}")



if __name__ == '__main__':
export_pretrained_model(2, '/home/fabian/temp/dataset2.zip', strict=False, export_crossval_predictions=True, folds=(0, ))
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ nnUNetv2_plot_overlay_pngs = "nnunetv2.utilities.overlay_plots:entry_point_gener
nnUNetv2_download_pretrained_model_by_url = "nnunetv2.model_sharing.entry_points:download_by_url"
nnUNetv2_install_pretrained_model_from_zip = "nnunetv2.model_sharing.entry_points:install_from_zip_entry_point"
nnUNetv2_export_model_to_zip = "nnunetv2.model_sharing.entry_points:export_pretrained_model_entry"
nnUNetv2_export_model_to_checkpoint = "nnunetv2.model_sharing.entry_points:export_model_to_checkpoint"
nnUNetv2_move_plans_between_datasets = "nnunetv2.experiment_planning.plans_for_pretraining.move_plans_between_datasets:entry_point_move_plans_between_datasets"
nnUNetv2_evaluate_folder = "nnunetv2.evaluation.evaluate_predictions:evaluate_folder_entry_point"
nnUNetv2_evaluate_simple = "nnunetv2.evaluation.evaluate_predictions:evaluate_simple_entry_point"
Expand Down