Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New feature: adding a parameter to control the number of processes used by the validation dataloader #2053

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions nnunetv2/configuration.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,6 @@
import os

# Global default for the number of worker processes used across nnU-Net
# (preprocessing, export, ...). Override with the nnUNet_def_n_proc
# environment variable.
default_num_processes = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])

ANISO_THRESHOLD = 3  # determines when a sample is considered anisotropic (3 means that the spacing in the low
# resolution axis must be 3x as large as the next largest spacing)
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from nnunetv2.preprocessing.normalization.map_channel_name_to_normalization import get_normalization_scheme
from nnunetv2.preprocessing.resampling.default_resampling import resample_data_or_seg_to_shape, compute_new_shape
from nnunetv2.utilities.dataset_name_id_conversion import maybe_convert_to_dataset_name
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.json_export import recursive_fix_for_json_export
from nnunetv2.utilities.utils import get_filenames_of_train_images_and_targets
Expand Down Expand Up @@ -100,14 +99,10 @@ def static_estimate_VRAM_usage(patch_size: Tuple[int],
"""
Works for PlainConvUNet, ResidualEncoderUNet
"""
a = torch.get_num_threads()
torch.set_num_threads(get_allowed_n_proc_DA())
# print(f'instantiating network, patch size {patch_size}, pool op: {arch_kwargs["strides"]}')
net = get_network_from_plans(arch_class_name, arch_kwargs, arch_kwargs_req_import, input_channels,
output_channels,
allow_init=False)
ret = net.compute_conv_feature_map_size(patch_size)
torch.set_num_threads(a)
return ret

def determine_resampling(self, *args, **kwargs):
Expand Down
19 changes: 13 additions & 6 deletions nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
from nnunetv2.training.lr_scheduler.polylr import PolyLRScheduler
from nnunetv2.utilities.collate_outputs import collate_outputs
from nnunetv2.utilities.crossval_split import generate_crossval_split
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA, get_allowed_n_proc_DA_val
from nnunetv2.utilities.file_path_utilities import check_workers_alive_and_busy
from nnunetv2.utilities.get_network_from_plans import get_network_from_plans
from nnunetv2.utilities.helpers import empty_cache, dummy_context
Expand Down Expand Up @@ -635,16 +635,23 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)

def init_dataloaders(self, dl_tr, tr_transforms, dl_val, val_transforms):
    """
    Wrap the plain train/val dataloaders in background augmenters.

    The number of background processes for the training pipeline comes from
    get_allowed_n_proc_DA(); the validation pipeline is controlled
    independently via get_allowed_n_proc_DA_val(). For either pipeline, a
    value of 0 means the data is produced sequentially in the main process
    (SingleThreadedAugmenter) instead of via worker processes.

    :param dl_tr: plain training dataloader
    :param tr_transforms: transforms applied to training batches
    :param dl_val: plain validation dataloader
    :param val_transforms: transforms applied to validation batches
    :return: tuple (mt_gen_train, mt_gen_val)
    """
    num_processes_train = get_allowed_n_proc_DA()
    if num_processes_train == 0:
        mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
    else:
        mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, data_loader=dl_tr,
                                         transform=tr_transforms,
                                         num_processes=num_processes_train, num_cached=6, seeds=None,
                                         pin_memory=self.device.type == 'cuda', wait_time=0.02)

    num_processes_val = get_allowed_n_proc_DA_val()
    if num_processes_val == 0:
        mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
    else:
        # fewer cached batches than the train pipeline (3 vs 6): validation tolerates
        # a shallower prefetch queue
        mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, data_loader=dl_val,
                                       transform=val_transforms, num_processes=num_processes_val,
                                       num_cached=3, seeds=None, pin_memory=self.device.type == 'cuda',
                                       wait_time=0.02)
    return mt_gen_train, mt_gen_val
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import numpy as np
import torch
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter
from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \
GammaTransform
Expand All @@ -21,15 +20,12 @@
ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
from nnunetv2.training.data_augmentation.custom_transforms.deep_supervision_donwsampling import \
DownsampleSegForDSTransform2
from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
LimitedLenWrapper
from nnunetv2.training.data_augmentation.custom_transforms.masking import MaskTransform
from nnunetv2.training.data_augmentation.custom_transforms.region_based_training import \
ConvertSegmentationToRegionsTransform
from nnunetv2.training.data_augmentation.custom_transforms.transforms_for_dummy_2d import Convert3DTo2DTransform, \
Convert2DTo3DTransform
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA


class nnUNetTrainerDA5(nnUNetTrainer):
Expand Down Expand Up @@ -338,17 +334,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


def _brightnessadditive_localgamma_transform_scale(x, y):
Expand Down Expand Up @@ -399,17 +385,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainerDA5_10epochs(nnUNetTrainerDA5):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
from batchgenerators.dataloading.single_threaded_augmenter import SingleThreadedAugmenter

from nnunetv2.training.data_augmentation.custom_transforms.limited_length_multithreaded_augmenter import \
LimitedLenWrapper
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from nnunetv2.utilities.default_n_proc_DA import get_allowed_n_proc_DA


class nnUNetTrainerDAOrd0(nnUNetTrainer):
Expand Down Expand Up @@ -42,17 +37,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainer_DASegOrd0(nnUNetTrainer):
Expand Down Expand Up @@ -91,17 +76,7 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)


class nnUNetTrainer_DASegOrd0_NoMirroring(nnUNetTrainer):
Expand Down Expand Up @@ -144,14 +119,4 @@ def get_dataloaders(self):

dl_tr, dl_val = self.get_plain_dataloaders(initial_patch_size, dim)

allowed_num_processes = get_allowed_n_proc_DA()
if allowed_num_processes == 0:
mt_gen_train = SingleThreadedAugmenter(dl_tr, tr_transforms)
mt_gen_val = SingleThreadedAugmenter(dl_val, val_transforms)
else:
mt_gen_train = LimitedLenWrapper(self.num_iterations_per_epoch, dl_tr, tr_transforms,
allowed_num_processes, 6, None, True, 0.02)
mt_gen_val = LimitedLenWrapper(self.num_val_iterations_per_epoch, dl_val, val_transforms,
max(1, allowed_num_processes // 2), 3, None, True, 0.02)

return mt_gen_train, mt_gen_val
return self.init_dataloaders(dl_tr, tr_transforms, dl_val, val_transforms)
10 changes: 10 additions & 0 deletions nnunetv2/utilities/default_n_proc_DA.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
import os


def get_allowed_n_proc_DA_val():
    """
    Number of worker processes for the validation dataloader.

    Can be overridden with the nnUNet_n_proc_DA_val environment variable;
    a value of 0 means the validation data is loaded sequentially in the
    main process. If the variable is unset, defaults to half the training
    dataloader processes, clamped to at least 1 unless training itself runs
    with 0 processes — this preserves the previous max(1, n_proc // 2)
    behavior for setups that do not set the variable.
    """
    env_val = os.environ.get('nnUNet_n_proc_DA_val')
    if env_val is not None:
        return int(env_val)
    n_proc_train = get_allowed_n_proc_DA()
    # keep validation multi-processed whenever training is, even for n_proc_train == 1
    return 0 if n_proc_train == 0 else max(1, n_proc_train // 2)


def get_allowed_n_proc_DA():
"""
This function is used to set the number of processes used on different Systems. It is specific to our cluster
Expand Down
Loading