diff --git a/nnunetv2/inference/data_iterators.py b/nnunetv2/inference/data_iterators.py
index 2486bf6df..960b648f3 100644
--- a/nnunetv2/inference/data_iterators.py
+++ b/nnunetv2/inference/data_iterators.py
@@ -58,6 +58,44 @@ def preprocess_fromfiles_save_to_queue(list_of_lists: List[List[str]],
         raise e
 
 
+def preprocess_fromfiles_noqueue(list_of_lists: List[List[str]],
+                                 list_of_segs_from_prev_stage_files: Union[None, List[str]],
+                                 output_filenames_truncated: Union[None, List[str]],
+                                 plans_manager: PlansManager,
+                                 dataset_json: dict,
+                                 configuration_manager: ConfigurationManager,
+                                 verbose: bool = False):
+    print("Running preprocessing in non-multiprocessing mode")
+
+    data_iterator = []
+    label_manager = plans_manager.get_label_manager(dataset_json)
+    preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+
+    for idx in range(len(list_of_lists)):
+        input_files = list_of_lists[idx]
+        seg_file = list_of_segs_from_prev_stage_files[idx] if list_of_segs_from_prev_stage_files is not None else None
+        output_file = output_filenames_truncated[idx] if output_filenames_truncated is not None else None
+
+        data, seg, data_properties = preprocessor.run_case(input_files, seg_file, plans_manager,
+                                                           configuration_manager, dataset_json)
+
+        if seg_file is not None:
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+            data = np.vstack((data, seg_onehot))
+
+        data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
+
+        preprocessed_data = {
+            'data': data,
+            'data_properties': data_properties,
+            'ofile': output_file
+        }
+
+        data_iterator.append(preprocessed_data)
+
+    return data_iterator
+
+
 def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      list_of_segs_from_prev_stage_files: Union[None, List[str]],
                                      output_filenames_truncated: Union[None, List[str]],
@@ -67,56 +105,64 @@ def preprocessing_iterator_fromfiles(list_of_lists: List[List[str]],
                                      num_processes: int,
                                      pin_memory: bool = False,
                                      verbose: bool = False):
-    context = multiprocessing.get_context('spawn')
-    manager = Manager()
-    num_processes = min(len(list_of_lists), num_processes)
-    assert num_processes >= 1
-    processes = []
-    done_events = []
-    target_queues = []
-    abort_event = manager.Event()
-    for i in range(num_processes):
-        event = manager.Event()
-        queue = Manager().Queue(maxsize=1)
-        pr = context.Process(target=preprocess_fromfiles_save_to_queue,
-                             args=(
-                                 list_of_lists[i::num_processes],
-                                 list_of_segs_from_prev_stage_files[
-                                     i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
-                                 output_filenames_truncated[
-                                     i::num_processes] if output_filenames_truncated is not None else None,
-                                 plans_manager,
-                                 dataset_json,
-                                 configuration_manager,
-                                 queue,
-                                 event,
-                                 abort_event,
-                                 verbose
-                             ), daemon=True)
-        pr.start()
-        target_queues.append(queue)
-        done_events.append(event)
-        processes.append(pr)
-
-    worker_ctr = 0
-    while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
-        # import IPython;IPython.embed()
-        if not target_queues[worker_ctr].empty():
-            item = target_queues[worker_ctr].get()
-            worker_ctr = (worker_ctr + 1) % num_processes
-        else:
-            all_ok = all(
-                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
-            if not all_ok:
-                raise RuntimeError('Background workers died. Look for the error message further up! If there is '
-                                   'none then your RAM was full and the worker was killed by the OS. Use fewer '
-                                   'workers or get more RAM in that case!')
-            sleep(0.01)
-            continue
-        if pin_memory:
-            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
-        yield item
-    [p.join() for p in processes]
+
+    if num_processes > 1:
+        context = multiprocessing.get_context('spawn')
+        manager = Manager()
+        num_processes = min(len(list_of_lists), num_processes)
+        assert num_processes >= 1
+        processes = []
+        done_events = []
+        target_queues = []
+        abort_event = manager.Event()
+        for i in range(num_processes):
+            event = manager.Event()
+            queue = Manager().Queue(maxsize=1)
+            pr = context.Process(target=preprocess_fromfiles_save_to_queue,
+                                 args=(
+                                     list_of_lists[i::num_processes],
+                                     list_of_segs_from_prev_stage_files[
+                                         i::num_processes] if list_of_segs_from_prev_stage_files is not None else None,
+                                     output_filenames_truncated[
+                                         i::num_processes] if output_filenames_truncated is not None else None,
+                                     plans_manager,
+                                     dataset_json,
+                                     configuration_manager,
+                                     queue,
+                                     event,
+                                     abort_event,
+                                     verbose
+                                 ), daemon=True)
+            pr.start()
+            target_queues.append(queue)
+            done_events.append(event)
+            processes.append(pr)
+
+        worker_ctr = 0
+        while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+            if not target_queues[worker_ctr].empty():
+                item = target_queues[worker_ctr].get()
+                worker_ctr = (worker_ctr + 1) % num_processes
+            else:
+                all_ok = all(
+                    [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+                if not all_ok:
+                    raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+                                       'none then your RAM was full and the worker was killed by the OS. Use fewer '
+                                       'workers or get more RAM in that case!')
+                sleep(0.01)
+                continue
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
+        [p.join() for p in processes]
+    else:
+        data_iterator = preprocess_fromfiles_noqueue(list_of_lists, list_of_segs_from_prev_stage_files,
+                                                     output_filenames_truncated, plans_manager, dataset_json,
+                                                     configuration_manager, verbose=verbose)
+        for item in data_iterator:
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
 
 
 class PreprocessAdapter(DataLoader):
@@ -253,6 +299,37 @@ def preprocess_fromnpy_save_to_queue(list_of_images: List[np.ndarray],
         raise e
 
 
+def preprocess_fromnpy_noqueue(list_of_images: List[np.ndarray],
+                               list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
+                               list_of_image_properties: List[dict],
+                               truncated_ofnames: Union[List[str], None],
+                               plans_manager: PlansManager,
+                               dataset_json: dict,
+                               configuration_manager: ConfigurationManager,
+                               verbose: bool = False):
+    print("Running preprocessing in non-multiprocessing mode")
+    data_iterator = []
+    label_manager = plans_manager.get_label_manager(dataset_json)
+    preprocessor = configuration_manager.preprocessor_class(verbose=verbose)
+
+    for i in range(len(list_of_images)):
+        image = list_of_images[i]
+        seg_prev_stage = list_of_segs_from_prev_stage[i] if list_of_segs_from_prev_stage is not None else None
+        props = list_of_image_properties[i]
+        ofname = truncated_ofnames[i] if truncated_ofnames is not None else None
+        data, seg = preprocessor.run_case_npy(image, seg_prev_stage, props, plans_manager, configuration_manager,
+                                              dataset_json)
+        if seg_prev_stage is not None:
+            seg_onehot = convert_labelmap_to_one_hot(seg[0], label_manager.foreground_labels, data.dtype)
+            data = np.vstack((data, seg_onehot))
+        data = torch.from_numpy(data).to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        preprocessed_data = {
+            'data': data,
+            'data_properties': props,
+            'ofile': ofname
+        }
+        data_iterator.append(preprocessed_data)
+    return data_iterator
+
+
 def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
                                    list_of_segs_from_prev_stage: Union[List[np.ndarray], None],
                                    list_of_image_properties: List[dict],
@@ -263,52 +340,60 @@ def preprocessing_iterator_fromnpy(list_of_images: List[np.ndarray],
                                    num_processes: int,
                                    pin_memory: bool = False,
                                    verbose: bool = False):
-    context = multiprocessing.get_context('spawn')
-    manager = Manager()
-    num_processes = min(len(list_of_images), num_processes)
-    assert num_processes >= 1
-    target_queues = []
-    processes = []
-    done_events = []
-    abort_event = manager.Event()
-    for i in range(num_processes):
-        event = manager.Event()
-        queue = manager.Queue(maxsize=1)
-        pr = context.Process(target=preprocess_fromnpy_save_to_queue,
-                             args=(
-                                 list_of_images[i::num_processes],
-                                 list_of_segs_from_prev_stage[
-                                     i::num_processes] if list_of_segs_from_prev_stage is not None else None,
-                                 list_of_image_properties[i::num_processes],
-                                 truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
-                                 plans_manager,
-                                 dataset_json,
-                                 configuration_manager,
-                                 queue,
-                                 event,
-                                 abort_event,
-                                 verbose
-                             ), daemon=True)
-        pr.start()
-        done_events.append(event)
-        processes.append(pr)
-        target_queues.append(queue)
-
-    worker_ctr = 0
-    while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
-        if not target_queues[worker_ctr].empty():
-            item = target_queues[worker_ctr].get()
-            worker_ctr = (worker_ctr + 1) % num_processes
-        else:
-            all_ok = all(
-                [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
-            if not all_ok:
-                raise RuntimeError('Background workers died. Look for the error message further up! If there is '
-                                   'none then your RAM was full and the worker was killed by the OS. Use fewer '
-                                   'workers or get more RAM in that case!')
-            sleep(0.01)
-            continue
-        if pin_memory:
-            [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
-        yield item
-    [p.join() for p in processes]
+    if num_processes > 1:
+        context = multiprocessing.get_context('spawn')
+        manager = Manager()
+        num_processes = min(len(list_of_images), num_processes)
+        assert num_processes >= 1
+        target_queues = []
+        processes = []
+        done_events = []
+        abort_event = manager.Event()
+        for i in range(num_processes):
+            event = manager.Event()
+            queue = manager.Queue(maxsize=1)
+            pr = context.Process(target=preprocess_fromnpy_save_to_queue,
+                                 args=(
+                                     list_of_images[i::num_processes],
+                                     list_of_segs_from_prev_stage[
+                                         i::num_processes] if list_of_segs_from_prev_stage is not None else None,
+                                     list_of_image_properties[i::num_processes],
+                                     truncated_ofnames[i::num_processes] if truncated_ofnames is not None else None,
+                                     plans_manager,
+                                     dataset_json,
+                                     configuration_manager,
+                                     queue,
+                                     event,
+                                     abort_event,
+                                     verbose
+                                 ), daemon=True)
+            pr.start()
+            done_events.append(event)
+            processes.append(pr)
+            target_queues.append(queue)
+
+        worker_ctr = 0
+        while (not done_events[worker_ctr].is_set()) or (not target_queues[worker_ctr].empty()):
+            if not target_queues[worker_ctr].empty():
+                item = target_queues[worker_ctr].get()
+                worker_ctr = (worker_ctr + 1) % num_processes
+            else:
+                all_ok = all(
                    [i.is_alive() or j.is_set() for i, j in zip(processes, done_events)]) and not abort_event.is_set()
+                if not all_ok:
+                    raise RuntimeError('Background workers died. Look for the error message further up! If there is '
+                                       'none then your RAM was full and the worker was killed by the OS. Use fewer '
+                                       'workers or get more RAM in that case!')
+                sleep(0.01)
+                continue
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
+        [p.join() for p in processes]
+    else:
+        data_iterator = preprocess_fromnpy_noqueue(list_of_images, list_of_segs_from_prev_stage,
+                                                   list_of_image_properties, truncated_ofnames, plans_manager,
+                                                   dataset_json, configuration_manager, verbose=verbose)
+        for item in data_iterator:
+            if pin_memory:
+                [i.pin_memory() for i in item.values() if isinstance(i, torch.Tensor)]
+            yield item
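Reviewer note, not part of the diff: a minimal sketch of how the new num_processes <= 1 branch of preprocessing_iterator_fromfiles is exercised. The file names and the plans_manager / dataset_json / configuration_manager objects are hypothetical placeholders (e.g. taken from an already initialized nnUNetPredictor). One design consequence worth flagging: the no-queue helpers preprocess every case up front into a plain list, so peak RAM grows with the number of cases instead of being bounded by the maxsize=1 queues of the multiprocessing path.

    from nnunetv2.inference.data_iterators import preprocessing_iterator_fromfiles

    # plans_manager, dataset_json and configuration_manager are assumed to exist;
    # the paths below are made-up examples.
    items = preprocessing_iterator_fromfiles(
        [['case0_0000.nii.gz']],            # one case, one input channel
        None,                               # no previous-stage segmentations
        ['out/case0'],                      # truncated output filename
        plans_manager, dataset_json, configuration_manager,
        num_processes=1,                    # <= 1 selects preprocess_fromfiles_noqueue, no spawn workers
        pin_memory=False, verbose=True)
    for item in items:
        # same keys as the multiprocessing path produces
        data, props, ofile = item['data'], item['data_properties'], item['ofile']
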
diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py
index 1f5ede64f..ac8dfd8f1 100644
--- a/nnunetv2/inference/predict_from_raw_data.py
+++ b/nnunetv2/inference/predict_from_raw_data.py
@@ -345,9 +345,76 @@ def predict_from_data_iterator(self,
         each element returned by data_iterator must be a dict with 'data', 'ofile' and 'data_properties' keys!
         If 'ofile' is None, the result will be returned instead of written to a file
         """
-        with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
-            worker_list = [i for i in export_pool._pool]
-            r = []
+        use_multiprocessing = num_processes_segmentation_export > 0
+
+        if use_multiprocessing:
+            with multiprocessing.get_context("spawn").Pool(num_processes_segmentation_export) as export_pool:
+                worker_list = [i for i in export_pool._pool]
+                r = []
+                for preprocessed in data_iterator:
+                    data = preprocessed['data']
+                    if isinstance(data, str):
+                        delfile = data
+                        data = torch.from_numpy(np.load(data))
+                        os.remove(delfile)
+
+                    ofile = preprocessed['ofile']
+                    if ofile is not None:
+                        print(f'\nPredicting {os.path.basename(ofile)}:')
+                    else:
+                        print(f'\nPredicting image of shape {data.shape}:')
+
+                    print(f'perform_everything_on_device: {self.perform_everything_on_device}')
+
+                    properties = preprocessed['data_properties']
+
+                    # let's not get into a runaway situation where the GPU predicts so fast that the disk has to be swamped with
+                    # npy files
+                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+                    while not proceed:
+                        sleep(0.1)
+                        proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
+
+                    prediction = self.predict_logits_from_preprocessed_data(data).cpu()
+
+                    if ofile is not None:
+                        # this needs to go into background processes
+                        # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
+                        #                               self.dataset_json, ofile, save_probabilities)
+                        print('sending off prediction to background worker for resampling and export')
+                        r.append(
+                            export_pool.starmap_async(
+                                export_prediction_from_logits,
+                                ((prediction, properties, self.configuration_manager, self.plans_manager,
+                                  self.dataset_json, ofile, save_probabilities),)
+                            )
+                        )
+                    else:
+                        # convert_predicted_logits_to_segmentation_with_correct_shape(
+                        #     prediction, self.plans_manager,
+                        #     self.configuration_manager, self.label_manager,
+                        #     properties,
+                        #     save_probabilities)
+
+                        print('sending off prediction to background worker for resampling')
+                        r.append(
+                            export_pool.starmap_async(
+                                convert_predicted_logits_to_segmentation_with_correct_shape, (
+                                    (prediction, self.plans_manager,
+                                     self.configuration_manager, self.label_manager,
+                                     properties,
+                                     save_probabilities),)
+                            )
+                        )
+                    if ofile is not None:
+                        print(f'done with {os.path.basename(ofile)}')
+                    else:
+                        print(f'\nDone with image of shape {data.shape}:')
+                ret = [i.get()[0] for i in r]
+
+        else:
+            print("Running in non-multiprocessing mode")
+            ret = []
             for preprocessed in data_iterator:
                 data = preprocessed['data']
                 if isinstance(data, str):
@@ -365,49 +432,22 @@ def predict_from_data_iterator(self,
 
                 properties = preprocessed['data_properties']
 
-                # let's not get into a runaway situation where the GPU predicts so fast that the disk has to b swamped with
-                # npy files
-                proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
-                while not proceed:
-                    sleep(0.1)
-                    proceed = not check_workers_alive_and_busy(export_pool, worker_list, r, allowed_num_queued=2)
-
                 prediction = self.predict_logits_from_preprocessed_data(data).cpu()
 
                 if ofile is not None:
-                    # this needs to go into background processes
-                    # export_prediction_from_logits(prediction, properties, self.configuration_manager, self.plans_manager,
-                    #                               self.dataset_json, ofile, save_probabilities)
-                    print('sending off prediction to background worker for resampling and export')
-                    r.append(
-                        export_pool.starmap_async(
-                            export_prediction_from_logits,
-                            ((prediction, properties, self.configuration_manager, self.plans_manager,
-                              self.dataset_json, ofile, save_probabilities),)
-                        )
-                    )
+                    export_prediction_from_logits(prediction, properties, self.configuration_manager,
+                                                  self.plans_manager, self.dataset_json, ofile, save_probabilities)
                 else:
-                    # convert_predicted_logits_to_segmentation_with_correct_shape(
-                    #     prediction, self.plans_manager,
-                    #     self.configuration_manager, self.label_manager,
-                    #     properties,
-                    #     save_probabilities)
-
-                    print('sending off prediction to background worker for resampling')
-                    r.append(
-                        export_pool.starmap_async(
-                            convert_predicted_logits_to_segmentation_with_correct_shape, (
-                                (prediction, self.plans_manager,
-                                 self.configuration_manager, self.label_manager,
-                                 properties,
-                                 save_probabilities),)
-                        )
-                    )
+                    ret.append(convert_predicted_logits_to_segmentation_with_correct_shape(prediction,
+                                                                                           self.plans_manager,
+                                                                                           self.configuration_manager,
+                                                                                           self.label_manager,
+                                                                                           properties,
+                                                                                           save_probabilities))
                 if ofile is not None:
                     print(f'done with {os.path.basename(ofile)}')
                 else:
                     print(f'\nDone with image of shape {data.shape}:')
-            ret = [i.get()[0] for i in r]
 
         if isinstance(data_iterator, MultiThreadedAugmenter):
             data_iterator._finish()
@@ -658,6 +698,12 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor) \
             predicted_logits = predicted_logits[(slice(None), *slicer_revert_padding[1:])]
         return predicted_logits
 
+def _getDefaultValue(env: str, dtype: type, default):
+    try:
+        val = dtype(os.environ.get(env) or default)
+    except (TypeError, ValueError):
+        val = default
+    return val
 
 def predict_entry_point_modelfolder():
     import argparse
@@ -692,10 +738,10 @@ def predict_entry_point_modelfolder():
                         help='Continue an aborted previous prediction (will not overwrite existing files)')
     parser.add_argument('-chk', type=str, required=False, default='checkpoint_final.pth',
                         help='Name of the checkpoint you want to use. Default: checkpoint_final.pth')
-    parser.add_argument('-npp', type=int, required=False, default=3,
+    parser.add_argument('-npp', type=int, required=False, default=_getDefaultValue('nnUNet_npp', int, 3),
                         help='Number of processes used for preprocessing. More is not always better. Beware of '
                              'out-of-RAM issues. Default: 3')
-    parser.add_argument('-nps', type=int, required=False, default=3,
+    parser.add_argument('-nps', type=int, required=False, default=_getDefaultValue('nnUNet_nps', int, 3),
                         help='Number of processes used for segmentation export. More is not always better. Beware of '
                              'out-of-RAM issues. Default: 3')
     parser.add_argument('-prev_stage_predictions', type=str, required=False, default=None,
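Usage note, not part of the diff: how _getDefaultValue resolves the new nnUNet_npp / nnUNet_nps environment variables when the entry point builds its argument parser. A sketch, assuming the helper is importable from the patched module:

    import os
    from nnunetv2.inference.predict_from_raw_data import _getDefaultValue

    os.environ['nnUNet_npp'] = '1'
    print(_getDefaultValue('nnUNet_npp', int, 3))  # 1: the environment variable wins
    print(_getDefaultValue('nnUNet_nps', int, 3))  # 3: an unset variable falls back to the default
    os.environ['nnUNet_npp'] = 'not-a-number'
    print(_getDefaultValue('nnUNet_npp', int, 3))  # 3: a failed conversion falls back to the default

Since argparse defaults are evaluated at parser construction time, the variables have to be set before the entry point runs, e.g. nnUNet_npp=1 nnUNet_nps=0 nnUNetv2_predict_from_modelfolder ... from a shell.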
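Finally, a sketch of driving both new single-process branches end to end through nnUNetPredictor (model and data paths are hypothetical). With this patch, num_processes_preprocessing=1 routes preprocessing through the no-queue helpers and num_processes_segmentation_export=0 makes predict_from_data_iterator resample and export synchronously in the main process; on the unpatched code, 0 export workers would fail as soon as the pool is created.

    import torch
    from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

    predictor = nnUNetPredictor(device=torch.device('cuda'))
    predictor.initialize_from_trained_model_folder('/path/to/model_folder', use_folds=(0,),
                                                   checkpoint_name='checkpoint_final.pth')
    predictor.predict_from_files('/path/to/input_folder', '/path/to/output_folder',
                                 num_processes_preprocessing=1,        # no spawn-based preprocessing workers
                                 num_processes_segmentation_export=0)  # export runs in the main process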