|
| 1 | +import itertools |
| 2 | + |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | +from torch.autograd import Variable |
| 6 | + |
| 7 | +from .predict_handlers import DiffScoreHandler, LogitScoreHandler, \ |
| 8 | + WritePredictionsHandler |
| 9 | +from ..sequences import Genome |
| 10 | +from ..sequences import sequence_to_encoding |
| 11 | + |
| 12 | + |
def predict(model, batch_sequences, use_cuda=False):
    """Run a forward pass of `model` on one batch with autograd disabled.

    Parameters
    ----------
    model : object
        Any object exposing a `forward` method (e.g. a `torch.nn.Module`).
    batch_sequences : array-like
        Data accepted by `torch.Tensor(...)`. Axes 1 and 2 are swapped
        before the forward pass, so the input must be at least 3-D
        (presumably `(batch, sequence_length, n_bases)` -- TODO confirm
        against callers).
    use_cuda : bool, optional
        Default is False. If True, the input tensor is moved to the GPU
        before the forward pass.

    Returns
    -------
    torch.Tensor
        Whatever `model.forward` returns for the transposed batch.
    """
    inputs = torch.Tensor(batch_sequences)
    if use_cuda:
        inputs = inputs.cuda()
    if hasattr(torch, "no_grad"):
        # PyTorch >= 0.4: `volatile=True` only emits a warning and no
        # longer disables gradient tracking, so use the supported context
        # manager to avoid building an autograd graph during inference.
        with torch.no_grad():
            return model.forward(inputs.transpose(1, 2))
    # Legacy PyTorch (< 0.4): `volatile` is the way to disable autograd.
    inputs = Variable(inputs, volatile=True)
    return model.forward(inputs.transpose(1, 2))
| 20 | + |
| 21 | + |
def predict_on_encoded_sequences(model,
                                 sequences,
                                 batch_size=64,
                                 use_cuda=False):
    """Predict on already-encoded sequences, one mini-batch at a time.

    Parameters
    ----------
    model : object
        Passed through to `predict`.
    sequences : numpy.ndarray
        A 3-D array; it is sliced along its first axis into batches.
    batch_size : int, optional
        Default is 64. Maximum number of rows per batch.
    use_cuda : bool, optional
        Default is False. Passed through to `predict`.

    Returns
    -------
    numpy.ndarray
        The per-batch outputs stacked vertically with `np.vstack`.
    """
    # Unpacking three values keeps the requirement that the input is 3-D.
    total, _, _ = sequences.shape
    batch_outputs = [
        predict(model,
                sequences[lo:lo + batch_size, :, :],
                use_cuda=use_cuda).data.cpu().numpy()
        for lo in range(0, total, batch_size)
    ]
    return np.vstack(batch_outputs)
| 36 | + |
| 37 | + |
def in_silico_mutagenesis_sequences(input_sequence,
                                    mutate_n_bases=1):
    """Enumerate every in silico mutation of `input_sequence`.

    Parameters
    ----------
    input_sequence : str
        The reference sequence to mutate.
    mutate_n_bases : int
        Default is 1. How many positions are mutated simultaneously.

    Returns
    -------
    list
        One entry per possible mutation. Each entry is a list of
        `(position, base)` tuples, e.g. `[(0, 'T')]` when a single base
        is mutated: the position to change and the base that replaces
        the reference base there.

        With one mutated base at a time and three alternative bases per
        position (i.e. assuming `Genome.BASES_ARR` holds four bases), a
        sequence of length 1000 yields a list of length 3000.
    """
    # For each position, the bases that differ from the reference base.
    alternatives = [
        [candidate for candidate in Genome.BASES_ARR if candidate != ref_base]
        for ref_base in input_sequence
    ]

    mutations = []
    n_positions = len(input_sequence)
    for chosen in itertools.combinations(range(n_positions), mutate_n_bases):
        choices = [alternatives[pos] for pos in chosen]
        # Every way of assigning an alternative base to each chosen position.
        mutations.extend(
            list(zip(chosen, replacement))
            for replacement in itertools.product(*choices))
    return mutations
| 77 | + |
| 78 | + |
| 79 | +def _ism_sample_id(dna_sequence, mutation_information): |
| 80 | + positions = [] |
| 81 | + refs = [] |
| 82 | + alts = [] |
| 83 | + for (position, alt) in mutation_information: |
| 84 | + positions.append(str(position)) |
| 85 | + refs.append(dna_sequence[position]) |
| 86 | + alts.append(alt) |
| 87 | + return (';'.join(positions), ';'.join(refs), ';'.join(alts)) |
| 88 | + |
| 89 | + |
def in_silico_mutagenesis_predict(model,
                                  batch_size,
                                  sequence,
                                  mutations_list,
                                  use_cuda=False,
                                  reporters=None):
    """Predict on every mutated sequence and hand results to reporters.

    Parameters
    ----------
    model : object
        Passed through to `predict`.
    batch_size : int
        Maximum number of mutated sequences per forward pass.
    sequence : str
        The reference sequence that the mutations are applied to.
    mutations_list : list
        One entry per mutated sequence; each entry is an iterable of
        `(position, alt)` tuples as accepted by `mutate_sequence`.
    use_cuda : bool, optional
        Default is False. Passed through to `predict`.
    reporters : list or None, optional
        Default is None (no reporters). Each reporter must provide
        `handle_batch_predictions(outputs, batch_ids)` and
        `write_to_file()`.

    Returns
    -------
    None
        Results are delivered through the reporters only.
    """
    # Avoid a shared mutable default argument.
    if reporters is None:
        reporters = []
    current_sequence_encoding = sequence_to_encoding(
        sequence, Genome.BASE_TO_INDEX)
    for start in range(0, len(mutations_list), batch_size):
        batch_mutations = mutations_list[start:start + batch_size]

        # Allocate exactly one row per mutation in this batch. Sizing by
        # `batch_size` would leave all-zero padding rows in a final partial
        # batch, so the model outputs would no longer line up with
        # `batch_ids`.
        mutated_sequences = np.zeros(
            (len(batch_mutations), *current_sequence_encoding.shape))

        batch_ids = []
        for ix, mutation_info in enumerate(batch_mutations):
            mutated_seq = mutate_sequence(
                current_sequence_encoding, mutation_info)
            mutated_sequences[ix, :, :] = mutated_seq
            batch_ids.append(_ism_sample_id(sequence, mutation_info))
        outputs = predict(
            model, mutated_sequences, use_cuda=use_cuda).data.cpu().numpy()

        for r in reporters:
            r.handle_batch_predictions(outputs, batch_ids)

    # Flush each reporter once, after all batches have been handled.
    for r in reporters:
        r.write_to_file()
