Commit 35eb088

Merge pull request #8 from CAAI/raphael-dev
Torchio upgrade - retaining torch compatibility
2 parents 0bb9ef8 + 6931478 commit 35eb088

10 files changed: +597 −97 lines
@@ -0,0 +1,383 @@
from torch.utils.data import Dataset
from rhtorch.data.DataAugmentation3D import DataAugmentation3D
from pathlib import Path
import torch
import json
import pickle
import glob
import numpy as np
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchio import ScalarImage, Subject, SubjectsDataset, Queue
from torchio.transforms import Lambda, RandomAffine, RandomFlip, Compose
from torchio.data import UniformSampler
from sklearn.model_selection import train_test_split


augmenter = DataAugmentation3D(rotation_range=[5, 5, 5],
                               shift_range=[0.05, 0.05, 0.05],
                               shear_range=[2, 2, 0],
                               zoom_lower=[0.9, 0.9, 0.9],
                               zoom_upper=[1.2, 1.2, 1.2],
                               zoom_independent=True,
                               data_format='channels_last',  # relative to position of batch_size
                               flip_axis=[0, 1, 2],
                               fill_mode='reflect')


def swap_axes(x):
    return np.swapaxes(np.swapaxes(x, 0, 1), 1, 2)


def numpy_reader(path):
    return np.load(path), np.eye(4)


def load_data_splitting(mode="train", pkl_file=None, datadir=None, fold=0, duplicate_list=0, quick_test=False):

    # search for a pkl/json file in the data folder if none was passed
    if pkl_file:
        pkl_file = datadir.joinpath(pkl_file)
    else:
        # glob returns plain strings, so wrap the first match in a Path
        matches = glob.glob(f"{datadir}/*_train_test_split_*_fold.json")
        pkl_file = Path(matches[0]) if matches else None
    if pkl_file is None or not pkl_file.exists():
        raise FileNotFoundError(
            "Data train/test split info file not found. Add file to data folder or declare in config file.")

    # use the json or pickle loader depending on the file extension
    if pkl_file.name.endswith(".json"):
        with open(pkl_file, 'r') as f:
            split_data_info = json.load(f)
    else:
        # pickle requires binary mode
        with open(pkl_file, 'rb') as f:
            split_data_info = pickle.load(f)

    patients = split_data_info[f"{mode}_{fold}"]
    if duplicate_list:
        patients = np.repeat(patients, duplicate_list)

    # trim the test/train set for a quick test run
    if quick_test:
        keep = 2 if mode == 'test' else 10
        patients = patients[:keep]

    return patients
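
# For reference, a hypothetical sketch of the split file read above: a JSON
# (or pickled) dict keyed by "{mode}_{fold}", where the values are lists of
# patient folder names under the data folder, e.g.
#
#     {
#         "train_0": ["patient1", "patient2", "patient3"],
#         "test_0": ["patient4"],
#         "train_1": ["patient2", "patient3", "patient4"],
#         "test_1": ["patient1"]
#     }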


## TORCHIO way of setting up the data module

class TIODataModule(pl.LightningDataModule):
    def __init__(self, config, quick_test=False):
        super().__init__()
        self.config = config
        self.k = config['k_fold']
        self.quick_test = quick_test
        self.batch_size = self.config['batch_size']
        self.datadir = Path(self.config['data_folder'])
        self.augment = self.config['augment']
        self.num_workers = 4
        # normalization factor for PET data
        self.pet_norm = self.config['pet_normalization_constant']

        # for the queue
        self.patch_size = self.config['patch_size']  # e.g. [16, 128, 128]
        patches_per_volume = int(
            np.max(self.patch_size) / np.min(self.patch_size))
        self.queue_length = patches_per_volume
        self.samples_per_volume = int(
            patches_per_volume / 2) if patches_per_volume > 1 else 1
        self.sampler = UniformSampler(self.patch_size)
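
        # Worked example (assuming the shape hinted above): with
        # patch_size = [16, 128, 128], patches_per_volume = 128 / 16 = 8,
        # giving queue_length = 8 and samples_per_volume = 4.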

        # variables to be filled later
        self.subjects = None
        self.test_subjects = None
        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.train_queue = None
        self.val_queue = None
        self.test_queue = None
        self.transform = None
        self.preprocess = None

    # Normalization functions
    def get_normalization_transform(self, tr):
        if tr == 'pet_hard_normalization':
            return Lambda(lambda x: x / self.pet_norm)
        elif tr == 'ct_normalization':
            return Lambda(lambda x: (x + 1024.0) / 2000.0)
        else:
            return None
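
    # Note: ct_normalization maps HU values linearly, so roughly
    # [-1024, 976] HU -> [0, 1]; e.g. -1024 HU -> 0.0 and 0 HU -> 0.512.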

    def prepare_patient_info(self, filename, preprocess_step=None):

        # NUMPY files
        if filename.name.endswith('.npy'):
            rawdata = ScalarImage(filename, reader=numpy_reader)
        # NIFTI, MINC, NRRD, MHA files, or DICOM folder
        else:
            rawdata = ScalarImage(filename)

        if preprocess_step:
            pp = self.get_normalization_transform(preprocess_step)
            return pp(rawdata)
        else:
            return rawdata

    def prepare_patient_data(self, mode='train'):
        """ data is organized as
            data_dir
            ├── patient1
            |   ├── pet_highdose.nii.gz
            |   ├── pet_lowdose.nii.gz
            |   └── ct.nii.gz
            ├── patient2
            |   ├── pet_highdose.nii.gz
            |   ├── pet_lowdose.nii.gz
            |   └── ct.nii.gz
            ├── ... etc
        """

        # load the train/valid patient list from the split file
        patients = load_data_splitting(mode,
                                       self.config['data_split_pkl'],
                                       self.datadir,
                                       fold=self.k,
                                       duplicate_list=self.config['repeat_patient_list'],
                                       quick_test=self.quick_test)

        # create a Subject object for each patient
        subjects = []
        for p in patients:
            p_folder = self.datadir.joinpath(p)
            patient_dict = {'id': p}

            for file_type in ['input', 'target']:
                file_info = self.config[file_type + '_files']
                for i in range(len(file_info['name'])):
                    input_path = p_folder.joinpath(file_info['name'][i])
                    transf = file_info['preprocess_step'][i]
                    patient_dict[f"{file_type}{i}"] = self.prepare_patient_info(
                        input_path, transf)

            # Subject instantiation
            s = Subject(patient_dict)
            subjects.append(s)

        return subjects

    def prepare_data(self):
        self.subjects = self.prepare_patient_data('train')
        self.test_subjects = self.prepare_patient_data('test')

    def get_augmentation_transform(self):
        augment = Compose([
            RandomAffine(scales=(0.9, 1.2),  # zoom
                         degrees=5,          # rotation
                         translation=5,      # shift
                         isotropic=False,    # wrt zoom
                         center='image',
                         image_interpolation='linear'),
            RandomFlip(axes=(0, 1, 2))
        ])
        return augment

    def setup(self, stage=None):
        # train/val split of the training subjects
        train_subjects, val_subjects = train_test_split(
            self.subjects, test_size=.2, random_state=42)

        # setup for trainer.fit()
        if stage in (None, 'fit'):
            self.transform = self.get_augmentation_transform() if self.augment else None

            # datasets
            self.train_set = SubjectsDataset(train_subjects,
                                             transform=self.transform)
            self.val_set = SubjectsDataset(val_subjects)

            # queues
            self.train_queue = Queue(self.train_set,
                                     self.queue_length,
                                     self.samples_per_volume,
                                     self.sampler,
                                     num_workers=self.num_workers)

            self.val_queue = Queue(self.val_set,
                                   self.queue_length,
                                   self.samples_per_volume,
                                   self.sampler,
                                   num_workers=self.num_workers)

        # setup for trainer.test()
        if stage in (None, 'test'):
            self.test_set = SubjectsDataset(self.test_subjects)
            self.test_queue = Queue(self.test_set,
                                    self.queue_length,
                                    self.samples_per_volume,
                                    self.sampler,
                                    num_workers=self.num_workers)

    def train_dataloader(self):
        # the Queue handles its own workers, so the wrapping DataLoader
        # keeps the default num_workers=0
        return DataLoader(self.train_queue, self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.val_queue, self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_queue, self.batch_size)
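
# A minimal usage sketch (hypothetical; "MyModel" is a placeholder
# LightningModule and `config` is a dict with the keys read in __init__):
#
#     dm = TIODataModule(config, quick_test=True)
#     dm.prepare_data()
#     dm.setup('fit')
#     trainer = pl.Trainer(max_epochs=1)
#     trainer.fit(MyModel(), datamodule=dm)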


## Generic DataModule without TORCHIO

class GenericDataModule(pl.LightningDataModule):
    def __init__(self, config, quick_test=False):
        super().__init__()
        self.config = config
        self.k = config['k_fold']
        self.quick_test = quick_test
        self.batch_size = self.config['batch_size']
        self.datadir = Path(self.config['data_folder'])
        self.augment = self.config['augment']
        self.num_workers = 4
        # normalization factor for PET data
        self.pet_norm = self.config['pet_normalization_constant']

        # variables to be filled later
        self.subjects = None
        self.test_subjects = None
        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.transform = None
        self.preprocess = None

    def prepare_data(self):
        """ data is organized as
            data_dir
            ├── patient1
            |   ├── pet_highdose.npy
            |   ├── pet_lowdose.npy
            |   └── ct.npy
            ├── patient2
            |   ├── pet_highdose.npy
            |   ├── pet_lowdose.npy
            |   └── ct.npy
            ├── ... etc
        """

        # load the train/valid patient list from the split file
        self.subjects = load_data_splitting('train',
                                            self.config['data_split_pkl'],
                                            self.datadir,
                                            fold=self.k,
                                            duplicate_list=self.config['repeat_patient_list'],
                                            quick_test=self.quick_test)

        # load the test patient list from the split file
        self.test_subjects = load_data_splitting('test',
                                                 self.config['data_split_pkl'],
                                                 self.datadir,
                                                 fold=self.k,
                                                 duplicate_list=self.config['repeat_patient_list'],
                                                 quick_test=self.quick_test)

    def setup(self, stage=None):
        # train/val split of the training subjects
        train_subjects, val_subjects = train_test_split(
            self.subjects, test_size=.2, random_state=42)

        # setup for trainer.fit()
        if stage in (None, 'fit'):
            # datasets
            self.train_set = DatasetFullVolume(train_subjects, self.config, self.config['augment'])
            self.val_set = DatasetFullVolume(val_subjects, self.config)

        # setup for trainer.test()
        if stage in (None, 'test'):
            self.test_set = DatasetFullVolume(self.test_subjects, self.config)

    def train_dataloader(self):
        return DataLoader(self.train_set, self.batch_size, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_set, self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_set, self.batch_size, num_workers=self.num_workers)


class DatasetFullVolume(Dataset):
    """ Serves full image volumes as PyTorch dataset samples """

    def __init__(self, patients, conf=None, augment=False):
        """ Initialization """
        self.patients = patients
        self.config = conf
        self.datadir = Path(self.config['data_folder'])
        self.augment = augment

        # normalization factor for PET data
        self.pet_norm = conf['pet_normalization_constant']

        # data shape
        self.full_data_shape = conf['data_shape']
        self.color_channels = conf['color_channels_in']

    def __len__(self):
        'Denotes the total number of samples in the dataset'
        return len(self.patients)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.patients[index]

        # Load data and get label
        X, y = self.data_generation(ID)

        return torch.from_numpy(X).float(), torch.from_numpy(y).float()

    def pet_hard_normalization(self, data):
        return data / self.pet_norm

    @staticmethod
    def ct_normalization(data):
        return (data + 1024.0) / 2000.0

    def data_generation(self, patient_id):
        """ Generates one (input, target) pair """
        # X : (n_channels, *data_shape)

        # this loads and returns the full image size
        X, y = self.load_volume(patient_id)

        return X.astype(np.float64), y.astype(np.float64)

    def load_volume(self, patient_id):

        # initialize input data with the correct number of channels
        dat = np.zeros((self.color_channels, *self.full_data_shape))

        # --- Load data and labels
        for i in range(self.color_channels):
            fname = self.datadir.joinpath(patient_id).joinpath(
                self.config['input_files']['name'][i])
            # memory-map the volume; note this assumes headerless raw binary,
            # for true .npy files np.load(fname, mmap_mode='r') would be safer
            pet = np.memmap(fname, dtype='double', mode='r')
            dat[i, ...] = pet.reshape(self.full_data_shape)
            normalization_func = getattr(self, self.config['input_files']['preprocess_step'][i])
            dat[i, ...] = normalization_func(dat[i, ...])

        # the target contains the high-dose PET image
        fname2 = self.datadir.joinpath(patient_id).joinpath(
            self.config['target_files']['name'][0])
        target = np.memmap(fname2, dtype='double', mode='r')
        target = target.reshape(1, *self.full_data_shape)
        target = self.pet_hard_normalization(target)

        # manually augment the data half of the time
        if self.augment and np.random.random() < 0.5:
            dat, target = augmenter.random_transform_sample(dat, target)

        return dat, target
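
# A hypothetical sketch of the config dict both data modules read; only the
# key names are taken from the code above, all values are illustrative:
#
#     config = {
#         'k_fold': 0,
#         'batch_size': 2,
#         'data_folder': '/path/to/data_dir',
#         'augment': True,
#         'patch_size': [16, 128, 128],          # TIODataModule only
#         'pet_normalization_constant': 32767.0,
#         'data_split_pkl': 'my_train_test_split_2_fold.json',
#         'repeat_patient_list': 0,
#         'data_shape': [128, 128, 128],         # GenericDataModule only
#         'color_channels_in': 2,                # GenericDataModule only
#         'input_files': {
#             'name': ['pet_lowdose.npy', 'ct.npy'],
#             'preprocess_step': ['pet_hard_normalization', 'ct_normalization']},
#         'target_files': {
#             'name': ['pet_highdose.npy'],
#             'preprocess_step': ['pet_hard_normalization']},
#     }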

0 commit comments