diff --git a/get_data/__init__.py b/get_data/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/get_data/data_gen_flow.py b/get_data/data_gen_flow.py
new file mode 100644
index 0000000..23d3fd2
--- /dev/null
+++ b/get_data/data_gen_flow.py
@@ -0,0 +1,157 @@
+
+import os
+import numpy as np
+import pandas as pd
+from tensorflow.keras.preprocessing.image import ImageDataGenerator
+
+
+def train_generator(proj_dir, batch_size, input_channel=3):
+
+    """
+    Create a data generator for the training dataset;
+
+    Arguments:
+        proj_dir {path} -- path to the main project folder;
+        batch_size {int} -- batch size for the data generator;
+        input_channel {int} -- number of image input channels;
+
+    Return:
+        Keras data generator;
+
+    """
+
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    ### load train data based on input channels
+    if input_channel == 1:
+        fn = 'train_arr_1ch.npy'
+    elif input_channel == 3:
+        #fn = 'train_arr_3ch_crop.npy'
+        fn = 'train_arr_3ch.npy'
+    x_train = np.load(os.path.join(pro_data_dir, fn))
+
+    ### load train labels
+    train_df = pd.read_csv(os.path.join(pro_data_dir, 'train_img_df.csv'))
+    y_train = np.asarray(train_df['label']).astype('int').reshape((-1, 1))
+
+    ## data generator with light augmentation
+    datagen = ImageDataGenerator(
+        featurewise_center=False,
+        samplewise_center=False,
+        featurewise_std_normalization=False,
+        samplewise_std_normalization=False,
+        zca_whitening=False,
+        zca_epsilon=1e-06,
+        rotation_range=5,
+        width_shift_range=0.1,
+        height_shift_range=0.1,
+        brightness_range=None,
+        shear_range=0.0,
+        zoom_range=0,
+        channel_shift_range=0.0,
+        fill_mode="nearest",
+        cval=0.0,
+        horizontal_flip=False,
+        vertical_flip=False,
+        rescale=None,
+        preprocessing_function=None,
+        data_format=None,
+        validation_split=0.0,
+        dtype=None,
+    )
+
+    ### train generator
+    train_gen = datagen.flow(
+        x=x_train,
+        y=y_train,
+        subset=None,
+        batch_size=batch_size,
+        seed=42,
+        shuffle=True,
+    )
+    print('Train generator created')
+
+    return train_gen
+
+
+def val_generator(proj_dir, batch_size, input_channel=3):
+
+    """
+    Create a data generator for the validation dataset;
+
+    Arguments:
+        proj_dir {path} -- path to the main project folder;
+        batch_size {int} -- batch size for the data generator;
+        input_channel {int} -- number of image input channels;
+
+    Return:
+        validation data, labels and a Keras data generator;
+
+    """
+
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    ### load val data based on input channels
+    if input_channel == 1:
+        fn = 'val_arr_1ch.npy'
+    elif input_channel == 3:
+        fn = 'val_arr_3ch.npy'
+    x_val = np.load(os.path.join(pro_data_dir, fn))
+
+    ### load val labels
+    val_df = pd.read_csv(os.path.join(pro_data_dir, 'val_img_df.csv'))
+    y_val = np.asarray(val_df['label']).astype('int').reshape((-1, 1))
+
+    ## plain generator: no augmentation for validation
+    datagen = ImageDataGenerator(
+        featurewise_center=False,
+        samplewise_center=False,
+        featurewise_std_normalization=False,
+        samplewise_std_normalization=False,
+        zca_whitening=False,
+        zca_epsilon=1e-06,
+        rotation_range=0,
+        width_shift_range=0.0,
+        height_shift_range=0.0,
+        brightness_range=None,
+        shear_range=0.0,
+        zoom_range=0,
+        channel_shift_range=0.0,
+        fill_mode="nearest",
+        cval=0.0,
+        horizontal_flip=False,
+        vertical_flip=False,
+        rescale=None,
+        preprocessing_function=None,
+        data_format=None,
+        validation_split=0.0,
+        dtype=None,
+    )
+
+    val_gen = datagen.flow(
+        x=x_val,
+        y=y_val,
+        subset=None,
+        batch_size=batch_size,
+        seed=42,
+        shuffle=True,
+    )
+    print('Val generator created')
+
+    return x_val, y_val, val_gen
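For reference, a minimal sketch of wiring the two generators into training, assuming the pro_data arrays and CSVs above already exist; my_model and the numbers below are placeholders:

from get_data.data_gen_flow import train_generator, val_generator

proj_dir = '/path/to/project'   # hypothetical project root
batch_size = 32

train_gen = train_generator(proj_dir, batch_size)
x_val, y_val, val_gen = val_generator(proj_dir, batch_size)

# my_model: any compiled Keras model taking (192, 192, 3) inputs
my_model.fit(
    train_gen,
    steps_per_epoch=train_gen.n // batch_size,
    validation_data=val_gen,
    validation_steps=val_gen.n // batch_size,
    epochs=10,
)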
diff --git a/get_data/exval_dataset.py b/get_data/exval_dataset.py
new file mode 100644
index 0000000..d1fbdfd
--- /dev/null
+++ b/get_data/exval_dataset.py
@@ -0,0 +1,214 @@
+import glob
+import os
+import pandas as pd
+import numpy as np
+from sklearn.model_selection import train_test_split
+from utils.crop_image import crop_image
+from utils.respacing import respacing
+from utils.nrrd_reg import nrrd_reg_rigid_ref
+from get_data.get_img_dataset import img_dataset
+
+
+def exval_pat_dataset(out_dir, proj_dir, crop_shape=[192, 192, 140],
+                      interp_type='linear', input_channel=3,
+                      norm_type='np_clip', data_exclude=None, new_spacing=[1, 1, 3]):
+
+    """
+    Preprocess data (respacing, registration, cropping) for the chest CT dataset;
+
+    Arguments:
+        out_dir {path} -- path to result outputs;
+        proj_dir {path} -- path to the main project folder;
+
+    Keyword arguments:
+        new_spacing {tuple} -- respacing size, default [1, 1, 3];
+        data_exclude {list} -- patient data to exclude due to data issues, default: None;
+        crop_shape {np.array} -- array size after cropping;
+        interp_type {str} -- interpolation type for respacing, default: 'linear';
+
+    Return:
+        saved nrrd image data;
+    """
+
+    NSCLC_data_dir = '/mnt/aertslab/DATA/Lung/TOPCODER/nrrd_data'
+    NSCLC_reg_dir = os.path.join(out_dir, 'data/NSCLC_data_reg')
+    exval1_dir = os.path.join(out_dir, 'exval1')
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+
+    if not os.path.exists(NSCLC_reg_dir):
+        os.mkdir(NSCLC_reg_dir)
+    if not os.path.exists(exval1_dir):
+        os.mkdir(exval1_dir)
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    reg_temp_img = os.path.join(exval1_dir, 'NSCLC001.nrrd')
+
+    df_label = pd.read_csv(os.path.join(pro_data_dir, 'label_NSCLC.csv'))
+    df_label.dropna(subset=['ctdose_contrast', 'top_coder_id'], how='any', inplace=True)
+    df_id = pd.read_csv(os.path.join(pro_data_dir, 'harvard_rt.csv'))
+
+    ## create df for dir, ID and labels on the patient level
+    fns = []
+    IDs = []
+    labels = []
+    list_fn = [fn for fn in sorted(glob.glob(NSCLC_data_dir + '/*nrrd'))]
+    for fn in list_fn:
+        ID = fn.split('/')[-1].split('_')[2][0:5].strip()
+        for label, top_coder_id in zip(df_label['ctdose_contrast'], df_label['top_coder_id']):
+            tc_id = top_coder_id.split('_')[2].strip()
+            if tc_id == ID:
+                IDs.append(ID)
+                labels.append(label)
+                fns.append(fn)
+    print('ID:', len(IDs))
+    print('file:', len(fns))
+    print('label:', len(labels))
+    print('contrast scan in ex val:', labels.count(1))
+    print('non-contrast scan in ex val:', labels.count(0))
+    df = pd.DataFrame({'ID': IDs, 'file': fns, 'label': labels})
+    df.to_csv(os.path.join(pro_data_dir, 'exval_pat_df.csv'))
+    print('total test scan:', df.shape[0])
+
+    ## delete excluded scans and repeated scans
+    if data_exclude != None:
+        df_exclude = df[df['ID'].isin(data_exclude)]
+        print('exclude scans:', df_exclude)
+        df.drop(df[df['ID'].isin(data_exclude)].index, inplace=True)
+        print('total test scans:', df.shape[0])
+    pd.options.display.max_columns = 100
+    pd.set_option('display.max_rows', 500)
+
+    ### registration, respacing, cropping
+    for fn, ID in zip(df['file'], df['ID']):
+        print(ID)
+        ## respacing
+        img_nrrd = respacing(
+            nrrd_dir=fn,
+            interp_type=interp_type,
+            new_spacing=new_spacing,
+            patient_id=ID,
+            return_type='nrrd',
+            save_dir=None
+        )
+        ## registration
+        img_reg = nrrd_reg_rigid_ref(
+            img_nrrd=img_nrrd,
+            fixed_img_dir=reg_temp_img,
+            patient_id=ID,
+            save_dir=None
+        )
+        ## crop image to crop_shape
+        img_crop = crop_image(
+            nrrd_file=img_reg,
+            patient_id=ID,
+            crop_shape=crop_shape,
+            return_type='nrrd',
+            save_dir=NSCLC_reg_dir
+        )
+
+
+def exval_img_dataset(proj_dir, slice_range=range(50, 120), input_channel=3,
+                      norm_type='np_clip', split=True, fn_arr_1ch=None):
+
+    """
+    Get stacked image slices from scan-level CT plus corresponding labels and IDs;
+
+    Args:
+        proj_dir {path} -- path to the main project folder;
+        slice_range {range} -- image slice range in the z direction for cropping;
+
+    Keyword args:
+        input_channel {int} -- image channels, default: 3;
+        norm_type {str} -- image normalization type: 'np_clip' or 'np_interp';
+        split {bool} -- split the exval data into a fine-tuning set and a test set;
+
+    Returns:
+        img_df {pd.df} -- dataframe with preprocessed image paths, labels and IDs (image level);
+
+    """
+
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+    df = pd.read_csv(os.path.join(pro_data_dir, 'exval_pat_df.csv'))
+    fns = df['file']
+    labels = df['label']
+    IDs = df['ID']
+
+    ## split dataset for fine-tuning and testing the model
+    if split == True:
+        data_exval1, data_exval2, label_exval1, label_exval2, ID_exval1, ID_exval2 = train_test_split(
+            fns,
+            labels,
+            IDs,
+            stratify=labels,
+            shuffle=True,
+            test_size=0.2,
+            random_state=42
+        )
+        nrrds = [data_exval1, data_exval2]
+        labels = [label_exval1, label_exval2]
+        IDs = [ID_exval1, ID_exval2]
+        fn_arrs = ['exval1_arr1.npy', 'exval1_arr2.npy']
+        fn_dfs = ['exval1_img_df1.csv', 'exval1_img_df2.csv']
+
+        ## create numpy arrays for image slices
+        for nrrd, label, ID, fn_arr, fn_df in zip(nrrds, labels, IDs, fn_arrs, fn_dfs):
+            img_dataset(
+                pro_data_dir=pro_data_dir,
+                run_type='exval',
+                nrrds=nrrd,
+                IDs=ID,
+                labels=label,
+                fn_arr_1ch=None,
+                fn_arr_3ch=fn_arr,
+                fn_df=fn_df,
+                slice_range=slice_range,
+                input_channel=3,
+                norm_type=norm_type,
+            )
+        print('train and test datasets created!')
+
+    ## use the entire exval dataset to test the model
+    elif split == False:
+        img_dataset(
+            pro_data_dir=pro_data_dir,
+            run_type='exval',
+            nrrds=fns,
+            IDs=IDs,
+            labels=labels,
+            fn_arr_1ch=None,
+            fn_arr_3ch='exval1_arr.npy',
+            fn_df='exval1_img_df.csv',
+            slice_range=slice_range,
+            input_channel=3,
+            norm_type=norm_type,
+        )
+        print('total patient:', len(IDs))
+        print('exval datasets created!')
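A hedged sketch of the external-validation flow these two functions imply (both directories are placeholders):

from get_data.exval_dataset import exval_pat_dataset, exval_img_dataset

out_dir = '/path/to/outputs'    # hypothetical
proj_dir = '/path/to/project'   # hypothetical

# step 1: respace, register and crop the NSCLC scans; writes exval_pat_df.csv
exval_pat_dataset(out_dir=out_dir, proj_dir=proj_dir)

# step 2: stack axial slices into numpy arrays, 80/20 split for fine-tuning vs testing
exval_img_dataset(proj_dir=proj_dir, split=True)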
diff --git a/get_data/get_img_dataset.py b/get_data/get_img_dataset.py
new file mode 100644
index 0000000..88aec52
--- /dev/null
+++ b/get_data/get_img_dataset.py
@@ -0,0 +1,167 @@
+
+import os
+import pandas as pd
+import numpy as np
+import SimpleITK as sitk
+import h5py
+
+
+def img_dataset(pro_data_dir, run_type, nrrds, IDs, labels, fn_arr_1ch, fn_arr_3ch, fn_df,
+                slice_range, input_channel=3, norm_type='np_clip'):
+
+    """
+    Get stacked image slices from scan-level CT plus corresponding labels and IDs;
+
+    Args:
+        pro_data_dir {path} -- path to processed data;
+        run_type {str} -- train, val, test, external val, or pred;
+        nrrds {list} -- list of paths to CT scan files in nrrd format;
+        IDs {list} -- list of patient IDs;
+        labels {list} -- list of patient labels;
+        fn_arr_1ch {str} -- filename for the 1-channel numpy array of stacked image slices;
+        fn_arr_3ch {str} -- filename for the 3-channel numpy array of stacked image slices;
+        fn_df {str} -- filename for the dataframe of image paths, labels and IDs;
+        slice_range {range} -- image slice range in the z direction for cropping;
+
+    Keyword args:
+        input_channel {int} -- image channels, default: 3;
+        norm_type {str} -- image normalization type: 'np_clip' or 'np_interp';
+
+    Returns:
+        img_df {pd.df} -- dataframe with preprocessed image paths, labels and IDs (image level);
+
+    """
+
+    # get image slices and save them as a numpy array
+    count = 0
+    slice_numbers = []
+    list_fn = []
+    arr = np.empty([0, 192, 192])
+
+    for nrrd, patient_id in zip(nrrds, IDs):
+        count += 1
+        print(count)
+        nrrd = sitk.ReadImage(nrrd, sitk.sitkFloat32)
+        img_arr = sitk.GetArrayFromImage(nrrd)
+        #data = img_arr[30:78, :, :]
+        #data = img_arr[17:83, :, :]
+        data = img_arr[slice_range, :, :]
+        ### clip signals lower than -1024 (air)
+        data[data <= -1024] = -1024
+        ### strip the skull, skull HU = ~700
+        data[data > 700] = 0
+        ### normalize HU to 0 - 1; signals outside [-200, 200] saturate at 0 or 1
+        if norm_type == 'np_interp':
+            data = np.interp(data, [-200, 200], [0, 1])
+        elif norm_type == 'np_clip':
+            data = np.clip(data, a_min=-200, a_max=200)
+            MAX, MIN = data.max(), data.min()
+            data = (data - MIN) / (MAX - MIN)
+        ## stack all image arrays into one array for CNN input
+        arr = np.concatenate([arr, data], 0)
+
+        ### create patient ID and slice index for each image
+        slice_numbers.append(data.shape[0])
+        for i in range(data.shape[0]):
+            fn = patient_id + '_' + 'slice%s'%(f'{i:03d}')
+            list_fn.append(fn)
+
+    ### convert the stacked array to 1-channel or 3-channel CNN input
+    if input_channel == 1:
+        img_arr = arr.reshape(arr.shape[0], arr.shape[1], arr.shape[2], 1)
+        print('img_arr shape:', img_arr.shape)
+        np.save(os.path.join(pro_data_dir, fn_arr_1ch), img_arr)
+    elif input_channel == 3:
+        img_arr = np.broadcast_to(arr, (3, arr.shape[0], arr.shape[1], arr.shape[2]))
+        img_arr = np.transpose(img_arr, (1, 2, 3, 0))
+        print('img_arr shape:', img_arr.shape)
+        np.save(os.path.join(pro_data_dir, fn_arr_3ch), img_arr)
+        #fn = os.path.join(pro_data_dir, 'exval_arr_3ch.h5')
+        #h5f = h5py.File(fn, 'w')
+        #h5f.create_dataset('dataset_exval_arr_3ch', data=img_arr)
+
+    ### generate labels for CT slices
+    if run_type == 'pred':
+        ### make a dataframe containing img dir
+        img_df = pd.DataFrame({'fn': list_fn})
+        img_df.to_csv(os.path.join(pro_data_dir, fn_df))
+        print('data size:', img_df.shape[0])
+    else:
+        list_label = []
+        for label, slice_number in zip(labels, slice_numbers):
+            list_label.extend([label] * slice_number)
+        ### make a dataframe containing img dir and labels
+        img_df = pd.DataFrame({'fn': list_fn, 'label': list_label})
+        pd.options.display.max_columns = 100
+        pd.set_option('display.max_rows', 500)
+        print(img_df[0:100])
+        img_df.to_csv(os.path.join(pro_data_dir, fn_df))
+        print('data size:', img_df.shape[0])
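To make the normalization above concrete, a small worked example of the 'np_clip' branch on a handful of HU values (illustrative numbers only):

import numpy as np

data = np.array([-1000., -200., 0., 200., 650.])   # a few voxel HU values
data = np.clip(data, a_min=-200, a_max=200)        # -> [-200, -200, 0, 200, 200]
MAX, MIN = data.max(), data.min()
data = (data - MIN) / (MAX - MIN)                  # -> [0.0, 0.0, 0.5, 1.0, 1.0]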
+
+
+def get_img_dataset(proj_dir, run_type, data_tot, ID_tot, label_tot, slice_range,
+                    input_channel=3, norm_type='np_clip'):
+
+    """
+    Get numpy arrays of stacked image slices, labels and IDs for the train, val and test datasets;
+
+    Args:
+        proj_dir {path} -- path to the main project folder;
+        run_type {str} -- train, val, test, external val, or pred;
+        data_tot {list} -- list of data paths: ['data_train', 'data_val', 'data_test'];
+        ID_tot {list} -- list of image IDs: ['ID_train', 'ID_val', 'ID_test'];
+        label_tot {list} -- list of image labels: ['label_train', 'label_val', 'label_test'];
+        slice_range {range} -- image slice range in the z direction for cropping;
+
+    Keyword args:
+        input_channel {int} -- image channels, default: 3;
+        norm_type {str} -- image normalization type: 'np_clip' or 'np_interp';
+
+    """
+
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    fns_arr_1ch = ['train_arr_1ch.npy', 'val_arr_1ch.npy', 'test_arr_1ch.npy']
+    fns_arr_3ch = ['train_arr_3ch.npy', 'val_arr_3ch.npy', 'test_arr_3ch.npy']
+    fns_df = ['train_img_df.csv', 'val_img_df.csv', 'test_img_df.csv']
+
+    for nrrds, IDs, labels, fn_arr_1ch, fn_arr_3ch, fn_df in zip(
+            data_tot, ID_tot, label_tot, fns_arr_1ch, fns_arr_3ch, fns_df):
+
+        img_dataset(
+            pro_data_dir=pro_data_dir,
+            run_type=run_type,
+            nrrds=nrrds,
+            IDs=IDs,
+            labels=labels,
+            fn_arr_1ch=fn_arr_1ch,
+            fn_arr_3ch=fn_arr_3ch,
+            fn_df=fn_df,
+            slice_range=slice_range,
+            input_channel=input_channel,
+            norm_type=norm_type,
+        )
diff --git a/get_data/get_pat_dataset.py b/get_data/get_pat_dataset.py
new file mode 100644
index 0000000..5072dbf
--- /dev/null
+++ b/get_data/get_pat_dataset.py
@@ -0,0 +1,174 @@
+
+import glob
+import os
+import pandas as pd
+from sklearn.model_selection import train_test_split
+
+
+def pat_df(label_dir, label_file, cohort, data_reg_dir, MDACC_data_dir):
+
+    """
+    Create a dataframe containing data path, patient ID and label on the
+    patient level;
+
+    Arguments:
+        label_dir {path} -- path to the label csv files;
+        label_file {csv} -- csv file containing the label info;
+        cohort {str} -- patient cohort name (PMH, CHUM, MDACC, CHUS);
+        data_reg_dir {path} -- path to the registered data of the cohort;
+        MDACC_data_dir {path} -- path to the MDACC patient data;
+
+    Return:
+        pandas dataframe of patient data;
+
+    """
+
+    ## labels
+    labels = []
+    df_label = pd.read_csv(os.path.join(label_dir, label_file))
+    df_label['Contrast'] = df_label['Contrast'].map({'Yes': 1, 'No': 0})
+    if cohort == 'CHUM':
+        for file_ID, label in zip(df_label['File ID'], df_label['Contrast']):
+            scan = file_ID.split('_')[2].strip()
+            if scan == 'CT-SIM':
+                labels.append(label)
+            elif scan == 'CT-PET':
+                continue
+    elif cohort == 'CHUS':
+        labels = df_label['Contrast'].to_list()
+    elif cohort == 'PMH':
+        labels = df_label['Contrast'].to_list()
+    elif cohort == 'MDACC':
+        fns = [fn for fn in sorted(glob.glob(MDACC_data_dir + '/*nrrd'))]
+        IDs = []
+        for fn in fns:
+            ID = 'MDACC' + fn.split('/')[-1].split('-')[2][1:4].strip()
+            IDs.append(ID)
+        labels = df_label['Contrast'].to_list()
+        print('MDACC label:', len(labels))
+        print('MDACC ID:', len(IDs))
+        ## check that labels and data are matched
+        for fn in fns:
+            fn = fn.split('/')[-1]
+            if fn not in df_label['File ID'].to_list():
+                print(fn)
+        ## make df and delete duplicated patient scans
+        df = pd.DataFrame({'ID': IDs, 'labels': labels})
+        df.drop_duplicates(subset=['ID'], keep='last', inplace=True)
+        labels = df['labels'].to_list()
+        #print('MDACC label:', len(labels))
+
+    ## data
+    fns = [fn for fn in sorted(glob.glob(data_reg_dir + '/*nrrd'))]
+
+    ## patient ID
+    IDs = []
+    for fn in fns:
+        ID = fn.split('/')[-1].split('.')[0].strip()
+        IDs.append(ID)
+    ## check IDs against labels
+    if cohort == 'MDACC':
+        list1 = list(set(IDs) - set(df['ID'].to_list()))
+        print(list1)
+    ## create dataframe
+    print('cohort:', cohort)
+    print('ID:', len(IDs))
+    print('file:', len(fns))
+    print('label:', len(labels))
+    df = pd.DataFrame({'ID': IDs, 'file': fns, 'label': labels})
+
+    return df
+
+
+def get_pat_dataset(data_dir, out_dir, proj_dir):
+
+    """
+    Get data paths, patient IDs and labels for all the cohorts;
+
+    Arguments:
+        data_dir {path} -- path to the CT data;
+        out_dir {path} -- path to outputs;
+        proj_dir {path} -- path to processed data;
+
+    Return:
+        lists of patient data, labels and IDs;
+
+    """
+
+    MDACC_data_dir = os.path.join(data_dir, '0_image_raw_MDACC')
+    CHUM_reg_dir = os.path.join(out_dir, 'data/CHUM_data_reg')
+    CHUS_reg_dir = os.path.join(out_dir, 'data/CHUS_data_reg')
+    PMH_reg_dir = os.path.join(out_dir, 'data/PMH_data_reg')
+    MDACC_reg_dir = os.path.join(out_dir, 'data/MDACC_data_reg')
+    label_dir = os.path.join(out_dir, 'data_pro')
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+
+    cohorts = ['CHUM', 'CHUS', 'PMH', 'MDACC']
+    label_files = ['label_CHUM.csv', 'label_CHUS.csv', 'label_PMH.csv', 'label_MDACC.csv']
+    data_reg_dirs = [CHUM_reg_dir, CHUS_reg_dir, PMH_reg_dir, MDACC_reg_dir]
+    df_tot = []
+    for cohort, label_file, data_reg_dir in zip(cohorts, label_files, data_reg_dirs):
+        df = pat_df(
+            label_dir=label_dir,
+            label_file=label_file,
+            cohort=cohort,
+            data_reg_dir=data_reg_dir,
+            MDACC_data_dir=MDACC_data_dir
+        )
+        df_tot.append(df)
+
+    ## get df for the different cohorts
+    df_CHUM = df_tot[0]
+    df_CHUS = df_tot[1]
+    df_PMH = df_tot[2]
+    df_MDACC = df_tot[3]
+
+    ## train-val split
+    df = pd.concat([df_PMH, df_CHUM, df_CHUS], ignore_index=True)
+    data = df['file']
+    label = df['label']
+    ID = df['ID']
+    data_train, data_val, label_train, label_val, ID_train, ID_val = train_test_split(
+        data,
+        label,
+        ID,
+        stratify=label,
+        test_size=0.3,
+        random_state=42
+    )
+
+    ## test patient data
+    data_test = df_MDACC['file']
+    label_test = df_MDACC['label']
+    ID_test = df_MDACC['ID']
+
+    ## save train, val, test df on the patient level
+    train_pat_df = pd.DataFrame({'ID': ID_train, 'file': data_train, 'label': label_train})
+    val_pat_df = pd.DataFrame({'ID': ID_val, 'file': data_val, 'label': label_val})
+    test_pat_df = pd.DataFrame({'ID': ID_test, 'file': data_test, 'label': label_test})
+    train_pat_df.to_csv(os.path.join(pro_data_dir, 'train_pat_df.csv'))
+    val_pat_df.to_csv(os.path.join(pro_data_dir, 'val_pat_df.csv'))
+    test_pat_df.to_csv(os.path.join(pro_data_dir, 'test_pat_df.csv'))
+
+    ## return data, labels and IDs as lists
+    data_tot = [data_train, data_val, data_test]
+    label_tot = [label_train, label_val, label_test]
+    ID_tot = [ID_train, ID_val, ID_test]
+
+    return data_tot, label_tot, ID_tot
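A sketch of how the patient-level and image-level steps chain together; paths are placeholders and the slice_range value is only an example:

from get_data.get_pat_dataset import get_pat_dataset
from get_data.get_img_dataset import get_img_dataset

data_dir = '/path/to/raw_CT'    # hypothetical
out_dir = '/path/to/outputs'    # hypothetical
proj_dir = '/path/to/project'   # hypothetical

data_tot, label_tot, ID_tot = get_pat_dataset(data_dir, out_dir, proj_dir)
get_img_dataset(
    proj_dir=proj_dir,
    run_type='train',
    data_tot=data_tot,
    ID_tot=ID_tot,
    label_tot=label_tot,
    slice_range=range(17, 83),   # example z-range; matches a commented choice in img_dataset
)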
diff --git a/get_data/preprocess_data.py b/get_data/preprocess_data.py
new file mode 100644
index 0000000..2120010
--- /dev/null
+++ b/get_data/preprocess_data.py
@@ -0,0 +1,187 @@
+import glob
+import os
+import pandas as pd
+import nrrd
+from utils.respacing import respacing
+from utils.nrrd_reg import nrrd_reg_rigid_ref
+from utils.crop_image import crop_image
+
+
+def preprocess_data(data_dir, out_dir, new_spacing=(1, 1, 3), data_exclude=None,
+                    crop_shape=[192, 192, 10], interp_type='linear'):
+
+    """
+    Preprocess data: respacing, registration and cropping;
+
+    Arguments:
+        data_dir {path} -- path to the CT data;
+        out_dir {path} -- path to result outputs;
+
+    Keyword arguments:
+        new_spacing {tuple} -- respacing size;
+        data_exclude {list} -- patient data to exclude due to data issues, default: None;
+        crop_shape {np.array} -- array size after cropping;
+        interp_type {str} -- interpolation type for respacing, default: 'linear';
+
+    Return:
+        saved nrrd image data;
+    """
+
+    CHUM_data_dir = os.path.join(data_dir, '0_image_raw_CHUM')
+    CHUS_data_dir = os.path.join(data_dir, '0_image_raw_CHUS')
+    PMH_data_dir = os.path.join(data_dir, '0_image_raw_PMH')
+    MDACC_data_dir = os.path.join(data_dir, '0_image_raw_MDACC')
+    CHUM_reg_dir = os.path.join(out_dir, 'data/CHUM_data_reg')
+    CHUS_reg_dir = os.path.join(out_dir, 'data/CHUS_data_reg')
+    PMH_reg_dir = os.path.join(out_dir, 'data/PMH_data_reg')
+    MDACC_reg_dir = os.path.join(out_dir, 'data/MDACC_data_reg')
+
+    if not os.path.exists(CHUM_reg_dir):
+        os.mkdir(CHUM_reg_dir)
+    if not os.path.exists(CHUS_reg_dir):
+        os.mkdir(CHUS_reg_dir)
+    if not os.path.exists(PMH_reg_dir):
+        os.mkdir(PMH_reg_dir)
+    if not os.path.exists(MDACC_reg_dir):
+        os.mkdir(MDACC_reg_dir)
+
+    reg_temp_img = os.path.join(PMH_reg_dir, 'PMH050.nrrd')
+
+    # get PMH data
+    #------------------
+    fns = [fn for fn in sorted(glob.glob(PMH_data_dir + '/*nrrd'))]
+    ## PMH patient ID
+    IDs = []
+    for fn in fns:
+        ID = 'PMH' + fn.split('/')[-1].split('-')[1][2:5].strip()
+        IDs.append(ID)
+    ## PMH dataframe
+    df_PMH = pd.DataFrame({'ID': IDs, 'file': fns})
+    pd.options.display.max_colwidth = 100
+    file = df_PMH['file'][0]
+    data, header = nrrd.read(file)
+    print(data.shape)
+    print('PMH data:', len(IDs))
+
+    # get CHUM data
+    #-----------------
+    fns = []
+    for fn in sorted(glob.glob(CHUM_data_dir + '/*nrrd')):
+        scan_ = fn.split('/')[-1].split('_')[2].strip()
+        if scan_ == 'CT-SIM':
+            fns.append(fn)
+        else:
+            continue
+    ## CHUM patient ID
+    IDs = []
+    for fn in fns:
+        ID = 'CHUM' + fn.split('/')[-1].split('_')[1].split('-')[2].strip()
+        IDs.append(ID)
+    ## CHUM dataframe
+    df_CHUM = pd.DataFrame({'ID': IDs, 'file': fns})
+    pd.options.display.max_colwidth = 100
+    file = df_CHUM['file'][0]
+    data, header = nrrd.read(file)
+    print(data.shape)
+    print('CHUM data:', len(IDs))
+
+    # get CHUS data
+    #---------------
+    fns = []
+    for fn in sorted(glob.glob(CHUS_data_dir + '/*nrrd')):
+        scan = fn.split('/')[-1].split('_')[2].strip()
+        if scan == 'CT-SIMPET':
+            fns.append(fn)
+        else:
+            continue
+    ## CHUS patient ID
+    IDs = []
+    for fn in fns:
+        ID = 'CHUS' + fn.split('/')[-1].split('_')[1].split('-')[2].strip()
+        IDs.append(ID)
+    ## CHUS dataframe
+    df_CHUS = pd.DataFrame({'ID': IDs, 'file': fns})
+    pd.options.display.max_colwidth = 100
+    file = df_CHUS['file'][0]
+    data, header = nrrd.read(file)
+    print(data.shape)
+    print('CHUS data:', len(IDs))
+
+    # get MDACC dataset
+    #------------------
+    fns = [fn for fn in sorted(glob.glob(MDACC_data_dir + '/*nrrd'))]
+    ## MDACC patient ID
+    IDs = []
+    for fn in fns:
+        ID = 'MDACC' + fn.split('/')[-1].split('-')[2][1:4].strip()
+        IDs.append(ID)
+    ## MDACC dataframe
+    df_MDACC = pd.DataFrame({'ID': IDs, 'file': fns})
+    df_MDACC.drop_duplicates(subset=['ID'], keep='last', inplace=True)
+    print('MDACC data:', df_MDACC.shape[0])
+
+    # combine datasets for the train-val split
+    #------------------------------------------
+    df = pd.concat([df_PMH, df_CHUM, df_CHUS, df_MDACC], ignore_index=True)
+    print('total patients:', df.shape[0])
+    ## exclude data with certain conditions
+    if data_exclude != None:
+        df_exclude = df[df['ID'].isin(data_exclude)]
+        print(df_exclude)
+        df.drop(df[df['ID'].isin(data_exclude)].index, inplace=True)
+        print(df.shape[0])
+
+    for fn, ID in zip(df['file'], df['ID']):
+
+        print(ID)
+
+        ## set up the save dir
+        if ID[:-3] == 'PMH':
+            save_dir = PMH_reg_dir
+        elif ID[:-3] == 'CHUM':
+            save_dir = CHUM_reg_dir
+        elif ID[:-3] == 'CHUS':
+            save_dir = CHUS_reg_dir
+        elif ID[:-3] == 'MDACC':
+            save_dir = MDACC_reg_dir
+
+        ## respacing
+        img_nrrd = respacing(
+            nrrd_dir=fn,
+            interp_type=interp_type,
+            new_spacing=new_spacing,
+            patient_id=ID,
+            return_type='nrrd',
+            save_dir=None
+        )
+
+        ## registration
+        img_reg = nrrd_reg_rigid_ref(
+            img_nrrd=img_nrrd,
+            fixed_img_dir=reg_temp_img,
+            patient_id=ID,
+            save_dir=None
+        )
+
+        ## crop image to crop_shape
+        img_crop = crop_image(
+            nrrd_file=img_reg,
+            patient_id=ID,
+            crop_shape=crop_shape,
+            return_type='nrrd',
+            save_dir=save_dir
+        )
diff --git a/go_model/__init__.py b/go_model/__init__.py
new file mode 100644
index 0000000..e69de29
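A minimal sketch of invoking the preprocessing step above, assuming the four raw cohort folders exist under data_dir (paths are placeholders):

from get_data.preprocess_data import preprocess_data

preprocess_data(
    data_dir='/path/to/raw_CT',    # hypothetical: holds the 0_image_raw_* folders
    out_dir='/path/to/outputs',    # hypothetical: registered data lands in out_dir/data/*
)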
diff --git a/go_model/callbacks.py b/go_model/callbacks.py
new file mode 100644
index 0000000..f1c65c9
--- /dev/null
+++ b/go_model/callbacks.py
@@ -0,0 +1,74 @@
+#----------------------------------------------------------------------
+# Deep learning for classification of contrast CT;
+# Transfer learning using Google Inception V3;
+#----------------------------------------------------------------------
+
+import os
+import tensorflow as tf
+from tensorflow.keras.callbacks import EarlyStopping
+from tensorflow.keras.callbacks import LearningRateScheduler
+from tensorflow.keras.callbacks import ModelCheckpoint
+from tensorflow.keras.callbacks import TensorBoard
+
+
+# ----------------------------------------------------------------------------------
+# learning rate scheduler
+# ----------------------------------------------------------------------------------
+def scheduler(epoch, lr):
+    if epoch < 10:
+        return lr
+    else:
+        return lr * tf.math.exp(-0.1)
+
+# ----------------------------------------------------------------------------------
+# callbacks
+# ----------------------------------------------------------------------------------
+def callbacks(log_dir):
+
+    check_point = ModelCheckpoint(
+        filepath=os.path.join(log_dir, 'model.{epoch:02d}-{val_loss:.2f}.h5'),
+        monitor='val_acc',
+        verbose=1,
+        save_best_only=True,
+        save_weights_only=True,
+        mode='max'
+    )
+
+    tensor_board = TensorBoard(
+        log_dir=log_dir,
+        histogram_freq=0,
+        write_graph=True,
+        write_images=False,
+        update_freq='epoch',
+        profile_batch=2,
+        embeddings_freq=0,
+        embeddings_metadata=None
+    )
+
+    early_stopping = EarlyStopping(
+        monitor='val_loss',
+        min_delta=0,
+        patience=20,
+        verbose=0,
+        mode='auto',
+        baseline=None,
+        restore_best_weights=False
+    )
+
+    my_callbacks = [
+        #ModelSave(),
+        early_stopping,
+        #LearningRateScheduler(scheduler),
+        #check_point,
+        tensor_board
+    ]
+
+    return my_callbacks
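For context, a sketch of how these callbacks plug into a training run; log_dir is a placeholder and the model and generators come from the other modules:

from go_model.callbacks import callbacks

my_callbacks = callbacks(log_dir='/path/to/log')   # hypothetical path
model.fit(
    train_gen,
    epochs=50,
    validation_data=val_gen,
    callbacks=my_callbacks,   # early stopping on val_loss (patience 20) + TensorBoard
)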
diff --git a/go_model/evaluate_model.py b/go_model/evaluate_model.py
new file mode 100644
index 0000000..7be0b9f
--- /dev/null
+++ b/go_model/evaluate_model.py
@@ -0,0 +1,122 @@
+
+import os
+import numpy as np
+import pandas as pd
+from tensorflow.keras.models import load_model
+
+
+def evaluate_model(run_type, out_dir, proj_dir, saved_model,
+                   threshold=0.5, activation='sigmoid'):
+
+    """
+    Evaluate the model on validation/test/external validation data;
+
+    Args:
+        run_type {str} -- 'val', 'test', 'exval1' or 'exval2';
+        out_dir {path} -- path to the main output folder;
+        proj_dir {path} -- path to the main project folder;
+        saved_model {str} -- saved model name;
+
+    Keyword args:
+        threshold {float} -- threshold to determine the positive class;
+        activation {str or function} -- activation function, default: 'sigmoid';
+
+    Returns:
+        loss and accuracy on the chosen dataset;
+
+    """
+
+    # check folders
+    #---------------
+    model_dir = os.path.join(out_dir, 'model')
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+
+    if not os.path.exists(model_dir):
+        os.mkdir(model_dir)
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    # load data and labels based on the run type
+    #--------------------------------------------
+    if run_type == 'val':
+        fn_data = 'val_arr_3ch.npy'
+        fn_label = 'val_img_df.csv'
+        fn_pred = 'val_img_pred.csv'
+    elif run_type == 'test':
+        fn_data = 'test_arr_3ch.npy'
+        fn_label = 'test_img_df.csv'
+        fn_pred = 'test_img_pred.csv'
+    elif run_type == 'exval1':
+        fn_data = 'exval1_arr2.npy'
+        fn_label = 'exval1_img_df2.csv'
+        fn_pred = 'exval1_img_pred.csv'
+    elif run_type == 'exval2':
+        fn_data = 'rtog_0617_arr.npy'
+        fn_label = 'rtog_img_df.csv'
+        fn_pred = 'exval2_img_pred.csv'
+
+    x_data = np.load(os.path.join(pro_data_dir, fn_data))
+    df = pd.read_csv(os.path.join(pro_data_dir, fn_label))
+    y_label = np.asarray(df['label']).astype('int').reshape((-1, 1))
+
+    ## load the saved model and evaluate
+    #------------------------------------
+    model = load_model(os.path.join(model_dir, saved_model))
+    score = model.evaluate(x_data, y_label)
+    loss = np.around(score[0], 3)
+    acc = np.around(score[1], 3)
+    print('loss:', loss)
+    print('acc:', acc)
+
+    if activation == 'sigmoid':
+        y_pred = model.predict(x_data)
+        y_pred_class = [1 * (x[0] >= threshold) for x in y_pred]
+    elif activation == 'softmax':
+        y_pred_prob = model.predict(x_data)
+        y_pred = y_pred_prob[:, 1]
+        y_pred_class = np.argmax(y_pred_prob, axis=1)
+
+    # save a dataframe of the predictions
+    #-------------------------------------
+    ID = []
+    for file in df['fn']:
+        if run_type in ['val', 'test', 'exval1']:
+            id = file.split('_')[0].strip()
+        elif run_type == 'exval2':
+            id = file.split('_s')[0].strip()
+        ID.append(id)
+    df['ID'] = ID
+    df['y_pred'] = y_pred
+    df['y_pred_class'] = y_pred_class
+    df_test_pred = df[['ID', 'fn', 'label', 'y_pred', 'y_pred_class']]
+    df_test_pred.to_csv(os.path.join(pro_data_dir, fn_pred))
+
+    return loss, acc
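A sketch of calling the evaluation step on the held-out MDACC test split; the saved-model name is a placeholder:

from go_model.evaluate_model import evaluate_model

loss, acc = evaluate_model(
    run_type='test',
    out_dir='/path/to/outputs',                   # hypothetical
    proj_dir='/path/to/project',                  # hypothetical
    saved_model='EffNetB4_2021_01_01_00_00_00',   # hypothetical name from train_model
)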
diff --git a/go_model/finetune_model.py b/go_model/finetune_model.py
new file mode 100644
index 0000000..6718a9d
--- /dev/null
+++ b/go_model/finetune_model.py
@@ -0,0 +1,106 @@
+import os
+import numpy as np
+import pandas as pd
+from time import localtime, strftime
+from tensorflow.keras.models import load_model
+
+
+def finetune_model(out_dir, proj_dir, HN_model, batch_size, epoch,
+                   freeze_layer, input_channel=3):
+
+    """
+    Fine-tune the head-and-neck model on chest CT data;
+
+    Args:
+        out_dir {path} -- path to the main output folder;
+        proj_dir {path} -- path to the main project folder;
+        HN_model {str} -- saved head-and-neck model name;
+        batch_size {int} -- batch size for loading the data;
+        epoch {int} -- number of epochs to fine-tune the model, 10 or 20;
+        freeze_layer {int} -- number of layers of the HN model to freeze during fine-tuning;
+
+    Keyword args:
+        input_channel {int} -- image channels: 1 or 3;
+
+    Returns:
+        fine-tuned model for chest CT;
+
+    """
+
+    model_dir = os.path.join(out_dir, 'model')
+    pro_data_dir = os.path.join(proj_dir, 'pro_data')
+    if not os.path.exists(model_dir):
+        os.mkdir(model_dir)
+    if not os.path.exists(pro_data_dir):
+        os.mkdir(pro_data_dir)
+
+    ### load train data
+    if input_channel == 1:
+        fn = 'exval1_arr1.npy'
+    elif input_channel == 3:
+        fn = 'exval1_arr1.npy'
+    x_train = np.load(os.path.join(pro_data_dir, fn))
+    ### load train labels
+    train_df = pd.read_csv(os.path.join(pro_data_dir, 'exval1_img_df1.csv'))
+    y_train = np.asarray(train_df['label']).astype('int').reshape((-1, 1))
+    print('successfully loaded data!')
+
+    ## load the saved model
+    model = load_model(os.path.join(model_dir, HN_model))
+    model.summary()
+
+    ### freeze a specific number of layers
+    if freeze_layer != None:
+        for layer in model.layers[0:freeze_layer]:
+            layer.trainable = False
+        for layer in model.layers:
+            print(layer, layer.trainable)
+    else:
+        for layer in model.layers:
+            layer.trainable = True
+    model.summary()
+
+    ### fit the data to the model
+    history = model.fit(
+        x=x_train,
+        y=y_train,
+        batch_size=batch_size,
+        epochs=epoch,
+        verbose=1,
+        callbacks=None,
+        validation_split=0.3,
+        shuffle=True,
+        class_weight=None,
+        sample_weight=None,
+        initial_epoch=0
+    )
+
+    #### save the final model
+    run_model = HN_model.split('_')[0].strip()
+    model_fn = 'Tuned' + '_' + str(run_model) + '_' + \
+               str(strftime('%Y_%m_%d_%H_%M_%S', localtime()))
+    model.save(os.path.join(model_dir, model_fn))
+    tuned_model = model
+    print('fine-tuning complete!')
+    print('saved fine-tuned model as:', model_fn)
+
+    return tuned_model, model_fn
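For orientation, a hedged sketch of a fine-tuning call, assuming a saved head-and-neck model already sits under out_dir/model (all names and values below are placeholders):

from go_model.finetune_model import finetune_model

tuned_model, model_fn = finetune_model(
    out_dir='/path/to/outputs',                   # hypothetical output root
    proj_dir='/path/to/project',                  # hypothetical project root
    HN_model='EffNetB4_2021_01_01_00_00_00',      # hypothetical saved HN model name
    batch_size=32,
    epoch=10,
    freeze_layer=None,                            # None -> all layers stay trainable
)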
diff --git a/go_model/get_model.py b/go_model/get_model.py
new file mode 100644
index 0000000..11f97e6
--- /dev/null
+++ b/go_model/get_model.py
@@ -0,0 +1,91 @@
+import os
+from tensorflow.keras.utils import plot_model
+from models.simple_cnn import simple_cnn
+from models.EfficientNet import EfficientNet
+from models.ResNet import ResNet
+from models.Inception import Inception
+from models.VGGNet import VGGNet
+from models.TLNet import TLNet
+
+
+def get_model(out_dir, run_model, activation, input_shape=(192, 192, 3),
+              freeze_layer=None, transfer=False):
+
+    """
+    Generate CNN models;
+
+    Args:
+        out_dir {path} -- path to the main output folder;
+        run_model {str} -- specific CNN model type;
+        activation {str or function} -- activation function of the last layer: 'sigmoid', 'softmax', etc;
+
+    Keyword args:
+        input_shape {np.array} -- input data shape;
+        transfer {boolean} -- whether to use transfer learning;
+        freeze_layer {int} -- number of layers to freeze;
+
+    Returns:
+        deep learning model;
+
+    """
+
+    train_dir = os.path.join(out_dir, 'train')
+    if not os.path.exists(train_dir):
+        os.mkdir(train_dir)
+
+    if run_model == 'cnn':
+        my_model = simple_cnn(
+            input_shape=input_shape,
+            activation=activation,
+        )
+    elif run_model == 'ResNet101V2':
+        my_model = ResNet(
+            resnet='ResNet101V2',   #'ResNet50V2',
+            transfer=transfer,
+            freeze_layer=freeze_layer,
+            input_shape=input_shape,
+            activation=activation,
+        )
+    elif run_model == 'EffNetB4':
+        my_model = EfficientNet(
+            effnet='EffNetB4',
+            transfer=transfer,
+            freeze_layer=freeze_layer,
+            input_shape=input_shape,
+            activation=activation
+        )
+    elif run_model == 'TLNet':
+        my_model = TLNet(
+            resnet='ResNet101V2',
+            input_shape=input_shape,
+            activation=activation
+        )
+    elif run_model == 'InceptionV3':
+        my_model = Inception(
+            inception='InceptionV3',
+            transfer=transfer,
+            freeze_layer=freeze_layer,
+            input_shape=input_shape,
+            activation=activation
+        )
+
+    print(my_model)
+
+    # plot the cnn architecture and save a png
+    fn = os.path.join(train_dir, str(run_model) + '.png')
+    plot_model(
+        model=my_model,
+        to_file=fn,
+        show_shapes=True,
+        show_layer_names=True
+    )
+
+    return my_model
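For reference, a minimal sketch of building one of these models; the directory and hyperparameters are placeholders, and the model must still be compiled before training:

from go_model.get_model import get_model

my_model = get_model(
    out_dir='/path/to/outputs',   # hypothetical; a 'train' subfolder holds the .png plot
    run_model='EffNetB4',
    activation='sigmoid',
    input_shape=(192, 192, 3),
)
my_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['acc'])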
diff --git a/go_model/pred_model.py b/go_model/pred_model.py
new file mode 100644
index 0000000..40a87a8
--- /dev/null
+++ b/go_model/pred_model.py
@@ -0,0 +1,100 @@
+
+"""
+  ----------------------------------------------
+  DeepContrast - run DeepContrast pipeline step5
+  ----------------------------------------------
+  Author: AIM Harvard
+
+  Python Version: 3.8.5
+  ----------------------------------------------
+
+"""
+
+import os
+import numpy as np
+import pandas as pd
+from tensorflow.keras.models import load_model
+
+
+def pred_model(model_dir, pred_data_dir, saved_model, thr_img, thr_prob, fns_pat_pred,
+               fns_img_pred, fns_arr, fns_img_df):
+
+    """
+    Model prediction;
+
+    @params:
+      model_dir     - required : path to load the CNN model
+      pred_data_dir - required : path to the folder of processed prediction data
+      saved_model   - required : CNN model name from the training step
+      thr_img       - required : threshold to decide the image-level label
+      thr_prob      - required : threshold to decide the patient-level label
+
+    """
+
+    for fn_arr, fn_img_df, fn_img_pred in zip(fns_arr, fns_img_df, fns_img_pred):
+
+        ### load numpy array
+        x_exval = np.load(os.path.join(pred_data_dir, fn_arr))
+        ### load IDs
+        df = pd.read_csv(os.path.join(pred_data_dir, fn_img_df))
+        ### load the saved model
+        model = load_model(os.path.join(model_dir, saved_model))
+        ## prediction
+        y_pred = model.predict(x_exval)
+        y_pred_class = [1 * (x[0] >= thr_img) for x in y_pred]
+
+        ### save a dataframe of the predictions
+        ID = []
+        for file in df['fn']:
+            id = file.split('_s')[0].strip()
+            ID.append(id)
+        df['ID'] = ID
+        df['y_pred'] = y_pred
+        df['y_pred_class'] = y_pred_class
+        df_img_pred = df[['ID', 'fn', 'y_pred', 'y_pred_class']]
+        df_img_pred.to_csv(os.path.join(pred_data_dir, fn_img_pred), index=False)
+
+    ## calculate patient-level predictions
+    for fn_img_pred, fn_pat_pred in zip(fns_img_pred, fns_pat_pred):
+        df = pd.read_csv(os.path.join(pred_data_dir, fn_img_pred))
+        df.drop(['fn'], axis=1, inplace=True)
+        df_mean = df.groupby(['ID']).mean()
+        preds = df_mean['y_pred']
+        y_pred = []
+        for pred in preds:
+            if pred > thr_prob:
+                pred = 1
+            else:
+                pred = 0
+            y_pred.append(pred)
+        df_mean['predictions'] = y_pred
+        df_mean.drop(['y_pred', 'y_pred_class'], axis=1, inplace=True)
+        df_mean.to_csv(os.path.join(pred_data_dir, fn_pat_pred))
+        print(str(fn_pat_pred.split('_')[0]) + str(' cohort'))
+        print('Total scan:', df_mean.shape[0])
+        print(df_mean['predictions'].value_counts())
diff --git a/go_model/train_model.py b/go_model/train_model.py
new file mode 100644
index 0000000..063dca4
--- /dev/null
+++ b/go_model/train_model.py
@@ -0,0 +1,115 @@
+
+import os
+import numpy as np
+from time import localtime, strftime
+from go_model.callbacks import callbacks
+from utils.write_txt import write_txt
+
+
+def train_model(out_dir, model, run_model, train_gen, val_gen, x_val, y_val,
+                batch_size, epoch, opt, loss_func, lr):
+
+    """
+    Train the model;
+
+    Args:
+        out_dir {path} -- path for output files;
+        model {cnn model} -- cnn model;
+        run_model {str} -- cnn model name;
+        train_gen {Keras data generator} -- training data generator with data augmentation;
+        val_gen {Keras data generator} -- validation data generator;
+        x_val {np.array} -- numpy array of the validation data;
+        y_val {np.array} -- numpy array of the validation labels;
+        batch_size {int} -- batch size for data loading;
+        epoch {int} -- training epochs;
+        opt {str or function} -- optimization function: 'adam';
+        loss_func {str or function} -- loss function: 'binary_crossentropy';
+        lr {float} -- learning rate;
+
+    Returns:
+        training accuracy, loss, model
+
+    """
+
+    model_dir = os.path.join(out_dir, 'model')
+    log_dir = os.path.join(out_dir, 'log')
+    if not os.path.exists(model_dir):
+        os.mkdir(model_dir)
+    if not os.path.exists(log_dir):
+        os.mkdir(log_dir)
+
+    ## compile the model
+    print('compile model')
+    model.compile(
+        optimizer=opt,
+        loss=loss_func,
+        metrics=['acc']
+    )
+
+    ## callback functions
+    my_callbacks = callbacks(log_dir)
+
+    ## fit the model
+    history = model.fit(
+        train_gen,
+        steps_per_epoch=train_gen.n//batch_size,
+        epochs=epoch,
+        validation_data=val_gen,
+        #validation_data=(x_val, y_val),
+        validation_steps=val_gen.n//batch_size,
+        #validation_steps=y_val.shape[0]//batch_size,
+        verbose=2,
+        callbacks=my_callbacks,
+        shuffle=True,
+        class_weight=None,
+        sample_weight=None,
+        initial_epoch=0
+    )
+
+    ## validation acc and loss
+    score = model.evaluate(x_val, y_val)
+    loss = np.around(score[0], 3)
+    acc = np.around(score[1], 3)
+    print('val loss:', loss)
+    print('val acc:', acc)
+
+    ## save the final model
+    saved_model = str(run_model) + '_' + str(strftime('%Y_%m_%d_%H_%M_%S', localtime()))
+    model.save(os.path.join(model_dir, saved_model))
+    print(saved_model)
+
+    ## save validation results to a txt file
+    write_txt(
+        run_type='train',
+        out_dir=out_dir,
+        loss=loss,
+        acc=acc,
+        cms=None,
+        cm_norms=None,
+        reports=None,
+        prc_aucs=None,
+        roc_stats=None,
+        run_model=run_model,
+        saved_model=saved_model,
+        epoch=epoch,
+        batch_size=batch_size,
+        lr=lr
+    )
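Putting the training pieces together, a hedged end-to-end sketch (paths and hyperparameters are illustrative, not values from this repo; train_model compiles the model itself):

from get_data.data_gen_flow import train_generator, val_generator
from go_model.get_model import get_model
from go_model.train_model import train_model

proj_dir = '/path/to/project'   # hypothetical
out_dir = '/path/to/outputs'    # hypothetical
batch_size = 32

train_gen = train_generator(proj_dir, batch_size)
x_val, y_val, val_gen = val_generator(proj_dir, batch_size)
my_model = get_model(out_dir, run_model='EffNetB4', activation='sigmoid')
train_model(
    out_dir=out_dir, model=my_model, run_model='EffNetB4',
    train_gen=train_gen, val_gen=val_gen, x_val=x_val, y_val=y_val,
    batch_size=batch_size, epoch=50, opt='adam',
    loss_func='binary_crossentropy', lr=1e-4,
)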
diff --git a/models/EfficientNet.py b/models/EfficientNet.py
new file mode 100644
index 0000000..b43317e
--- /dev/null
+++ b/models/EfficientNet.py
@@ -0,0 +1,109 @@
+
+import os
+import numpy as np
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
+from tensorflow.keras.applications import (
+    EfficientNetB3,
+    EfficientNetB4,
+    EfficientNetB5
+    )
+
+
+def EfficientNet(effnet, input_shape, transfer=False, freeze_layer=None, activation='sigmoid'):
+
+    """
+    EfficientNets: B3, B4, B5;
+
+    Args:
+        effnet {str} -- EfficientNet variant;
+        input_shape {np.array} -- input data shape;
+
+    Keyword args:
+        transfer {boolean} -- whether to use transfer learning;
+        freeze_layer {int} -- number of layers to freeze;
+        activation {str or function} -- activation function of the last layer: 'sigmoid', 'softmax';
+
+    Returns:
+        EfficientNet model;
+
+    """
+
+    # determine whether to use transfer learning
+    if transfer == True:
+        weights = 'imagenet'
+    elif transfer == False:
+        weights = None
+
+    ### determine the input shape
+    default_shape = (224, 224, 3)
+    if input_shape == default_shape:
+        include_top = True
+    else:
+        include_top = False
+
+    ### determine the EfficientNet base model
+    if effnet == 'EffNetB3':
+        base_model = EfficientNetB3(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif effnet == 'EffNetB4':
+        base_model = EfficientNetB4(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif effnet == 'EffNetB5':
+        base_model = EfficientNetB5(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    base_model.trainable = True
+
+    ### create the top model
+    inputs = base_model.input
+    x = base_model.output
+    x = GlobalAveragePooling2D()(x)
+    x = Dropout(0.3)(x)
+    x = Dense(1000, activation='relu')(x)
+    x = Dropout(0.3)(x)
+    outputs = Dense(1, activation=activation)(x)
+    model = Model(inputs=inputs, outputs=outputs)
+
+    ### freeze a specific number of layers
+    if freeze_layer == 1:
+        for layer in base_model.layers[0:5]:
+            layer.trainable = False
+        for layer in base_model.layers:
+            print(layer, layer.trainable)
+    elif freeze_layer == 5:
+        for layer in base_model.layers[0:16]:
+            layer.trainable = False
+        for layer in base_model.layers:
+            print(layer, layer.trainable)
+    else:
+        for layer in base_model.layers:
+            layer.trainable = True
+
+    model.summary()
+
+    return model
diff --git a/models/GoogleNet.py b/models/GoogleNet.py
new file mode 100644
index 0000000..67087e2
--- /dev/null
+++ b/models/GoogleNet.py
@@ -0,0 +1,119 @@
+#----------------------------------------------------------------------
+# Deep learning for classification of contrast CT;
+# Transfer learning using Google Inception V3;
+#----------------------------------------------------------------------
+
+import os
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import GlobalAveragePooling2D
+from tensorflow.keras.layers import Dense, Dropout, BatchNormalization
+from tensorflow.keras.applications import ResNet50V2
+from tensorflow.keras.applications import ResNet101V2
+from tensorflow.keras.applications import ResNet152V2
+
+# ----------------------------------------------------------------------------------
+# transfer learning CNN model
+# ----------------------------------------------------------------------------------
+def GoogleNet(resnet, transfer_learning, input_shape, batch_momentum, activation,
+              activation_out, loss_function, optimizer, dropout_rate):
+
+    ### determine whether to use transfer learning
+    if transfer_learning == True:
+        weights = 'imagenet'
+    elif transfer_learning == False:
+        weights = None
+
+    ### determine the input shape
+    default_shape = (224, 224, 3)
+    if input_shape == default_shape:
+        include_top = True
+    else:
+        include_top = False
+
+    ### determine the ResNet base model
+    if resnet == 'ResNet50V2':
+        base_model = ResNet50V2(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif resnet == 'ResNet101V2':
+        base_model = ResNet101V2(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif resnet == 'ResNet152V2':
+        base_model = ResNet152V2(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+
+    ### create the top model
+    out = base_model.output
+    out = BatchNormalization(momentum=batch_momentum)(out)
+    out = GlobalAveragePooling2D()(out)
+    out = Dropout(dropout_rate)(out)
+    predictions = Dense(1, activation=activation_out)(out)
+    model = Model(inputs=base_model.input, outputs=predictions)
+
+    print('compile model')
+    model.compile(
+        loss=loss_function,
+        optimizer=optimizer,
+        metrics=['accuracy']
+    )
+    model.summary()
+
+    return model
diff --git a/models/Inception.py b/models/Inception.py
new file mode 100644
index 0000000..182ec5d
--- /dev/null
+++ b/models/Inception.py
@@ -0,0 +1,113 @@
+
+import os
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
+from tensorflow.keras.applications import Xception
+from tensorflow.keras.applications import InceptionV3
+from tensorflow.keras.applications import InceptionResNetV2
+
+
+def Inception(inception, input_shape, transfer=False, freeze_layer=None, activation='sigmoid'):
+
+    """
+    Google Inception nets: Xception, InceptionV3, InceptionResNetV2;
+    Keras CNN models for reference: https://keras.io/api/applications/
+        InceptionV3 (top-1 acc 0.779),
+        InceptionResNetV2 (top-1 acc 0.803),
+        ResNet152V2 (top-1 acc 0.780)
+
+    Args:
+        inception {str} -- Inception variant;
+        input_shape {np.array} -- input data shape;
+
+    Keyword args:
+        transfer {boolean} -- whether to use transfer learning;
+        freeze_layer {int} -- number of layers to freeze;
+        activation {str or function} -- activation function of the last layer: 'sigmoid', 'softmax';
+
+    Returns:
+        Inception model;
+
+    """
+
+    # determine whether to use transfer learning
+    if transfer == True:
+        weights = 'imagenet'
+    elif transfer == False:
+        weights = None
+
+    ### determine the input shape
+    default_shape = (224, 224, 3)
+    if input_shape == default_shape:
+        include_top = True
+    else:
+        include_top = False
+
+    ## determine n_output
+    if activation == 'softmax':
+        n_output = 2
+    elif activation == 'sigmoid':
+        n_output = 1
+
+    ### determine the Inception base model
+    if inception == 'Xception':
+        base_model = Xception(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif inception == 'InceptionV3':
+        base_model = InceptionV3(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+    elif inception == 'InceptionResNetV2':
+        base_model = InceptionResNetV2(
+            weights=weights,
+            include_top=include_top,
+            input_shape=input_shape,
+            pooling=None
+        )
+
+    ## create the top model
+    inputs = base_model.input
+    x = base_model.output
+    x = GlobalAveragePooling2D()(x)
+    x = Dropout(0.3)(x)
+    x = Dense(1000, activation='relu')(x)
+    x = Dropout(0.3)(x)
+    outputs = Dense(n_output, activation=activation)(x)
+    model = Model(inputs=inputs, outputs=outputs)
+
+    ### freeze a specific number of layers
+    if freeze_layer == 1:
+        for layer in base_model.layers[0:5]:
+            layer.trainable = False
+        for layer in base_model.layers:
+            print(layer, layer.trainable)
+    elif freeze_layer == 5:
+        for layer in base_model.layers[0:16]:
+            layer.trainable = False
+        for layer in base_model.layers:
+            print(layer, layer.trainable)
+    else:
+        for layer in base_model.layers:
+            layer.trainable = True
+
+    model.summary()
+
+    return model
diff --git a/models/ResNet.py b/models/ResNet.py new file mode 100644 index 0000000..3d11c0f --- /dev/null +++ b/models/ResNet.py @@ -0,0 +1,140 @@ + +import os +import numpy as np +import pandas as pd +import tensorflow +from tensorflow import keras +from tensorflow.keras.models import Model +from tensorflow.keras.layers import GlobalAveragePooling2D +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.applications import ( + ResNet50, + ResNet101, + ResNet152, + ResNet50V2, + ResNet101V2, + ResNet152V2 + ) + + + +def ResNet(resnet, input_shape, transfer=False, freeze_layer=None, activation='sigmoid'): + + """ + ResNet: 50, 101, 152 (V1 and V2) + + Args: + resnet {str} -- ResNet variant, e.g. 'ResNet101'; + input_shape {tuple} -- input data shape; + + Keyword args: + transfer {boolean} -- whether to use transfer learning; + freeze_layer {int} -- number of layers to freeze; + activation {str or function} -- activation function in last layer: 'sigmoid', 'softmax'; + + Returns: + ResNet model; + + """ + + ### determine whether to use transfer learning or not + if transfer: + weights = 'imagenet' + else: + weights = None + + ### determine input shape + default_shape = (224, 224, 3) + if input_shape == default_shape: + include_top = True + else: + include_top = False + + ### determine n_output + if activation == 'softmax': + n_output = 2 + elif activation == 'sigmoid': + n_output = 1 + + ### determine ResNet base model; pass the weights selected above so the + ### transfer flag actually takes effect + if resnet == 'ResNet50V2': + base_model = ResNet50V2( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet101V2': + base_model = ResNet101V2( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet152V2': + base_model = ResNet152V2( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet50': + base_model = ResNet50( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet101': + base_model = ResNet101( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet152': + base_model = ResNet152( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + + ### create top model + inputs = base_model.input + x = base_model.output + x = GlobalAveragePooling2D()(x) + x = Dense(1000, activation='relu')(x) + outputs = Dense(n_output, activation=activation)(x) + model = Model(inputs=inputs, outputs=outputs) + + ### freeze specific number of layers; the remaining layers stay trainable + if freeze_layer == 1: + for layer in base_model.layers[0:5]: + layer.trainable = False + for layer in base_model.layers: + print(layer, layer.trainable) + elif freeze_layer == 5: + for layer in base_model.layers[0:16]: + layer.trainable = False + for layer in base_model.layers: + print(layer, layer.trainable) + else: + for layer in base_model.layers: + layer.trainable = True + + model.summary() + + + return model + +
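+# Sketch (assumption, not part of the original file): freeze_layer=1 and +# freeze_layer=5 above are repo-specific codes for freezing the first 5 and +# first 16 layers; a more general helper would freeze an explicit prefix: +def freeze_first_n(base_model, n): + ## first n layers fixed, all remaining layers trainable + for layer in base_model.layers[:n]: + layer.trainable = False + for layer in base_model.layers[n:]: + layer.trainable = True + + 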
diff --git a/models/TLNet.py b/models/TLNet.py new file mode 100644 index 0000000..8586cdd --- /dev/null +++ b/models/TLNet.py @@ -0,0 +1,91 @@ + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt +import glob +import tensorflow +from tensorflow import keras +from tensorflow.keras import Input +from tensorflow.keras.models import Model +from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator +from tensorflow.keras.layers import GlobalAveragePooling2D +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.applications.inception_v3 import InceptionV3 +from tensorflow.keras.applications.vgg16 import VGG16 +from tensorflow.keras.applications import ResNet50 +from tensorflow.keras.applications import ResNet152 +from tensorflow.keras.applications import ResNet101 +from tensorflow.keras.applications import ResNet50V2 +from tensorflow.keras.applications import ResNet101V2 +from tensorflow.keras.applications import ResNet152V2 + + + + +def TLNet(resnet, input_shape, activation='sigmoid'): + + """ + Transfer learning based on ResNet: the ImageNet-weighted base is frozen + and used as a feature extractor; only the new top layers are trained. + + Args: + resnet {str} -- ResNet variant, e.g. 'ResNet101V2'; + input_shape {tuple} -- input data shape; + + Keyword args: + activation {str or function} -- activation function in last layer: 'sigmoid', 'softmax'; + + Returns: + Transfer learning model; + + """ + + ### determine ResNet base model + if resnet == 'ResNet50V2': + base_model = ResNet50V2( + weights='imagenet', + include_top=False, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet101V2': + base_model = ResNet101V2( + weights='imagenet', + include_top=False, + input_shape=input_shape, + pooling=None + ) + elif resnet == 'ResNet152V2': + base_model = ResNet152V2( + weights='imagenet', + include_top=False, + input_shape=input_shape, + pooling=None + ) + + base_model.trainable = False + + ### create top model; training=False keeps the frozen base's batch-norm + ### layers in inference mode + inputs = Input(shape=input_shape) + x = base_model(inputs, training=False) + x = GlobalAveragePooling2D()(x) + x = Dense(1000, activation='relu')(x) + #x = Dense(1024, activation='relu')(x) + #x = Dense(512, activation='relu')(x) + outputs = Dense(1, activation=activation)(x) + model = Model(inputs, outputs) + + return model + +
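+# Fine-tuning sketch (assumption, not part of the original file): once the new +# head has been trained on top of the frozen base, a common second stage is to +# unfreeze the base and continue training at a low learning rate: +def unfreeze_for_finetuning(model, base_model, lr=1e-5): + base_model.trainable = True + ## recompile so the trainability change takes effect + model.compile( + optimizer=Adam(learning_rate=lr), + loss=BinaryCrossentropy(), + metrics=['accuracy'] + ) + return model + + 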
diff --git a/models/VGGNet.py b/models/VGGNet.py new file mode 100644 index 0000000..3fcfd31 --- /dev/null +++ b/models/VGGNet.py @@ -0,0 +1,107 @@ +#---------------------------------------------------------------------- +# Deep learning for classification of contrast CT; +# Transfer learning using VGG16/VGG19; +#------------------------------------------------------------------------------------------- + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt +import glob +import tensorflow +from tensorflow.keras.models import Model +from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator +from tensorflow.keras.layers import GlobalAveragePooling2D +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.metrics import BinaryAccuracy +from tensorflow.keras.applications import EfficientNetB5 +from tensorflow.keras.applications import EfficientNetB4 +from tensorflow.keras.applications import EfficientNetB3 +from tensorflow.keras.applications import DenseNet121 +from tensorflow.keras.applications import Xception +from tensorflow.keras.applications import InceptionV3 +from tensorflow.keras.applications import InceptionResNetV2 +from tensorflow.keras.applications import VGG16 +from tensorflow.keras.applications import VGG19 + + +# ---------------------------------------------------------------------------------- +# transfer learning CNN model +# ---------------------------------------------------------------------------------- +def VGGNet(VGG, transfer, freeze_layer, input_shape, activation): + + ### Keras CNN models for use: https://keras.io/api/applications/ + ### InceptionV3 (top-1 acc 0.779), + ### InceptionResNetV2 (top-1 acc 0.803), + ### ResNet152V2 (top-1 acc 0.780), + ### EfficientNetB4 + + ### determine whether to use transfer learning or not + if transfer: + weights = 'imagenet' + else: + weights = None + + ### determine input shape + default_shape = (224, 224, 3) + if input_shape == default_shape: + include_top = True + else: + include_top = False + + ### determine n_output + if activation == 'softmax': + n_output = 2 + elif activation == 'sigmoid': + n_output = 1 + + ### determine VGG base model + if VGG == 'VGG16': + base_model = VGG16( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + elif VGG == 'VGG19': + base_model = VGG19( + weights=weights, + include_top=include_top, + input_shape=input_shape, + pooling=None + ) + + ### create top model + inputs = base_model.input + x = base_model.output + x = GlobalAveragePooling2D()(x) + x = Dropout(0.3)(x) + x = Dense(1000, activation='relu')(x) + x = Dropout(0.3)(x) + outputs = Dense(n_output, activation=activation)(x) + model = Model(inputs=inputs, outputs=outputs) + + ### freeze specific number of layers; the remaining layers stay trainable + if freeze_layer == 1: + for layer in base_model.layers[0:5]: + layer.trainable = False + for layer in base_model.layers: + print(layer, layer.trainable) + elif freeze_layer == 5: + for layer in base_model.layers[0:16]: + layer.trainable = False + for layer in base_model.layers: + print(layer, layer.trainable) + else: + for layer in base_model.layers: + layer.trainable = True + model.summary() + + return model + + + diff --git a/models/ViT.py b/models/ViT.py new file mode 100644 index 0000000..5474f84 --- /dev/null +++ b/models/ViT.py @@ -0,0 +1,289 @@ +#---------------------------------------------------------------------- +# Deep learning for classification of contrast CT; +# Vision Transformer (ViT) for image classification; +#------------------------------------------------------------------------------------------- + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt +import glob +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import layers +from tensorflow.keras.models import Model +from tensorflow.keras.preprocessing.image import img_to_array, load_img, ImageDataGenerator +from tensorflow.keras.layers import GlobalAveragePooling2D +from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.applications.inception_v3 import InceptionV3 +from tensorflow.keras.applications.vgg16 import VGG16 +from tensorflow.keras.applications import ResNet50 +from tensorflow.keras.applications import ResNet152 +from tensorflow.keras.applications import ResNet101 +from tensorflow.keras.applications import ResNet50V2 +from tensorflow.keras.applications import ResNet101V2 +from tensorflow.keras.applications import ResNet152V2 +import tensorflow_addons as tfa + 
+#--------------------------------------------------------------------- +# Vision Transformer +#--------------------------------------------------------------------- + +""" +Title: Image classification with Vision Transformer +Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/) +Date created: 2021/01/18 +Last modified: 2021/01/18 +Description: Implementing the Vision Transformer (ViT) model for image classification. +""" + +""" +## Introduction +This example implements the [Vision Transformer (ViT)](https://arxiv.org/abs/2010.11929) +model by Alexey Dosovitskiy et al. for image classification, +and demonstrates it on the CIFAR-100 dataset. +The ViT model applies the Transformer architecture with self-attention to sequences of +image patches, without using convolution layers. +This example requires TensorFlow 2.4 or higher, as well as +[TensorFlow Addons](https://www.tensorflow.org/addons/overview), +which can be installed using the following command: +```python +pip install -U tensorflow-addons +``` +""" + +""" +## Setup +""" + + +num_classes = 100 +input_shape = (32, 32, 3) +learning_rate = 0.001 +weight_decay = 0.0001 +batch_size = 256 +num_epochs = 100 +image_size = 72 # We'll resize input images to this size +patch_size = 6 # Size of the patches to be extracted from the input images +num_patches = (image_size // patch_size) ** 2 +projection_dim = 64 +num_heads = 4 +transformer_units = [projection_dim * 2, projection_dim] # Size of the transformer layers +transformer_layers = 8 +mlp_head_units = [2048, 1024] # Size of the dense layers of the final classifier + + +## Use data augmentation +data_augmentation = keras.Sequential( + [layers.Normalization(), + layers.Resizing(image_size, image_size), + layers.RandomFlip("horizontal"), + layers.RandomRotation(factor=0.02), + layers.RandomZoom(height_factor=0.2, width_factor=0.2)], + name="data_augmentation", + ) +# Compute the mean and the variance of the training data for normalization. 
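+# x_train is used below but never defined in this file; in the Keras tutorial +# this example is adapted from, the CIFAR-100 data is loaded first (assumed here): +(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data() +print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}") 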
+data_augmentation.layers[0].adapt(x_train) + +## Implement multilayer perceptron (MLP) +def mlp(x, hidden_units, dropout_rate): + for units in hidden_units: + x = layers.Dense(units, activation=tf.nn.gelu)(x) + x = layers.Dropout(dropout_rate)(x) + return x + +## Implement patch creation as a layer +class Patches(layers.Layer): + def __init__(self, patch_size): + super(Patches, self).__init__() + self.patch_size = patch_size + + def call(self, images): + batch_size = tf.shape(images)[0] + patches = tf.image.extract_patches( + images=images, + sizes=[1, self.patch_size, self.patch_size, 1], + strides=[1, self.patch_size, self.patch_size, 1], + rates=[1, 1, 1, 1], + padding="VALID", + ) + patch_dims = patches.shape[-1] + patches = tf.reshape(patches, [batch_size, -1, patch_dims]) + return patches + +## Let's display patches for a sample image +plt.figure(figsize=(4, 4)) +image = x_train[np.random.choice(range(x_train.shape[0]))] +plt.imshow(image.astype("uint8")) +plt.axis("off") + +resized_image = tf.image.resize( + tf.convert_to_tensor([image]), size=(image_size, image_size) +) +patches = Patches(patch_size)(resized_image) +print(f"Image size: {image_size} X {image_size}") +print(f"Patch size: {patch_size} X {patch_size}") +print(f"Patches per image: {patches.shape[1]}") +print(f"Elements per patch: {patches.shape[-1]}") + +n = int(np.sqrt(patches.shape[1])) +plt.figure(figsize=(4, 4)) +for i, patch in enumerate(patches[0]): + ax = plt.subplot(n, n, i + 1) + patch_img = tf.reshape(patch, (patch_size, patch_size, 3)) + plt.imshow(patch_img.numpy().astype("uint8")) + plt.axis("off") + +""" +## Implement the patch encoding layer +The `PatchEncoder` layer will linearly transform a patch by projecting it into a +vector of size `projection_dim`. In addition, it adds a learnable position +embedding to the projected vector. +""" + +class PatchEncoder(layers.Layer): + + def __init__(self, num_patches, projection_dim): + super(PatchEncoder, self).__init__() + self.num_patches = num_patches + self.projection = layers.Dense(units=projection_dim) + self.position_embedding = layers.Embedding( + input_dim=num_patches, + output_dim=projection_dim + ) + + def call(self, patch): + positions = tf.range(start=0, limit=self.num_patches, delta=1) + encoded = self.projection(patch) + self.position_embedding(positions) + return encoded + + +""" +## Build the ViT model +The ViT model consists of multiple Transformer blocks, +which use the `layers.MultiHeadAttention` layer as a self-attention mechanism +applied to the sequence of patches. The Transformer blocks produce a +`[batch_size, num_patches, projection_dim]` tensor, which is processed via an +classifier head with softmax to produce the final class probabilities output. +Unlike the technique described in the [paper](https://arxiv.org/abs/2010.11929), +which prepends a learnable embedding to the sequence of encoded patches to serve +as the image representation, all the outputs of the final Transformer block are +reshaped with `layers.Flatten()` and used as the image +representation input to the classifier head. +Note that the `layers.GlobalAveragePooling1D` layer +could also be used instead to aggregate the outputs of the Transformer block, +especially when the number of patches and the projection dimensions are large. +""" + +def create_vit_classifier(): + inputs = layers.Input(shape=input_shape) + # Augment data. + augmented = data_augmentation(inputs) + # Create patches. + patches = Patches(patch_size)(augmented) + # Encode patches. 
+ encoded_patches = PatchEncoder(num_patches, projection_dim)(patches) + + # Create multiple layers of the Transformer block. + for _ in range(transformer_layers): + # Layer normalization 1. + x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + # Create a multi-head attention layer. + attention_output = layers.MultiHeadAttention( + num_heads=num_heads, + key_dim=projection_dim, + dropout=0.1 + )(x1, x1) + # Skip connection 1. + x2 = layers.Add()([attention_output, encoded_patches]) + # Layer normalization 2. + x3 = layers.LayerNormalization(epsilon=1e-6)(x2) + # MLP. + x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1) + # Skip connection 2. + encoded_patches = layers.Add()([x3, x2]) + + # Create a [batch_size, projection_dim] tensor. + representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches) + representation = layers.Flatten()(representation) + representation = layers.Dropout(0.5)(representation) + # Add MLP. + features = mlp( + representation, + hidden_units=mlp_head_units, + dropout_rate=0.5 + ) + # Classify outputs. + logits = layers.Dense(num_classes)(features) + # Create the Keras model. + model = keras.Model(inputs=inputs, outputs=logits) + + return model + + +## Compile, train, and evaluate the model +def run_experiment(model): + + optimizer = tfa.optimizers.AdamW( + learning_rate=learning_rate, + weight_decay=weight_decay + ) + + # the 100-class logits call for a sparse categorical loss; the metric is + # named 'accuracy' so the checkpoint monitor 'val_accuracy' matches + model.compile( + optimizer=optimizer, + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')] + ) + + checkpoint_filepath = '/tmp/checkpoint' + checkpoint_callback = keras.callbacks.ModelCheckpoint( + checkpoint_filepath, + monitor="val_accuracy", + save_best_only=True, + save_weights_only=True, + ) + + history = model.fit( + x=x_train, + y=y_train, + batch_size=batch_size, + epochs=num_epochs, + validation_split=0.1, + callbacks=[checkpoint_callback], + ) + + model.load_weights(checkpoint_filepath) + loss, accuracy = model.evaluate(x_test, y_test) + print(f"Test accuracy: {round(accuracy * 100, 2)}%") + + return history + + +vit_classifier = create_vit_classifier() +history = run_experiment(vit_classifier) + + +""" +After 100 epochs, the ViT model achieves around 55% accuracy and +82% top-5 accuracy on the test data. These are not competitive results on the CIFAR-100 dataset, +as a ResNet50V2 trained from scratch on the same data can achieve 67% accuracy. +Note that the state of the art results reported in the +[paper](https://arxiv.org/abs/2010.11929) are achieved by pre-training the ViT model using +the JFT-300M dataset, then fine-tuning it on the target dataset. To improve the model quality +without pre-training, you can try to train the model for more epochs, use a larger number of +Transformer layers, resize the input images, change the patch size, or increase the projection dimensions. +Besides, as mentioned in the paper, the quality of the model is affected not only by architecture choices, +but also by parameters such as the learning rate schedule, optimizer, weight decay, etc. +In practice, it's recommended to fine-tune a ViT model +that was pre-trained using a large, high-resolution dataset. 
+""" + + + + + diff --git a/models/__pycache__/EffNet.cpython-38.pyc b/models/__pycache__/EffNet.cpython-38.pyc new file mode 100644 index 0000000..5f0feb8 Binary files /dev/null and b/models/__pycache__/EffNet.cpython-38.pyc differ diff --git a/models/__pycache__/EffNetB5_model.cpython-38.pyc b/models/__pycache__/EffNetB5_model.cpython-38.pyc new file mode 100644 index 0000000..a9083e8 Binary files /dev/null and b/models/__pycache__/EffNetB5_model.cpython-38.pyc differ diff --git a/models/__pycache__/EffNet_model.cpython-38.pyc b/models/__pycache__/EffNet_model.cpython-38.pyc new file mode 100644 index 0000000..f838301 Binary files /dev/null and b/models/__pycache__/EffNet_model.cpython-38.pyc differ diff --git a/models/__pycache__/EfficientNet.cpython-38.pyc b/models/__pycache__/EfficientNet.cpython-38.pyc new file mode 100644 index 0000000..a622f96 Binary files /dev/null and b/models/__pycache__/EfficientNet.cpython-38.pyc differ diff --git a/models/__pycache__/Inception.cpython-38.pyc b/models/__pycache__/Inception.cpython-38.pyc new file mode 100644 index 0000000..bef2d26 Binary files /dev/null and b/models/__pycache__/Inception.cpython-38.pyc differ diff --git a/models/__pycache__/ResNet.cpython-38.pyc b/models/__pycache__/ResNet.cpython-38.pyc new file mode 100644 index 0000000..18b7a11 Binary files /dev/null and b/models/__pycache__/ResNet.cpython-38.pyc differ diff --git a/models/__pycache__/ResNet50.cpython-38.pyc b/models/__pycache__/ResNet50.cpython-38.pyc new file mode 100644 index 0000000..72a250d Binary files /dev/null and b/models/__pycache__/ResNet50.cpython-38.pyc differ diff --git a/models/__pycache__/ResNet50_model.cpython-38.pyc b/models/__pycache__/ResNet50_model.cpython-38.pyc new file mode 100644 index 0000000..d4da8fb Binary files /dev/null and b/models/__pycache__/ResNet50_model.cpython-38.pyc differ diff --git a/models/__pycache__/ResNet_model.cpython-38.pyc b/models/__pycache__/ResNet_model.cpython-38.pyc new file mode 100644 index 0000000..5b0612c Binary files /dev/null and b/models/__pycache__/ResNet_model.cpython-38.pyc differ diff --git a/models/__pycache__/TLNet.cpython-38.pyc b/models/__pycache__/TLNet.cpython-38.pyc new file mode 100644 index 0000000..d885b34 Binary files /dev/null and b/models/__pycache__/TLNet.cpython-38.pyc differ diff --git a/models/__pycache__/TL_ResNet50.cpython-38.pyc b/models/__pycache__/TL_ResNet50.cpython-38.pyc new file mode 100644 index 0000000..016914a Binary files /dev/null and b/models/__pycache__/TL_ResNet50.cpython-38.pyc differ diff --git a/models/__pycache__/TL_ResNet50_model.cpython-38.pyc b/models/__pycache__/TL_ResNet50_model.cpython-38.pyc new file mode 100644 index 0000000..d261179 Binary files /dev/null and b/models/__pycache__/TL_ResNet50_model.cpython-38.pyc differ diff --git a/models/__pycache__/VGG16.cpython-38.pyc b/models/__pycache__/VGG16.cpython-38.pyc new file mode 100644 index 0000000..2c7dd96 Binary files /dev/null and b/models/__pycache__/VGG16.cpython-38.pyc differ diff --git a/models/__pycache__/VGGNet.cpython-38.pyc b/models/__pycache__/VGGNet.cpython-38.pyc new file mode 100644 index 0000000..bc9b22f Binary files /dev/null and b/models/__pycache__/VGGNet.cpython-38.pyc differ diff --git a/models/__pycache__/cnn_model.cpython-38.pyc b/models/__pycache__/cnn_model.cpython-38.pyc new file mode 100644 index 0000000..4d8ce9c Binary files /dev/null and b/models/__pycache__/cnn_model.cpython-38.pyc differ diff --git a/models/__pycache__/models.cpython-38.pyc b/models/__pycache__/models.cpython-38.pyc new 
file mode 100644 index 0000000..fe8e725 Binary files /dev/null and b/models/__pycache__/models.cpython-38.pyc differ diff --git a/models/__pycache__/simple_cnn.cpython-38.pyc b/models/__pycache__/simple_cnn.cpython-38.pyc new file mode 100644 index 0000000..44d82fe Binary files /dev/null and b/models/__pycache__/simple_cnn.cpython-38.pyc differ diff --git a/models/simple_cnn.py b/models/simple_cnn.py new file mode 100644 index 0000000..cca7a3f --- /dev/null +++ b/models/simple_cnn.py @@ -0,0 +1,82 @@ + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt +import glob +import tensorflow as tf +from tensorflow.keras.models import Model +from tensorflow.keras.layers import GlobalAveragePooling2D +from tensorflow.keras.layers import Conv2D +from tensorflow.keras.layers import MaxPooling2D +from tensorflow.keras.layers import Flatten +from tensorflow.keras.layers import Dense +from tensorflow.keras.layers import Dropout +from tensorflow.keras.layers import BatchNormalization +from tensorflow.keras.models import Sequential +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.losses import BinaryCrossentropy + + + + +def simple_cnn(input_shape, activation): + + """ + simple CNN model + + Args: + input_shape {tuple} -- input data shape; + activation {str or function} -- activation function in last layer: 'softmax' or 'sigmoid'; + + Returns: + simple cnn model + + """ + + ## determine n_output + if activation == 'softmax': + n_output = 2 + elif activation == 'sigmoid': + n_output = 1 + + model = Sequential() + + model.add(Conv2D(16, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) + model.add(BatchNormalization(momentum=0.95)) + model.add(MaxPooling2D(pool_size=(2, 2))) + #model.add(Dropout(0.3)) + + model.add(BatchNormalization(momentum=0.95)) + model.add(Conv2D(64, kernel_size=(3, 3), activation='relu')) + model.add(MaxPooling2D(pool_size=(2, 2))) + #model.add(Dropout(0.3)) + + model.add(BatchNormalization(momentum=0.95)) + model.add(Conv2D(128, kernel_size=(3, 3), activation='relu')) + model.add(MaxPooling2D(pool_size=(2, 2))) + #model.add(Dropout(0.3)) + + model.add(BatchNormalization(momentum=0.95)) + model.add(Conv2D(128, kernel_size=(3, 3), activation='relu')) + model.add(MaxPooling2D(pool_size=(2, 2))) + #model.add(Dropout(0.3)) + + model.add(Flatten()) + model.add(BatchNormalization(momentum=0.95)) + model.add(Dense(256, activation='relu')) + model.add(Dropout(0.3)) + + model.add(BatchNormalization(momentum=0.95)) + model.add(Dense(256, activation='relu')) + model.add(Dropout(0.3)) + model.add(Dense(n_output, activation=activation)) + + model.summary() + + return model + +
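+# Usage sketch (illustrative, not part of the original file); the input shape +# and optimizer settings are assumptions mirroring the training scripts: +if __name__ == '__main__': + model = simple_cnn(input_shape=(192, 192, 3), activation='sigmoid') + model.compile( + optimizer=Adam(learning_rate=1e-5), + loss=BinaryCrossentropy(), + metrics=['accuracy'] + ) + + 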
diff --git a/prediction/data_prepro.py b/prediction/data_prepro.py new file mode 100644 index 0000000..bbf00c0 --- /dev/null +++ b/prediction/data_prepro.py @@ -0,0 +1,133 @@ +import glob +import shutil +import os +import pandas as pd +import numpy as np +import nrrd +import re +import matplotlib +import matplotlib.pyplot as plt +import pickle +from time import gmtime, strftime +from datetime import datetime +import timeit +from sklearn.model_selection import train_test_split +from tensorflow.keras.utils import to_categorical +from utils.resize_3d import resize_3d +from utils.crop_image import crop_image +from utils.respacing import respacing +from utils.nrrd_reg import nrrd_reg_rigid_ref +from get_data.get_img_dataset import img_dataset + + + +def data_prepro(body_part, out_dir, reg_temp_img, new_spacing=[1, 1, 3], + input_channel=3, norm_type='np_clip'): + + """ + data preprocessing: respacing, registration, crop + + Args: + body_part {str} -- 'head_and_neck' or 'chest'; sets the crop shape and + the slice range used to extract axial slices; + out_dir {str} -- output dir; scans are read from out_dir/pred_data; + reg_temp_img {str} -- filename of the registration template image; + + Keyword args: + new_spacing {list} -- respacing size, default [1, 1, 3]; + input_channel {int} -- input channel 1 or 3; + norm_type {str} -- normalization method for image, 'np_clip' or 'np_interp'; + + Returns: + df_img {pd.df} -- dataframe with image ID and patient ID; + img_arr {np.array} -- stacked numpy array from all slices of all scans; + + """ + + data_dir = os.path.join(out_dir, 'pred_data') + if not os.path.isdir(data_dir): + os.mkdir(data_dir) + reg_template = os.path.join(data_dir, reg_temp_img) + + if body_part == 'head_and_neck': + crop_shape = [192, 192, 100] + slice_range = range(17, 83) + elif body_part == 'chest': + crop_shape = [192, 192, 140] + slice_range = range(50, 120) + + ## registration, respacing, cropping + img_ids = [] + pat_ids = [] + slice_numbers = [] + arr = np.empty([0, 192, 192]) + for fn in sorted(glob.glob(data_dir + '/*nrrd')): + pat_id = fn.split('/')[-1].split('.')[0].strip() + print(pat_id) + ## respacing + img_nrrd = respacing( + nrrd_dir=fn, + interp_type='linear', + new_spacing=new_spacing, + patient_id=pat_id, + return_type='nrrd', + save_dir=None + ) + ## registration + img_reg = nrrd_reg_rigid_ref( + img_nrrd=img_nrrd, + fixed_img_dir=reg_template, + patient_id=pat_id, + save_dir=None + ) + ## crop image to the body-part-specific shape, e.g. (192, 192, 100) + img_crop = crop_image( + nrrd_file=img_reg, + patient_id=pat_id, + crop_shape=crop_shape, + return_type='npy', + save_dir=None + ) + + ## choose slice range to cover body part + if slice_range is None: + data = img_crop + else: + data = img_crop[slice_range, :, :] + ## clip signals lower than -1024 + data[data <= -1024] = -1024 + ## strip skull, skull HU = ~700 + data[data > 700] = 0 + ## normalize HU to 0 - 1; all signals outside of [-200, 200] are clipped + if norm_type == 'np_interp': + data = np.interp(data, [-200, 200], [0, 1]) + elif norm_type == 'np_clip': + data = np.clip(data, a_min=-200, a_max=200) + MAX, MIN = data.max(), data.min() + data = (data - MIN) / (MAX - MIN) + ## stack all image arrays to one array for CNN input + arr = np.concatenate([arr, data], 0) + + ## create image ID and slice index for img + slice_numbers.append(data.shape[0]) + for i in range(data.shape[0]): + img = data[i, :, :] + img_id = pat_id + '_' + 'slice%s'%(f'{i:03d}') + img_ids.append(img_id) + pat_ids.append(pat_id) + + ## generate patient and slice ID + df_img = pd.DataFrame({'pat_id': pat_ids, 'img_id': img_ids}) + #print('data size:\n', df_img) + + ## convert 1 channel input to 3 channel inputs for CNN + if input_channel == 1: + img_arr = arr.reshape(arr.shape[0], arr.shape[1], arr.shape[2], 1) + #print('img_arr shape:', img_arr.shape) + #np.save(os.path.join(pro_data_dir, fn_arr_1ch), img_arr) + elif input_channel == 3: + img_arr = np.broadcast_to(arr, (3, arr.shape[0], arr.shape[1], arr.shape[2])) + img_arr = np.transpose(img_arr, (1, 2, 3, 0)) + + return df_img, img_arr + 
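+# Worked example (assumption, not part of the original file): the np_clip +# branch clips HU values to [-200, 200] and then min-max scales to [0, 1]: +if __name__ == '__main__': + demo = np.clip(np.array([-500., 0., 300.]), a_min=-200, a_max=200) + demo = (demo - demo.min()) / (demo.max() - demo.min()) + print(demo) # [0. 0.5 1.] + 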
diff --git a/prediction/model_pred.py b/prediction/model_pred.py new file mode 100644 index 0000000..d11c2ba --- /dev/null +++ b/prediction/model_pred.py @@ -0,0 +1,93 @@ + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt +import nrrd +import scipy.stats as ss +import SimpleITK as stik +import glob +from PIL import Image +from collections import Counter +import skimage.transform as st +from datetime import datetime +from time import gmtime, strftime +import pickle +import tensorflow +from tensorflow.keras.models import Model +from tensorflow.keras.models import load_model +from sklearn.metrics import classification_report +from sklearn.metrics import confusion_matrix + + + +def model_pred(body_part, df_img, img_arr, out_dir, thr_img=0.5, thr_pat=0.5): + + """ + model prediction for IV contrast + + Args: + body_part {str} -- 'head_and_neck' or 'chest'; selects the saved model; + df_img {pd.df} -- dataframe with scan and axial slice ID; + img_arr {np.array} -- numpy array stacked with axial image slices; + out_dir {str} -- directory for saved models and results output; + + Keyword args: + thr_img {float} -- threshold to determine prediction class on image level; + thr_pat {float} -- threshold to determine prediction class on patient level; + + Returns: + dataframes of model predictions on image level and patient level + """ + + model_dir = os.path.join(out_dir, 'model') + pred_dir = os.path.join(out_dir, 'pred') + if not os.path.isdir(model_dir): + os.mkdir(model_dir) + if not os.path.isdir(pred_dir): + os.mkdir(pred_dir) + + if body_part == 'head_and_neck': + saved_model = 'EffNet_2021_08_24_09_57_13' + elif body_part == 'chest': + saved_model = 'Tuned_EfficientNetB4_2021_08_27_20_26_55' + + ## load saved model + print(saved_model) + model = load_model(os.path.join(model_dir, saved_model)) + ## prediction + y_pred = model.predict(img_arr, batch_size=32) + y_pred_class = [1 * (x[0] >= thr_img) for x in y_pred] + #print(y_pred) + #print(y_pred_class) + + ## save a dataframe for prediction on image level; + ## predict() returns shape (N, 1), so flatten before assigning a column + df_img['y_pred'] = np.around(y_pred, 3).flatten() + df_img['y_pred_class'] = y_pred_class + df_img_pred = df_img[['pat_id', 'img_id', 'y_pred', 'y_pred_class']] + fn = 'df_img_pred' + '_' + str(saved_model) + '.csv' + df_img_pred.to_csv(os.path.join(out_dir, fn), index=False) + + ## calculate patient level prediction + df_img_pred.drop(['img_id'], axis=1, inplace=True) + df_mean = df_img_pred.groupby(['pat_id']).mean() + preds = df_mean['y_pred'] + y_pred = [] + for pred in preds: + if pred > thr_pat: + pred = 1 + else: + pred = 0 + y_pred.append(pred) + df_mean['predictions'] = y_pred + df_mean.drop(['y_pred', 'y_pred_class'], axis=1, inplace=True) + df_pat_pred = df_mean + fn = 'df_pat_pred' + '_' + str(saved_model) + '.csv' + df_pat_pred.to_csv(os.path.join(out_dir, fn)) + print('image level pred:\n', df_img_pred) + print('patient level pred:\n', df_pat_pred) + +
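+# Toy check (assumption, not part of the original file): patient-level scores +# are the mean of per-slice probabilities, thresholded at thr_pat: +if __name__ == '__main__': + df = pd.DataFrame({'pat_id': ['a', 'a', 'b', 'b'], 'y_pred': [0.9, 0.7, 0.2, 0.4]}) + means = df.groupby('pat_id')['y_pred'].mean() + print((means > 0.5).astype(int)) # a -> 1, b -> 0 + + 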
diff --git a/run_step1_data.py b/run_step1_data.py new file mode 100644 index 0000000..a2c74b2 --- /dev/null +++ b/run_step1_data.py @@ -0,0 +1,49 @@ +import os +import numpy as np +import pandas as pd +import matplotlib +import matplotlib.pyplot as plt +import glob +from time import gmtime, strftime +from datetime import datetime +import timeit +import yaml +import argparse +from get_data.get_img_dataset import get_img_dataset +from get_data.get_pat_dataset import get_pat_dataset +from get_data.preprocess_data import preprocess_data + + + + +if __name__ == '__main__': + + proj_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project' + data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST' + out_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' + data_exclude = None + run_type = None # FIXME: run_type must be set before get_img_dataset is called + + print('\n--- STEP 1 - GET DATA ---\n') + + preprocess_data( + data_dir=data_dir, + out_dir=out_dir, + new_spacing=(1, 1, 3), + data_exclude=data_exclude, + crop_shape=[192, 192, 100] + ) + + data_tot, label_tot, ID_tot = get_pat_dataset( + data_dir=data_dir, + out_dir=out_dir, + proj_dir=proj_dir, + ) + + get_img_dataset( + proj_dir=proj_dir, + run_type=run_type, + data_tot=data_tot, + ID_tot=ID_tot, + label_tot=label_tot, + slice_range=range(17, 83), + ) diff --git a/run_step2_train.py b/run_step2_train.py new file mode 100644 index 0000000..c954a4e --- /dev/null +++ b/run_step2_train.py @@ -0,0 +1,76 @@ + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib +import matplotlib.pyplot as plt +import glob +from time import gmtime, strftime +from datetime import datetime +import timeit +import yaml +import argparse +import pydot +import pydotplus +import graphviz +#from pydotplus import graphviz +from tensorflow.keras.optimizers import Adam +from tensorflow.keras.optimizers import SGD +from tensorflow.keras.losses import BinaryCrossentropy +from tensorflow.keras.utils import plot_model +from get_data.data_gen_flow import train_generator +from get_data.data_gen_flow import val_generator +from go_model.get_model import get_model +from go_model.train_model import train_model +from go_model.train_model import callbacks + + +if __name__ == '__main__': + + out_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' + proj_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project' + batch_size = 32 + lr = 1e-5 + epoch = 1 + activation = 'sigmoid' # 'softmax' + loss_func = BinaryCrossentropy(from_logits=True) + opt = Adam(learning_rate=lr) + run_model = 'EffNetB4' + + ## data generator for train and val data + train_gen = train_generator( + proj_dir=proj_dir, + batch_size=batch_size, + ) + + x_val, y_val, val_gen = val_generator( + proj_dir=proj_dir, + batch_size=batch_size, + ) + + my_model = get_model( + out_dir=out_dir, + run_model=run_model, + activation=activation, + input_shape=(192, 192, 3), + freeze_layer=None, + transfer=False + ) + + ### train model + train_model( + out_dir=out_dir, + model=my_model, + run_model=run_model, + train_gen=train_gen, + val_gen=val_gen, + x_val=x_val, + y_val=y_val, + batch_size=batch_size, + epoch=epoch, + opt=opt, + loss_func=loss_func, + lr=lr + ) + 
diff --git a/run_step3_test.py b/run_step3_test.py new file mode 100644 index 0000000..5ff59a6 --- /dev/null +++ b/run_step3_test.py @@ -0,0 +1,64 @@ + + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib +import matplotlib.pyplot as plt +import glob +from time import gmtime, strftime +from datetime import datetime +import timeit +import yaml +import argparse +from tensorflow.keras.optimizers import Adam +from go_model.evaluate_model import evaluate_model +from utils.write_txt import write_txt +from utils.get_stats_plots import get_stats_plots + + + + +if __name__ == '__main__': + + + proj_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project' + out_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' + saved_model = 'EffNet_2021_08_21_22_41_34' + ## other saved models: 'ResNet_2021_07_18_06_28_40', 'cnn_2021_07_19_21_56_34', 'inception_2021_08_21_15_16_12' + batch_size = 32 + epoch = 500 + lr = 1e-5 + run_model = 'EffNet' + + + print('\n--- STEP 3 - MODEL EVALUATION ---\n') + + for run_type in ['val', 'test']: + ## evaluate model + loss, acc = evaluate_model( + run_type=run_type, + out_dir=out_dir, + proj_dir=proj_dir, + saved_model=saved_model, + threshold=0.5, + activation='sigmoid' + ) + ## get statistics and plots; pass the computed loss/acc through + get_stats_plots( + out_dir=out_dir, + proj_dir=proj_dir, + run_type=run_type, + run_model=run_model, + loss=loss, + acc=acc, + saved_model=saved_model, + epoch=epoch, + batch_size=batch_size, + lr=lr, + thr_img=0.5, + thr_prob=0.5, + thr_pos=0.5, + bootstrap=1000, + ) diff --git a/run_step4_exval.py b/run_step4_exval.py new file mode 100644 index 0000000..a818403 --- /dev/null +++ b/run_step4_exval.py @@ -0,0 +1,88 @@ + +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib +import matplotlib.pyplot as plt +import glob +from time import gmtime, strftime +from datetime import datetime +import timeit +import yaml +import argparse +from tensorflow.keras.optimizers import Adam +from go_model.finetune_model import finetune_model +from utils.write_txt import write_txt +from utils.get_stats_plots import get_stats_plots +from go_model.evaluate_model import evaluate_model +from get_data.exval_dataset import exval_pat_dataset +from get_data.exval_dataset import exval_img_dataset + + + +if __name__ == '__main__': + + + out_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' + proj_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project' + batch_size = 32 + lr = 1e-5 + activation = 'sigmoid' + freeze_layer = None + epoch = 1 + get_data = False + fine_tune = False + HN_model = 'EffNet_2021_08_21_22_41_34' + saved_model = 'Tuned_EfficientNetB4_2021_08_27_20_26_55' + run_model = 'EffNetB4' + + print('\n--- STEP 4 - MODEL EX VAL ---\n') + + ## get chest CT data and do preprocessing + if get_data == True: + exval_pat_dataset( + out_dir=out_dir, + proj_dir=proj_dir, + crop_shape=[192, 192, 140], + ) + + exval_img_dataset( + proj_dir=proj_dir, + slice_range=range(50, 120), + ) + + ## fine tune models by freezing some layers + if fine_tune == True: + tuned_model, model_fn = finetune_model( + out_dir=out_dir, + proj_dir=proj_dir, + HN_model=HN_model, + batch_size=batch_size, + epoch=epoch, + freeze_layer=freeze_layer, + ) + + ## evaluate finetuned model on chest CT data + for run_type in ['exval1', 'exval2']: + loss, acc = evaluate_model( + run_type=run_type, + out_dir=out_dir, + proj_dir=proj_dir, + saved_model=saved_model, + ) + get_stats_plots( + out_dir=out_dir, + proj_dir=proj_dir, + run_type=run_type, + run_model=run_model, + loss=loss, + acc=acc, + saved_model=saved_model, + epoch=epoch, + batch_size=batch_size, + lr=lr + ) + + + 
diff --git a/run_step5_pred.py b/run_step5_pred.py new file mode 100644 index 0000000..cd3db34 --- /dev/null +++ b/run_step5_pred.py @@ -0,0 +1,39 @@ + +import os +import numpy as np +import pandas as pd +import glob +import yaml +import argparse +from time import gmtime, strftime +from datetime import datetime +import timeit +from prediction.data_prepro import data_prepro +from prediction.model_pred import model_pred + + + + +if __name__ == '__main__': + + body_part = 'head_and_neck' + out_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' + reg_temp_img = 'PMH049.nrrd' + + print('\n--- STEP 5 - MODEL PREDICTION ---\n') + + ## data preprocessing + df_img, img_arr = data_prepro( + body_part=body_part, + out_dir=out_dir, + reg_temp_img=reg_temp_img + ) + + ## model prediction + model_pred( + body_part=body_part, + df_img=df_img, + img_arr=img_arr, + out_dir=out_dir, + ) + diff --git a/utils/.crop_roi.py.swo b/utils/.crop_roi.py.swo new file mode 100644 index 0000000..6932cd8 Binary files /dev/null and b/utils/.crop_roi.py.swo differ diff --git a/utils/.crop_roi.py.swp b/utils/.crop_roi.py.swp new file mode 100644 index 0000000..c743cfb Binary files /dev/null and b/utils/.crop_roi.py.swp differ diff --git a/utils/HN_patient_meta.py b/utils/HN_patient_meta.py new file mode 100644 index 0000000..776fbf7 --- /dev/null +++ b/utils/HN_patient_meta.py @@ -0,0 +1,446 @@ +import os +import numpy as np +import pandas as pd +from sklearn.metrics import confusion_matrix +from sklearn.metrics import accuracy_score, classification_report +from plot_cm import plot_cm + + +def pat_meta_info(pro_data_dir): + + """ + Summarize patient and scan metadata for the head-and-neck CT cohort: + CT artifact annotations, clinical variables (HPV, stage, site, sex, age), + scanner/acquisition parameters, and contrast labels from DICOM metadata. + + Args: + pro_data_dir {str} -- directory containing the processed csv files + (annotations, train/val/test patient splits, clinical metadata); + + """ + + + #----------------------------------------- + # CT scan artifacts + #----------------------------------------- + df = pd.read_csv(os.path.join(pro_data_dir, 'ContrastAnnotation_HN.csv')) + df.drop_duplicates(subset=['Patient ID'], keep='last', inplace=True) + df.dropna(subset=['Artifact-OP'], inplace=True) + print('annotation data with no duplicates:', df.shape[0]) + + train_data = pd.read_csv(os.path.join(pro_data_dir, 'train_pat_df.csv')) + val_data = pd.read_csv(os.path.join(pro_data_dir, 'val_pat_df.csv')) + test_data = pd.read_csv(os.path.join(pro_data_dir, 'test_pat_df.csv')) + tot_data = pd.concat([train_data, val_data, test_data]) + datas = [tot_data, train_data, val_data, test_data] + names = ['tot', 'train', 'val', 'test'] + + ## artifacts for train, val, test and tot dataset + for name, data in zip(names, datas): + pat_ids = [] + artifacts = [] + for patientid, artifact, note in zip(df['Patient ID'], df['Artifact-OP'], df['Notes']): + ## use consistent ID + if patientid[3:7] == 'CHUM': + pat_id = 'CHUM' + patientid[-3:] + elif patientid[3:7] == 'CHUS': + pat_id = 'CHUS' + patientid[-3:] + elif patientid[:3] == 'OPC': + pat_id = 'PMH' + patientid[-3:] + elif patientid[:5] == 'HNSCC': + pat_id = 'MDACC' + patientid[-3:] + ## find very severe artifacts + if note == 'really bad artifact': + artifact = 'very bad' + else: + artifact = artifact + ## append artifacts + if pat_id in data['ID'].to_list(): + pat_ids.append(pat_id) + artifacts.append(artifact) + + df_af = pd.DataFrame({'ID': pat_ids, 'Artifact-OP': artifacts}) + print('----------------------------') + print(name) + print('----------------------------') + print('data size:', df_af.shape[0]) + print('data with artifact:', df_af.loc[df_af['Artifact-OP'].isin(['Bad', 'Yes', 'Minimal'])].shape[0]) + print(df_af['Artifact-OP'].value_counts()) + print(df_af['Artifact-OP'].value_counts(normalize=True).round(3)) + + #----------------------------------------- + # clean and group metadata + #----------------------------------------- + df = pd.read_csv(os.path.join(pro_data_dir, 'clinical_meta_data.csv')) + print('\nmeta data 
size:', df.shape[0]) + df.drop_duplicates(subset=['patientid'], keep='last', inplace=True) + print('meta data with no duplicates:', df.shape[0]) + + ## combine HPV info from tow cols + hpvs = [] + df['hpv'] = df.iloc[:, 8].astype(str) + df.iloc[:, 9].astype(str) + for hpv in df['hpv']: + if hpv in ['nannan', 'Unknownnan', 'Nnan', 'Not testednan', 'no tissuenan']: + hpv = 'unknown' + elif hpv in [' positivenan', 'Pnan', '+nan', 'nanpositive', 'Positivenan', + 'Positive -Strongnan', 'Positive -focalnan']: + hpv = 'positive' + elif hpv in [' Negativenan', 'Negativenan', '-nan', 'nannegative']: + hpv = 'negative' + hpvs.append(hpv) + df['hpv'] = hpvs + + ## overall stage + stages = [] + for stage in df['ajccstage']: + if stage in ['I', 'Stade I']: + stage = 'I' + elif stage in ['II', 'Stade II', 'StageII']: + stage = 'II' + elif stage in ['III', 'Stade III', 'Stage III']: + stage = 'III' + elif stage in ['IVA', 'IV', 'IVB', 'Stade IVA', 'Stage IV', 'Stade IVB']: + stage = 'IV' + stages.append(stage) + df['ajccstage'] = stages + + ## primary cancer sites + sites = [] + for site in df['diseasesite']: + if site in ['Oropharynx']: + site = site + elif site in ['Larynx', 'Hypopharynx', 'Nasopharynx']: + site = 'Larynx/Hypopharynx/Nasopharynx' + elif site in ['Oral cavity']: + site = site + else: + site = 'Unknown/Other' + sites.append(site) + df['diseasesite'] = sites + + ## sex + df['gender'].replace(['F'], 'Female', inplace=True) + df['gender'].replace(['M'], 'Male', inplace=True) + + #----------------------------------------- + # patient meta data + #----------------------------------------- + ## actual patient data with images + train_data = pd.read_csv(os.path.join(pro_data_dir, 'train_pat_df.csv')) + val_data = pd.read_csv(os.path.join(pro_data_dir, 'val_pat_df.csv')) + test_data = pd.read_csv(os.path.join(pro_data_dir, 'test_pat_df.csv')) + print('train data:', train_data.shape[0]) + print('val data:', val_data.shape[0]) + print('test data:', test_data.shape[0]) + + ## print contrast info in train, val, test sets + datas = [train_data, val_data, test_data] + names = ['train', 'val', 'test'] + for data, name in zip(datas, names): + print('\n') + print('----------------------------') + print(name) + print('----------------------------') + print(data['label'].value_counts()) + print(data['label'].value_counts(normalize=True).round(3)) + + ## find patient metadata + datas = [train_data, val_data, test_data] + metas = [] + for data in datas: + ids = [] + genders = [] + ages = [] + tcats = [] + stages = [] + sites = [] + ncats = [] + hpvs = [] + ## find meta info + for patientid, gender, age, t_cat, ajccstage, site, n_cat, hpv in zip( + df['patientid'], df['gender'], df['ageatdiag'], df['t-category'], + df['ajccstage'], df['diseasesite'], df['n-category'], df['hpv']): + ## 4 datasets + if patientid[3:7] == 'CHUM': + pat_id = 'CHUM' + patientid[-3:] + elif patientid[3:7] == 'CHUS': + pat_id = 'CHUS' + patientid[-3:] + elif patientid[:3] == 'OPC': + pat_id = 'PMH' + patientid[-3:] + elif patientid[:5] == 'HNSCC': + pat_id = 'MDACC' + patientid[-3:] + if pat_id in data['ID'].to_list(): + #print(pat_id) + ids.append(patientid) + genders.append(gender) + ages.append(age) + tcats.append(t_cat) + stages.append(ajccstage) + sites.append(site) + ncats.append(n_cat) + hpvs.append(hpv) + ## create new df for train, val, test meta info + meta = pd.DataFrame( + {'id': ids, + 'gender': genders, + 'age': ages, + 't_stage': tcats, + 'stage': stages, + 'site': sites, + 'n_stage': ncats, + 'hpv': hpvs} + ) + 
metas.append(meta) + ## concat 3 datasets to 1 big dataset + all_meta = pd.concat([metas[0], metas[1], metas[2]]) + metas.append(all_meta) + ## print meta info + datasets = ['train', 'val', 'test', 'all'] + for df, dataset in zip(metas, datasets): + print('\n') + print('----------------------------') + print(dataset) + print('----------------------------') + print('patient number:', df.shape[0]) + print('\n') + print(df['gender'].value_counts()) + print(df['gender'].value_counts(normalize=True).round(3)) + print('\n') + print(df['t_stage'].value_counts()) + print(df['t_stage'].value_counts(normalize=True).round(3)) + print('\n') + print(df['stage'].value_counts()) + print(df['stage'].value_counts(normalize=True).round(3)) + print('\n') + print(df['site'].value_counts()) + print(df['site'].value_counts(normalize=True).round(3)) + print('\n') + print(df['n_stage'].value_counts()) + print(df['n_stage'].value_counts(normalize=True).round(3)) + print('\n') + print(df['hpv'].value_counts()) + print(df['hpv'].value_counts(normalize=True).round(3)) + print('\n') + print('mediam age:', df['age'].median()) + print('age max:', df['age'].max()) + print('age min:', df['age'].min()) + print('---------------------------------------------') + + #------------------------------------------------------------ + # CT meata data + #------------------------------------------------------------ + df = pd.read_csv(os.path.join(pro_data_dir, 'clinical_meta_data.csv')) + df.drop_duplicates(subset=['patientid'], keep='last', inplace=True) + print(df.shape[0]) + print(all_meta.shape[0]) + df0 = df[~df['patientid'].isin(all_meta['id'].to_list())] + df = df[~df['patientid'].isin(df0['patientid'].to_list())] + print('patient not in list:', df.shape[0]) + + ## combine CT scanner and model names + IDs = [] + for manufacturer, model in zip(df['manufacturer'], df['manufacturermodelname']): + ID = str(manufacturer) + ' ' + str(model) + IDs.append(ID) + df['ID'] = IDs + #print(df['manufacturer'].value_counts()) + print('-------------------') + print('CT scanner') + print('-------------------') + #print(df['manufacturermodelname'].value_counts()) + #print(df['manufacturermodelname'].value_counts(normalize=True).round(3)) + print(df['ID'].value_counts()) + print(df['ID'].value_counts(normalize=True).round(3)) + print(df.shape[0]) + + ## KVP + print('\n') + print('-------------------') + print('KVP') + print('-------------------') + print('kvp mean:', df['kvp'].mean().round(3)) + print('kvp median:', df['kvp'].median()) + print('kvp mode:', df['kvp'].mode()) + print('kvp std:', df['kvp'].std().round(3)) + print('kvp min:', df['kvp'].min()) + print('kvp max:', df['kvp'].max()) + + ## slice thickness + print('\n') + print('-------------------') + print('slice thickness') + print('-------------------') + print('thk mean:', df['slicethickness'].mean().round(3)) + print('thk median:', df['slicethickness'].median()) + print('thk mode:', df['slicethickness'].mode()) + print('thk std:', df['slicethickness'].std().round(3)) + print('thk min:', df['slicethickness'].min()) + print('thk max:', df['slicethickness'].max()) + print(df['slicethickness'].value_counts()) + print(df['slicethickness'].shape[0]) + + ## spatial resolution + print('\n') + print(df['rows'].value_counts()) + + ## pixel spacing + pixels = [] + for pixel in df['pixelspacing']: + pixel = pixel.split("'")[1] + pixel = float(pixel) + pixels.append(pixel) + df['pixel'] = pixels + df['pixel'].round(3) + print('\n') + print('-------------------') + print('pixel size') + 
print('-------------------') + print('pixel mean:', df['pixel'].mean().round(3)) + print('pixel median:', df['pixel'].median().round(3)) + print('pixel mode:', df['pixel'].mode().round(3)) + print('pixel std:', df['pixel'].std().round(3)) + print('pixel min:', df['pixel'].min().round(3)) + print('pixel max:', df['pixel'].max().round(3)) + + data = pd.concat([train_data, val_data, test_data]) + + #----------------------------------------------------------------- + # contrast information from mata data + #---------------------------------------------------------------- + df = pd.read_csv(os.path.join(pro_data_dir, 'clinical_meta_data.csv')) + print('\n') + print('-----------------------------------') + print('contrast information from meta dta') + print('-----------------------------------') + print(df['contrastbolusagent'].value_counts()) + print(df['contrastbolusagent'].value_counts(normalize=True).round(3)) + list_contrast = set(df['contrastbolusagent'].to_list()) + print('contrast agents bolus number:', len(list_contrast)) + print(list_contrast) + df['contrastbolusagent'] = df['contrastbolusagent'].fillna(2) + + pat_ids = [] + contrasts = [] + for patientid, contrast in zip(df['patientid'], df['contrastbolusagent']): + if patientid[3:7] == 'CHUM': + pat_id = 'CHUM' + patientid[-3:] + elif patientid[3:7] == 'CHUS': + pat_id = 'CHUS' + patientid[-3:] + elif patientid[:3] == 'OPC': + pat_id = 'PMH' + patientid[-3:] + elif patientid[:5] == 'HNSCC': + pat_id = 'MDACC' + patientid[-3:] + if pat_id in data['ID'].to_list(): + pat_ids.append(pat_id) + ## change contrast annotation in meta data + if contrast in ['N', 'n', 'NO']: + contrast = 0 + elif contrast == 2: + contrast = contrast + else: + contrast = 1 + contrasts.append(contrast) + df = pd.DataFrame({'ID': pat_ids, 'contrast': contrasts}) + + ## match metadata annotations with clinical expert + ids = [] + contrasts = [] + labels = [] + for ID, label in zip(data['ID'], data['label']): + for pat_id, contrast in zip(df['ID'], df['contrast']): + if pat_id == ID and contrast != 2 and contrast != label: + ids.append(pat_id) + contrasts.append(contrast) + labels.append(label) + print('\n') + print('-----------------------------------') + print('contrast information from meta dta') + print('-----------------------------------') + print('mismatch ID:', ids) + print('mismatch label:', labels) + print('mismatch label:', contrasts) + print('mismatch number:', len(contrasts)) + print('total patient:', df['contrast'].shape[0]) + print(df['contrast'].value_counts()) + print(df['contrast'].value_counts(normalize=True).round(3)) + + ## print contrast info in train, val, test sets + datas = [train_data, val_data, test_data] + names = ['train', 'val', 'test'] + conss = [] + for data, name in zip(datas, names): + cons = [] + IDs = [] + labels = [] + for ID, label in zip(data['ID'], data['label']): + for pat_id, con in zip(df['ID'], df['contrast']): + if pat_id == ID: + cons.append(con) + labels.append(label) + IDs.append(pat_id) + df_con = pd.DataFrame({'ID': IDs, 'label': labels, 'contrast': cons}) + conss.append(df_con) + names = ['train', 'val', 'test'] + for name, con in zip(names, conss): + print('\n') + print('----------------------------') + print(name) + print('----------------------------') + print(con['contrast'].value_counts()) + print(con['contrast'].value_counts(normalize=True).round(3)) + #print(con['label']) + + #-------------------------------------------------------------------- + # calculate confusion matrix, accuracy and AUC for contrast 
metadata + #-------------------------------------------------------------------- + for name, con in zip(['val', 'test'], [conss[1], conss[2]]): + contrasts = [] + for contrast, label in zip(con['contrast'], con['label']): + if contrast == 2 and label == 0: + contrast = 1 + elif contrast == 2 and label == 1: + contrast = 0 + else: + contrast = contrast + contrasts.append(contrast) + con['contrast'] = contrasts + cm = confusion_matrix(con['label'], con['contrast']) + cm_norm = cm.astype('float')/cm.sum(axis=1)[:, np.newaxis] + cm_norm = np.around(cm_norm, 2) + print('\n') + print(name) + print(cm_norm) + print(cm) + FP = cm.sum(axis=0) - np.diag(cm) + FN = cm.sum(axis=1) - np.diag(cm) + TP = np.diag(cm) + TN = cm.sum() - (FP + FN + TP) + ACC = (TP + TN)/(TP + FP + FN + TN) + TPR = TP/(TP + FN) + TNR = TN/(TN + FP) + AUC = (TPR + TNR)/2 + report = classification_report(con['label'], con['contrast']) + print('AUC:', np.around(AUC[1], 3)) + print('ACC:', np.around(ACC[1], 3)) + print('report:', report) + + # plot confusion matrix + save_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/metadata' + for cm0, cm_type in zip([cm, cm_norm], ['raw', 'norm']): + plot_cm( + cm0=cm0, + cm_type=cm_type, + level=name, + save_dir=save_dir + ) + +if __name__ == '__main__': + + pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data' + + pat_meta_info(pro_data_dir) + + diff --git a/utils/] b/utils/] new file mode 100644 index 0000000..592ace3 --- /dev/null +++ b/utils/] @@ -0,0 +1,147 @@ +import os +import numpy as np +import pandas as pd +import pickle +from time import gmtime, strftime +from datetime import datetime +import timeit +from utils.cm_all import cm_all +from utils.roc_all import roc_all +from utils.prc_all import prc_all +from utils.acc_loss import acc_loss +from utils.write_txt import write_txt + + + +def get_stats_plots(out_dir, proj_dir, run_type, run_model, loss, acc, + saved_model, epoch, batch_size, lr, thr_img=0.5, + thr_prob=0.5, thr_pos=0.5, bootstrap=1000): + + """ + generate model val/test statistics and plot curves; + + Args: + loss {float} -- validation loss; + acc {float} -- validation accuracy; + run_model {str} -- cnn model name; + batch_size {int} -- batch size for data loading; + epoch {int} -- training epoch; + out_dir {path} -- path for output files; + opt {str or function} -- optimized function: 'adam'; + lr {float} -- learning rate; + + Keyword args: + bootstrap {int} -- number of bootstrap to calculate 95% CI for AUC; + thr_img {float} -- threshold to determine positive class on image level; + thr_prob {float} -- threshold to determine positive class on patient + level (mean prob score); + thr_pos {float} -- threshold to determine positive class on patient + level (positive class percentage); + Returns: + Model prediction statistics and plots: ROC, PRC, confusion matrix, etc. 
+ + """ + + pro_data_dir = os.path.join(proj_dir, 'pro_data') + train_dir = os.path.join(out_dir, 'train') + val_dir = os.path.join(out_dir, 'val') + test_dir = os.path.join(out_dir, 'test') + exval1_dir = os.path.join(out_dir, 'exval1') + exval2_dir = os.path.join(out_dir, 'exval2') + + if not os.path.exist(train_dir): + os.mkdir(train_dir) + if not os.path.exist(val_dir): + os.mkdir(val_dir) + if not os.path.exist(test_dir): + os.mkdir(test_dir) + if not os.path.exist(exval1_dir): + os.mkdir(exval1_dir) + if not os.path.exist(exval2_dir): + os.mkdir(exval2_dir) + + ### determine if this is train or test + if run_type == 'val': + fn_df_pred = 'val_img_pred.csv' + save_dir = val_dir + elif run_type == 'test': + fn_df_pred = 'test_img_pred.csv' + save_dir = test_dir + elif run_type == 'exval1': + fn_df_pred = 'exval1_img_pred.csv' + save_dir = exval1_dir + elif run_type == 'exval2': + fn_df_pred = 'exval2_img_pred.csv' + save_dir = exval2_dir + + cms = [] + cm_norms = [] + reports = [] + roc_stats = [] + prc_aucs = [] + levels = ['img', 'patient_mean_prob', 'patient_mean_pos'] + + for level in levels: + + ## confusion matrix + cm, cm_norm, report = cm_all( + run_type=run_type, + level=level, + thr_img=thr_img, + thr_prob=thr_prob, + thr_pos=thr_pos, + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + cms.append(cm) + cm_norms.append(cm_norm) + reports.append(report) + + ## ROC curves + roc_stat = roc_all( + run_type=run_type, + level=level, + thr_prob=thr_prob, + thr_pos=thr_pos, + bootstrap=bootstrap, + color='blue', + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + roc_stats.append(roc_stat) + + ## PRC curves + prc_auc = prc_all( + run_type=run_type, + level=level, + thr_prob=thr_prob, + thr_pos=thr_pos, + color='red', + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + prc_aucs.append(prc_auc) + + ### save validation results to txt + write_txt( + run_type=run_type, + out_dir=out_dir, + loss=loss, + acc=acc, + cms=cms, + cm_norms=cm_norms, + reports=reports, + prc_aucs=prc_aucs, + roc_stats=roc_stats, + run_model=run_model, + saved_model=saved_model, + epoch=epoch, + batch_size=batch_size, + lr=lr + ) + + print('saved model as:', saved_model) + diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__pycache__/__init__.cpython-38.pyc b/utils/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..d4bb908 Binary files /dev/null and b/utils/__pycache__/__init__.cpython-38.pyc differ diff --git a/utils/__pycache__/acc_loss.cpython-38.pyc b/utils/__pycache__/acc_loss.cpython-38.pyc new file mode 100644 index 0000000..d184801 Binary files /dev/null and b/utils/__pycache__/acc_loss.cpython-38.pyc differ diff --git a/utils/__pycache__/cm_all.cpython-38.pyc b/utils/__pycache__/cm_all.cpython-38.pyc new file mode 100644 index 0000000..851c266 Binary files /dev/null and b/utils/__pycache__/cm_all.cpython-38.pyc differ diff --git a/utils/__pycache__/cm_patient_pos.cpython-38.pyc b/utils/__pycache__/cm_patient_pos.cpython-38.pyc new file mode 100644 index 0000000..21892db Binary files /dev/null and b/utils/__pycache__/cm_patient_pos.cpython-38.pyc differ diff --git a/utils/__pycache__/cm_patient_prob.cpython-38.pyc b/utils/__pycache__/cm_patient_prob.cpython-38.pyc new file mode 100644 index 0000000..77ad0a4 Binary files /dev/null and b/utils/__pycache__/cm_patient_prob.cpython-38.pyc differ diff --git 
a/utils/__pycache__/crop_image.cpython-38.pyc b/utils/__pycache__/crop_image.cpython-38.pyc new file mode 100644 index 0000000..8f84dff Binary files /dev/null and b/utils/__pycache__/crop_image.cpython-38.pyc differ diff --git a/utils/__pycache__/get_stats_plots.cpython-38.pyc b/utils/__pycache__/get_stats_plots.cpython-38.pyc new file mode 100644 index 0000000..281473c Binary files /dev/null and b/utils/__pycache__/get_stats_plots.cpython-38.pyc differ diff --git a/utils/__pycache__/make_plots.cpython-38.pyc b/utils/__pycache__/make_plots.cpython-38.pyc new file mode 100644 index 0000000..f8f0924 Binary files /dev/null and b/utils/__pycache__/make_plots.cpython-38.pyc differ diff --git a/utils/__pycache__/mean_CI.cpython-38.pyc b/utils/__pycache__/mean_CI.cpython-38.pyc new file mode 100644 index 0000000..5b623d0 Binary files /dev/null and b/utils/__pycache__/mean_CI.cpython-38.pyc differ diff --git a/utils/__pycache__/nrrd_reg.cpython-38.pyc b/utils/__pycache__/nrrd_reg.cpython-38.pyc new file mode 100644 index 0000000..81bebb4 Binary files /dev/null and b/utils/__pycache__/nrrd_reg.cpython-38.pyc differ diff --git a/utils/__pycache__/plot_cm.cpython-38.pyc b/utils/__pycache__/plot_cm.cpython-38.pyc new file mode 100644 index 0000000..d16ca0e Binary files /dev/null and b/utils/__pycache__/plot_cm.cpython-38.pyc differ diff --git a/utils/__pycache__/plot_prc.cpython-38.pyc b/utils/__pycache__/plot_prc.cpython-38.pyc new file mode 100644 index 0000000..cc489ca Binary files /dev/null and b/utils/__pycache__/plot_prc.cpython-38.pyc differ diff --git a/utils/__pycache__/plot_roc.cpython-38.pyc b/utils/__pycache__/plot_roc.cpython-38.pyc new file mode 100644 index 0000000..18c9fef Binary files /dev/null and b/utils/__pycache__/plot_roc.cpython-38.pyc differ diff --git a/utils/__pycache__/plot_train_curve.cpython-38.pyc b/utils/__pycache__/plot_train_curve.cpython-38.pyc new file mode 100644 index 0000000..b2e29f1 Binary files /dev/null and b/utils/__pycache__/plot_train_curve.cpython-38.pyc differ diff --git a/utils/__pycache__/prc_all.cpython-38.pyc b/utils/__pycache__/prc_all.cpython-38.pyc new file mode 100644 index 0000000..3ecd4ce Binary files /dev/null and b/utils/__pycache__/prc_all.cpython-38.pyc differ diff --git a/utils/__pycache__/prc_img.cpython-38.pyc b/utils/__pycache__/prc_img.cpython-38.pyc new file mode 100644 index 0000000..b633c45 Binary files /dev/null and b/utils/__pycache__/prc_img.cpython-38.pyc differ diff --git a/utils/__pycache__/prc_patient_mean_prob.cpython-38.pyc b/utils/__pycache__/prc_patient_mean_prob.cpython-38.pyc new file mode 100644 index 0000000..d7510ab Binary files /dev/null and b/utils/__pycache__/prc_patient_mean_prob.cpython-38.pyc differ diff --git a/utils/__pycache__/resize_3d.cpython-38.pyc b/utils/__pycache__/resize_3d.cpython-38.pyc new file mode 100644 index 0000000..59e772d Binary files /dev/null and b/utils/__pycache__/resize_3d.cpython-38.pyc differ diff --git a/utils/__pycache__/respacing.cpython-38.pyc b/utils/__pycache__/respacing.cpython-38.pyc new file mode 100644 index 0000000..e84851c Binary files /dev/null and b/utils/__pycache__/respacing.cpython-38.pyc differ diff --git a/utils/__pycache__/roc_all.cpython-38.pyc b/utils/__pycache__/roc_all.cpython-38.pyc new file mode 100644 index 0000000..3a9af76 Binary files /dev/null and b/utils/__pycache__/roc_all.cpython-38.pyc differ diff --git a/utils/__pycache__/roc_bootstrap.cpython-38.pyc b/utils/__pycache__/roc_bootstrap.cpython-38.pyc new file mode 100644 index 0000000..f8e92ab Binary 
files /dev/null and b/utils/__pycache__/roc_bootstrap.cpython-38.pyc differ
diff --git a/utils/__pycache__/roc_img.cpython-38.pyc b/utils/__pycache__/roc_img.cpython-38.pyc new file mode 100644 index 0000000..9ccadba Binary files /dev/null and b/utils/__pycache__/roc_img.cpython-38.pyc differ
diff --git a/utils/__pycache__/roc_patient_mean_prob.cpython-38.pyc b/utils/__pycache__/roc_patient_mean_prob.cpython-38.pyc new file mode 100644 index 0000000..ea3dfd8 Binary files /dev/null and b/utils/__pycache__/roc_patient_mean_prob.cpython-38.pyc differ
diff --git a/utils/__pycache__/roc_patient_median_prob.cpython-38.pyc b/utils/__pycache__/roc_patient_median_prob.cpython-38.pyc new file mode 100644 index 0000000..0c5b498 Binary files /dev/null and b/utils/__pycache__/roc_patient_median_prob.cpython-38.pyc differ
diff --git a/utils/__pycache__/roc_patient_pos_rate.cpython-38.pyc b/utils/__pycache__/roc_patient_pos_rate.cpython-38.pyc new file mode 100644 index 0000000..295bbfb Binary files /dev/null and b/utils/__pycache__/roc_patient_pos_rate.cpython-38.pyc differ
diff --git a/utils/__pycache__/roc_patient_thr.cpython-38.pyc b/utils/__pycache__/roc_patient_thr.cpython-38.pyc new file mode 100644 index 0000000..c9b9e6b Binary files /dev/null and b/utils/__pycache__/roc_patient_thr.cpython-38.pyc differ
diff --git a/utils/__pycache__/tensorboard.cpython-38.pyc b/utils/__pycache__/tensorboard.cpython-38.pyc new file mode 100644 index 0000000..c315363 Binary files /dev/null and b/utils/__pycache__/tensorboard.cpython-38.pyc differ
diff --git a/utils/__pycache__/write_txt.cpython-38.pyc b/utils/__pycache__/write_txt.cpython-38.pyc new file mode 100644 index 0000000..77b6509 Binary files /dev/null and b/utils/__pycache__/write_txt.cpython-38.pyc differ
diff --git a/utils/acc_loss.py b/utils/acc_loss.py new file mode 100644 index 0000000..eb082ed --- /dev/null +++ b/utils/acc_loss.py @@ -0,0 +1,33 @@
+import os
+import numpy as np
+import pandas as pd
+import seaborn as sn
+import matplotlib.pyplot as plt
+import glob
+from time import gmtime, strftime
+from datetime import datetime
+import timeit
+from sklearn.metrics import accuracy_score
+from tensorflow.keras.models import load_model
+
+def acc_loss(run_type, saved_model, output_dir, save_dir, x_val, y_val):
+
+    ### determine if this is train or test
+    if run_type == 'val':
+        fn = 'df_val_pred.p'
+    elif run_type == 'test':
+        fn = 'df_test_pred.p'
+
+    df_sum = pd.read_pickle(os.path.join(save_dir, fn))
+    ## y_true/y_pred allow a cross-check via accuracy_score(y_true, y_pred)
+    y_true = df_sum['label'].to_numpy()
+    y_pred = df_sum['y_pred_class'].to_numpy()
+
+    ### acc and loss; x_val/y_val are the validation images and labels
+    ### supplied by the caller for model.evaluate()
+    model = load_model(os.path.join(output_dir, saved_model))
+    score = model.evaluate(x_val, y_val)
+    loss = np.around(score[0], 3)
+    acc = np.around(score[1], 3)
+    print('loss:', loss)
+    print('acc:', acc)
+
+    return acc, loss
diff --git a/utils/chest_patient_meta.py b/utils/chest_patient_meta.py new file mode 100644 index 0000000..1e7346a --- /dev/null +++ b/utils/chest_patient_meta.py @@ -0,0 +1,343 @@
+"""
+    ----------------------------------------
+    get patient and CT metadata for chest CT
+    ----------------------------------------
+    ----------------------------------------
+    Author: AIM Harvard
+
+    Python Version: 3.8.8
+    ----------------------------------------
+    Collect patient and scan metadata (age, gender, stage, histology, slice
+    thickness, voxel spacing, scanner type and kVp) for the chest CT external
+    validation datasets: the RTOG-0617 test set and the Harvard-RT cohort.
+    Optionally split the Harvard-RT cohort into train/val sets for fine-tuning,
+    then print summary statistics for each split.
+
+"""
+
+
+import os
+import numpy as np
+import pandas as pd
+import glob
+from sklearn.model_selection import train_test_split
+
+#----------------------------------------------------------------------------------------
+# external val dataset using lung CT
+#----------------------------------------------------------------------------------------
+def chest_metadata(harvard_rt_dir, data_exclude, pro_data_dir, data_pro_dir, split):
+
+    """
+
+    Get patient and scan metadata for chest CT
+
+    @params:
+      harvard_rt_dir - required : dir with registered Harvard-RT nrrd scans
+      data_exclude   - required : list of scan IDs to exclude (or None)
+      pro_data_dir   - required : dir with processed csv files
+      data_pro_dir   - required : dir with Harvard-RT meta csv files
+      split          - required : if True, split Harvard-RT into train/val sets
+
+    """
+
+    #----------------------------
+    ## rtog_0617 test dataset
+    #-----------------------------
+    df_pat = pd.read_csv(os.path.join(pro_data_dir, 'rtog_pat_df.csv'))
+    df = pd.read_csv(os.path.join(pro_data_dir, 'rtog_final_curation.csv'))
+    df['gender'].replace([1], 0, inplace=True)
+    df['gender'].replace([2], 1, inplace=True)
+
+    IDs = []
+    ages = []
+    genders = []
+    stages = []
+    thks = []
+    histologys = []
+    sizes = []
+    spacings = []
+    print(df['patid'])
+    print(df_pat['ID'])
+    for patid, age, gender, stage, thk, histology, size, spacing in zip(
+            df['patid'],
+            df['age'],
+            df['gender'],
+            df['ajcc_stage_grp'],
+            df['spacing_Z'],
+            df['histology'],
+            df['size_X'],
+            df['spacing_X'],
+            ):
+        if patid in df_pat['ID'].to_list():
+            IDs.append(patid)
+            ages.append(age)
+            genders.append(gender)
+            stages.append(stage)
+            thks.append(thk)
+            histologys.append(histology)
+            sizes.append(size)
+            spacings.append(spacing)
+
+    ## patient meta - test set
+    df_test = pd.DataFrame({
+        'ID': IDs,
+        'gender': genders,
+        'age': ages,
+        'stage': stages,
+        'histology': histologys
+        })
+    print('df_test:', df_test.shape[0])
+
+    ## CT scan meta data
+    df_scan1 = pd.DataFrame({
+        'ID': IDs,
+        'thk': thks,
+        'size': sizes,
+        'spacing': spacings
+        })
+    print('df_scan1:', df_scan1.shape[0])
+
+    #---------------------------------------
+    ## harvard-rt train and val dataset
+    #----------------------------------------
+    df1 = pd.read_csv(os.path.join(data_pro_dir, 'harvard_rt_meta.csv'))
+    df1.dropna(subset=['ctdose_contrast', 'top_coder_id'], how='any', inplace=True)
+    df2 = pd.read_csv(os.path.join(data_pro_dir, 'harvard_rt.csv'))
+
+    ## all scan ID to list
+    IDs = []
+    list_fn = [fn for fn in sorted(glob.glob(harvard_rt_dir + '/*nrrd'))]
+    for fn in list_fn:
+        ID = fn.split('/')[-1].split('.')[0].strip()
+        IDs.append(ID)
+    print('IDs:', len(IDs))
+    print('top coder ID:', df1['top_coder_id'].shape[0])
+
+    #------------------------------
+    # meta file 1 - harvard_rt_meta
+    #------------------------------
+    genders = []
+    scanners = []
+    kvps = []
+    thks = []
+    tstages = []
+    nstages = []
+    mstages = []
+    stages = []
+    labels = []
+    for top_coder_id, label, gender, scanner, \
+        kvp, thk, tstage, nstage, mstage, stage in zip(
+            df1['top_coder_id'],
+            df1['ctdose_contrast'],
+            df1['gender'],
+            df1['scanner_type'],
+            df1['kvp_value'],
+            df1['slice_thickness'],
+            df1['clin_tstage'],
+            df1['clin_nstage'],
+            df1['clin_mstage'],
+            df1['clin_stage']
+            ):
+        tc_id = top_coder_id.split('_')[2].strip()
+        if tc_id in IDs:
+            labels.append(label)
+            genders.append(gender)
+            scanners.append(scanner)
+            kvps.append(kvp)
+            thks.append(thk)
+            tstages.append(tstage)
+            nstages.append(nstage)
+            mstages.append(mstage)
+            stages.append(stage)
+
+    #-------------------------
+    # meta file 2 - harvard_rt
+    #-------------------------
+    ages = []
+    histologys = []
+    sizes = []
+    spacings = []
+    for topcoder_id, age, histology, size, spacing in zip(
+            df2['topcoder_id'],
+            df2['age'],
+            df2['histology'],
+            df2['raw_size_x'],
+            df2['raw_spacing_x'],
+            ):
+        if topcoder_id in IDs:
+            ages.append(age)
+            histologys.append(histology)
+            sizes.append(size)
+            spacings.append(spacing)
+
+    ## delete excluded scans and repeated scans
+    if data_exclude is not None:
+        df_exclude = df[df['ID'].isin(data_exclude)]
+        print('exclude scans:', df_exclude)
+        df.drop(df[df['ID'].isin(data_exclude)].index, inplace=True)
+        print('total scans:', df.shape[0])
+    pd.options.display.max_columns = 100
+    pd.set_option('display.max_rows', 500)
+    #print(df[0:50])
+
+    #---------------------------------------------------
+    # split dataset for fine-tuning model and test model
+    #---------------------------------------------------
+    if split == True:
+        ID1, ID2, gender1, gender2, age1, age2, tstage1, tstage2, nstage1, \
+        nstage2, mstage1, mstage2, stage1, stage2, histo1, histo2 = train_test_split(
+            IDs,
+            genders,
+            ages,
+            tstages,
+            nstages,
+            mstages,
+            stages,
+            histologys,
+            stratify=labels,
+            shuffle=True,
+            test_size=0.2,
+            random_state=42
+            )
+
+    ## patient meta - train df
+    df_train = pd.DataFrame({
+        'ID': ID1,
+        'gender': gender1,
+        'age': age1,
+        'stage': stage1,
+        'histology': histo1,
+        })
+    #df.to_csv(os.path.join(pro_data_dir, 'exval_pat_df.csv'))
+    print('train set:', df_train.shape[0])
+
+    ## patient meta - val df
+    df_val = pd.DataFrame({
+        'ID': ID2,
+        'gender': gender2,
+        'age': age2,
+        'stage': stage2,
+        'histology': histo2,
+        })
+    print('val set:', df_val.shape[0])
+
+    ## patient meta - train + val + test
+    df_tot = pd.concat([df_train, df_val, df_test])
+
+    ## print patient meta
+    dfs = [df_train, df_val, df_test, df_tot]
+    datasets = ['train', 'val', 'test', 'all']
+
+    for df, dataset in zip(dfs, datasets):
+        print('\n')
+        print('----------------------------')
+        print(dataset)
+        print('----------------------------')
+        print('patient number:', df.shape[0])
+        print('median age:', df['age'].median().round(3))
+        print('age max:', df['age'].max().round(3))
+        print('age min:', df['age'].min().round(3))
+        print('\n')
+        print(df['gender'].value_counts())
+        print(df['gender'].value_counts(normalize=True).round(3))
+        print('\n')
+        print(df['stage'].value_counts())
+        print(df['stage'].value_counts(normalize=True).round(3))
+        print('\n')
+        print(df['histology'].value_counts())
+        print(df['histology'].value_counts(normalize=True).round(3))
+
+    #-----------------------------------------
+    ## print scan meta data
+    #-----------------------------------------
+    ## CT scan meta data
+    df_scan2 = pd.DataFrame({
+        'ID': IDs,
+        'thk': thks,
+        'size': sizes,
+        'spacing': spacings
+        })
+    #print('scanner:', scanners)
+    #print('kvp:', kvps)
+    ## scan parameters
+    df_scan3 = pd.DataFrame({
+        'scanner': scanners,
+        'kvp': kvps,
+        })
+
+    ## print scan metadata
+    df_scan = pd.concat([df_scan1, df_scan2])
+    df = df_scan
+    print('\n')
+    print('-------------------')
+    print('thickness and size')
+    print('-------------------')
+    print('print scan metadata:')
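+    ## a one-line summary sketch of the scan statistics printed below
+    ## (illustrative only, not wired into the pipeline):
+    ##     print(df[['thk', 'size', 'spacing']].describe().round(3))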
+    print('patient number:', df.shape[0])
+    print(df['thk'].value_counts())
+    print(df['thk'].value_counts(normalize=True).round(3))
+    print(df['size'].value_counts())
+    print(df['size'].value_counts(normalize=True).round(3))
+    ## slice thickness
+    print('\n')
+    print('-------------------')
+    print('slice thickness')
+    print('-------------------')
+    print('thk mean:', df['thk'].mean().round(3))
+    print('thk median:', df['thk'].median())
+    print('thk mode:', df['thk'].mode())
+    print('thk std:', df['thk'].std().round(3))
+    print('thk min:', df['thk'].min())
+    print('thk max:', df['thk'].max())
+    ## voxel spacing
+    print('\n')
+    print('-------------------')
+    print('spacing info')
+    print('-------------------')
+    print('spacing mean:', df['spacing'].mean().round(3))
+    print('spacing median:', df['spacing'].median())
+    print('spacing mode:', df['spacing'].mode())
+    print('spacing std:', df['spacing'].std().round(3))
+    print('spacing min:', df['spacing'].min())
+    print('spacing max:', df['spacing'].max())
+
+    df = df_scan3
+    print('\n')
+    print('-------------------')
+    print('scanner info')
+    print('-------------------')
+    print('patient number:', df.shape[0])
+    ## scanner type
+    print(df['scanner'].value_counts())
+    print(df['scanner'].value_counts(normalize=True).round(3))
+    ## tube voltage (kvp)
+    print('\n')
+    print('-------------------')
+    print('KVP')
+    print('-------------------')
+    print('kvp mean:', df['kvp'].mean().round(3))
+    print('kvp median:', df['kvp'].median())
+    print('kvp mode:', df['kvp'].mode())
+    print('kvp std:', df['kvp'].std().round(3))
+    print('kvp min:', df['kvp'].min())
+    print('kvp max:', df['kvp'].max())
+
+#-----------------------------------------------------------------------------------
+# run functions
+#-----------------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data'
+    data_pro_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data_pro'
+    harvard_rt_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/NSCLC_data_reg'
+
+    chest_metadata(
+        harvard_rt_dir=harvard_rt_dir,
+        data_exclude=None,
+        pro_data_dir=pro_data_dir,
+        data_pro_dir=data_pro_dir,
+        split=True,
+        )
+
+
diff --git a/utils/cm_all.py b/utils/cm_all.py new file mode 100644 index 0000000..d0f6504 --- /dev/null +++ b/utils/cm_all.py @@ -0,0 +1,111 @@
+#----------------------------------------------------------------------
+# Deep learning for classification for contrast CT;
+# Transfer learning using Google Inception V3;
+#-----------------------------------------------------------------------------------------
+import os
+import numpy as np
+import pandas as pd
+import seaborn as sn
+import matplotlib.pyplot as plt
+import glob
+from time import gmtime, strftime
+from datetime import datetime
+import timeit
+from sklearn.model_selection import train_test_split, GroupShuffleSplit
+from sklearn.metrics import classification_report, confusion_matrix
+from sklearn.metrics import accuracy_score
+from sklearn.metrics import roc_curve, auc, precision_recall_curve
+from utils.plot_cm import plot_cm
+
+# ----------------------------------------------------------------------------------
+# confusion matrix on image and patient level
+# ----------------------------------------------------------------------------------
+def cm_all(run_type, level, thr_img, thr_prob, thr_pos, pro_data_dir, save_dir, fn_df_pred):
+
+    df_sum = pd.read_csv(os.path.join(pro_data_dir, fn_df_pred))
+
+    if level == 'img':
+        y_true =
df_sum['label'].to_numpy() + preds = df_sum['y_pred'].to_numpy() + y_pred = [] + for pred in preds: + if pred > thr_img: + pred = 1 + else: + pred = 0 + y_pred.append(pred) + y_pred = np.asarray(y_pred) + print_info = 'cm image:' + + elif level == 'patient_mean_prob': + df_mean = df_sum.groupby(['ID']).mean() + y_true = df_mean['label'].to_numpy() + preds = df_mean['y_pred'].to_numpy() + y_pred = [] + for pred in preds: + if pred > thr_prob: + pred = 1 + else: + pred = 0 + y_pred.append(pred) + y_pred = np.asarray(y_pred) + print_info = 'cm patient prob:' + + elif level == 'patient_mean_pos': + df_mean = df_sum.groupby(['ID']).mean() + y_true = df_mean['label'].to_numpy() + pos_rates = df_mean['y_pred_class'].to_list() + y_pred = [] + for pos_rate in pos_rates: + if pos_rate > thr_pos: + pred = 1 + else: + pred = 0 + y_pred.append(pred) + y_pred = np.asarray(y_pred) + print_info = 'cm patient pos:' + + ### using confusion matrix to calculate AUC + cm = confusion_matrix(y_true, y_pred) + cm_norm = cm.astype('float64') / cm.sum(axis=1)[:, np.newaxis] + cm_norm = np.around(cm_norm, 2) + + ## classification report + report = classification_report(y_true, y_pred, digits=3) + + # statistics + fp = cm[0][1] + fn = cm[1][0] + tp = cm[1][1] + tn = cm[0][0] + acc = (tp + tn)/(tp + fp + fn + tn) + tpr = tp/(tp + fn) + tnr = tn/(tn + fp) + tpr = np.around(tpr, 3) + tnr = np.around(tnr, 3) + auc5 = (tpr + tnr)/2 + auc5 = np.around(auc5, 3) + + print(print_info) + print(cm) + print(cm_norm) + print(report) + + ## plot cm + for cm0, cm_type in zip([cm, cm_norm], ['raw', 'norm']): + plot_cm( + cm0=cm0, + cm_type=cm_type, + level=level, + save_dir=save_dir + ) + + return cm, cm_norm, report + + + + + + + + diff --git a/utils/contrast_metadata.py b/utils/contrast_metadata.py new file mode 100644 index 0000000..27d79da --- /dev/null +++ b/utils/contrast_metadata.py @@ -0,0 +1,60 @@ +import os +import pandas as pd +import numpy as np +import glob + + +pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data' + +## actual data from df +train_data = pd.read_csv(os.path.join(pro_data_dir, 'train_pat_df.csv')) +val_data = pd.read_csv(os.path.join(pro_data_dir, 'val_pat_df.csv')) +test_data = pd.read_csv(os.path.join(pro_data_dir, 'test_pat_df.csv')) +data = pd.concat([train_data, val_data, test_data]) + +## clinical meta data +df = pd.read_csv(os.path.join(pro_data_dir, 'clinical_meta_data.csv')) +df['contrastbolusagent'] = df['contrastbolusagent'].fillna(2) + +pat_ids = [] +contrasts = [] +for patientid, contrast in zip(df['patientid'], df['contrastbolusagent']): + if patientid[3:7] == 'CHUM': + pat_id = 'CHUM' + patientid[-3:] + elif patientid[3:7] == 'CHUS': + pat_id = 'CHUS' + patientid[-3:] + elif patientid[:3] == 'OPC': + pat_id = 'PMH' + patientid[-3:] + elif patientid[:5] == 'HNSCC': + pat_id = 'MDACC' + patientid[-3:] + if pat_id in data['ID'].to_list(): + pat_ids.append(pat_id) + ## change contrast annotation in meta data + if contrast in ['N', 'n', 'NO']: + contrast = 0 + elif contrast == 2: + contrast = contrast + else: + contrast = 1 + contrasts.append(contrast) +df = pd.DataFrame({'ID': pat_ids, 'contrast': contrasts}) + +## match metadata annotations with clinical expert +ids = [] +contrasts = [] +labels = [] +for ID, label in zip(data['ID'], data['label']): + for pat_id, contrast in zip(df['ID'], df['contrast']): + if pat_id == ID and contrast != 2 and contrast != label: + ids.append(pat_id) + contrasts.append(contrast) + labels.append(label) +print('mismatch ID:', ids) 
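+## an equivalent vectorized sketch of the loop above (illustrative only;
+## assumes the same 'data' and 'df' frames defined earlier in this script):
+##     merged = data.merge(df, on='ID')
+##     mism = merged[(merged['contrast'] != 2) & (merged['contrast'] != merged['label'])]
+##     print('mismatch number:', len(mism))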
+print('mismatch label:', labels)
+print('mismatch contrast:', contrasts)
+print('mismatch number:', len(contrasts))
+
+print('total patient:', df['contrast'].shape[0])
+print(df['contrast'].value_counts())
+print(df['contrast'].value_counts(normalize=True))
+
diff --git a/utils/crop_image.py b/utils/crop_image.py new file mode 100644 index 0000000..edd1f53 --- /dev/null +++ b/utils/crop_image.py @@ -0,0 +1,125 @@
+import os
+import itertools
+import operator
+import numpy as np
+import SimpleITK as sitk
+from scipy import ndimage
+from tensorflow import keras
+import tensorflow as tf
+import matplotlib.pyplot as plt
+import matplotlib.cm as cm
+import cv2
+import nrrd
+from PIL import Image
+
+
+#--------------------------------------------------------------------------------------
+# crop image
+#-------------------------------------------------------------------------------------
+def crop_image(nrrd_file, patient_id, crop_shape, return_type, save_dir):
+
+    ## load sitk and arr
+    img_arr = sitk.GetArrayFromImage(nrrd_file)
+    ## Return top 25 rows of 3D volume, centered in x-y space / start at anterior (y=0)?
+#    img_arr = np.transpose(img_arr, (2, 1, 0))
+#    print("image_arr shape: ", img_arr.shape)
+    c, y, x = img_arr.shape
+#    x, y, c = image_arr.shape
+#    print('c:', c)
+#    print('y:', y)
+#    print('x:', x)
+
+    ## Get center of mass to center the crop in Y plane
+    mask_arr = np.copy(img_arr)
+    ## binary body mask: tissue (HU > -500) -> 1, air (HU <= -500) -> 0
+    mask_arr[mask_arr > -500] = 1
+    mask_arr[mask_arr <= -500] = 0
+    #print("mask_arr min and max:", np.amin(mask_arr), np.amax(mask_arr))
+    centermass = ndimage.measurements.center_of_mass(mask_arr)   # z,x,y
+    cpoint = c - crop_shape[2]//2
+    #print("cpoint, ", cpoint)
+    centermass = ndimage.measurements.center_of_mass(mask_arr[cpoint, :, :])
+    #print("center of mass: ", centermass)
+    startx = int(centermass[0] - crop_shape[0]//2)
+    starty = int(centermass[1] - crop_shape[1]//2)
+    #startx = x//2 - crop_shape[0]//2
+    #starty = y//2 - crop_shape[1]//2
+    startz = int(c - crop_shape[2])
+    #print("start X, Y, Z: ", startx, starty, startz)
+
+    ## crop image using crop shape
+    if startz < 0:
+        img_arr = np.pad(
+            img_arr,
+            ((abs(startz)//2, abs(startz)//2), (0, 0), (0, 0)),
+            'constant',
+            constant_values=-1024
+            )
+        img_crop_arr = img_arr[
+            0:crop_shape[2],
+            starty:starty + crop_shape[1],
+            startx:startx + crop_shape[0]
+            ]
+    else:
+        img_crop_arr = img_arr[
+#            0:crop_shape[2],
+            startz:startz + crop_shape[2],
+            starty:starty + crop_shape[1],
+            startx:startx + crop_shape[0]
+            ]
+        if img_crop_arr.shape[0] < crop_shape[2]:
+            print('initial cropped image shape too small:', img_arr.shape)
+            print(crop_shape[2], img_crop_arr.shape[0])
+            img_crop_arr = np.pad(
+                img_crop_arr,
+                ((int(crop_shape[2] - img_crop_arr.shape[0]), 0), (0, 0), (0, 0)),
+                'constant',
+                constant_values=-1024
+                )
+            print("padded size: ", img_crop_arr.shape)
+    #print(img_crop_arr.shape)
+    ## get nrrd from numpy array
+    img_crop_nrrd = sitk.GetImageFromArray(img_crop_arr)
+    img_crop_nrrd.SetSpacing(nrrd_file.GetSpacing())
+    img_crop_nrrd.SetOrigin(nrrd_file.GetOrigin())
+
+    if save_dir is not None:
+        fn = str(patient_id) + '.nrrd'
+        writer = sitk.ImageFileWriter()
+        writer.SetFileName(os.path.join(save_dir, fn))
+        writer.SetUseCompression(True)
+        writer.Execute(img_crop_nrrd)
+
+    if return_type == 'nrrd':
+        return img_crop_nrrd
+
+    elif return_type == 'npy':
+        return img_crop_arr
+
+#-----------------------------------------------------------------------
+# run codes to test
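+# crop_image() takes a SimpleITK image rather than a file path; a minimal
+# calling sketch (the path and patient ID below are placeholders):
+#     img = sitk.ReadImage('/path/to/scan.nrrd', sitk.sitkFloat32)
+#     arr = crop_image(img, 'PAT001', [192, 192, 110], 'npy', None)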
+#-----------------------------------------------------------------------
+if __name__ == '__main__':
+
+#    output_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/output'
+    output_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test'
+    file_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_PMH'
+    test_file = 'PMH_OPC-00050_CT-SIM_raw_raw_raw_xx.nrrd'
+    ## crop_image() only implements the 'nrrd' and 'npy' return types
+    return_type = 'nrrd'
+    crop_shape = [192, 192, 110]
+
+    train_file = os.path.join(file_dir, test_file)
+    ## read the nrrd into a SimpleITK image before cropping
+    nrrd_img = sitk.ReadImage(train_file, sitk.sitkFloat32)
+
+    img_crop = crop_image(
+        nrrd_file=nrrd_img,
+        patient_id='PMH_OPC-00050',
+        crop_shape=crop_shape,
+        return_type=return_type,
+        save_dir=output_dir
+        )
+    print('crop arr shape:', img_crop)
+#    arr_crop = image_arr_crop[0, :, :]
+#    img_dir = os.path.join(output_dir, 'arr_crop.jpg')
+#    plt.imsave(img_dir, arr_crop, cmap='gray')
+    print('successfully save image!!!')
+
+
diff --git a/utils/delete_row.py b/utils/delete_row.py new file mode 100644 index 0000000..bf0c6b8 --- /dev/null +++ b/utils/delete_row.py @@ -0,0 +1,14 @@
+import pandas as pd
+import os
+
+
+#data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/ahmed_data'
+#df = pd.read_csv(os.path.join(data_dir, 'rtog-0617_pat.csv'))
+#df.drop(df.index[[343]], inplace=True)
+#df.to_csv(os.path.join(data_dir, 'rtog-0617_pat.csv'))
+
+
+data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data'
+df = pd.read_csv(os.path.join(data_dir, 'rtog_img_df.csv'))
+df.drop(df.index[range(24010, 24080)], axis=0, inplace=True)
+df.to_csv(os.path.join(data_dir, 'rtog_img_df.csv'))
diff --git a/utils/error_analysis.py b/utils/error_analysis.py new file mode 100644 index 0000000..920b28e --- /dev/null +++ b/utils/error_analysis.py @@ -0,0 +1,286 @@
+
+import os
+import numpy as np
+import pandas as pd
+import seaborn as sn
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import glob
+from collections import Counter
+from datetime import datetime
+from time import localtime, strftime
+import tensorflow as tf
+from tensorflow.keras.models import load_model
+from crop_image import crop_image
+from resize_3d import resize_3d
+import SimpleITK as sitk
+
+#----------------------------------------------------------------
+# error analysis on image level
+#---------------------------------------------------------------
+def error_img(run_type, val_save_dir, test_save_dir, input_channel, crop, pro_data_dir):
+
+    ### load train data based on input channels
+    if run_type == 'val':
+        file_dir = val_save_dir
+        input_fn = 'df_val_pred.csv'
+        output_fn = 'val_error_img.csv'
+        col = 'y_val'
+    elif run_type == 'test':
+        file_dir = test_save_dir
+        input_fn = 'df_test_pred.csv'
+        output_fn = 'test_error_img.csv'
+        col = 'y_test'
+    elif run_type == 'exval2':
+        file_dir = pro_data_dir
+        input_fn = 'exval2_img_pred.csv'
+        output_fn = 'exval2_error_img.csv'
+        col = 'y_exval2'
+
+    ### dataframe with label and predictions
+    df = pd.read_csv(os.path.join(file_dir, input_fn))
+    print(df[0:10])
+    #df.drop([col, 'ID'], axis=1, inplace=True)
+    df = df.drop(df[df['y_pred_class'] == df['label']].index)
+    df[['y_pred']] = df[['y_pred']].round(3)
+    pd.options.display.max_columns = 100
+    pd.set_option('display.max_rows', 500)
+    print(df[0:200])
+    df.to_csv(os.path.join(file_dir, output_fn))
+    df_error_img = df
+
+    return df_error_img
+
+#----------------------------------------------------
+# generate images for error check
+#----------------------------------------------------
+def save_error_img(df_error_img, run_type, n_img, input_channel,
+                   val_error_dir, test_error_dir,
+                   pro_data_dir):
+
+    ## load img data first, since the index selection below reads from x_val
+    if run_type == 'val':
+        if input_channel == 1:
+            fn = 'val_arr_1ch.npy'
+        elif input_channel == 3:
+            fn = 'val_arr_3ch.npy'
+        x_val = np.load(os.path.join(pro_data_dir, fn))
+        file_dir = val_error_dir
+    elif run_type == 'test':
+        if input_channel == 1:
+            fn = 'test_arr_1ch.npy'
+        elif input_channel == 3:
+            fn = 'test_arr_3ch.npy'
+        x_val = np.load(os.path.join(pro_data_dir, fn))
+        file_dir = test_error_dir
+    elif run_type == 'exval2':
+        if input_channel == 1:
+            fn = 'exval2_arr_1ch.npy'
+        elif input_channel == 3:
+            fn = 'exval2_arr.npy'
+        x_val = np.load(os.path.join(pro_data_dir, fn))
+        file_dir = pro_data_dir
+
+    ### store all indices that have wrong predictions
+    indices = []
+    df = df_error_img
+    for i in range(df.shape[0]):
+        indices.append(i)
+    print(x_val.shape)
+    arr = x_val.take(indices=indices, axis=0)
+    print(arr.shape)
+    arr = arr[:, :, :, 0]
+    print(arr.shape)
+    arr = arr.reshape((arr.shape[0], 192, 192))
+    print(arr.shape)
+
+    ### display images for error checks
+    count = 0
+    for i in range(n_img):
+#    for i in range(arr.shape[0]):
+        #count += 1
+        #print(count)
+        img = arr[i, :, :]
+        fn = str(i) + '.jpeg'
+        img_fn = os.path.join(file_dir, fn)
+        mpl.image.imsave(img_fn, img, cmap='gray')
+
+    print('save images complete!!!')
+
+#----------------------------------------------------------------------
+# error analysis on patient level
+#----------------------------------------------------------------------
+def error_pat(run_type, val_dir, test_dir, exval2_dir, threshold, pro_data_dir):
+
+    if run_type == 'val':
+        input_fn = 'val_img_pred.csv'
+        output_fn = 'val_pat_error.csv'
+        drop_col = 'y_val'
+        save_dir = val_dir
+    elif run_type == 'test':
+        input_fn = 'test_img_pred.csv'
+        output_fn = 'test_pat_error.csv'
+        drop_col = 'y_test'
+        save_dir = test_dir
+    elif run_type == 'exval2':
+        input_fn = 'rtog_img_pred.csv'
+        output_fn = 'rtog_pat_error.csv'
+        drop_col = 'y_exval2'
+        save_dir = exval2_dir
+
+    df_sum = pd.read_csv(os.path.join(pro_data_dir, input_fn))
+    df_mean = df_sum.groupby(['ID']).mean()
+    y_true = df_mean['label']
+    y_pred = df_mean['y_pred']
+    y_pred_classes = []
+    for pred in y_pred:
+        if pred < threshold:
+            y_pred_class = 0
+        elif pred >= threshold:
+            y_pred_class = 1
+        y_pred_classes.append(y_pred_class)
+    df_mean['y_pred_class_thr'] = y_pred_classes
+    df = df_mean
+    df[['y_pred', 'y_pred_class']] = df[['y_pred', 'y_pred_class']].round(3)
+    df['label'] = df['label'].astype(int)
+    df = df.drop(df[df['y_pred_class_thr'] == df['label']].index)
+    #df.drop(['y_exval'], inplace=True, axis=1)
+    df.drop([drop_col], inplace=True, axis=1)
+    pd.options.display.max_columns = 100
+    pd.set_option('display.max_rows', 500)
+    print(df)
+    #print("miss predicted scan:", df.shape[0])
+    df_error_patient = df
+    df.to_csv(os.path.join(save_dir, output_fn))
+
+#----------------------------------------------------------------------
+# save error scan
+#----------------------------------------------------------------------
+def save_error_pat(run_type, val_save_dir, test_save_dir, val_error_dir, test_error_dir, norm_type,
+                   interp_type, output_size, PMH_reg_dir, CHUM_reg_dir, CHUS_reg_dir, MDACC_reg_dir):
+    count = 0
+    dirs = []
+
+    if run_type == 'val':
+        save_dir = val_error_dir
+        df = pd.read_csv(os.path.join(val_save_dir, 'val_error_patient_new.csv'))
+        IDs = df['ID'].to_list()
+        for ID in IDs:
+            print('error scan:', ID)
+            fn = str(ID) + '.nrrd'
+            if ID[:-3] == 'PMH':
+                dir = os.path.join(PMH_reg_dir, fn)
+            elif ID[:-3] == 'CHUS':
+                dir = os.path.join(CHUS_reg_dir, fn)
+            elif ID[:-3] == 'CHUM':
+                dir = os.path.join(CHUM_reg_dir, fn)
+            dirs.append(dir)
+
+    elif run_type == 'test':
+        save_dir = test_error_dir
+        df = pd.read_csv(os.path.join(test_save_dir, 'test_error_patient_new.csv'))
+        IDs = df['ID'].to_list()
+        for ID in IDs:
+            print('error scan:', ID)
+            fn = str(ID) + '.nrrd'
+            dir = os.path.join(MDACC_reg_dir, fn)
+            dirs.append(dir)
+
+    for file_dir, patient_id in zip(dirs, IDs):
+        count += 1
+        print(count)
+        nrrd = sitk.ReadImage(file_dir, sitk.sitkFloat32)
+        img_arr = sitk.GetArrayFromImage(nrrd)
+        data = img_arr[30:78, :, :]
+        data[data <= -1024] = -1024
+        data[data > 700] = 0
+        if norm_type == 'np_interp':
+            arr_img = np.interp(data, [-200, 200], [0, 1])
+        elif norm_type == 'np_clip':
+            arr_img = np.clip(data, a_min=-200, a_max=200)
+            MAX, MIN = arr_img.max(), arr_img.min()
+            arr_img = (arr_img - MIN) / (MAX - MIN)
+        ## save npy array to image
+        img = sitk.GetImageFromArray(arr_img)
+        fn = str(patient_id) + '.nrrd'
+        sitk.WriteImage(img, os.path.join(save_dir, fn))
+        print('save error scans!')
+
+    ## save error scan as nrrd file
+
+#----------------------------------------------------------------------------
+# main function
+#---------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    val_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/val'
+    test_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/test'
+    exval2_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval2'
+    val_error_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/val/error'
+    test_error_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/test/error'
+    mdacc_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_mdacc'
+    CHUM_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/CHUM_data_reg'
+    CHUS_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/CHUS_data_reg'
+    PMH_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/PMH_data_reg'
+    MDACC_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/MDACC_data_reg'
+    pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data'
+
+    input_channel = 3
+    crop = True
+    n_img = 20
+    save_img = False
+    save_pat = False
+    error_image = False
+    thr_prob = 0.5
+    run_type = 'val'
+    norm_type = 'np_clip'
+    crop_shape = [192, 192, 110]
+    return_type1 = 'nrrd'
+    return_type2 = 'npy'
+    interp_type = 'linear'
+    output_size = (96, 96, 36)
+
+    error_pat(
+        run_type=run_type,
+        val_dir=val_dir,
+        test_dir=test_dir,
+        exval2_dir=exval2_dir,
+        threshold=thr_prob,
+        pro_data_dir=pro_data_dir
+        )
+
+    if error_image == True:
+        df = error_img(
+            run_type=run_type,
+            val_save_dir=val_dir,
+            test_save_dir=test_dir,
+            input_channel=input_channel,
+            crop=crop,
+            pro_data_dir=pro_data_dir
+            )
+
+    if save_pat == True:
+        save_error_pat(
+            run_type=run_type,
+            val_save_dir=val_dir,
+            test_save_dir=test_dir,
+            val_error_dir=val_error_dir,
+            test_error_dir=test_error_dir,
+            norm_type=norm_type,
+            interp_type=interp_type,
+            output_size=output_size,
+            PMH_reg_dir=PMH_reg_dir,
+            CHUM_reg_dir=CHUM_reg_dir,
+            CHUS_reg_dir=CHUS_reg_dir,
+            MDACC_reg_dir=MDACC_reg_dir
+            )
+
+    if save_img == True:
+        save_error_img(
+            df_error_img=df,
+            n_img=n_img,
+            run_type=run_type,
+            input_channel=input_channel,
+            val_error_dir=val_error_dir,
+            test_error_dir=test_error_dir,
+            pro_data_dir=pro_data_dir
+            )
diff --git a/utils/get_pat_img_df.py b/utils/get_pat_img_df.py new file mode 100644 index 0000000..a6b7462 --- /dev/null +++ b/utils/get_pat_img_df.py @@ -0,0 +1,94 @@
+import pandas as pd
+import numpy as np
+import os
+import glob
+
+
+#-------------------------------------------------------------------------
+# create patient df with ID, label and dir
+#-------------------------------------------------------------------------
+def get_pat_df(pro_data_dir, reg_data_dir, label_file, fn_pat_df):
+
+    ## create df for dir, ID and labels on patient level
+    df_label = pd.read_csv(os.path.join(pro_data_dir, label_file))
+    df_label['Contrast'] = df_label['Contrast'].map({'Yes': 1, 'No': 0})
+    fns = [fn for fn in sorted(glob.glob(reg_data_dir + '/*nrrd'))]
+    IDs = []
+    for fn in fns:
+        ID = fn.split('/')[-1].split('_')[1].split('.')[0].strip()
+        IDs.append(ID)
+    pat_ids = []
+    labels = []
+    for pat_id, pat_label in zip(df_label['Patient ID'], df_label['Contrast']):
+        if pat_id in IDs:
+            pat_id = 'rtog' + '_' + str(pat_id)
+            pat_ids.append(pat_id)
+            labels.append(pat_label)
+    print("ID:", len(pat_ids))
+    print("dir:", len(fns))
+    print("label:", len(labels))
+    print('contrast scan in ex val:', labels.count(1))
+    print('non-contrast scan in ex val:', labels.count(0))
+    ## note: this assumes every scan in reg_data_dir appears in the label file,
+    ## so that IDs, fns and labels line up one-to-one
+    df = pd.DataFrame({'ID': IDs, 'file': fns, 'label': labels})
+    df.to_csv(os.path.join(pro_data_dir, fn_pat_df))
+    print('total scan:', df.shape[0])
+
+#-------------------------------------------------------------------------
+# create img df with ID, label
+#-------------------------------------------------------------------------
+def get_img_df(pro_data_dir, fn_pat_df, fn_img_df, slice_range):
+
+    pat_df = pd.read_csv(os.path.join(pro_data_dir, fn_pat_df))
+    ## img ID
+    slice_number = len(slice_range)
+    img_ids = []
+    for pat_id in pat_df['ID']:
+        for i in range(slice_number):
+            img_id = 'rtog' + '_' + pat_id + '_' + 'slice%s'%(f'{i:03d}')
+            img_ids.append(img_id)
+    #print(img_ids)
+    print(len(img_ids))
+    ## img label
+    pat_label = pat_df['label'].to_list()
+    img_label = []
+    for label in pat_label:
+        list2 = [label] * slice_number
+        img_label.extend(list2)
+    #print(img_label)
+    print(len(img_label))
+    ### making dataframe containing img IDs and labels
+    df = pd.DataFrame({'fn': img_ids, 'label': img_label})
+    print(df[0:10])
+    df.to_csv(os.path.join(pro_data_dir, fn_img_df))
+    print('total img:', df.shape[0])
+
+#-------------------------------------------------------------------------
+# create patient df with ID, label and dir
+#-------------------------------------------------------------------------
+if __name__ == '__main__':
+
+    pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data'
+    slice_range = range(50, 120)
+    reg_data_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/ahmed_data/rtog-0617_reg'
+    label_file = 'label_RTOG0617.csv'
+    fn_pat_df = 'rtog_pat_df.csv'
+    fn_img_df = 'rtog_img_df.csv'
+
+    get_pat_df(
+        pro_data_dir=pro_data_dir,
+        reg_data_dir=reg_data_dir,
+        label_file=label_file,
+        fn_pat_df=fn_pat_df
+        )
+
+    get_img_df(
+        pro_data_dir=pro_data_dir,
+        fn_pat_df=fn_pat_df,
+        fn_img_df=fn_img_df,
+        slice_range=slice_range
+        )
+
+
+
diff --git a/utils/get_pred_img_df.py b/utils/get_pred_img_df.py new file mode 100644 index 0000000..7ed2e3f --- /dev/null +++ b/utils/get_pred_img_df.py @@ -0,0 +1,48 @@
+import pandas as pd
+import numpy as np
+import os
+from utils.make_plots import make_plots
+
+def get_pred_img_df(ahmed_data_dir, pro_data_dir, slice_range):
+
+    df = pd.read_csv(os.path.join(ahmed_data_dir, 'rtog-0617_img_pred.csv'))
+    df_label = pd.read_csv(os.path.join(pro_data_dir, 'label_RTOG0617.csv'))
+    df_label['Contrast'] = df_label['Contrast'].map({'Yes': 1, 'No': 0})
+    pat_label = df_label['Contrast'].to_list()
+    slice_number = len(slice_range)
+    img_label = []
+    for label in pat_label:
+        list1 = [label] * slice_number
+        img_label.extend(list1)
+    df['label'] = img_label
+    df.to_csv(os.path.join(pro_data_dir, 'exval2_img_pred.csv'))
+
+if __name__ == '__main__':
+
+    ahmed_data_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data_pro'
+    pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data'
+    exval2_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/d_pro'
+    slice_range = range(50, 70)
+    saved_model = 'FineTuned_model_2021_07_27_16_44_40'
+
+    get_pred_img_df(
+        ahmed_data_dir=ahmed_data_dir,
+        pro_data_dir=pro_data_dir,
+        slice_range=slice_range
+        )
+
+    make_plots(
+        run_type='exval2',
+        thr_img=0.321,
+        thr_prob=0.379,
+        thr_pos=0.575,
+        bootstrap=1000,
+        pro_data_dir=pro_data_dir,
+        save_dir=exval2_dir,
+        loss=None,
+        acc=None,
+        run_model='ResNet',
+        saved_model=saved_model,
+        epoch=100,
+        batch_size=32,
+        lr=1e-5
+        )
diff --git a/utils/get_stats_plots.py b/utils/get_stats_plots.py new file mode 100644 index 0000000..b8d44d1 --- /dev/null +++ b/utils/get_stats_plots.py @@ -0,0 +1,147 @@
+import os
+import numpy as np
+import pandas as pd
+import pickle
+from time import gmtime, strftime
+from datetime import datetime
+import timeit
+from utils.cm_all import cm_all
+from utils.roc_all import roc_all
+from utils.prc_all import prc_all
+#from utils.acc_loss import acc_loss
+from utils.write_txt import write_txt
+
+
+
+def get_stats_plots(out_dir, proj_dir, run_type, run_model, loss, acc,
+                    saved_model, epoch, batch_size, lr, thr_img=0.5,
+                    thr_prob=0.5, thr_pos=0.5, bootstrap=1000):
+
+    """
+    generate model val/test statistics and plot curves;
+
+    Args:
+        loss {float} -- validation loss;
+        acc {float} -- validation accuracy;
+        run_model {str} -- cnn model name;
+        batch_size {int} -- batch size for data loading;
+        epoch {int} -- training epoch;
+        out_dir {path} -- path for output files;
+        opt {str or function} -- optimizer, e.g. 'adam';
+        lr {float} -- learning rate;
+
+    Keyword args:
+        bootstrap {int} -- number of bootstraps to calculate 95% CI for AUC;
+        thr_img {float} -- threshold to determine positive class on image level;
+        thr_prob {float} -- threshold to determine positive class on patient
+                            level (mean prob score);
+        thr_pos {float} -- threshold to determine positive class on patient
+                           level (positive class percentage);
+    Returns:
+        Model prediction statistics and plots: ROC, PRC, confusion matrix, etc.
+ + """ + + pro_data_dir = os.path.join(proj_dir, 'pro_data') + train_dir = os.path.join(out_dir, 'train') + val_dir = os.path.join(out_dir, 'val') + test_dir = os.path.join(out_dir, 'test') + exval1_dir = os.path.join(out_dir, 'exval1') + exval2_dir = os.path.join(out_dir, 'exval2') + + if not os.path.exists(train_dir): + os.mkdir(train_dir) + if not os.path.exists(val_dir): + os.mkdir(val_dir) + if not os.path.exists(test_dir): + os.mkdir(test_dir) + if not os.path.exists(exval1_dir): + os.mkdir(exval1_dir) + if not os.path.exists(exval2_dir): + os.mkdir(exval2_dir) + + ### determine if this is train or test + if run_type == 'val': + fn_df_pred = 'val_img_pred.csv' + save_dir = val_dir + elif run_type == 'test': + fn_df_pred = 'test_img_pred.csv' + save_dir = test_dir + elif run_type == 'exval1': + fn_df_pred = 'exval1_img_pred.csv' + save_dir = exval1_dir + elif run_type == 'exval2': + fn_df_pred = 'exval2_img_pred.csv' + save_dir = exval2_dir + + cms = [] + cm_norms = [] + reports = [] + roc_stats = [] + prc_aucs = [] + levels = ['img', 'patient_mean_prob', 'patient_mean_pos'] + + for level in levels: + + ## confusion matrix + cm, cm_norm, report = cm_all( + run_type=run_type, + level=level, + thr_img=thr_img, + thr_prob=thr_prob, + thr_pos=thr_pos, + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + cms.append(cm) + cm_norms.append(cm_norm) + reports.append(report) + + ## ROC curves + roc_stat = roc_all( + run_type=run_type, + level=level, + thr_prob=thr_prob, + thr_pos=thr_pos, + bootstrap=bootstrap, + color='blue', + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + roc_stats.append(roc_stat) + + ## PRC curves + prc_auc = prc_all( + run_type=run_type, + level=level, + thr_prob=thr_prob, + thr_pos=thr_pos, + color='red', + pro_data_dir=pro_data_dir, + save_dir=save_dir, + fn_df_pred=fn_df_pred + ) + prc_aucs.append(prc_auc) + + ### save validation results to txt + write_txt( + run_type=run_type, + out_dir=out_dir, + loss=loss, + acc=acc, + cms=cms, + cm_norms=cm_norms, + reports=reports, + prc_aucs=prc_aucs, + roc_stats=roc_stats, + run_model=run_model, + saved_model=saved_model, + epoch=epoch, + batch_size=batch_size, + lr=lr + ) + + print('saved model as:', saved_model) + diff --git a/utils/gradcam.py b/utils/gradcam.py new file mode 100644 index 0000000..676fd90 --- /dev/null +++ b/utils/gradcam.py @@ -0,0 +1,301 @@ +from tensorflow.keras.models import Model +import tensorflow as tf +from tensorflow import keras +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import numpy as np +import cv2 +import numpy as np +import pandas as pd +from tensorflow.keras.models import load_model +import tensorflow as tf +import os + +#--------------------------------------------------------------------------------- +# get data +#--------------------------------------------------------------------------------- +def data(input_channel, i, val_save_dir, test_save_dir): + + ### load train data based on input channels + if run_type == 'val': + if input_channel == 1: + fn = 'val_arr_1ch.npy' + elif input_channel == 3: + fn = 'val_arr_3ch.npy' + data = np.load(os.path.join(pro_data_dir, fn)) + df = pd.read_csv(os.path.join(val_save_dir, 'val_pred_df.csv')) + elif run_type == 'test': + if input_channel == 1: + fn = 'test_arr_1ch.npy' + elif input_channel == 3: + fn = 'test_arr_3ch.npy' + data = np.load(os.path.join(pro_data_dir, fn)) + df = pd.read_csv(os.path.join(test_save_dir, 'test_pred_df.csv')) + elif run_type == 'exval': + if 
input_channel == 1:
+            fn = 'exval_arr_1ch.npy'
+        elif input_channel == 3:
+            fn = 'exval_arr_3ch.npy'
+        data = np.load(os.path.join(pro_data_dir, fn))
+        df = pd.read_csv(os.path.join(exval_save_dir, 'exval_pred_df.csv'))
+
+    ### load label
+    y_true = df['label']
+    y_pred_class = df['y_pred_class']
+    y_pred = df['y_pred']
+    ID = df['fn']
+    ### find the ith image to show grad-cam map
+    img = data[i, :, :, :]
+    img = img.reshape((1, 192, 192, 3))
+    label = y_true[i]
+    pred_index = y_pred_class[i]
+    y_pred = y_pred[i]
+    ID = ID[i]
+
+    return img, label, pred_index, y_pred, ID
+
+#------------------------------------------------------------------------------------
+# find last conv layer
+#-----------------------------------------------------------------------------------
+def find_target_layer(model, saved_model):
+
+    # find the final conv layer by looping layers in reverse order
+    for layer in reversed(model.layers):
+        # check to see if the layer has a 4D output
+        if len(layer.output_shape) == 4:
+            return layer.name
+    raise ValueError("Could not find 4D layer. Cannot apply GradCAM.")
+
+#----------------------------------------------------------------------------------
+# calculate gradient class activation map
+#----------------------------------------------------------------------------------
+def compute_heatmap(model, saved_model, image, pred_index, last_conv_layer):
+
+    """
+    construct our gradient model by supplying (1) the inputs
+    to our pre-trained model, (2) the output of the (presumably)
+    final 4D layer in the network, and (3) the output of the
+    softmax activations from the model
+    """
+    gradModel = Model(
+        inputs=[model.inputs],
+        outputs=[model.get_layer(last_conv_layer).output, model.output]
+        )
+
+    # record operations for automatic differentiation
+    with tf.GradientTape() as tape:
+        """
+        cast the image tensor to a float-32 data type, pass the
+        image through the gradient model, and grab the loss
+        associated with the specific class index
+        """
+        print(pred_index)
+        inputs = tf.cast(image, tf.float32)
+        print(image.shape)
+        last_conv_layer_output, preds = gradModel(inputs)
+        print(preds)
+        print(preds.shape)
+#        class_channel = preds[:, pred_index]
+        class_channel = preds
+    # use automatic differentiation to compute the gradients
+    grads = tape.gradient(class_channel, last_conv_layer_output)
+    """
+    This is a vector where each entry is the mean intensity of the gradient
+    over a specific feature map channel
+    """
+    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
+    """
+    We multiply each channel in the feature map array
+    by "how important this channel is" with regard to the top predicted class
+    then sum all the channels to obtain the heatmap class activation
+    """
+    last_conv_layer_output = last_conv_layer_output[0]
+    heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis]
+    heatmap = tf.squeeze(heatmap)
+
+    # For visualization purpose, we will also normalize the heatmap between 0 & 1
+    heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap)
+    heatmap = heatmap.numpy()
+
+    return heatmap
+
+#------------------------------------------------------------------------------------
+# save gradcam heat map
+#-----------------------------------------------------------------------------------
+def save_gradcam(image, heatmap, val_gradcam_dir, test_gradcam_dir, alpha, i):
+
+#    print('heatmap:', heatmap.shape)
+    # Rescale heatmap to a range 0-255
+    heatmap = np.uint8(255 * heatmap)
+    # Use jet colormap to colorize heatmap
+    jet = cm.get_cmap("jet")
+    # Use RGB
values of the colormap + jet_colors = jet(np.arange(256))[:, :3] + jet_heatmap = jet_colors[heatmap] + + # resize heatmap + jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap) + jet_heatmap0 = jet_heatmap.resize(re_size) + jet_heatmap1 = keras.preprocessing.image.img_to_array(jet_heatmap0) +# print('jet_heatmap:', jet_heatmap1.shape) + + # resize background CT image + img = image.reshape((192, 192, 3)) + img = keras.preprocessing.image.array_to_img(img) + img0 = img.resize(re_size) + img1 = keras.preprocessing.image.img_to_array(img0) +# print('img shape:', img1.shape) + + # Superimpose the heatmap on original image + superimposed_img = jet_heatmap1 * alpha + img1 + superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img) + + # Save the superimposed image + if run_type == 'val': + save_dir = val_gradcam_dir + elif run_type == 'test': + save_dir = test_gradcam_dir + elif run_type == 'exval': + save_dir = exval_gradcam_dir + + fn1 = str(conv_n) + '_' + str(i) + '_' + 'gradcam.png' + fn2 = str(conv_n) + '_' + str(i) + '_' + 'heatmap.png' + fn3 = str(conv_n) + '_' + str(i) + '_' + 'heatmap_raw.png' + fn4 = str(i) + '_' + 'CT.png' + superimposed_img.save(os.path.join(save_dir, fn1)) +# jet_heatmap0.save(os.path.join(save_dir, fn2)) +# jet_heatmap.save(os.path.join(save_dir, fn3)) +# img0.save(os.path.join(save_dir, fn4)) + + +if __name__ == '__main__': + + train_img_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/train_img_dir' + val_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/val' + test_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test' + exval_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/exval' + val_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/val/gradcam' + test_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test/gradcam' + exval_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test/gradcam' + pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data' + model_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/model' + input_channel = 3 + re_size = (192, 192) + i = 72 + crop = True + alpha = 0.9 + saved_model = 'ResNet_2021_07_18_06_28_40' + show_network = False + conv_n = 'conv5' + run_type = 'val' + + #--------------------------------------------------------- + # run main function + #-------------------------------------------------------- + if run_type == 'val': + save_dir = val_save_dir + elif run_type == 'test': + save_dir = test_save_dir + + ## load model and find conv layers + model = load_model(os.path.join(model_dir, saved_model)) +# model.summary() + + list_i = [100, 105, 110, 115, 120, 125] + for i in list_i: + image, label, pred_index, y_pred, ID = data( + input_channel=input_channel, + i=i, + val_save_dir=val_save_dir, + test_save_dir=test_save_dir + ) + + conv_list = ['conv2', 'conv3', 'conv4', 'conv5'] + conv_list = ['conv4'] + for conv_n in conv_list: + if conv_n == 'conv2': + last_conv_layer = 'conv2_block3_1_conv' + elif conv_n == 'conv3': + last_conv_layer = 'conv3_block4_1_conv' + elif conv_n == 'conv4': + last_conv_layer = 'conv4_block6_1_conv' + elif conv_n == 'conv5': + last_conv_layer = 'conv5_block3_out' + + heatmap = compute_heatmap( + model=model, + saved_model=saved_model, + image=image, + pred_index=pred_index, + last_conv_layer=last_conv_layer + ) + + save_gradcam( + image=image, + heatmap=heatmap, + val_gradcam_dir=val_gradcam_dir, + test_gradcam_dir=test_gradcam_dir, + alpha=alpha, + i=i + ) + + print('label:', label) 
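+            ## quick sanity-check sketch on the heatmap (illustrative only):
+            ## compute_heatmap() normalizes values into [0, 1], so anything
+            ## outside that range points to a wrong conv layer name:
+            ##     assert 0.0 <= heatmap.min() <= heatmap.max() <= 1.0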
+ print('ID:', ID) + print('y_pred:', y_pred) + print('prediction:', pred_index) + print('conv layer:', conv_n) + + + +# if last_conv_layer is None: +# last_conv_layer = find_target_layer( +# model=model, +# saved_model=saved_model +# ) +# print(last_conv_layer) +# +# if show_network == True: +# for idx in range(len(model.layers)): +# print(model.get_layer(index = idx).name) + +# # compute the guided gradients +# castConvOutputs = tf.cast(convOutputs > 0, "float32") +# castGrads = tf.cast(grads > 0, "float32") +# guidedGrads = castConvOutputs * castGrads * grads +# # the convolution and guided gradients have a batch dimension +# # (which we don't need) so let's grab the volume itself and +# # discard the batch +# convOutputs = convOutputs[0] +# guidedGrads = guidedGrads[0] +# +# # compute the average of the gradient values, and using them +# # as weights, compute the ponderation of the filters with +# # respect to the weights +# weights = tf.reduce_mean(guidedGrads, axis=(0, 1)) +# cam = tf.reduce_sum(tf.multiply(weights, convOutputs), axis=-1) +# +# # grab the spatial dimensions of the input image and resize +# # the output class activation map to match the input image +# # dimensions +## (w, h) = (image.shape[2], image.shape[1]) +## heatmap = cv2.resize(cam.numpy(), (w, h)) +# heatmap = cv2.resize(heatmap.numpy(), (64, 64)) +# # normalize the heatmap such that all values lie in the range +## # [0, 1], scale the resulting values to the range [0, 255], +## # and then convert to an unsigned 8-bit integer +# numer = heatmap - np.min(heatmap) +# eps = 1e-8 +# denom = (heatmap.max() - heatmap.min()) + eps +# heatmap = numer / denom +# heatmap = (heatmap * 255).astype("uint8") +# colormap=cv2.COLORMAP_VIRIDIS +# heatmap = cv2.applyColorMap(heatmap, colormap) +# print('heatmap shape:', heatmap.shape) +## img = image[:, :, :, 0] +## print('img shape:', img.shape) +# img = image.reshape((64, 64, 3)) +# print(img.shape) +# output = cv2.addWeighted(img, 0.5, heatmap, 0.5, 0) +# +# +# return heatmap, output diff --git a/utils/gradcam2.py b/utils/gradcam2.py new file mode 100644 index 0000000..869e2bc --- /dev/null +++ b/utils/gradcam2.py @@ -0,0 +1,401 @@ +from tensorflow.keras.models import Model +import tensorflow as tf +from tensorflow import keras +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import numpy as np +import cv2 +import numpy as np +import pandas as pd +from tensorflow.keras.models import load_model +import tensorflow as tf +import os +import SimpleITK as sitk + +#------------------------------------------------------------------------------------ +# find last conv layer +#----------------------------------------------------------------------------------- +def find_target_layer(cnn_model): + + # find the final conv layer by looping layers in reverse order + for layer in reversed(cnn_model.layers): + # check to see if the layer has a 4D output + if len(layer.output_shape) == 4: + return layer.name + raise ValueError("Could not find 4D layer. 
Cannot apply GradCAM.") + +#---------------------------------------------------------------------------------- +# calculate gradient class actiavtion map +#---------------------------------------------------------------------------------- +def compute_heatmap(cnn_model, image, pred_index, last_conv_layer): + + """ + construct our gradient model by supplying (1) the inputs + to our pre-trained model, (2) the output of the (presumably) + final 4D layer in the network, and (3) the output of the + softmax activations from the model + """ + gradModel = Model( + inputs=[cnn_model.inputs], + outputs=[cnn_model.get_layer(last_conv_layer).output, cnn_model.output] + ) + + # record operations for automatic differentiation + with tf.GradientTape() as tape: + """ + cast the image tensor to a float-32 data type, pass the + image through the gradient model, and grab the loss + associated with the specific class index + """ + #print(pred_index) + inputs = tf.cast(image, tf.float32) + #print(image.shape) + last_conv_layer_output, preds = gradModel(inputs) + #print(preds) + #print(preds.shape) + # class_channel = preds[:, pred_index] + class_channel = preds + + ## use automatic differentiation to compute the gradients + grads = tape.gradient(class_channel, last_conv_layer_output) + """ + This is a vector where each entry is the mean intensity of the gradient + over a specific feature map channel + """ + pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)) + """ + We multiply each channel in the feature map array + by "how important this channel is" with regard to the top predicted class + then sum all the channels to obtain the heatmap class activation + """ + last_conv_layer_output = last_conv_layer_output[0] + heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis] + heatmap = tf.squeeze(heatmap) + + # For visualization purpose, we will also normalize the heatmap between 0 & 1 + heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap) + heatmap = heatmap.numpy() + + return heatmap + +#------------------------------------------------------------------------------------ +# save gradcam heat map +#----------------------------------------------------------------------------------- +def save_gradcam(run_type, img_back, heatmap, val_gradcam_dir, test_gradcam_dir, + exval2_gradcam_dir, alpha, img_id): + +# print('heatmap:', heatmap.shape) + # Rescale heatmap to a range 0-255 + heatmap = np.uint8(255 * heatmap) + # Use jet colormap to colorize heatmap + jet = cm.get_cmap('jet') + # Use RGB values of the colormap + jet_colors = jet(np.arange(256))[:, :3] + jet_heatmap = jet_colors[heatmap] + + # resize heatmap + jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap) + jet_heatmap0 = jet_heatmap.resize(re_size) + jet_heatmap1 = keras.preprocessing.image.img_to_array(jet_heatmap0) +# print('jet_heatmap:', jet_heatmap1.shape) + + # resize background CT image + img = img_back.reshape((192, 192, 3)) + img = keras.preprocessing.image.array_to_img(img) + ## resize if resolution of raw image too low + #img0 = img.resize(re_size) + img1 = keras.preprocessing.image.img_to_array(img) +# print('img shape:', img1.shape) + + # Superimpose the heatmap on original image + superimposed_img = jet_heatmap1 * alpha + img1 + superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img) + + # Save the superimposed image + if run_type == 'val': + save_dir = val_gradcam_dir + elif run_type == 'test': + save_dir = test_gradcam_dir + elif run_type == 'exval2': + save_dir = exval2_gradcam_dir + + fn1 
= str(img_id) + '_' + str(conv_n) + '_' + 'gradcam.png' + fn2 = str(img_id) + '_' + str(conv_n) + '_' + 'heatmap.png' + fn3 = str(img_id) + '_' + str(conv_n) + '_' + 'heatmap_raw.png' + fn4 = str(img_id) + '_' + 'CT.png' + superimposed_img.save(os.path.join(save_dir, fn1)) +# jet_heatmap0.save(os.path.join(save_dir, fn2)) +# jet_heatmap.save(os.path.join(save_dir, fn3)) +# img0.save(os.path.join(save_dir, fn4)) + +#--------------------------------------------------------------------------------- +# get background image +#--------------------------------------------------------------------------------- +def get_background(img_id, slice_range, PMH_reg_dir, CHUM_reg_dir, CHUS_reg_dir, + MDACC_reg_dir): + + pat_id = img_id.split('_')[0] + if pat_id[:-3] == 'PMH': + reg_dir = PMH_reg_dir + elif pat_id[:-3] == 'CHUM': + reg_dir = CHUM_reg_dir + elif pat_id[:-3] == 'CHUS': + reg_dir = CHUS_reg_dir + elif pat_id[:-3] == 'MDACC': + reg_dir = MDACC_reg_dir + elif pat_id[:4] == 'rtog': + reg_dir = rtog_reg_dir + pat_id = img_id.split('_s')[0] + + nrrd_id = str(pat_id) + '.nrrd' + data_dir = os.path.join(reg_dir, nrrd_id) + ### get image slice and save them as numpy array + nrrd = sitk.ReadImage(data_dir, sitk.sitkFloat32) + img_arr = sitk.GetArrayFromImage(nrrd) + print(img_arr.shape) + data = img_arr[slice_range, :, :] + #slice_n = img_id.split('_')[1][6:] + slice_n = img_id.split('slice0')[1] + slice_n = int(slice_n) + print(slice_n) + arr = data[slice_n, :, :] + arr = np.clip(arr, a_min=-200, a_max=200) + MAX, MIN = arr.max(), arr.min() + arr = (arr - MIN) / (MAX - MIN) + #print(arr.shape) + arr = np.repeat(arr[..., np.newaxis], 3, axis=-1) + arr = arr.reshape((1, 192, 192, 3)) + #print(arr.shape) + img_back = arr + #np.save(os.path.join(pro_data_dir, fn_arr_3ch), img_arr) + + return img_back + +#--------------------------------------------------------------------------------- +# get data +#--------------------------------------------------------------------------------- +def gradcam(run_type, input_channel, img_IDs, conv_list, val_dir, test_dir, + exval_dir, model_dir, saved_model, data_pro_dir, pro_data_dir, + run_model): + + ## load model and find conv layers + cnn_model = load_model(os.path.join(model_dir, saved_model)) + # model.summary() + + ### load train data based on input channels + if run_type == 'val': + if input_channel == 1: + fn = 'val_arr_1ch.npy' + elif input_channel == 3: + fn = 'val_arr_3ch.npy' + data = np.load(os.path.join(data_pro_dir, fn)) + df = pd.read_csv(os.path.join(pro_data_dir, 'val_img_pred.csv')) + save_dir = val_dir + elif run_type == 'test': + if input_channel == 1: + fn = 'test_arr_1ch.npy' + elif input_channel == 3: + fn = 'test_arr_3ch.npy' + data = np.load(os.path.join(data_pro_dir, fn)) + df = pd.read_csv(os.path.join(test_dir, 'df_test_pred.csv')) + save_dir = test_dir + elif run_type == 'exval2': + if input_channel == 1: + fn = 'exval_arr_1ch.npy' + elif input_channel == 3: + fn = 'rtog_arr.npy' + data = np.load(os.path.join(pro_data_dir, fn)) + df = pd.read_csv(os.path.join(pro_data_dir, 'rtog_img_pred.csv')) + save_dir = exval2_dir + print("successfully load data!") + print(img_IDs) + print(df[0:10]) + ## load data for gradcam + img_inds = df[df['fn'].isin(img_IDs)].index.tolist() + print(img_inds) + if img_inds == []: + print("list is empty. 
Choose other slices.") + else: + for i, img_id in zip(img_inds, img_IDs): + print('image ID:', img_id) + print('index:', i) + image = data[i, :, :, :] + image = image.reshape((1, 192, 192, 3)) + label = df['label'][i] + pred_index = df['y_pred_class'][i] + y_pred = df['y_pred'][i] + ## get background CT image + img_back = get_background( + img_id=img_id, + slice_range=slice_range, + PMH_reg_dir=PMH_reg_dir, + CHUM_reg_dir=CHUM_reg_dir, + CHUS_reg_dir=CHUS_reg_dir, + MDACC_reg_dir=MDACC_reg_dir + ) + if run_model == 'ResNet101V2': + for conv_n in conv_list: + if conv_n == 'conv2': + last_conv_layer = 'conv2_block3_1_conv' + elif conv_n == 'conv3': + last_conv_layer = 'conv3_block4_1_conv' + elif conv_n == 'conv4': + last_conv_layer = 'conv4_block6_1_conv' + elif conv_n == 'conv5': + last_conv_layer = 'conv5_block3_out' + elif run_model == 'EfficientNetB4': + last_conv_layer = 'top_conv' + #last_conv_layer = 'top_activation' + ## compute heatnap + heatmap = compute_heatmap( + cnn_model=cnn_model, + image=image, + pred_index=pred_index, + last_conv_layer=last_conv_layer + ) + ## save heatmap + save_gradcam( + run_type=run_type, + img_back=img_back, + heatmap=heatmap, + val_gradcam_dir=val_gradcam_dir, + test_gradcam_dir=test_gradcam_dir, + exval2_gradcam_dir=exval2_gradcam_dir, + alpha=alpha, + img_id=img_id + ) + + print('label:', label) + print('ID:', img_id) + print('y_pred:', y_pred) + print('pred class:', pred_index) + #print('conv layer:', conv_n) +#--------------------------------------------------------------------------------- +# get data +#--------------------------------------------------------------------------------- +if __name__ == '__main__': + + train_img_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/train_img_dir' + val_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/val' + test_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/test' + exval_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval' + exval2_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval2' + val_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/val/gradcam' + test_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/test/gradcam' + exval_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval/gradcam' + exval2_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval2/gradcam' + pro_data_dir = '/home/bhkann/zezhong/git_repo/IV-Contrast-CNN-Project/pro_data' + model_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/model' + data_pro_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data_pro' + CHUM_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/CHUM_data_reg' + CHUS_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/CHUS_data_reg' + PMH_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/PMH_data_reg' + MDACC_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data/MDACC_data_reg' + rtog_reg_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/ahmed_data/rtog-0617_reg' + input_channel = 3 + re_size = (192, 192) + crop = True + alpha = 0.9 + + run_type = 'exval2' + run_model = 'EfficientNetB4' + if run_type in ['exval', 'exval2']: + slice_range = range(50, 120) + slice_ids = ["{0:03}".format(i) for i in range(70)] + #saved_model = 'FineTuned_model_2021_08_02_17_22_10' + saved_model = 'Tuned_EfficientNetB4_2021_08_27_20_26_55' + elif run_type in ['val', 'test']: + slice_range = range(17, 83) + slice_ids = ["{0:03}".format(i) for i in range(66)] + #saved_model = 'ResNet_2021_07_18_06_28_40' + 
saved_model = 'EffNet_2021_08_24_09_57_13' + print(run_type) + print(slice_range) + print(slice_ids) + show_network = False + conv_n = 'conv5' + conv_list = ['conv2', 'conv3', 'conv4', 'conv5'] + + ## image ID + pat_id = 'PMH423' + ## earlier candidate lists, kept for reference; only the last assignment is used + #pat_ids = ['rtog_0617-438343'] + #pat_ids = ['PMH574', 'PMH146', 'PMH135'] + #pat_ids = ['PMH433', 'PMH312', 'PMH234', 'PMH281', 'PMH511', 'PMH405'] + #pat_ids = ['PMH465', 'PMH287', 'PMH308', 'PMH276', 'PMH595', 'PMH467'] + pat_ids = ['rtog_0617-349454', 'rtog_0617-438343', 'rtog_0617-292370', 'rtog_0617-349454'] + img_IDs = [] + for pat_id in pat_ids: + for slice_id in slice_ids: + img_id = pat_id + '_' + 'slice' + str(slice_id) + img_IDs.append(img_id) + + gradcam( + run_type=run_type, + input_channel=input_channel, + img_IDs=img_IDs, + conv_list=conv_list, + val_dir=val_dir, + test_dir=test_dir, + exval_dir=exval_dir, + model_dir=model_dir, + saved_model=saved_model, + data_pro_dir=data_pro_dir, + pro_data_dir=pro_data_dir, + run_model=run_model + ) diff --git a/utils/gradcam_new.py b/utils/gradcam_new.py new file mode 100644 index 0000000..d748268 --- /dev/null +++ b/utils/gradcam_new.py @@ -0,0 +1,312 @@ +from tensorflow.keras.models import Model +import tensorflow as tf +from tensorflow import keras +import matplotlib.pyplot as plt +import matplotlib.cm as cm +import numpy as np +import cv2 +import pandas as pd +from tensorflow.keras.models import load_model +import SimpleITK as sitk +import os +
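+# Minimal sketch of the HU windowing/normalization applied in pat_data()
+# below ('np_clip' variant); the helper name is illustrative only:
+def _window_and_normalize(img_arr, lo=-200, hi=200):
+    # clip to the [lo, hi] HU window, then rescale to [0, 1]
+    arr = np.clip(img_arr, a_min=lo, a_max=hi)
+    return (arr - arr.min()) / (arr.max() - arr.min())
+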
+#--------------------------------------------------------------------------------- +# get data +#--------------------------------------------------------------------------------- +def pat_data(pat_id, run_type, input_channel, i, val_save_dir, test_save_dir): + + ### load data and labels + if run_type == 'val': + df = pd.read_csv(os.path.join(data_pro_dir, 'df_pat_val.csv')) + if pat_id[:-3] == 'PMH': + data_dir = PMH_reg_dir + elif pat_id[:-3] == 'CHUM': + data_dir = CHUM_reg_dir + elif pat_id[:-3] == 'CHUS': + data_dir = CHUS_reg_dir + elif run_type == 'test': + df = pd.read_csv(os.path.join(data_pro_dir, 'df_pat_test.csv')) + data_dir = MDACC_reg_dir + elif run_type == 'exval': + df = pd.read_csv(os.path.join(data_pro_dir, 'df_pat_exval.csv')) + data_dir = NSCLS_reg_dir + + ## create numpy array + scan_dir = os.path.join(data_dir, pat_id) + nrrd = sitk.ReadImage(scan_dir, sitk.sitkFloat32) + img_arr = sitk.GetArrayFromImage(nrrd) + data = img_arr[slice_range, :, :] + ### clip signals lower than -1024 + data[data <= -1024] = -1024 + ### strip skull; skull HU is ~700 + data[data > 700] = 0 + ### normalize HU to 0 - 1; all signals outside of [0, 1] will be 0 + if norm_type == 'np_interp': + data = np.interp(data, [-200, 200], [0, 1]) + elif norm_type == 'np_clip': + data = np.clip(data, a_min=-200, a_max=200) + MAX, MIN = data.max(), data.min() + data = (data - MIN) / (MAX - MIN) + ## stack the single channel to 3 channels for CNN input + data = np.repeat(data[..., np.newaxis], 3, axis=-1) + + ### load label + y_true = df['label'].loc[df['ID'] == pat_id] + + y_pred_class = df['y_pred_class'] + y_pred = df['y_pred'] + ID = df['fn'] + ### find the ith image to show grad-cam map + img = data[i, :, :, :] + img = img.reshape((1, 192, 192, 3)) + label = y_true[i] + pred_index = y_pred_class[i] + y_pred = y_pred[i] + ID = ID[i] + + return img, label, pred_index, y_pred, ID + +#------------------------------------------------------------------------------------ +# find last conv layer +#----------------------------------------------------------------------------------- +def find_target_layer(model, saved_model): + + # find the final conv layer by looping layers in reverse order + for layer in reversed(model.layers): + # check to see if the layer has a 4D output + if len(layer.output_shape) == 4: + return layer.name + raise ValueError("Could not find 4D layer.
Cannot apply GradCAM.") + +#---------------------------------------------------------------------------------- +# calculate gradient class actiavtion map +#---------------------------------------------------------------------------------- +def compute_heatmap(model, saved_model, image, pred_index, last_conv_layer): + + """ + construct our gradient model by supplying (1) the inputs + to our pre-trained model, (2) the output of the (presumably) + final 4D layer in the network, and (3) the output of the + softmax activations from the model + """ + gradModel = Model( + inputs=[model.inputs], + outputs=[model.get_layer(last_conv_layer).output, model.output] + ) + + # record operations for automatic differentiation + with tf.GradientTape() as tape: + """ + cast the image tensor to a float-32 data type, pass the + image through the gradient model, and grab the loss + associated with the specific class index + """ + print(pred_index) + inputs = tf.cast(image, tf.float32) + print(image.shape) + last_conv_layer_output, preds = gradModel(inputs) + print(preds) + print(preds.shape) + # class_channel = preds[:, pred_index] + class_channel = preds + # use automatic differentiation to compute the gradients + grads = tape.gradient(class_channel, last_conv_layer_output) + """ + This is a vector where each entry is the mean intensity of the gradient + over a specific feature map channel + """ + pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)) + """ + We multiply each channel in the feature map array + by "how important this channel is" with regard to the top predicted class + then sum all the channels to obtain the heatmap class activation + """ + last_conv_layer_output = last_conv_layer_output[0] + heatmap = last_conv_layer_output @ pooled_grads[..., tf.newaxis] + heatmap = tf.squeeze(heatmap) + + # For visualization purpose, we will also normalize the heatmap between 0 & 1 + heatmap = tf.maximum(heatmap, 0) / tf.math.reduce_max(heatmap) + heatmap = heatmap.numpy() + + return heatmap + + +#------------------------------------------------------------------------------------ +# save gradcam heat map +#----------------------------------------------------------------------------------- +def save_gradcam(image, heatmap, val_gradcam_dir, test_gradcam_dir, alpha, i): + +# print('heatmap:', heatmap.shape) + # Rescale heatmap to a range 0-255 + heatmap = np.uint8(255 * heatmap) + # Use jet colormap to colorize heatmap + jet = cm.get_cmap("jet") + # Use RGB values of the colormap + jet_colors = jet(np.arange(256))[:, :3] + jet_heatmap = jet_colors[heatmap] + + # resize heatmap + jet_heatmap = keras.preprocessing.image.array_to_img(jet_heatmap) + jet_heatmap0 = jet_heatmap.resize(re_size) + jet_heatmap1 = keras.preprocessing.image.img_to_array(jet_heatmap0) +# print('jet_heatmap:', jet_heatmap1.shape) + + # resize background CT image + img = image.reshape((192, 192, 3)) + img = keras.preprocessing.image.array_to_img(img) + img0 = img.resize(re_size) + img1 = keras.preprocessing.image.img_to_array(img0) +# print('img shape:', img1.shape) + + # Superimpose the heatmap on original image + superimposed_img = jet_heatmap1 * alpha + img1 + superimposed_img = keras.preprocessing.image.array_to_img(superimposed_img) + + # Save the superimposed image + if run_type == 'val': + save_dir = val_gradcam_dir + elif run_type == 'test': + save_dir = test_gradcam_dir + fn1 = str(conv_n) + '_' + str(i) + '_' + 'gradcam.png' + fn2 = str(conv_n) + '_' + str(i) + '_' + 'heatmap.png' + fn3 = str(conv_n) + '_' + str(i) + '_' + 
'heatmap_raw.png' + fn4 = str(i) + '_' + 'CT.png' + superimposed_img.save(os.path.join(save_dir, fn1)) +# jet_heatmap0.save(os.path.join(save_dir, fn2)) +# jet_heatmap.save(os.path.join(save_dir, fn3)) +# img0.save(os.path.join(save_dir, fn4)) + + +if __name__ == '__main__': + + train_img_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/train_img_dir' + val_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/val' + test_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test' + val_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/val/gradcam' + test_gradcam_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test/gradcam' + data_pro_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data_pro' + model_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/model' + input_channel = 3 + re_size = (192, 192) + i = 72 + crop = True + alpha = 0.9 + saved_model = 'ResNet_2021_07_18_06_28_40' + show_network = False + conv_n = 'conv5' + run_type = 'val' + pat_id = 'PMH423' # illustrative patient ID, as used in gradcam2.py + # NOTE: pat_data() also expects the *_reg_dir paths, norm_type and + # slice_range to be defined here (see gradcam2.py for example values) + + #--------------------------------------------------------- + # run main function + #-------------------------------------------------------- + if run_type == 'val': + save_dir = val_save_dir + elif run_type == 'test': + save_dir = test_save_dir + + ## load model and find conv layers + model = load_model(os.path.join(model_dir, saved_model)) +# model.summary() + + list_i = [100, 105, 110, 115, 120, 125] + for i in list_i: + image, label, pred_index, y_pred, ID = pat_data( + pat_id=pat_id, + run_type=run_type, + input_channel=input_channel, + i=i, + val_save_dir=val_save_dir, + test_save_dir=test_save_dir + ) + + #conv_list = ['conv2', 'conv3', 'conv4', 'conv5'] # full set, kept for reference + conv_list = ['conv4'] + for conv_n in conv_list: + if conv_n == 'conv2': + last_conv_layer = 'conv2_block3_1_conv' + elif conv_n == 'conv3': + last_conv_layer = 'conv3_block4_1_conv' + elif conv_n == 'conv4': + last_conv_layer = 'conv4_block6_1_conv' + elif conv_n == 'conv5': + last_conv_layer = 'conv5_block3_out' + + heatmap = compute_heatmap( + model=model, + saved_model=saved_model, + image=image, + pred_index=pred_index, + last_conv_layer=last_conv_layer + ) + + save_gradcam( + image=image, + heatmap=heatmap, + val_gradcam_dir=val_gradcam_dir, + test_gradcam_dir=test_gradcam_dir, + alpha=alpha, + i=i + ) + + print('label:', label) + print('ID:', ID) + print('y_pred:', y_pred) + print('prediction:', pred_index) + print('conv layer:', conv_n) diff --git a/utils/import_to_tensorboard.py b/utils/import_to_tensorboard.py new file mode 100644 index 0000000..fab0626 --- /dev/null +++ b/utils/import_to_tensorboard.py @@ -0,0 +1,48 @@ +"""Imports a protobuf model as a graph in Tensorboard.""" + + +from tensorflow.python.client import session +from tensorflow.python.framework import importer +from tensorflow.python.framework import ops +from tensorflow.python.summary import summary +from tensorflow.python.tools import saved_model_utils + +# Try importing TensorRT ops if available +# TODO(aaroey): ideally we should import everything from contrib, but currently +# tensorrt module would cause build errors when being imported in +# tensorflow/contrib/__init__.py. Fix it. +# pylint: disable=unused-import,g-import-not-at-top,wildcard-import +try: + from tensorflow.contrib.tensorrt.ops.gen_trt_engine_op import * +except ImportError: + pass +# pylint: enable=unused-import,g-import-not-at-top,wildcard-import + + +def import_to_tensorboard(model_dir, log_dir, tag_set): + """View a SavedModel as a graph in Tensorboard. + Args: + model_dir: The directory containing the SavedModel to import. + log_dir: The location for the Tensorboard log to begin visualization from. + tag_set: Group of tag(s) of the MetaGraphDef to load, in string format, + separated by ','. For a tag-set containing multiple tags, all tags must be + passed in. + Usage: Call this function with your SavedModel location and desired log + directory. Launch Tensorboard by pointing it to the log directory. View your + imported SavedModel as a graph. + """ + with session.Session(graph=ops.Graph()) as sess: + input_graph_def = saved_model_utils.get_meta_graph_def(model_dir, + tag_set).graph_def + importer.import_graph_def(input_graph_def) + + pb_visual_writer = summary.FileWriter(log_dir) + pb_visual_writer.add_graph(sess.graph) + print("Model Imported.
Visualize by running: " "tensorboard --logdir={}".format(log_dir)) + + +# example invocation; model_dir, log_dir and tag_set must be defined by the +# caller before running this module as a script: +#import_to_tensorboard(model_dir, log_dir, tag_set) + + diff --git a/utils/mean_CI.py b/utils/mean_CI.py new file mode 100644 index 0000000..3402865 --- /dev/null +++ b/utils/mean_CI.py @@ -0,0 +1,42 @@ +import os +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import scipy.stats as ss +import pickle + + +# ---------------------------------------------------------------------------------- +# define function for calculating mean and 95% CI +# ---------------------------------------------------------------------------------- +def mean_CI(data): + + mean = np.mean(np.array(data)) + CI = ss.t.interval( + alpha=0.95, + df=len(data)-1, + loc=np.mean(data), + scale=ss.sem(data) + ) + lower = CI[0] + upper = CI[1] + + return mean, lower, upper + +## percentile-based alternative, kept for reference: +##def mean_CI(metric): +## alpha = 0.95 +## mean = np.mean(np.array(metric)) +## p_up = (1.0 - alpha)/2.0*100 +## lower = max(0.0, np.percentile(metric, p_up)) +## p_down = ((alpha + (1.0 - alpha)/2.0)*100) +## upper = min(1.0, np.percentile(metric, p_down)) +## return mean, lower, upper diff --git a/utils/nrrd_reg.py b/utils/nrrd_reg.py new file mode 100644 index 0000000..0849b9c --- /dev/null +++ b/utils/nrrd_reg.py @@ -0,0 +1,56 @@ +import sys, os, glob +import SimpleITK as sitk +#import pydicom +import numpy as np + + +def nrrd_reg_rigid_ref(img_nrrd, fixed_img_dir, patient_id, save_dir): + + fixed_img = sitk.ReadImage(fixed_img_dir, sitk.sitkFloat32) + moving_img = img_nrrd +# moving_img = sitk.ReadImage(img_nrrd, sitk.sitkUInt32) + #moving_img = sitk.ReadImage(input_path, sitk.sitkFloat32) + + transform = sitk.CenteredTransformInitializer( + fixed_img, + moving_img, + sitk.Euler3DTransform(), + sitk.CenteredTransformInitializerFilter.GEOMETRY + ) + + # multi-resolution rigid registration using Mutual Information + registration_method = sitk.ImageRegistrationMethod() + registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50) + registration_method.SetMetricSamplingStrategy(registration_method.RANDOM) + registration_method.SetMetricSamplingPercentage(0.01) + registration_method.SetInterpolator(sitk.sitkLinear) + + registration_method.SetOptimizerAsGradientDescent( + learningRate=1.0, + numberOfIterations=100, + convergenceMinimumValue=1e-6, + convergenceWindowSize=10 + ) + + registration_method.SetOptimizerScalesFromPhysicalShift() + registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1]) + registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0]) + registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn() + registration_method.SetInitialTransform(transform) + final_transform = registration_method.Execute(fixed_img, moving_img) + moving_img_resampled = sitk.Resample( + moving_img, + fixed_img, + final_transform, + sitk.sitkLinear, + 0.0, + moving_img.GetPixelID() + ) + img_reg = moving_img_resampled + + if save_dir is not None: + nrrd_fn = str(patient_id) + '.nrrd' + sitk.WriteImage(img_reg, os.path.join(save_dir, nrrd_fn)) + + return img_reg + #return fixed_img, moving_img, final_transform diff --git a/utils/plot_cm.py b/utils/plot_cm.py new file mode 100644 index 0000000..aaceb3c --- /dev/null +++ 
b/utils/plot_cm.py @@ -0,0 +1,42 @@ +import seaborn as sn +import numpy as np +import matplotlib.pyplot as plt +import os + +def plot_cm(cm0, cm_type, level, save_dir): + + if cm_type == 'norm': + fmt = '' + elif cm_type == 'raw': + fmt = 'd' + + ax = sn.heatmap( + cm0, + annot=True, + cbar=True, + cbar_kws={'ticks': [-0.1]}, + annot_kws={'size': 26, 'fontweight': 'bold'}, + cmap='Blues', + fmt=fmt, + linewidths=0.5 + ) + + ax.axhline(y=0, color='k', linewidth=4) + ax.axhline(y=2, color='k', linewidth=4) + ax.axvline(x=0, color='k', linewidth=4) + ax.axvline(x=2, color='k', linewidth=4) + + ax.tick_params(direction='out', length=4, width=2, colors='k') + ax.xaxis.set_ticks_position('top') + ax.set_xticklabels([]) + ax.set_yticklabels([]) + ax.set_aspect('equal') + plt.tight_layout() + + fn = 'cm' + '_' + str(cm_type) + '_' + str(level) + '.png' + plt.savefig( + os.path.join(save_dir, fn), + format='png', + dpi=600 + ) + plt.close() diff --git a/utils/plot_prc.py b/utils/plot_prc.py new file mode 100644 index 0000000..252c5b2 --- /dev/null +++ b/utils/plot_prc.py @@ -0,0 +1,71 @@ +#---------------------------------------------------------------------- +# Deep learning for classification for contrast CT; +# Transfer learning using Google Inception V3; +#----------------------------------------------------------------------------------------- +import os +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import scipy.stats as ss +import pickle +from sklearn.metrics import auc, roc_auc_score +from sklearn.metrics import precision_recall_curve + + +# ---------------------------------------------------------------------------------- +# precision recall curve +# ---------------------------------------------------------------------------------- +def plot_prc(save_dir, y_true, y_pred, level, color): + + precision = dict() + recall = dict() + threshold = dict() + prc_auc = [] + + precision, recall, threshold = precision_recall_curve(y_true, y_pred) + RP_2D = np.array([recall, precision]) + RP_2D = RP_2D[np.argsort(RP_2D[:, 0])] + #prc_auc.append(auc(RP_2D[1], RP_2D[0])) + prc_auc = auc(RP_2D[1], RP_2D[0]) + prc_auc = np.around(prc_auc, 3) + #print('PRC AUC:', prc_auc) + #prc_auc = auc(precision, recall) + print(prc_auc) + #prc_auc = 1 + + fn = 'prc' + '_' + str(level) + '.png' + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + ax.set_aspect('equal') + plt.plot( + recall, + precision, + color=color, + linewidth=3, + label='AUC %0.3f' % prc_auc + ) + plt.xlim([0, 1.03]) + plt.ylim([0, 1.03]) + ax.axhline(y=0, color='k', linewidth=4) + ax.axhline(y=1.03, color='k', linewidth=4) + ax.axvline(x=0, color='k', linewidth=4) + ax.axvline(x=1.03, color='k', linewidth=4) + plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=16, fontweight='bold') + plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=16, fontweight='bold') + plt.xlabel('recall', fontweight='bold', fontsize=16) + plt.ylabel('precision', fontweight='bold', fontsize=16) + plt.legend(loc='lower left', prop={'size': 16, 'weight': 'bold'}) + plt.grid(True) +# plt.tight_layout(pad=0.2, h_pad=None, w_pad=None, rect=None) +# plt.subplots_adjust(left=0.1, right=0.9, top=0.9, bottom=0.1) + plt.savefig(os.path.join(save_dir, fn), format='png', dpi=600) + #plt.show() + plt.close() + + return prc_auc + + + + + + diff --git a/utils/plot_roc.py b/utils/plot_roc.py new file mode 100644 index 0000000..7c5d16e --- /dev/null +++ b/utils/plot_roc.py @@ -0,0 +1,49 @@ +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt 
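+# Example call for plot_roc() defined below (arrays and the save path are
+# illustrative only, not from the original experiments):
+#   y_true = np.array([0, 1, 1, 0, 1])
+#   y_pred = np.array([0.2, 0.8, 0.6, 0.4, 0.9])
+#   roc_auc = plot_roc(save_dir='/tmp', y_true=y_true, y_pred=y_pred,
+#                      level='img', color='blue')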
+import glob +import pickle +import os +import numpy as np +from sklearn.metrics import roc_auc_score +from sklearn.metrics import auc +from sklearn.metrics import roc_curve + + +def plot_roc(save_dir, y_true, y_pred, level, color): + + fpr = dict() + tpr = dict() + roc_auc = dict() + threshold = dict() + + ### calculate auc + fpr, tpr, threshold = roc_curve(y_true, y_pred) + roc_auc = auc(fpr, tpr) + roc_auc = np.around(roc_auc, 3) + #print('ROC AUC:', roc_auc) + + fn = 'roc'+ '_' + str(level) + '.png' + ### plot roc + fig = plt.figure() + ax = fig.add_subplot(1, 1, 1) + ax.set_aspect('equal') + plt.plot(fpr, tpr, color=color, linewidth=3, label='AUC %0.3f' % roc_auc) + plt.xlim([-0.03, 1]) + plt.ylim([0, 1.03]) + ax.axhline(y=0, color='k', linewidth=4) + ax.axhline(y=1.03, color='k', linewidth=4) + ax.axvline(x=-0.03, color='k', linewidth=4) + ax.axvline(x=1, color='k', linewidth=4) + plt.xticks([0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=16, fontweight='bold') + plt.yticks([0, 0.2, 0.4, 0.6, 0.8, 1.0], fontsize=16, fontweight='bold') + plt.xlabel('1 - Specificity', fontweight='bold', fontsize=16) + plt.ylabel('Sensitivity', fontweight='bold', fontsize=16) + plt.legend(loc='lower right', prop={'size': 16, 'weight': 'bold'}) + plt.grid(True) +# plt.tight_layout(pad=1.08, h_pad=None, w_pad=None, rect=None) + plt.savefig(os.path.join(save_dir, fn), format='png', dpi=600) + #plt.show() + plt.close() + + return roc_auc diff --git a/utils/plot_train_curve.py b/utils/plot_train_curve.py new file mode 100644 index 0000000..68330c3 --- /dev/null +++ b/utils/plot_train_curve.py @@ -0,0 +1,34 @@ +import os +import numpy as np +import pandas as pd +import seaborn as sn +import matplotlib.pyplot as plt + +def plot_train_curve(output_dir, epoch, fn, history): + + train_acc = history.history['accuracy'] + train_loss = history.history['loss'] + val_acc = history.history['val_accuracy'] + val_loss = history.history['val_loss'] + n_epoch = list(range(epoch)) + ## accuracy curves + plt.style.use('ggplot') + plt.figure(figsize=(15, 15)) + plt.subplot(2, 2, 1) + plt.plot(n_epoch, train_acc, label='Train Acc') + plt.plot(n_epoch, val_acc, label='Val Acc') + plt.legend(loc='lower right') + plt.title('Train and Tune Accuracy') + plt.xlabel('Epoch') + plt.ylabel('Accuracy') + ## loss curves + plt.subplot(2, 2, 2) + plt.plot(n_epoch, train_loss, label='Train Loss') + plt.plot(n_epoch, val_loss, label='Val Loss') + plt.legend(loc='upper right') + plt.title('Train and Val Loss') + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.show() + plt.savefig(os.path.join(output_dir, fn)) + plt.close() diff --git a/utils/prc_all.py b/utils/prc_all.py new file mode 100644 index 0000000..952ebf0 --- /dev/null +++ b/utils/prc_all.py @@ -0,0 +1,46 @@ +import os +import numpy as np +import pandas as pd +import pickle +from utils.mean_CI import mean_CI +from utils.plot_roc import plot_roc +from utils.roc_bootstrap import roc_bootstrap +from utils.plot_prc import plot_prc + + +# ---------------------------------------------------------------------------------- +# plot ROI +# ---------------------------------------------------------------------------------- + +def prc_all(run_type, level, thr_prob, thr_pos, color, pro_data_dir, save_dir, fn_df_pred): + + df_sum = pd.read_csv(os.path.join(pro_data_dir, fn_df_pred)) + + if level == 'img': + y_true = df_sum['label'].to_numpy() + y_pred = df_sum['y_pred'].to_numpy() + print_info = 'prc image:' + elif level == 'patient_mean_prob': + df_mean = df_sum.groupby(['ID']).mean() + y_true = 
df_mean['label'].to_numpy() + y_pred = df_mean['y_pred'].to_numpy() + print_info = 'prc patient prob:' + elif level == 'patient_mean_pos': + df_mean = df_sum.groupby(['ID']).mean() + y_true = df_mean['label'].to_numpy() + y_pred = df_mean['y_pred_class'].to_numpy() + print_info = 'prc patient pos:' + + prc_auc = plot_prc( + save_dir=save_dir, + y_true=y_true, + y_pred=y_pred, + level=level, + color=color, + ) + + print(print_info) + print(prc_auc) + + return prc_auc + diff --git a/utils/resize_3d.py b/utils/resize_3d.py new file mode 100644 index 0000000..9d3e0eb --- /dev/null +++ b/utils/resize_3d.py @@ -0,0 +1,57 @@ +#-------------------------------------------------------------------------- +# rescale to a common "more compact" size (either downsample or upsample) +#-------------------------------------------------------------------------- + +import SimpleITK as sitk +import sys +import os +import matplotlib.pyplot as plt + + +def resize_3d(img_nrrd, interp_type, output_size, patient_id, return_type, save_dir): + + ### calculate new spacing +# image = sitk.ReadImage(nrrd_image) + image = img_nrrd + input_size = image.GetSize() + input_spacing = image.GetSpacing() + output_spacing = ( + (input_size[0] * input_spacing[0]) / output_size[0], + (input_size[1] * input_spacing[1]) / output_size[1], + (input_size[2] * input_spacing[2]) / output_size[2] + ) + #print('{} {}'.format('input spacing: ', input_spacing)) + #print('{} {}'.format('output spacing: ', output_spacing)) + + ### choose interpolation algorithm + if interp_type == 'linear': + interp_type = sitk.sitkLinear + elif interp_type == 'bspline': + interp_type = sitk.sitkBSpline + elif interp_type == 'nearest_neighbor': + interp_type = sitk.sitkNearestNeighbor + + ### interpolate + resample = sitk.ResampleImageFilter() + resample.SetSize(output_size) + resample.SetOutputSpacing(output_spacing) + resample.SetOutputOrigin(image.GetOrigin()) + resample.SetOutputDirection(image.GetDirection()) + resample.SetInterpolator(interp_type) + img_nrrd = resample.Execute(image) + + ## save as numpy array + img_arr = sitk.GetArrayFromImage(img_nrrd) + + if return_type == 'nrrd': + writer = sitk.ImageFileWriter() + writer.SetFileName(os.path.join(save_dir, '{}.nrrd'.format(patient_id))) + writer.SetUseCompression(True) + writer.Execute(img_nrrd) + return img_nrrd + + elif return_type == 'npy': + return img_arr + + + diff --git a/utils/respacing.py b/utils/respacing.py new file mode 100644 index 0000000..3054b7c --- /dev/null +++ b/utils/respacing.py @@ -0,0 +1,95 @@ +#-------------------------------------------------------------------------- +# rescale to a common "more compact" size (either downsample or upsample) +#-------------------------------------------------------------------------- + +import SimpleITK as sitk +import sys +import os +import numpy as np + + +def respacing(nrrd_dir, interp_type, new_spacing, patient_id, return_type, save_dir): + + ### calculate new spacing + img = sitk.ReadImage(nrrd_dir) + old_size = img.GetSize() + old_spacing = img.GetSpacing() + #print('{} {}'.format('old size: ', old_size)) + #print('{} {}'.format('old spacing: ', old_spacing)) + + new_size = [ + int(round((old_size[0] * old_spacing[0]) / float(new_spacing[0]))), + int(round((old_size[1] * old_spacing[1]) / float(new_spacing[1]))), + int(round((old_size[2] * old_spacing[2]) / float(new_spacing[2]))) + ] + + #print('{} {}'.format('new size: ', new_size)) + + ### choose interpolation algorithm + if interp_type == 'linear': + interp_type = sitk.sitkLinear + 
elif interp_type == 'bspline': + interp_type = sitk.sitkBSpline + elif interp_type == 'nearest_neighbor': + interp_type = sitk.sitkNearestNeighbor + + ### interpolate + resample = sitk.ResampleImageFilter() + resample.SetOutputSpacing(new_spacing) + resample.SetSize(new_size) + resample.SetOutputOrigin(img.GetOrigin()) + resample.SetOutputDirection(img.GetDirection()) + resample.SetInterpolator(interp_type) + resample.SetDefaultPixelValue(img.GetPixelIDValue()) + resample.SetOutputPixelType(sitk.sitkFloat32) + img_nrrd = resample.Execute(img) + + ## save nrrd images + if save_dir != None: + writer = sitk.ImageFileWriter() + writer.SetFileName(os.path.join(save_dir, '{}.nrrd'.format(patient_id))) + writer.SetUseCompression(True) + writer.Execute(img_nrrd) + + ## save as numpy array + img_arr = sitk.GetArrayFromImage(img_nrrd) + + if return_type == 'nrrd': + return img_nrrd + + elif return_type == 'npy': + return img_arr + + +#----------------------------------------------------------------------- +# main function +#---------------------------------------------------------------------- +if __name__ == '__main__': + + + PMH_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_PMH' + PMH_reg_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/PMH_data_rego' + exval_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/exval' + NSCLC_data_dir = '/mnt/aertslab/DATA/Lung/TOPCODER/nrrd_data' + fixed_img = 'nsclc_rt_TC001.nrrd' + nrrd_dir = os.path.join(NSCLC_data_dir, fixed_img) + interp_type = 'linear' + new_spacing = (1, 1, 3) + patient_id = 'NSCLC001' + return_type = 'nrrd' + save_dir = exval_dir + + os.mkdir(exval_dir) if not os.path.isdir(exval_dir) else None + + img_nrrd = respacing( + nrrd_dir=nrrd_dir, + interp_type=interp_type, + new_spacing=new_spacing, + patient_id=patient_id, + return_type=return_type, + save_dir=save_dir + ) + + + + diff --git a/utils/roc_all.py b/utils/roc_all.py new file mode 100644 index 0000000..48f5d4c --- /dev/null +++ b/utils/roc_all.py @@ -0,0 +1,53 @@ +import os +import numpy as np +import pandas as pd +import pickle +from utils.mean_CI import mean_CI +from utils.plot_roc import plot_roc +from utils.roc_bootstrap import roc_bootstrap + + + +# ---------------------------------------------------------------------------------- +# plot ROI +# ---------------------------------------------------------------------------------- + +def roc_all(run_type, level, thr_prob, thr_pos, bootstrap, color, pro_data_dir, save_dir, + fn_df_pred): + + df_sum = pd.read_csv(os.path.join(pro_data_dir, fn_df_pred)) + + if level == 'img': + y_true = df_sum['label'].to_numpy() + y_pred = df_sum['y_pred'].to_numpy() + print_info = 'roc image:' + elif level == 'patient_mean_prob': + df_mean = df_sum.groupby(['ID']).mean() + y_true = df_mean['label'].to_numpy() + y_pred = df_mean['y_pred'].to_numpy() + print_info = 'roc patient prob:' + elif level == 'patient_mean_pos': + df_mean = df_sum.groupby(['ID']).mean() + y_true = df_mean['label'].to_numpy() + y_pred = df_mean['y_pred_class'].to_numpy() + print_info = 'roc patient pos:' + + auc = plot_roc( + save_dir=save_dir, + y_true=y_true, + y_pred=y_pred, + level=level, + color='blue' + ) + ### calculate roc, tpr, tnr with 1000 bootstrap + roc_stat = roc_bootstrap( + bootstrap=bootstrap, + y_true=y_true, + y_pred=y_pred + ) + + print(print_info) + print(roc_stat) + + return roc_stat + diff --git a/utils/roc_bootstrap.py b/utils/roc_bootstrap.py new file mode 100644 index 0000000..7de5fc6 --- /dev/null +++ b/utils/roc_bootstrap.py @@ -0,0 +1,55 @@ 
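+# Usage sketch (hypothetical arrays): estimate AUC/TPR/TNR and the operating
+# threshold with 95% CIs from 1000 bootstrap resamples:
+#   stat_roc = roc_bootstrap(bootstrap=1000, y_true=y_true, y_pred=y_pred)
+#   print(stat_roc)  # DataFrame indexed by AUC/TPR/TNR/THRE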
+#-------------------------------------------- +# calculate auc, tpr, tnr with n bootstrap +#------------------------------------------- + +import os +import numpy as np +import pandas as pd +import glob +from sklearn.utils import resample +import scipy.stats as ss +from utils.mean_CI import mean_CI +from sklearn.metrics import roc_auc_score +from sklearn.metrics import auc +from sklearn.metrics import roc_curve + + +def roc_bootstrap(bootstrap, y_true, y_pred): + + AUC = [] + THRE = [] + TNR = [] + TPR = [] + for j in range(bootstrap): + #print("bootstrap iteration: " + str(j+1) + " out of " + str(bootstrap)) + index = range(len(y_pred)) + indices = resample(index, replace=True, n_samples=int(len(y_pred))) + fpr, tpr, thre = roc_curve(y_true[indices], y_pred[indices]) + q = np.arange(len(tpr)) + roc = pd.DataFrame( + {'fpr' : pd.Series(fpr, index=q), + 'tpr' : pd.Series(tpr, index=q), + 'tnr' : pd.Series(1 - fpr, index=q), + 'tf' : pd.Series(tpr - (1 - fpr), index=q), + 'thre': pd.Series(thre, index=q)} + ) + ### calculate optimal TPR, TNR under the Youden index + roc_opt = roc.loc[(roc['tpr'] - roc['fpr']).idxmax(),:] + AUC.append(roc_auc_score(y_true[indices], y_pred[indices])) + TPR.append(roc_opt['tpr']) + TNR.append(roc_opt['tnr']) + THRE.append(roc_opt['thre']) + ### calculate mean and 95% CI + AUCs = np.around(mean_CI(AUC), 3) + TPRs = np.around(mean_CI(TPR), 3) + TNRs = np.around(mean_CI(TNR), 3) + THREs = np.around(mean_CI(THRE), 3) + #print(AUCs) + ### save results into dataframe + stat_roc = pd.DataFrame( + [AUCs, TPRs, TNRs, THREs], + columns=['mean', '95% CI -', '95% CI +'], + index=['AUC', 'TPR', 'TNR', 'THRE'] + ) + + return stat_roc diff --git a/utils/save_npy_to_h5.py b/utils/save_npy_to_h5.py new file mode 100644 index 0000000..24934b7 --- /dev/null +++ b/utils/save_npy_to_h5.py @@ -0,0 +1,10 @@ +import numpy as np +import h5py + +file_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/data_pro/exval_arr_3ch1.npy' +data = np.load(file_dir) + +with h5py.File('exval_arr_3ch1.h5', 'w') as hf: + hf.create_dataset("name-of-dataset", data=data) + +print("H5 created.") diff --git a/utils/save_pred.py b/utils/save_pred.py new file mode 100644 index 0000000..9af46e1 --- /dev/null +++ b/utils/save_pred.py @@ -0,0 +1,17 @@ +import pandas as pd +import os + +proj_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection/ahmed_data/results' +df_pred = pd.read_csv(os.path.join(proj_dir, 'rtog-0617_pat_pred.csv')) +df_id = pd.read_csv(os.path.join(proj_dir, 'check_list.csv')) +list_id = df_id['patient_id'].to_list() + +IDs = [] +preds = [] +for ID, pred in zip(df_pred['ID'], df_pred['predictions']): + if ID.split('_')[1] in list_id: + IDs.append(ID) + preds.append(pred) + +df_check = pd.DataFrame({'ID': IDs, 'pred': preds}) +df_check.to_csv(os.path.join(proj_dir, 'rtog_check.csv'), index=False) diff --git a/utils/save_prepro_scan.py b/utils/save_prepro_scan.py new file mode 100644 index 0000000..a02c5ea --- /dev/null +++ b/utils/save_prepro_scan.py @@ -0,0 +1,119 @@ +import glob +import shutil +import os +import pandas as pd +import nrrd +import re +from sklearn.model_selection import train_test_split +import pickle +import numpy as np +from time import gmtime, strftime +from datetime import datetime +import timeit +from respacing import respacing +from nrrd_reg import nrrd_reg_rigid_ref +from crop_image import crop_image + + + +def save_prepro_scan(PMH_data_dir, CHUM_data_dir, CHUS_data_dir, MDACC_data_dir, + PMH_reg_dir, CHUM_reg_dir, CHUS_reg_dir, MDACC_reg_dir, fixed_img_dir, + 
interp_type, new_spacing, return_type, data_exclude, crop_shape): + + for fn, ID in zip(fns, IDs): + + print(ID) + + ## set up save dir + if ID[:-3] == 'PMH': + save_dir = PMH_reg_dir + file_dir = os.path.join(PMH_data_dir, fn) + elif ID[:-3] == 'CHUM': + save_dir = CHUM_reg_dir + file_dir = os.path.join(CHUM_data_dir, fn) + elif ID[:-3] == 'CHUS': + save_dir = CHUS_reg_dir + file_dir = os.path.join(CHUS_data_dir, fn) + elif ID[:-3] == 'MDACC': + save_dir = MDACC_reg_dir + file_dir = os.path.join(MDACC_data_dir, fn) + + ## respacing + img_nrrd = respacing( + nrrd_dir=file_dir, + interp_type=interp_type, + new_spacing=new_spacing, + patient_id=ID, + return_type=return_type, + save_dir=None + ) + + ## registration + img_reg = nrrd_reg_rigid_ref( + img_nrrd=img_nrrd, + fixed_img_dir=fixed_img_dir, + patient_id=ID, + save_dir=None + ) + + ## crop image from (500, 500, 116) to (180, 180, 60) + img_crop = crop_image( + nrrd_file=img_reg, + patient_id=ID, + crop_shape=crop_shape, + return_type='nrrd', + save_dir=save_dir + ) + +if __name__ == '__main__': + + + fns = [ + 'mdacc_HNSCC-01-0028_CT-SIM-09-06-1999-_raw_raw_raw_xx.nrrd', + 'mdacc_HNSCC-01-0168_CT-SIM-11-19-2001-_raw_raw_raw_xx.nrrd', + 'mdacc_HNSCC-01-0181_CT-SIM-07-16-2002-_raw_raw_raw_xx.nrrd' + ] + IDs = ['MDACC028', 'MDACC168', 'MDACC181'] + + val_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/val' + test_save_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/test' + MDACC_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_mdacc' + PMH_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_pmh' + CHUM_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_chum' + CHUS_data_dir = '/media/bhkann/HN_RES1/HN_CONTRAST/0_image_raw_chus' + CHUM_reg_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/data/CHUM_data_reg' + CHUS_reg_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/data/CHUS_data_reg' + PMH_reg_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/data/PMH_data_reg' + MDACC_reg_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/data/MDACC_data_reg' + data_pro_dir = '/mnt/aertslab/USERS/Zezhong/constrast_detection/data_pro' + fixed_img_dir = os.path.join(data_pro_dir, 'PMH050.nrrd') + input_channel = 3 + crop = True + n_img = 20 + save_img = False + data_exclude = None + thr_prob = 0.5 + run_type = 'test' + norm_type = 'np_clip' + new_spacing = [1, 1, 3] + crop_shape = [192, 192, 100] + return_type = 'nrrd' + interp_type = 'linear' + + save_prepro_scan( + PMH_data_dir=PMH_data_dir, + CHUM_data_dir=CHUM_data_dir, + CHUS_data_dir=CHUS_data_dir, + MDACC_data_dir=MDACC_data_dir, + PMH_reg_dir=PMH_reg_dir, + CHUM_reg_dir=CHUM_reg_dir, + CHUS_reg_dir=CHUS_reg_dir, + MDACC_reg_dir=MDACC_reg_dir, + fixed_img_dir=fixed_img_dir, + interp_type=interp_type, + new_spacing=new_spacing, + return_type=return_type, + data_exclude=data_exclude, + crop_shape=crop_shape + ) + print("successfully save prepro scan!") diff --git a/utils/save_sitk_from_arr.py b/utils/save_sitk_from_arr.py new file mode 100644 index 0000000..350d2f1 --- /dev/null +++ b/utils/save_sitk_from_arr.py @@ -0,0 +1,45 @@ +#-------------------------------------------------------------------------------------------- +# save image as itk +#------------------------------------------------------------------------------------------- +def save_sitk_from_arr(img_sitk, new_arr, resize, save_dir): + + """ + When resize == True: Used for saving predictions where padding needs to be added to increase the size + of the prediction and match that 
of input to model. This function matches the size of the array in + img_sitk with the size of new_arr, and saves it. This is done equally on all sides, as the + input to the model and the model output have different dims to allow for shift data augmentation. + When resize == False: img_sitk is only used as a reference for spacing and origin. The numpy + array is not resized. + img_sitk: sitk object of input to model + new_arr: returned prediction from model - should be squeezed. + NOTE: the image array shape will always be equal to or larger than new_arr.shape, never smaller, + given that we are always cropping in data.py + """ + + if resize == True: + # get array from sitk object + img_arr = sitk.GetArrayFromImage(img_sitk) + # change new_arr.shape to match img_arr.shape + # getting amount of padding needed on each side + z_diff = int((img_arr.shape[0] - new_arr.shape[0]) / 2) + y_diff = int((img_arr.shape[1] - new_arr.shape[1]) / 2) + x_diff = int((img_arr.shape[2] - new_arr.shape[2]) / 2) + # pad, defaults to 0 + new_arr = np.pad(new_arr, ((z_diff, z_diff), (y_diff, y_diff), (x_diff, x_diff)), 'constant') + assert img_arr.shape == new_arr.shape, "returned array shape does not match your requested shape." + + # save sitk obj + new_sitk = sitk.GetImageFromArray(new_arr) + new_sitk.SetSpacing(img_sitk.GetSpacing()) + new_sitk.SetOrigin(img_sitk.GetOrigin()) + + if save_dir is not None: +# fn = "{}_{}_image_interpolated_roi_raw_gt.nrrd".format(dataset, patient_id) + fn = 'test_sitk.nrrd' + img_dir = os.path.join(save_dir, fn) + writer = sitk.ImageFileWriter() + writer.SetFileName(img_dir) + writer.SetUseCompression(True) + writer.Execute(new_sitk) + + return new_sitk + +# NOTE: this module assumes 'import os', 'import numpy as np' and +# 'import SimpleITK as sitk' at the top of the file. diff --git a/utils/scan_meta.py b/utils/scan_meta.py new file mode 100644 index 0000000..523d91c --- /dev/null +++ b/utils/scan_meta.py @@ -0,0 +1,56 @@ +import pandas as pd +import os +import numpy as np + + +data_dir = '/mnt/aertslab/USERS/Zezhong/contrast_detection' +meta_file = 'clinical_meta_data.csv' +df = pd.read_csv(os.path.join(data_dir, meta_file)) +IDs = [] +for manufacturer, model in zip(df['manufacturer'], df['manufacturermodelname']): + ID = str(manufacturer) + ' ' + str(model) + IDs.append(ID) +df['ID'] = IDs +#print(df['manufacturer'].value_counts()) +#print(df['manufacturermodelname'].value_counts()) +#print(df['ID'].value_counts()) +#print(df.shape[0]) + +## KVP +print('kvp mean:', df['kvp'].mean().round(3)) +print('kvp median:', df['kvp'].median()) +print('kvp mode:', df['kvp'].mode()) +print('kvp std:', df['kvp'].std()) +print('kvp min:', df['kvp'].min()) +print('kvp max:', df['kvp'].max()) + +## slice thickness +print('thk mean:', df['slicethickness'].mean().round(3)) +print('thk median:', df['slicethickness'].median()) +print('thk mode:', df['slicethickness'].mode()) +print('thk std:', df['slicethickness'].std().round(3)) +print('thk min:', df['slicethickness'].min()) +print('thk max:', df['slicethickness'].max()) +print(df['slicethickness'].value_counts()) +print(df['slicethickness'].shape[0]) + +## spatial resolution +print(df['rows'].value_counts()) + +## pixel spacing +pixels = [] +for pixel in df['pixelspacing']: + pixel = pixel.split("'")[1] + pixel = float(pixel) + pixels.append(pixel) +df['pixel'] = pixels +df['pixel'] = df['pixel'].round(3) +print('pixel mean:', df['pixel'].mean().round(3)) +print('pixel median:', df['pixel'].median().round(3)) +print('pixel mode:', df['pixel'].mode().round(3)) +print('pixel std:', df['pixel'].std()) +print('pixel min:', df['pixel'].min()) +print('pixel max:', 
df['pixel'].max()) + + + diff --git a/utils/tensorboard.py b/utils/tensorboard.py new file mode 100644 index 0000000..2936f6b --- /dev/null +++ b/utils/tensorboard.py @@ -0,0 +1,7 @@ + + + + +#log_path ='/media/bhkann/HN_RES1/HN_CONTRAST/log/train/events.out.tfevents.1624041416.bhkann-hpc1.862300.3087.v2' + +#tensorboard --logdir='/media/bhkann/HN_RES1/HN_CONTRAST/log/train/events.out.tfevents.1624041416.bhkann-hpc1.862300.3087.v2' diff --git a/utils/write_txt.py b/utils/write_txt.py new file mode 100644 index 0000000..31d1533 --- /dev/null +++ b/utils/write_txt.py @@ -0,0 +1,111 @@ +import os +import numpy as np +import pandas as pd +from datetime import datetime +from time import localtime, strftime + + + + + +def write_txt(run_type, out_dir, loss, acc, cms, cm_norms, reports, prc_aucs, + roc_stats, run_model, saved_model, epoch, batch_size, lr): + + """ + write model training, val, test results to txt files; + + Args: + loss {float} -- loss value; + acc {float} -- accuracy value; + cms {list} -- list contains confusion matrices; + cm_norms {list} -- list containing normalized confusion matrices; + reports {list} -- list containing classification reports; + prc_aucs {list} -- list containing prc auc; + roc_stats {list} -- list containing ROC statistics; + batch_size {int} -- batch size for data loading; + epoch {int} -- training epoch; + out_dir {path} -- path for output files; + + Returns: + save model results and parameters to txt file + + """ + + train_dir = os.path.join(out_dir, 'train') + val_dir = os.path.join(out_dir, 'val') + test_dir = os.path.join(out_dir, 'test') + exval1_dir = os.path.join(out_dir, 'exval1') + exval2_dir = os.path.join(out_dir, 'exval2') + + if not os.path.exists(train_dir): + os.mkdir(train_dir) + if not os.path.exists(val_dir): + os.mkdir(val_dir) + if not os.path.exists(test_dir): + os.mkdir(test_dir) + if not os.path.exists(exval1_dir): + os.mkdir(exval1_dir) + if not os.path.exists(exval2_dir): + os.mkdir(exval2_dir) + + if run_type == 'train': + log_fn = 'train_logs.text' + save_dir = train_dir + write_path = os.path.join(save_dir, log_fn) + with open(write_path, 'a') as f: + f.write('\n-------------------------------------------------------------------') + f.write('\ncreated time: %s' % strftime('%Y-%m-%d %H:%M:%S', localtime())) + f.write('\nval acc: %s' % acc) + f.write('\nval loss: %s' % loss) + f.write('\nrun model: %s' % run_model) + f.write('\nsaved model: %s' % saved_model) + f.write('\nepoch: %s' % epoch) + f.write('\nlearning rate: %s' % lr) + f.write('\nbatch size: %s' % batch_size) + f.write('\n') + f.close() + print('successfully save train logs.') + else: + if run_type == 'val': + log_fn = 'val_logs.text' + save_dir = val_dir + elif run_type == 'test': + log_fn = 'test_logs.text' + save_dir = test_dir + elif run_type == 'exval1': + log_fn = 'exval1_logs.text' + save_dir = exval1_dir + elif run_type == 'exval2': + log_fn = 'exval2_logs.text' + save_dir = exval2_dir + + write_path = os.path.join(save_dir, log_fn) + with open(write_path, 'a') as f: + f.write('\n------------------------------------------------------------------') + f.write('\ncreated time: %s' % strftime('%Y-%m-%d %H:%M:%S', localtime())) + f.write('\nval accuracy: %s' % acc) + f.write('\nval loss: %s' % loss) + f.write('\nprc image: %s' % prc_aucs[0]) + f.write('\nprc patient prob: %s' % prc_aucs[1]) + f.write('\nprc patient pos: %s' % prc_aucs[2]) + f.write('\nroc image:\n %s' % roc_stats[0]) + f.write('\nroc patient prob:\n %s' % roc_stats[1]) + f.write('\nroc patient 
pos:\n %s' % roc_stats[2]) + f.write('\ncm image:\n %s' % cms[0]) + f.write('\ncm image norm:\n %s' % cm_norms[0]) + f.write('\ncm patient prob:\n %s' % cms[1]) + f.write('\ncm patient prob norm:\n %s' % cm_norms[1]) + f.write('\ncm patient pos:\n %s' % cms[2]) + f.write('\ncm patient pos norm:\n %s' % cm_norms[2]) + f.write('\nreport image:\n %s' % reports[0]) + f.write('\nreport patient prob:\n %s' % reports[1]) + f.write('\nreport patient pos:\n %s' % reports[2]) + f.write('\nrun model: %s' % run_model) + f.write('\nsaved model: %s' % saved_model) + f.write('\nepoch: %s' % epoch) + f.write('\nlearning rate: %s' % lr) + f.write('\nbatch size: %s' % batch_size) + f.write('\n') + print('successfully saved logs.')
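+# Usage sketch for write_txt() (all values are illustrative; cms, cm_norms,
+# reports, prc_aucs and roc_stats follow the image / patient-prob /
+# patient-pos ordering the function expects):
+# write_txt(run_type='val', out_dir='/tmp/results', loss=0.18, acc=0.94,
+#           cms=cms, cm_norms=cm_norms, reports=reports, prc_aucs=prc_aucs,
+#           roc_stats=roc_stats, run_model='EfficientNetB4',
+#           saved_model='Tuned_EfficientNetB4_2021_08_27_20_26_55',
+#           epoch=50, batch_size=32, lr=1e-4)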