From 69fd9a761d8b24efb71c3ae1258aea4923e6639e Mon Sep 17 00:00:00 2001
From: Suraj Pai
Date: Thu, 21 Nov 2024 08:15:37 -0500
Subject: [PATCH 1/6] Non-mp predict support [WIP]

---
 nnunetv2/inference/data_iterators.py        | 265 ++++++++++++--------
 nnunetv2/inference/predict_from_raw_data.py | 116 ++++++---
 2 files changed, 244 insertions(+), 137 deletions(-)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index 2486bf6df..25dc0dc38 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -58,6 +58,29 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
         raise e
 
 
+def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
+                                 list_of_segs_from_prev_stage_files: Union[None, List[str]],
+                                 output_filenames_truncated: Union[None, List[str]],
+                                 plans_manager: PlansManager,
+                                 dataset_json: dict,
+                                 configuration_manager: ConfigurationManager,
+                                 verbose: bool = False):
+    print("Running preprocessing in non-multiprocessing mode")
+    data_iterator = []
+    for i in range(len(list_of_lists)):
+        input_files = list_of_lists[i]
+        seg_file = list_of_segs_from_prev_stage_files[i] if list_of_segs_from_prev_stage_files is not None else None
+        output_file = output_filenames_truncated[i] if output_filenames_truncated is not None else None
+        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+        data, seg, data_properties = preprocessor.run_case(input_files, seg_file, plans_manager, configuration_manager, dataset_json)
+        preprocessed_data = {
+            'data': data,
+            'data_properties': data_properties,
+            'ofile': output_file
+        }
+        data_iterator.append(preprocessed_data)
+    return data_iterator
+
 def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      list_of_segs_from_prev_stage_files: Union[None, List[str]],
                                      output_filenames_truncated: Union[None, List[str]],
@@ -67,56 +90,63 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      num_processes: int,
                                      pin_memory: bool = False,
                                      verbose: bool = False):
-    context = multiprocessing.get_context('spawn')
-    manager = Manager()
-    num_processes = min(len(list_of_lists), num_processes)
-    assert num_processes >= 1
-    processes = []
-    done_events = []
-    target_queues = []
-    abort_event = manager.Event()
-    for i in range(num_processes):
-        event = manager.Event()
-        queue = Manager().Queue(maxsize=1)
-        pr = context.Process(target=preprocess_fromfiles_save_to_queue,
-                             args=(
-                                 list_of_lists[i::num_processes],
-                                 list_of_segs_from_prev_stage_files[
-                                     i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
-                                 output_filenames_truncated[
-                                     i::num_processes] if output_filenames_truncated is not None else None,
-                                 plans_manager,
-                                 dataset_json,
-                                 configuration_manager,
-                                 queue,
-                                 event,
-                                 abort_event,
-                                 verbose
-                             ), daemon=True)
-        pr.start()
-        target_queues.append(queue)
-        done_events.append(event)
-        processes.append(pr)
-
-    worker_ctr = 0
-    while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
-        # import IPython;IPython.embed()
-        if not target_queues[worker_ctr].empty():
-            item = target_queues[worker_ctr].get()
-            worker_ctr = (worker_ctr + 1) % num_processes
-        else:
-            all_ok = all(
-                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
-            if not all_ok:
-                raise RuntimeError('Background workers died. Look for the error message further up! If there is '
-                                   'none then your RAM was full and the worker was killed by the OS. Use fewer '
-                                   'workers or get more RAM in that case!')
-            sleep(0.01)
-            continue
-        if pin_memory:
-            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
-        yield item
-    [p.join() for p in processes]
+    if num_processes > 1:
+        context = multiprocessing.get_context('spawn')
+        manager = Manager()
+        num_processes = min(len(list_of_lists), num_processes)
+        assert num_processes >= 1
+        processes = []
+        done_events = []
+        target_queues = []
+        abort_event = manager.Event()
+        for i in range(num_processes):
+            event = manager.Event()
+            queue = Manager().Queue(maxsize=1)
+            pr = context.Process(target=preprocess_fromfiles_save_to_queue,
+                                 args=(
+                                     list_of_lists[i::num_processes],
+                                     list_of_segs_from_prev_stage_files[
+                                         i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
+                                     output_filenames_truncated[
+                                         i::num_processes] if output_filenames_truncated is not None else None,
+                                     plans_manager,
+                                     dataset_json,
+                                     configuration_manager,
+                                     queue,
+                                     event,
+                                     abort_event,
+                                     verbose
+                                 ), daemon=True)
+            pr.start()
+            target_queues.append(queue)
+            done_events.append(event)
+            processes.append(pr)
+
+        worker_ctr = 0
+        while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+            if not target_queues[worker_ctr].empty():
+                item = target_queues[worker_ctr].get()
+                worker_ctr = (worker_ctr + 1) % num_processes
+            else:
+                all_ok = all(
+                    [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+                if not all_ok:
+                    raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+                                       'none then your RAM was full and the worker was killed by the OS. Use fewer '
+                                       'workers or get more RAM in that case!')
+                sleep(0.01)
+                continue
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
+        [p.join() for p in processes]
+    else:
+        print("Running preprocessing in non-multiprocessing mode")
+        data_iterator = preprocess_fromfiles_noqueue(list_of_lists, list_of_segs_from_prev_stage_files, output_filenames_truncated, plans_manager, dataset_json, configuration_manager, verbose=verbose)
+        for item in data_iterator:
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
 
 
 class PreprocessAdapter(DataLoader):
@@ -253,6 +283,35 @@ def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
         raise e
 
 
+def preprocess_fromnpy_noqueue(list_of_images: List[np.ndarray],
+                               list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
+                               list_of_image_properties: List[dict],
+                               truncated_ofnames: Union[List[str], None],
+                               plans_manager: PlansManager,
+                               dataset_json: dict,
+                               configuration_manager: ConfigurationManager,
+                               verbose: bool = False):
+    print("Running preprocessing in non-multiprocessing mode")
+    data_iterator = []
+    for i in range(len(list_of_images)):
+        image = list_of_images[i]
+        seg_prev_stage = list_of_segs_from_prev_stage[i] if list_of_segs_from_prev_stage is not None else None
+        props = list_of_image_properties[i]
+        ofname = truncated_ofnames[i] if truncated_ofnames is not None else None
+        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+        data, seg = preprocessor.run_case_npy(image, seg_prev_stage, props, plans_manager, configuration_manager, dataset_json)
+        if seg_prev_stage is not None:
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], plans_manager.get_label_manager(dataset_json).foreground_labels, data.dtype)
+            data = np.vstack((data, seg_onehot))
+        preprocessed_data = {
+            'data': data,
+            'data_properties': props,
+            'ofile': ofname if ofname is not None else None
+        }
+        data_iterator.append(preprocessed_data)
+    return data_iterator
+
+
 def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
                                    list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
                                    list_of_image_properties: List[dict],
@@ -263,52 +322,60 @@ def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
                                    num_processes: int,
                                    pin_memory: bool = False,
                                    verbose: bool = False):
-    context = multiprocessing.get_context('spawn')
-    manager = Manager()
-    num_processes = min(len(list_of_images), num_processes)
-    assert num_processes >= 1
-    target_queues = []
-    processes = []
-    done_events = []
-    abort_event = manager.Event()
-    for i in range(num_processes):
-        event = manager.Event()
-        queue = manager.Queue(maxsize=1)
-        pr = context.Process(target=preprocess_fromnpy_save_to_queue,
-                             args=(
-                                 list_of_images[i::num_processes],
-                                 list_of_segs_from_prev_stage[
-                                     i::num_processes] if list_of_segs_from_prev_stage is not None else None,
-                                 list_of_image_properties[i::num_processes],
-                                 truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
-                                 plans_manager,
-                                 dataset_json,
-                                 configuration_manager,
-                                 queue,
-                                 event,
-                                 abort_event,
-                                 verbose
-                             ), daemon=True)
-        pr.start()
-        done_events.append(event)
-        processes.append(pr)
-        target_queues.append(queue)
-
-    worker_ctr = 0
-    while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
-        if not target_queues[worker_ctr].empty():
-            item = target_queues[worker_ctr].get()
-            worker_ctr = (worker_ctr + 1) % num_processes
-        else:
-            all_ok = all(
-                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
-            if not all_ok:
-                raise RuntimeError('Background workers died. Look for the error message further up! If there is '
-                                   'none then your RAM was full and the worker was killed by the OS. Use fewer '
-                                   'workers or get more RAM in that case!')
-            sleep(0.01)
-            continue
-        if pin_memory:
-            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
-        yield item
-    [p.join() for p in processes]
+    if num_processes > 1:
+        context = multiprocessing.get_context('spawn')
+        manager = Manager()
+        num_processes = min(len(list_of_images), num_processes)
+        assert num_processes >= 1
+        target_queues = []
+        processes = []
+        done_events = []
+        abort_event = manager.Event()
+        for i in range(num_processes):
+            event = manager.Event()
+            queue = manager.Queue(maxsize=1)
+            pr = context.Process(target=preprocess_fromnpy_save_to_queue,
+                                 args=(
+                                     list_of_images[i::num_processes],
+                                     list_of_segs_from_prev_stage[
+                                         i::num_processes] if list_of_segs_from_prev_stage is not None else None,
+                                     list_of_image_properties[i::num_processes],
+                                     truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
+                                     plans_manager,
+                                     dataset_json,
+                                     configuration_manager,
+                                     queue,
+                                     event,
+                                     abort_event,
+                                     verbose
+                                 ), daemon=True)
+            pr.start()
+            done_events.append(event)
+            processes.append(pr)
+            target_queues.append(queue)
+
+        worker_ctr = 0
+        while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+            if not target_queues[worker_ctr].empty():
+                item = target_queues[worker_ctr].get()
+                worker_ctr = (worker_ctr + 1) % num_processes
+            else:
+                all_ok = all(
+                    [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+                if not all_ok:
+                    raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+                                       'none then your RAM was full and the worker was killed by the OS. Use fewer '
+                                       'workers or get more RAM in that case!')
+                sleep(0.01)
+                continue
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
+        [p.join() for p in processes]
+    else:
+        print("Running preprocessing in non-multiprocessing mode")
+        data_iterator = preprocess_fromnpy_noqueue(list_of_images, list_of_segs_from_prev_stage, list_of_image_properties, truncated_ofnames, plans_manager, dataset_json, configuration_manager, verbose=verbose)
+        for item in data_iterator:
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 1f5ede64f..f5e2247cc 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -345,9 +345,76 @@ def predict_from_data_iterator(self,
         each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
         If 'ofile' is None, the result will be returned instead of written to a file
         """
-        with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
-            worker_list = [i for i in export_pool._pool]
-            r = []
+        use_multiprocessing = num_processes_segmentation_export > 0
+
+        if use_multiprocessing:
+            with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
+                worker_list = [i for i in export_pool._pool]
+                r = []
+                for preprocessed in data_iterator:
+                    data = preprocessed['data']
+                    if isinstance(data, str):
+                        delfile = data
+                        data = torch.from_numpy(np.load(data))
+                        os.remove(delfile)
+
+                    ofile = preprocessed['ofile']
+                    if ofile is not None:
+                        print(f'\nPredicting {os.path.basename(ofile)}:')
+                    else:
+                        print(f'\nPredicting image of shape {data.shape}:')
+
+                    print(f'perform_everything_on_device: {self.perform_everything_on_device}')
+
+                    properties = preprocessed['data_properties']
+
+                    # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with
+                    # npy files
+                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+                    while not proceed:
+                        sleep(0.1)
+                        proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+
+                    prediction = self.predict_logits_from_preprocessed_data(data).cpu()
+
+                    if ofile is not None:
+                        # this needs to go into background processes
+                        # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
+                        #                               self.dataset_json, ofile, save_probabilities)
+                        print('sending off prediction to background worker for resampling and export')
+                        r.append(
+                            export_pool.starmap_async(
+                                export_prediction_from_logits,
+                                ((prediction, properties, self.configuration_manager, self.plans_manager,
+                                  self.dataset_json, ofile, save_probabilities),)
+                            )
+                        )
+                    else:
+                        # convert_predicted_logits_to_segmentation_with_correct_shape(
+                        #     prediction, self.plans_manager,
+                        #     self.configuration_manager, self.label_manager,
+                        #     properties,
+                        #     save_probabilities)
+
+                        print('sending off prediction to background worker for resampling')
+                        r.append(
+                            export_pool.starmap_async(
+                                convert_predicted_logits_to_segmentation_with_correct_shape, (
+                                    (prediction, self.plans_manager,
+                                     self.configuration_manager, self.label_manager,
+                                     properties,
+                                     save_probabilities),)
+                            )
+                        )
+                    if ofile is not None:
+                        print(f'done with {os.path.basename(ofile)}')
+                    else:
+                        print(f'\nDone with image of shape {data.shape}:')
+                ret = [i.get()[0] for i in r]
+
+        else:
+            print("Running in non-multiprocessing mode")
+            ret = []
             for preprocessed in data_iterator:
                 data = preprocessed['data']
                 if isinstance(data, str):
@@ -365,49 +432,22 @@ def predict_from_data_iterator(self,
 
                 properties = preprocessed['data_properties']
 
-                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to b swamped with
-                # npy files
-                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
-                while not proceed:
-                    sleep(0.1)
-                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
-
                 prediction = self.predict_logits_from_preprocessed_data(data).cpu()
 
                 if ofile is not None:
-                    # this needs to go into background processes
-                    # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
-                    #                               self.dataset_json, ofile, save_probabilities)
-                    print('sending off prediction to background worker for resampling and export')
-                    r.append(
-                        export_pool.starmap_async(
-                            export_prediction_from_logits,
-                            ((prediction, properties, self.configuration_manager, self.plans_manager,
-                              self.dataset_json, ofile, save_probabilities),)
-                        )
-                    )
+                    export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
+                                                  self.dataset_json, ofile, save_probabilities)
                 else:
-                    # convert_predicted_logits_to_segmentation_with_correct_shape(
-                    #     prediction, self.plans_manager,
-                    #     self.configuration_manager, self.label_manager,
-                    #     properties,
-                    #     save_probabilities)
-
-                    print('sending off prediction to background worker for resampling')
-                    r.append(
-                        export_pool.starmap_async(
-                            convert_predicted_logits_to_segmentation_with_correct_shape, (
-                                (prediction, self.plans_manager,
-                                 self.configuration_manager, self.label_manager,
-                                 properties,
-                                 save_probabilities),)
-                        )
-                    )
+                    ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(prediction, self.plans_manager,
+                                                                                           self.configuration_manager,
+                                                                                           self.label_manager,
+                                                                                           properties,
+                                                                                           save_probabilities))
+
                 if ofile is not None:
                     print(f'done with {os.path.basename(ofile)}')
                 else:
                     print(f'\nDone with image of shape {data.shape}:')
-            ret = [i.get()[0] for i in r]
 
     if isinstance(data_iterator, MultiThreadedAugmenter):
         data_iterator._finish()

From a595f6463a245490056e46edba3957ed02b6655c Mon Sep 17 00:00:00 2001
From: LennyN95
Date: Thu, 21 Nov 2024 16:35:26 +0100
Subject: [PATCH 2/6] add torchification

---
 nnunetv2/inference/data_iterators.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index 25dc0dc38..7db4e731a 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -73,6 +73,7 @@ def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
         output_file = output_filenames_truncated[i] if output_filenames_truncated is not None else None
         preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
         data, seg, data_properties = preprocessor.run_case(input_files, seg_file, plans_manager, configuration_manager, dataset_json)
+        data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
         preprocessed_data = {
             'data': data,
             'data_properties': data_properties,

From 6cb4e7addaf50b648e7ecb34087842c502164f2f Mon Sep 17 00:00:00 2001
From: LennyN95
Date: Thu, 21 Nov 2024 17:00:53 +0100
Subject: [PATCH 3/6] update preprocess_fromfiles_noqueue
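
Each element produced by the no-queue path mirrors what the queue-based
workers emit, so downstream consumers can treat both the same. A minimal
sketch of that contract (illustrative only, not part of the diff below;
assumes plans_manager, dataset_json and configuration_manager are set up as
in nnUNetPredictor):

    items = preprocess_fromfiles_noqueue(list_of_lists, None, None,
                                         plans_manager, dataset_json,
                                         configuration_manager)
    for item in items:
        item['data']             # torch.float32 tensor after PATCH 2
        item['data_properties']  # properties dict from the preprocessor
        item['ofile']            # truncated output filename, or None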
---
 nnunetv2/inference/data_iterators.py | 25 +++++++++++++++++++++----
 1 file changed, 21 insertions(+), 4 deletions(-)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index 7db4e731a..f8b8a02c2 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -65,21 +65,35 @@ def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
                                  dataset_json: dict,
                                  configuration_manager: ConfigurationManager,
                                  verbose: bool = False):
+
     print("Running preprocessing in non-multiprocessing mode")
+
     data_iterator = []
-    for i in range(len(list_of_lists)):
-        input_files = list_of_lists[i]
-        seg_file = list_of_segs_from_prev_stage_files[i] if list_of_segs_from_prev_stage_files is not None else None
-        output_file = output_filenames_truncated[i] if output_filenames_truncated is not None else None
+    label_manager = plans_manager.get_label_manager(dataset_json)
+
+    for idx in range(len(list_of_lists)):
+
+        input_files = list_of_lists[idx]
+        seg_file = list_of_segs_from_prev_stage_files[idx] if list_of_segs_from_prev_stage_files is not None else None
+        output_file = output_filenames_truncated[idx] if output_filenames_truncated is not None else None
+
         preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
         data, seg, data_properties = preprocessor.run_case(input_files, seg_file, plans_manager, configuration_manager, dataset_json)
+
+        if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None:
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+            data = np.vstack((data, seg_onehot))
+
        data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
+
         preprocessed_data = {
             'data': data,
             'data_properties': data_properties,
             'ofile': output_file
         }
+
         data_iterator.append(preprocessed_data)
+
     return data_iterator
 
 def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
@@ -91,6 +105,9 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      num_processes: int,
                                      pin_memory: bool = False,
                                      verbose: bool = False):
+
+    num_processes = 1
+
     if num_processes > 1:

From b570e8ae5c830cbcd1c339dda06df470e41caeb0 Mon Sep 17 00:00:00 2001
From: LennyN95
Date: Thu, 21 Nov 2024 18:34:19 +0100
Subject: [PATCH 4/6] dev-only: remove enforced num_processes = 1

---
 nnunetv2/inference/data_iterators.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index f8b8a02c2..f7e2ace53 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -105,9 +105,7 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      num_processes: int,
                                      pin_memory: bool = False,
                                      verbose: bool = False):
-
-    num_processes = 1
-
+
     if num_processes > 1:
         context = multiprocessing.get_context('spawn')
         manager = Manager()

From 3628055c9a2b16cae0e9b3b8f6f646cf92ff916d Mon Sep 17 00:00:00 2001
From: LennyN95
Date: Thu, 21 Nov 2024 19:28:25 +0100
Subject: [PATCH 5/6] add env to overwrite npp and nps

New environment variables:
- nnUNet_npp
- nnUNet_nps

Default values remain unchanged; the CLI parameters -npp and -nps override
the environment variables if set.
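
Sketch of the intended lookup order (hypothetical values; _getDefaultValue
is the helper added in the diff below):

    # with `export nnUNet_npp=2` in the environment and no -npp on the CLI,
    # the argparse default becomes 2:
    _getDefaultValue('nnUNet_npp', int, 3)   # -> 2
    # unset (or unparsable) environment variable -> the hardcoded default:
    _getDefaultValue('nnUNet_nps', int, 3)   # -> 3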
---
 nnunetv2/inference/predict_from_raw_data.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index f5e2247cc..ac8dfd8f1 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -698,6 +698,12 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
             predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
         return predicted_logits
 
+def _getDefaultValue(env: str, dtype: type, default: any) -> any:
+    try:
+        val = dtype(os.environ.get(env) or default)
+    except Exception:
+        val = default
+    return val
 
 def predict_entry_point_modelfolder():
     import argparse
@@ -732,10 +738,10 @@ def predict_entry_point_modelfolder():
                         help='Continue an aborted previous prediction (will not overwrite existing files)')
     parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
                         help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
-    parser.add_argument('-npp', type=int, required=False, default=3,
+    parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),
                         help='Number of processes used for preprocessing. More is not always better. Beware of '
                              'out-of-RAM issues. Default: 3')
-    parser.add_argument('-nps', type=int, required=False, default=3,
+    parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),
                         help='Number of processes used for segmentation export. More is not always better. Beware of '
                              'out-of-RAM issues. Default: 3')
     parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,

From 6bcec2199a3e72db83e9ed18e270b7a531818db0 Mon Sep 17 00:00:00 2001
From: Suraj Pai
Date: Fri, 22 Nov 2024 06:49:51 -0500
Subject: [PATCH 6/6] Fix preprocessor initialization order

---
 nnunetv2/inference/data_iterators.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index f7e2ace53..960b648f3 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -70,6 +70,7 @@ def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
 
     data_iterator = []
     label_manager = plans_manager.get_label_manager(dataset_json)
+    preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
 
     for idx in range(len(list_of_lists)):
 
@@ -77,7 +78,6 @@ def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
         seg_file = list_of_segs_from_prev_stage_files[idx] if list_of_segs_from_prev_stage_files is not None else None
         output_file = output_filenames_truncated[idx] if output_filenames_truncated is not None else None
 
-        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
         data, seg, data_properties = preprocessor.run_case(input_files, seg_file, plans_manager, configuration_manager, dataset_json)
 
         if list_of_segs_from_prev_stage_files is not None and list_of_segs_from_prev_stage_files[idx] is not None:
@@ -309,15 +309,17 @@ def preprocess_fromnpy_noqueue(list_of_images: List[np.ndarray],
                                verbose: bool = False):
     print("Running preprocessing in non-multiprocessing mode")
     data_iterator = []
+    label_manager = plans_manager.get_label_manager(dataset_json)
+    preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+
     for i in range(len(list_of_images)):
         image = list_of_images[i]
         seg_prev_stage = list_of_segs_from_prev_stage[i] if list_of_segs_from_prev_stage is not None else None
         props = list_of_image_properties[i]
         ofname = truncated_ofnames[i] if truncated_ofnames is not None else None
-        preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
         data, seg = preprocessor.run_case_npy(image, seg_prev_stage, props, plans_manager, configuration_manager, dataset_json)
         if seg_prev_stage is not None:
-            seg_onehot = convert_labelmap_to_one_hot(seg[0], plans_manager.get_label_manager(dataset_json).foreground_labels, data.dtype)
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
             data = np.vstack((data, seg_onehot))
         preprocessed_data = {
             'data': data,
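
Taken together, the series allows a fully single-process prediction. A
minimal sketch of how the new paths would be selected (illustrative only;
MODEL_FOLDER, INPUT_FOLDER and OUTPUT_FOLDER are placeholders, and predictor
construction is abbreviated):

    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor = nnUNetPredictor()
    predictor.initialize_from_trained_model_folder(MODEL_FOLDER, use_folds=(0,))

    # num_processes_preprocessing <= 1 takes the no-queue preprocessing branch,
    # num_processes_segmentation_export == 0 skips the export pool entirely
    predictor.predict_from_files(INPUT_FOLDER, OUTPUT_FOLDER,
                                 num_processes_preprocessing=1,
                                 num_processes_segmentation_export=0)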