New feature: adding a parameter to control the number of pseudo-validation cases #2061

Status: Open. Wants to merge 1 commit into base: master.

Summary: this PR threads a new val_iters argument (default 50) from the nnUNetv2_train command line through run_training() and get_trainer_from_args() into nnUNetTrainer, where it replaces the hard-coded num_val_iterations_per_epoch = 50. All shipped trainer variants are updated to accept and forward the parameter.
nnunetv2/run/run_training.py (18 additions & 9 deletions)
```diff
@@ -34,7 +34,8 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str],
                           trainer_name: str = 'nnUNetTrainer',
                           plans_identifier: str = 'nnUNetPlans',
                           use_compressed: bool = False,
-                          device: torch.device = torch.device('cuda')):
+                          device: torch.device = torch.device('cuda'),
+                          val_iters: int = 50):
     # load nnunet class and do sanity checks
     nnunet_trainer = recursive_find_python_class(join(nnunetv2.__path__[0], "training", "nnUNetTrainer"),
                                                  trainer_name, 'nnunetv2.training.nnUNetTrainer')
@@ -63,7 +64,8 @@ def get_trainer_from_args(dataset_name_or_id: Union[int, str],
     plans = load_json(plans_file)
     dataset_json = load_json(join(preprocessed_dataset_folder_base, 'dataset.json'))
     nnunet_trainer = nnunet_trainer(plans=plans, configuration=configuration, fold=fold,
-                                    dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device)
+                                    dataset_json=dataset_json, unpack_dataset=not use_compressed, device=device,
+                                    val_iters=val_iters)
     return nnunet_trainer


@@ -108,12 +110,12 @@ def cleanup_ddp():


 def run_ddp(rank, dataset_name_or_id, configuration, fold, tr, p, use_compressed, disable_checkpointing, c, val,
-            pretrained_weights, npz, val_with_best, world_size):
+            pretrained_weights, npz, val_with_best, world_size, val_iters):
     setup_ddp(rank, world_size)
     torch.cuda.set_device(torch.device('cuda', dist.get_rank()))

     nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, tr, p,
-                                           use_compressed)
+                                           use_compressed, val_iters=val_iters)

     if disable_checkpointing:
         nnunet_trainer.disable_checkpointing = disable_checkpointing
@@ -147,7 +149,8 @@ def run_training(dataset_name_or_id: Union[str, int],
                  only_run_validation: bool = False,
                  disable_checkpointing: bool = False,
                  val_with_best: bool = False,
-                 device: torch.device = torch.device('cuda')):
+                 device: torch.device = torch.device('cuda'),
+                 val_iters: int = 50):
     if isinstance(fold, str):
         if fold != 'all':
             try:
@@ -182,12 +185,15 @@ def run_training(dataset_name_or_id: Union[str, int],
                      pretrained_weights,
                      export_validation_probabilities,
                      val_with_best,
-                     num_gpus),
+                     num_gpus,
+                     val_iters
+                     ),
                 nprocs=num_gpus,
                 join=True)
     else:
         nnunet_trainer = get_trainer_from_args(dataset_name_or_id, configuration, fold, trainer_class_name,
-                                               plans_identifier, use_compressed_data, device=device)
+                                               plans_identifier, use_compressed_data, device=device,
+                                               val_iters=val_iters)

         if disable_checkpointing:
             nnunet_trainer.disable_checkpointing = disable_checkpointing
@@ -249,6 +255,9 @@ def run_training_entry():
                         help="Use this to set the device the training should run with. Available options are 'cuda' "
                              "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! "
                              "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_train [...] instead!")
+    parser.add_argument('-val_iters', type=int, default=50, required=False,
+                        help='Use this to adjust the number of pseudo-validation steps. Using a lower value increases '
+                             'the validation speed but decreases the confidence in the pseudo-validation dice.')
     args = parser.parse_args()

     assert args.device in ['cpu', 'cuda', 'mps'], f'-device must be either cpu, mps or cuda. Other devices are not tested/supported. Got: {args.device}.'
@@ -266,8 +275,8 @@ def run_training_entry():
         device = torch.device('mps')

     run_training(args.dataset_name_or_id, args.configuration, args.fold, args.tr, args.p, args.pretrained_weights,
-                 args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing, args.val_best,
-                 device=device)
+                 args.num_gpus, args.use_compressed, args.npz, args.c, args.val, args.disable_checkpointing,
+                 args.val_best, device=device, val_iters=args.val_iters)


 if __name__ == '__main__':
```
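With this change the number of pseudo-validation steps is tunable from the command line, e.g. `nnUNetv2_train 4 3d_fullres 0 -val_iters 20` (dataset ID, configuration and fold are placeholder values). A minimal programmatic sketch, assuming the remaining run_training() arguments keep their existing defaults:

```python
# Hypothetical usage sketch; argument values are examples, not from this PR.
# run_training() now forwards val_iters to the trainer constructor.
from nnunetv2.run.run_training import run_training

run_training(
    '4',            # dataset_name_or_id
    '3d_fullres',   # configuration
    0,              # fold
    val_iters=20,   # 20 pseudo-validation steps per epoch instead of 50
)
```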
nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py (3 additions & 3 deletions)
```diff
@@ -66,9 +66,9 @@
 from torch.nn.parallel import DistributedDataParallel as DDP


-class nnUNetTrainer(object):
+class nnUNetTrainer:
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
-                 device: torch.device = torch.device('cuda')):
+                 device: torch.device = torch.device('cuda'), val_iters: int = 50):
         # From https://grugbrain.dev/. Worth a read ya big brains ;-)

         # apex predator of grug is complexity
@@ -145,7 +145,7 @@ def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
         self.weight_decay = 3e-5
         self.oversample_foreground_percent = 0.33
         self.num_iterations_per_epoch = 250
-        self.num_val_iterations_per_epoch = 50
+        self.num_val_iterations_per_epoch = val_iters
         self.num_epochs = 1000
         self.current_epoch = 0
         self.enable_deep_supervision = True
```
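For context, this attribute bounds the pseudo-validation loop that runs at the end of every epoch. A paraphrased sketch of that loop, simplified from the existing nnUNetTrainer.run_training() and not part of this diff (method and attribute names follow the existing class):

```python
import torch

# Simplified paraphrase of nnUNetTrainer.run_training(); details omitted.
def run_training(self):
    for epoch in range(self.current_epoch, self.num_epochs):
        self.on_epoch_start()
        # ... self.num_iterations_per_epoch training steps ...
        with torch.no_grad():
            self.on_validation_epoch_start()
            val_outputs = []
            # val_iters sets num_val_iterations_per_epoch, so it
            # directly bounds this loop
            for batch_id in range(self.num_val_iterations_per_epoch):
                val_outputs.append(self.validation_step(next(self.dataloader_val)))
            self.on_validation_epoch_end(val_outputs)
        self.on_epoch_end()
```

Fewer steps means the per-epoch pseudo-validation Dice is computed from fewer patches, which is the speed-versus-confidence trade-off the new -val_iters help text describes.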
nnUNetTrainerBenchmark_5epochs:

```diff
@@ -7,8 +7,8 @@

 class nnUNetTrainerBenchmark_5epochs(nnUNetTrainer):
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
-                 device: torch.device = torch.device('cuda')):
-        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+                 device: torch.device = torch.device('cuda'), val_iters: int = 50):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
         assert self.fold == 0, "It makes absolutely no sense to specify a certain fold. Stick with 0 so that we can parse the results."
         self.disable_checkpointing = True
         self.num_epochs = 5
```
Trainer variant that disables deep supervision:

```diff
@@ -11,6 +11,7 @@ def __init__(
         dataset_json: dict,
         unpack_dataset: bool = True,
         device: torch.device = torch.device("cuda"),
+        val_iters: int = 50,
     ):
-        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
         self.enable_deep_supervision = False
```
nnUNetTrainer_probabilisticOversampling:

```diff
@@ -18,8 +18,8 @@ class nnUNetTrainer_probabilisticOversampling(nnUNetTrainer):
     If we switch to this oversampling then we can keep it at a constant 0.33 or whatever.
     """
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
-                 device: torch.device = torch.device('cuda')):
-        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+                 device: torch.device = torch.device('cuda'), val_iters: int = 50):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
         self.oversample_foreground_percent = float(np.mean(
             [not sample_idx < round(self.configuration_manager.batch_size * (1 - self.oversample_foreground_percent))
              for sample_idx in range(self.configuration_manager.batch_size)]))
```
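The oversampling expression above is easy to misread: it rounds the number of non-oversampled samples per batch and converts the remainder back into a per-sample probability. A standalone worked example (batch sizes here are assumed for illustration, not taken from this diff):

```python
import numpy as np

# Worked example of the oversample_foreground_percent expression above.
oversample_foreground_percent = 0.33

for batch_size in (2, 3):
    frac = float(np.mean(
        [not sample_idx < round(batch_size * (1 - oversample_foreground_percent))
         for sample_idx in range(batch_size)]))
    print(batch_size, frac)

# batch_size=2 -> 0.5    (1 of 2 samples is forced-foreground)
# batch_size=3 -> 0.3333 (1 of 3 samples is forced-foreground)
```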
Fixed-training-length variants (nnUNetTrainer_Xepochs):

```diff
@@ -5,72 +5,72 @@

 class nnUNetTrainer_5epochs(nnUNetTrainer):
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
-                 device: torch.device = torch.device('cuda')):
+                 device: torch.device = torch.device('cuda'), val_iters: int = 50):
         """used for debugging plans etc"""
-        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
         self.num_epochs = 5
```

The identical two-line change (the signature gains val_iters: int = 50 and super().__init__ forwards it) is applied to the remaining classes in this file: nnUNetTrainer_1epoch and nnUNetTrainer_10epochs (both carrying the same "used for debugging plans etc" docstring), nnUNetTrainer_20epochs, nnUNetTrainer_50epochs, nnUNetTrainer_100epochs, nnUNetTrainer_250epochs, nnUNetTrainer_2000epochs, nnUNetTrainer_4000epochs and nnUNetTrainer_8000epochs, each keeping its respective self.num_epochs setting.
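Because every variant now accepts and forwards val_iters, custom trainers can follow the same pattern. A minimal sketch; the class name nnUNetTrainer_quickVal is hypothetical and not part of this PR:

```python
import torch

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


class nnUNetTrainer_quickVal(nnUNetTrainer):
    """Hypothetical variant: fewer pseudo-validation steps by default."""
    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 unpack_dataset: bool = True,
                 device: torch.device = torch.device('cuda'), val_iters: int = 10):
        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
```

Placed where recursive_find_python_class() can discover it, such a class would be selectable with -tr nnUNetTrainer_quickVal.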
NoMirroring variants (nnUNetTrainer_Xepochs_NoMirroring):

```diff
@@ -5,8 +5,8 @@

 class nnUNetTrainer_250epochs_NoMirroring(nnUNetTrainer):
     def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict, unpack_dataset: bool = True,
-                 device: torch.device = torch.device('cuda')):
-        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device)
+                 device: torch.device = torch.device('cuda'), val_iters: int = 50):
+        super().__init__(plans, configuration, fold, dataset_json, unpack_dataset, device, val_iters)
         self.num_epochs = 250

     def configure_rotation_dummyDA_mirroring_and_inital_patch_size(self):
```

The same constructor change is applied to nnUNetTrainer_2000epochs_NoMirroring (num_epochs = 2000), nnUNetTrainer_4000epochs_NoMirroring (num_epochs = 4000) and nnUNetTrainer_8000epochs_NoMirroring (num_epochs = 8000). Their configure_rotation_dummyDA_mirroring_and_inital_patch_size overrides are not touched by this PR.