|
| 1 | +#!/usr/bin/env python3 |
| 2 | +"""Test PaccMann predictor.""" |
| 3 | +import argparse |
| 4 | +import json |
| 5 | +import logging |
| 6 | +import os |
| 7 | +import pickle |
| 8 | +import sys |
| 9 | +from copy import deepcopy |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | +import torch |
| 14 | +from tqdm import tqdm |
| 15 | +from paccmann_predictor.models import MODEL_FACTORY |
| 16 | +from paccmann_predictor.utils.hyperparams import OPTIMIZER_FACTORY |
| 17 | +from paccmann_predictor.utils.utils import get_device |
| 18 | +from pytoda.datasets import DrugSensitivityDataset |
| 19 | +from pytoda.smiles.smiles_language import SMILESTokenizer |
| 20 | +from scipy.stats import pearsonr |
| 21 | + |
| 22 | +# setup logging |
| 23 | +logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) |
| 24 | + |
| 25 | +# yapf: disable |
| 26 | +parser = argparse.ArgumentParser() |
| 27 | +parser.add_argument( |
| 28 | + 'test_sensitivity_filepath', type=str, |
| 29 | + help='Path to the drug sensitivity (IC50) data.' |
| 30 | +) |
| 31 | +parser.add_argument( |
| 32 | + 'gep_filepath', type=str, |
| 33 | + help='Path to the gene expression profile data.' |
| 34 | +) |
| 35 | +parser.add_argument( |
| 36 | + 'smi_filepath', type=str, |
| 37 | + help='Path to the SMILES data.' |
| 38 | +) |
| 39 | +parser.add_argument( |
| 40 | + 'gene_filepath', type=str, |
| 41 | + help='Path to a pickle object containing list of genes.' |
| 42 | +) |
| 43 | +parser.add_argument( |
| 44 | + 'smiles_language_filepath', type=str, |
| 45 | + help='Path to a folder with SMILES language .json files.' |
| 46 | +) |
| 47 | +parser.add_argument( |
| 48 | + 'model_filepath', type=str, |
| 49 | + help='Path to the stored model.' |
| 50 | +) |
| 51 | +parser.add_argument( |
| 52 | + 'predictions_filepath', type=str, |
| 53 | + help='Path to the predictions.' |
| 54 | +) |
| 55 | +parser.add_argument( |
| 56 | + 'params_filepath', type=str, |
| 57 | + help='Path to the parameter file.' |
| 58 | +) |
| 59 | +# yapf: enable |
| 60 | + |
| 61 | + |
| 62 | +def main( |
| 63 | + test_sensitivity_filepath, gep_filepath, |
| 64 | + smi_filepath, gene_filepath, smiles_language_filepath, model_filepath, predictions_filepath, |
| 65 | + params_filepath |
| 66 | +): |
| 67 | + |
| 68 | + logger = logging.getLogger('test') |
| 69 | + # Process parameter file: |
| 70 | + params = {} |
| 71 | + with open(params_filepath) as fp: |
| 72 | + params.update(json.load(fp)) |
| 73 | + |
| 74 | + |
| 75 | + # Prepare the dataset |
| 76 | + logger.info("Start data preprocessing...") |
| 77 | + |
| 78 | + # Load SMILES language |
| 79 | + smiles_language = SMILESTokenizer.from_pretrained(smiles_language_filepath) |
| 80 | + smiles_language.set_encoding_transforms( |
| 81 | + add_start_and_stop=params.get('add_start_and_stop', True), |
| 82 | + padding=params.get('padding', True), |
| 83 | + padding_length=params.get('smiles_padding_length', None) |
| 84 | + ) |
| 85 | + test_smiles_language = deepcopy(smiles_language) |
| 86 | + smiles_language.set_smiles_transforms( |
| 87 | + augment=params.get('augment_smiles', False), |
| 88 | + canonical=params.get('smiles_canonical', False), |
| 89 | + kekulize=params.get('smiles_kekulize', False), |
| 90 | + all_bonds_explicit=params.get('smiles_bonds_explicit', False), |
| 91 | + all_hs_explicit=params.get('smiles_all_hs_explicit', False), |
| 92 | + remove_bonddir=params.get('smiles_remove_bonddir', False), |
| 93 | + remove_chirality=params.get('smiles_remove_chirality', False), |
| 94 | + selfies=params.get('selfies', False), |
| 95 | + sanitize=params.get('selfies', False) |
| 96 | + ) |
| 97 | + test_smiles_language.set_smiles_transforms( |
| 98 | + augment=False, |
| 99 | + canonical=params.get('test_smiles_canonical', False), |
| 100 | + kekulize=params.get('smiles_kekulize', False), |
| 101 | + all_bonds_explicit=params.get('smiles_bonds_explicit', False), |
| 102 | + all_hs_explicit=params.get('smiles_all_hs_explicit', False), |
| 103 | + remove_bonddir=params.get('smiles_remove_bonddir', False), |
| 104 | + remove_chirality=params.get('smiles_remove_chirality', False), |
| 105 | + selfies=params.get('selfies', False), |
| 106 | + sanitize=params.get('selfies', False) |
| 107 | + ) |
| 108 | + |
| 109 | + # Load the gene list |
| 110 | + with open(gene_filepath, 'rb') as f: |
| 111 | + gene_list = pickle.load(f) |
| 112 | + |
| 113 | + # Assemble test dataset |
| 114 | + test_dataset = DrugSensitivityDataset( |
| 115 | + drug_sensitivity_filepath=test_sensitivity_filepath, |
| 116 | + smi_filepath=smi_filepath, |
| 117 | + gene_expression_filepath=gep_filepath, |
| 118 | + smiles_language=test_smiles_language, |
| 119 | + gene_list=gene_list, |
| 120 | + drug_sensitivity_min_max=params.get('drug_sensitivity_min_max', True), |
| 121 | + gene_expression_standardize=params.get( |
| 122 | + 'gene_expression_standardize', True |
| 123 | + ), |
| 124 | + gene_expression_min_max=params.get('gene_expression_min_max', False), |
| 125 | + gene_expression_processing_parameters=params.get( |
| 126 | + 'gene_expression_processing_parameters', {} |
| 127 | + ), |
| 128 | + device=torch.device(params.get('dataset_device', 'cpu')), |
| 129 | + iterate_dataset=False |
| 130 | + ) |
| 131 | + test_loader = torch.utils.data.DataLoader( |
| 132 | + dataset=test_dataset, |
| 133 | + batch_size=params['batch_size'], |
| 134 | + shuffle=False, |
| 135 | + drop_last=False, |
| 136 | + num_workers=params.get('num_workers', 0) |
| 137 | + ) |
| 138 | + logger.info( |
| 139 | + f'Test dataset has {len(test_dataset)} samples with {len(test_loader)} batches' |
| 140 | + ) |
| 141 | + |
| 142 | + device = get_device() |
| 143 | + logger.info( |
| 144 | + f'Device for data loader is {test_dataset.device} and for ' |
| 145 | + f'model is {device}' |
| 146 | + ) |
| 147 | + |
| 148 | + model_name = params.get('model_fn', 'paccmann') |
| 149 | + model = MODEL_FACTORY[model_name](params).to(device) |
| 150 | + model._associate_language(smiles_language) |
| 151 | + try: |
| 152 | + logger.info(f'Attempting to restore model from {model_filepath}...') |
| 153 | + model.load(model_filepath, map_location=device) |
| 154 | + except Exception: |
| 155 | + raise ValueError(f'Error in restoring model from {model_filepath}!') |
| 156 | + |
| 157 | + # Define optimizer |
| 158 | + optimizer = ( |
| 159 | + OPTIMIZER_FACTORY[params.get('optimizer', 'Adam')] |
| 160 | + (model.parameters(), lr=params.get('lr', 0.01)) |
| 161 | + ) |
| 162 | + |
| 163 | + num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) |
| 164 | + params.update({'number_of_parameters': num_params}) |
| 165 | + logger.info(f'Number of parameters {num_params}') |
| 166 | + |
| 167 | + # Start testing |
| 168 | + logger.info('Testing about to start... \n') |
| 169 | + model.eval() |
| 170 | + |
| 171 | + with torch.no_grad(): |
| 172 | + test_loss = 0 |
| 173 | + predictions = [] |
| 174 | + # gene_attentions = [] |
| 175 | + # epistemic_confs = [] |
| 176 | + # aleatoric_confs = [] |
| 177 | + labels = [] |
| 178 | + for ind, (smiles, gep, y) in tqdm(enumerate(test_loader)): |
| 179 | + y_hat, pred_dict = model( |
| 180 | + torch.squeeze(smiles.to(device)), gep.to(device), confidence = False |
| 181 | + ) |
| 182 | + predictions.extend(list(y_hat.detach().cpu().squeeze().numpy())) |
| 183 | + # gene_attentions.append(pred_dict['gene_attention']) |
| 184 | + # epistemic_confs.append(pred_dict['epistemic_confidence']) |
| 185 | + # aleatoric_confs.append(pred_dict['aleatoric_confidence']) |
| 186 | + labels.extend(list(y.detach().cpu().squeeze().numpy())) |
| 187 | + loss = model.loss(y_hat, y.to(device)) |
| 188 | + test_loss += loss.item() |
| 189 | + |
| 190 | + #gene_attentions = np.array([a.cpu().numpy() for atts in gene_attentions for a in atts]) |
| 191 | + #epistemic_confs = np.array([c.cpu().numpy() for conf in epistemic_confs for c in conf]).ravel() |
| 192 | + #aleatoric_confs = np.array([c.cpu().numpy() for conf in aleatoric_confs for c in conf]).ravel() |
| 193 | + predictions = np.array(predictions) |
| 194 | + labels = np.array(labels) |
| 195 | + |
| 196 | + pearson = pearsonr(predictions, labels)[0] |
| 197 | + rmse = np.sqrt(np.mean((predictions - labels)**2)) |
| 198 | + loss = test_loss / len(test_loader) |
| 199 | + logger.info( |
| 200 | + f"\t**RESULT**\t loss:{loss:.5f}, Pearson: {pearson:.3f}, RMSE: {rmse:.3f}" |
| 201 | + ) |
| 202 | + |
| 203 | + df = test_dataset.drug_sensitivity_df |
| 204 | + df['prediction'] = predictions |
| 205 | + df.to_csv(predictions_filepath+'.csv') |
| 206 | + |
| 207 | + #np.save(predictions_filepath+'_gene_attention.npy', gene_attentions) |
| 208 | + #np.save(predictions_filepath+'_epistemic_confidence.npy', epistemic_confs) |
| 209 | + #np.save(predictions_filepath+'_aleatoric_confidence.npy', aleatoric_confs) |
| 210 | + |
| 211 | +if __name__ == '__main__': |
| 212 | + # parse arguments |
| 213 | + args = parser.parse_args() |
| 214 | + # run the testing |
| 215 | + main( |
| 216 | + args.test_sensitivity_filepath, |
| 217 | + args.gep_filepath, args.smi_filepath, args.gene_filepath, |
| 218 | + args.smiles_language_filepath, args.model_filepath, args.predictions_filepath, args.params_filepath |
| 219 | + ) |
0 commit comments