diff --git a/nnunetv2/run/run_training.py b/nnunetv2/run/run_training.py index 93dd7598b..301f2c89a 100644 --- a/nnunetv2/run/run_training.py +++ b/nnunetv2/run/run_training.py @@ -34,7 +34,8 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str], trainer_name: str = 'nnUNetTrainer', plans_identifier: str = 'nnUNetPlans', use_compressed: bool = False, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), + val_iters: int = 50): # load nnunet class and do sanity checks nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"), trainer_name, 'nnunetv2.training.nnUNetTrainer') @@ -63,7 +64,8 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str], plans = load_json(plans_file) dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json')) nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold, - dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device) + dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device, + val_iters=val_iters) return nnunet_trainer @@ -108,12 +110,12 @@ def cleanup_ddp(): def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val, - pretrained_weights, npz, val_with_best, world_size): + pretrained_weights, npz, val_with_best, world_size, val_iters): setup_ddp(rank, world_size) torch.cuda.set_device(torch.device('cuda', dist.get_rank())) nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p, - use_compressed) + use_compressed, val_iters=val_iters) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing @@ -147,7 +149,8 @@ def run_training(dataset_name_or_id: Union[str, int], only_run_validation: bool = False, disable_checkpointing: bool = False, val_with_best: bool = False, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), + val_iters: int = 50): if isinstance(fold, str): if fold != 'all': try: @@ -182,12 +185,15 @@ def run_training(dataset_name_or_id: Union[str, int], pretrained_weights, export_validation_probabilities, val_with_best, - num_gpus), + num_gpus, + val_iters + ), nprocs=num_gpus, join=True) else: nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name, - plans_identifier, use_compressed_data, device=device) + plans_identifier, use_compressed_data, device=device, + val_iters=val_iters) if disable_checkpointing: nnunet_trainer.disable_checkpointing = disable_checkpointing @@ -249,6 +255,9 @@ def run_training_entry(): help="Use this to set the device the training should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!") + parser.add_argument('-val_iters', type=int, default=50, required=False, + help='Use this to adjust the number of pseudo-validation steps. Using a lower value increases ' + 'the validation speed but decreases the confidence in the pseudo-validation dice.') args = parser.parse_args() assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.' @@ -266,8 +275,8 @@ def run_training_entry(): device = torch.device('mps') run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights, - args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best, - device=device) + args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, + args.val_best, device=device, val_iters=args.val_iters) if __name__ == '__main__': diff --git a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py index 426bbf047..bbb758e1a 100644 --- a/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py +++ b/nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py @@ -66,9 +66,9 @@ from torch.nn.parallel import DistributedDataParallel as DDP -class nnUNetTrainer(object): +class nnUNetTrainer: def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), val_iters: int = 50): # From https://grugbrain.dev/. Worth a read ya big brains ;-) # apex predator of grug is complexity @@ -145,7 +145,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dic self.weight_decay = 3e-5 self.oversample_foreground_percent = 0.33 self.num_iterations_per_epoch = 250 - self.num_val_iterations_per_epoch = 50 + self.num_val_iterations_per_epoch = val_iters self.num_epochs = 1000 self.current_epoch = 0 self.enable_deep_supervision = True diff --git a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py index fad1fff99..921a795c7 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py +++ b/nnunetv2/training/nnUNetTrainer/variants/benchmarking/nnUNetTrainerBenchmark_5epochs.py @@ -7,8 +7,8 @@ class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results." self.disable_checkpointing = True self.num_epochs = 5 diff --git a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py index 1152fbeb4..478a4bbb2 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py +++ b/nnunetv2/training/nnUNetTrainer/variants/network_architecture/nnUNetTrainerNoDeepSupervision.py @@ -11,6 +11,7 @@ def __init__( dataset_json: dict, unpack_dataset: bool = True, device: torch.device = torch.device("cuda"), + val_iters: int = 50, ): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.enable_deep_supervision = False diff --git a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py index 467a6fd04..ac7f998c3 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py +++ b/nnunetv2/training/nnUNetTrainer/variants/sampling/nnUNetTrainer_probabilisticOversampling.py @@ -18,8 +18,8 @@ class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer): If we switch to this oversampling then we can keep it at a constant 0.33 or whatever. """ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.oversample_foreground_percent = float(np.mean( [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent)) for sample_idx in range(self.configuration_manager.batch_size)])) diff --git a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py index e3a71a000..f16094c4f 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py +++ b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs.py @@ -5,72 +5,72 @@ class nnUNetTrainer_5epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), val_iters: int = 50): """used for debugging plans etc""" - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 5 class nnUNetTrainer_1epoch(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), val_iters: int = 50): """used for debugging plans etc""" - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 1 class nnUNetTrainer_10epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): + device: torch.device = torch.device('cuda'), val_iters: int = 50): """used for debugging plans etc""" - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 10 class nnUNetTrainer_20epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 20 class nnUNetTrainer_50epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 50 class nnUNetTrainer_100epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 100 class nnUNetTrainer_250epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 250 class nnUNetTrainer_2000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 2000 class nnUNetTrainer_4000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 4000 class nnUNetTrainer_8000epochs(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 8000 diff --git a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py index c16b88508..bf9706620 100644 --- a/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py +++ b/nnunetv2/training/nnUNetTrainer/variants/training_length/nnUNetTrainer_Xepochs_NoMirroring.py @@ -5,8 +5,8 @@ class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 250 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @@ -19,8 +19,8 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): class nnUNetTrainer_2000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 2000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @@ -33,8 +33,8 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): class nnUNetTrainer_4000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 4000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): @@ -47,8 +47,8 @@ def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self): class nnUNetTrainer_8000epochs_NoMirroring(nnUNetTrainer): def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True, - device: torch.device = torch.device('cuda')): - super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device) + device: torch.device = torch.device('cuda'), val_iters: int = 50): + super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters) self.num_epochs = 8000 def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):