| 119 | + |
| 120 | + |
def _reverse_strand(dna_sequence):
    """Return the reverse complement of `dna_sequence` as a string.

    The sequence is walked from its last base to its first, and each base
    is replaced through the `Genome.COMPLEMENTARY_BASE` lookup.
    """
    return ''.join(
        Genome.COMPLEMENTARY_BASE[base] for base in reversed(dna_sequence))
| 124 | + |
| 125 | + |
def mutate_sequence(dna_encoded_sequence, mutation_information):
    """Return a copy of an encoded sequence with the given mutations applied.

    Parameters
    ----------
    dna_encoded_sequence : numpy.ndarray
        The encoded reference sequence, indexed as `[position, base_index]`.
        It is copied, never modified in place.
    mutation_information : iterable of (int, str)
        `(position, alt)` pairs; `alt` is looked up in
        `Genome.BASE_TO_INDEX` to find the column to set.

    Returns
    -------
    numpy.ndarray
        The mutated copy: at each listed position the row is cleared to 0
        and the column for the alternative base is set to 1.
    """
    encoded = np.copy(dna_encoded_sequence)
    for pos, new_base in mutation_information:
        column = Genome.BASE_TO_INDEX[new_base]
        # Clear the whole row first so only the new base remains set.
        encoded[pos, :] = 0
        encoded[pos, column] = 1
    return encoded
| 133 | + |
| 134 | + |
def in_silico_mutagenesis(model,
                          batch_size,
                          input_sequence,
                          features_list,
                          save_diffs,
                          mutate_n_bases=1,
                          use_cuda=False,
                          save_logits=None,
                          save_predictions=None):
    """Run in silico mutagenesis on one sequence and report the results.

    Every mutation of `input_sequence` is generated, the model is run on
    the unmutated sequence and on each mutated sequence, and the
    predictions are passed to the reporters selected by the `save_*`
    arguments.

    Parameters
    ----------
    model : object
        Passed through to `predict`.
    batch_size : int
        Maximum number of mutated sequences per forward pass.
    input_sequence : str
        The reference sequence to mutate.
    features_list : list
        Passed to each reporter (presumably the names of the predicted
        features -- confirm against the handler classes).
    save_diffs : object
        If truthy, a `DiffScoreHandler` is created and given this value
        as its final argument (presumably an output path -- confirm).
    mutate_n_bases : int, optional
        Default is 1. How many positions are mutated simultaneously.
    use_cuda : bool, optional
        Default is False. Whether inputs are moved to the GPU.
    save_logits : object or None, optional
        Default is None. If truthy, a `LogitScoreHandler` is created and
        given this value as its final argument.
    save_predictions : object or None, optional
        Default is None. If truthy, a `WritePredictionsHandler` is created
        and given this value as its final argument.

    Returns
    -------
    None
        Results are delivered through the reporters only.
    """
    # Honor the caller's `mutate_n_bases` instead of always mutating one base.
    mutated_sequences = in_silico_mutagenesis_sequences(
        input_sequence, mutate_n_bases=mutate_n_bases)

    current_sequence_encoding = sequence_to_encoding(
        input_sequence, Genome.BASE_TO_INDEX)

    # Add a leading batch axis of size 1 for the reference prediction.
    base_encoding = current_sequence_encoding.reshape(
        (1, *current_sequence_encoding.shape))
    # Use the same device setting as the mutated-sequence predictions so a
    # CUDA model is not fed a CPU tensor.
    base_preds = predict(
        model, base_encoding, use_cuda=use_cuda).data.cpu().numpy()

    reporters = []
    nonfeature_cols = ["pos", "ref", "alt"]
    if save_diffs:
        diff_handler = DiffScoreHandler(
            base_preds, features_list, nonfeature_cols, save_diffs)
        reporters.append(diff_handler)
    if save_logits:
        logit_handler = LogitScoreHandler(
            base_preds, features_list, nonfeature_cols, save_logits)
        reporters.append(logit_handler)
    if save_predictions:
        preds_handler = WritePredictionsHandler(
            features_list, nonfeature_cols, save_predictions)
        reporters.append(preds_handler)

    in_silico_mutagenesis_predict(
        model, batch_size, input_sequence, mutated_sequences,
        use_cuda=use_cuda, reporters=reporters)
0 commit comments