Skip to content

Commit d738481

Browse files
committed
add a predict module and bug fixes to performance metrics, non strand specific module
1 parent 946936d commit d738481

File tree

11 files changed

+362
-21
lines changed

11 files changed

+362
-21
lines changed

models/non_strand_specific_module.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import numpy as np
21
import torch
32
from torch.nn.modules import Module
43

@@ -23,7 +22,6 @@ class NonStrandSpecific(Module):
2322
def __init__(self, model, mode="mean"):
2423
super(NonStrandSpecific, self).__init__()
2524

26-
print(mode)
2725
self.model = model
2826

2927
if mode != "mean" and mode != "max":
@@ -42,19 +40,5 @@ def forward(self, input):
4240
if self.mode == "mean":
4341
return (output + output_from_rev) / 2
4442
else:
45-
max_output = torch.max(
46-
output.abs(), output_from_rev.abs())
47-
np_output = output.data.cpu().numpy()
48-
print(np_output)
49-
50-
it = np.nditer(np_output, flags=["multi_index"])
51-
while not it.finished:
52-
index = it.multi_index
53-
print(it[0])
54-
if max_output.data[index] != abs(it[0]):
55-
max_output.data[index] = output_from_rev.data[index]
56-
else:
57-
max_output.data[index] = it[0]
58-
it.iternext()
59-
return max_output
43+
return torch.max(output, output_from_rev)
6044

selene/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
__all__ = ["sequences", "targets", "samplers", "utils"]
1+
__all__ = ["predict", "sequences", "targets", "samplers", "utils"]
22
from .model_train import ModelController

selene/predict/__init__.py

Whitespace-only changes.

selene/predict/model_predict.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import itertools
2+
3+
import numpy as np
4+
import torch
5+
from torch.autograd import Variable
6+
7+
from .predict_handlers import DiffScoreHandler, LogitScoreHandler, \
8+
WritePredictionsHandler
9+
from ..sequences import Genome
10+
from ..sequences import sequence_to_encoding
11+
12+
13+
def predict(model, batch_sequences, use_cuda=False):
14+
inputs = torch.Tensor(batch_sequences)
15+
if use_cuda:
16+
inputs = inputs.cuda()
17+
inputs = Variable(inputs, volatile=True)
18+
outputs = model.forward(inputs.transpose(1, 2))
19+
return outputs
20+
21+
22+
def predict_on_encoded_sequences(model,
23+
sequences,
24+
batch_size=64,
25+
use_cuda=False):
26+
predictions = []
27+
n_examples, _, _ = sequences.shape
28+
29+
for i in range(0, n_examples, batch_size):
30+
start = i
31+
end = i + batch_size
32+
batch_sequences = sequences[start:end, :, :]
33+
outputs = predict(model, batch_sequences, use_cuda=use_cuda)
34+
predictions.append(outputs.data.cpu().numpy())
35+
return np.vstack(predictions)
36+
37+
38+
def in_silico_mutagenesis_sequences(input_sequence,
39+
mutate_n_bases=1):
40+
"""Creates a list containing each mutation that occurs from in silico
41+
mutagenesis across the whole sequence.
42+
43+
Parameters
44+
----------
45+
input_sequence : str
46+
mutate_n_bases : int
47+
48+
Returns
49+
-------
50+
list
51+
A list of all possible mutations. Each element in the list is
52+
itself a list of tuples, e.g. [(0, 'T')] if we are only mutating
53+
1 base at a time. Each tuple is the position to mutate and the base
54+
with which we are replacing the reference base.
55+
56+
For a sequence of length 1000, mutating 1 base at a time means that
57+
we return a list of length 3000.
58+
"""
59+
sequence_alts = []
60+
for index, ref in enumerate(input_sequence):
61+
alts = []
62+
for base in Genome.BASES_ARR:
63+
if base == ref:
64+
continue
65+
alts.append(base)
66+
sequence_alts.append(alts)
67+
68+
all_mutated_sequences = []
69+
for indices in itertools.combinations(
70+
range(len(input_sequence)), mutate_n_bases):
71+
pos_mutations = []
72+
for i in indices:
73+
pos_mutations.append(sequence_alts[i])
74+
for mutations in itertools.product(*pos_mutations):
75+
all_mutated_sequences.append(list(zip(indices, mutations)))
76+
return all_mutated_sequences
77+
78+
79+
def _ism_sample_id(dna_sequence, mutation_information):
80+
positions = []
81+
refs = []
82+
alts = []
83+
for (position, alt) in mutation_information:
84+
positions.append(str(position))
85+
refs.append(dna_sequence[position])
86+
alts.append(alt)
87+
return (';'.join(positions), ';'.join(refs), ';'.join(alts))
88+
89+
90+
def in_silico_mutagenesis_predict(model,
91+
batch_size,
92+
sequence,
93+
mutations_list,
94+
use_cuda=False,
95+
reporters=[]):
96+
current_sequence_encoding = sequence_to_encoding(
97+
sequence, Genome.BASE_TO_INDEX)
98+
for i in range(0, len(mutations_list), batch_size):
99+
start = i
100+
end = i + batch_size
101+
102+
mutated_sequences = np.zeros(
103+
(batch_size, *current_sequence_encoding.shape))
104+
105+
batch_ids = []
106+
for ix, mutation_info in enumerate(mutations_list[start:end]):
107+
mutated_seq = mutate_sequence(
108+
current_sequence_encoding, mutation_info)
109+
mutated_sequences[ix, :, :] = mutated_seq
110+
batch_ids.append(_ism_sample_id(sequence, mutation_info))
111+
outputs = predict(
112+
model, mutated_sequences, use_cuda=use_cuda).data.cpu().numpy()
113+
114+
for r in reporters:
115+
r.handle_batch_predictions(outputs, batch_ids)
116+
117+
for r in reporters:
118+
r.write_to_file()
119+
120+
121+
def _reverse_strand(dna_sequence):
122+
reverse_bases = [Genome.COMPLEMENTARY_BASE[b] for b in dna_sequence[::-1]]
123+
return ''.join(reverse_bases)
124+
125+
126+
def mutate_sequence(dna_encoded_sequence, mutation_information):
127+
mutated_seq = np.copy(dna_encoded_sequence)
128+
for (position, alt) in mutation_information:
129+
replace_base = Genome.BASE_TO_INDEX[alt]
130+
mutated_seq[position, :] = 0
131+
mutated_seq[position, replace_base] = 1
132+
return mutated_seq
133+
134+
135+
def in_silico_mutagenesis(model,
136+
batch_size,
137+
input_sequence,
138+
features_list,
139+
save_diffs,
140+
mutate_n_bases=1,
141+
use_cuda=False,
142+
save_logits=None,
143+
save_predictions=None):
144+
mutated_sequences = in_silico_mutagenesis_sequences(
145+
input_sequence, mutate_n_bases=1)
146+
147+
current_sequence_encoding = sequence_to_encoding(
148+
input_sequence, Genome.BASE_TO_INDEX)
149+
150+
base_encoding = current_sequence_encoding.reshape(
151+
(1, *current_sequence_encoding.shape))
152+
base_preds = predict(
153+
model, base_encoding).data.cpu().numpy()
154+
155+
reporters = []
156+
nonfeature_cols = ["pos", "ref", "alt"]
157+
if save_diffs:
158+
diff_handler = DiffScoreHandler(
159+
base_preds, features_list, nonfeature_cols, save_diffs)
160+
reporters.append(diff_handler)
161+
if save_logits:
162+
logit_handler = LogitScoreHandler(
163+
base_preds, features_list, nonfeature_cols, save_logits)
164+
reporters.append(logit_handler)
165+
if save_predictions:
166+
preds_handler = WritePredictionsHandler(
167+
features_list, nonfeature_cols, save_predictions)
168+
reporters.append(preds_handler)
169+
170+
in_silico_mutagenesis_predict(
171+
model, batch_size, input_sequence, mutated_sequences,
172+
use_cuda=use_cuda, reporters=reporters)
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .handler import PredictionsHandler
2+
from .diff_score_handler import DiffScoreHandler
3+
from .logit_score_handler import LogitScoreHandler
4+
from .write_predictions_handler import WritePredictionsHandler
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import numpy as np
2+
3+
from .handler import _write_to_file, PredictionsHandler
4+
5+
6+
class DiffScoreHandler(PredictionsHandler):
7+
8+
def __init__(self,
9+
baseline_prediction,
10+
features_list,
11+
nonfeature_columns,
12+
out_filename):
13+
self.baseline_prediction = baseline_prediction
14+
self.column_names = nonfeature_columns + features_list
15+
self.results = []
16+
self.samples = []
17+
self.out_filename = out_filename
18+
19+
def handle_batch_predictions(self,
20+
batch_predictions,
21+
batch_ids):
22+
absolute_diffs = np.abs(self.baseline_prediction - batch_predictions)
23+
self.results.append(absolute_diffs)
24+
self.samples.append(batch_ids)
25+
return absolute_diffs
26+
27+
def write_to_file(self):
28+
self.results = np.vstack(self.results)
29+
self.samples = np.vstack(self.samples)
30+
_write_to_file(self.results,
31+
self.samples,
32+
self.column_names,
33+
self.out_filename)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
"""
2+
This class is the abstract base class for handling model predicions
3+
"""
4+
from abc import ABCMeta
5+
from abc import abstractmethod
6+
7+
8+
def _write_to_file(feature_predictions, info_cols, column_names, filename):
9+
with open(filename, 'w+') as file_handle:
10+
file_handle.write("{columns}\n".format(
11+
columns='\t'.join(column_names)))
12+
for info, preds in zip(info_cols, feature_predictions):
13+
feature_cols = '\t'.join(
14+
probabilities_to_string(preds))
15+
info_cols = '\t'.join(info)
16+
file_handle.write(f"{info_cols}\t{feature_cols}\n")
17+
18+
19+
def probabilities_to_string(probabilities):
20+
return ["{:.2e}".format(p) for p in probabilities]
21+
22+
23+
class PredictionsHandler(metaclass=ABCMeta):
24+
"""
25+
The base class for handling model predictions.
26+
"""
27+
@abstractmethod
28+
def handle_batch_predictions(self, *args, **kwargs):
29+
"""
30+
Must be able to handle a batch of model predictions.
31+
"""
32+
raise NotImplementedError
33+
34+
@abstractmethod
35+
def write_to_file(self, *args, **kwargs):
36+
"""
37+
Writes accumulated handler results to file.
38+
"""
39+
raise NotImplementedError
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import numpy as np
2+
from scipy.special import logit
3+
4+
from .handler import _write_to_file, PredictionsHandler
5+
6+
7+
class LogitScoreHandler(PredictionsHandler):
8+
9+
def __init__(self,
10+
baseline_prediction,
11+
features_list,
12+
nonfeature_columns,
13+
out_filename):
14+
self.logit_baseline = logit(baseline_prediction)
15+
self.column_names = nonfeature_columns + features_list
16+
self.results = []
17+
self.samples = []
18+
self.out_filename = out_filename
19+
20+
def handle_batch_predictions(self,
21+
batch_predictions,
22+
batch_ids):
23+
absolute_logits = np.abs(self.logit_baseline - logit(batch_predictions))
24+
self.results.append(absolute_logits)
25+
self.samples.append(batch_ids)
26+
return absolute_logits
27+
28+
def write_to_file(self):
29+
self.results = np.vstack(self.results)
30+
self.samples = np.vstack(self.samples)
31+
_write_to_file(self.results,
32+
self.samples,
33+
self.column_names,
34+
self.out_filename)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import numpy as np
2+
3+
from .handler import _write_to_file, PredictionsHandler
4+
5+
class WritePredictionsHandler(PredictionsHandler):
6+
7+
def __init__(self, features_list, nonfeature_columns, out_filename):
8+
self.column_names = nonfeature_columns + features_list
9+
self.results = []
10+
self.samples = []
11+
self.out_filename = out_filename
12+
13+
def handle_batch_predictions(self,
14+
batch_predictions,
15+
batch_ids):
16+
self.results.append(batch_predictions)
17+
self.samples.append(batch_ids)
18+
return batch_predictions
19+
20+
def write_to_file(self):
21+
self.results = np.vstack(self.results)
22+
self.samples = np.vstack(self.samples)
23+
_write_to_file(self.results,
24+
self.samples,
25+
self.column_names,
26+
self.out_filename)
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import unittest
2+
3+
import numpy as np
4+
5+
from selene.predict.model_predict import in_silico_mutagenesis_sequences, \
6+
in_silico_mutagenesis_predict
7+
8+
9+
class TestModelPredict(unittest.TestCase):
10+
11+
def setUp(self):
12+
self.bases_arr = ['A', 'C', 'G', 'T']
13+
self.bases_encoding = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
14+
self.input_sequence = "ATCCG"
15+
16+
def test_in_silico_muta_sequences_single(self):
17+
observed = in_silico_mutagenesis_sequences("ATCCG")
18+
expected = [
19+
(0, 'C'), (0, 'G'), (0, 'T'),
20+
(1, 'A'), (1, 'C'), (1, 'G'),
21+
(2, 'A'), (2, 'G'), (2, 'T'),
22+
(3, 'A'), (3, 'G'), (3, 'T'),
23+
(4, 'A'), (4, 'C'), (4, 'T')]
24+
25+
expected_lists = [[e] for e in expected]
26+
self.assertListEqual(observed, expected_lists)
27+
28+
def test_in_silico_muta_sequences_double(self):
29+
observed = in_silico_mutagenesis_sequences(
30+
"ATC", mutate_n_bases=2)
31+
expected = [
32+
[(0, 'C'), (1, 'A')], [(0, 'G'), (1, 'A')], [(0, 'T'), (1, 'A')],
33+
[(0, 'C'), (1, 'C')], [(0, 'G'), (1, 'C')], [(0, 'T'), (1, 'C')],
34+
[(0, 'C'), (1, 'G')], [(0, 'G'), (1, 'G')], [(0, 'T'), (1, 'G')],
35+
36+
[(0, 'C'), (2, 'A')], [(0, 'G'), (2, 'A')], [(0, 'T'), (2, 'A')],
37+
[(0, 'C'), (2, 'G')], [(0, 'G'), (2, 'G')], [(0, 'T'), (2, 'G')],
38+
[(0, 'C'), (2, 'T')], [(0, 'G'), (2, 'T')], [(0, 'T'), (2, 'T')],
39+
40+
[(1, 'A'), (2, 'A')], [(1, 'C'), (2, 'A')], [(1, 'G'), (2, 'A')],
41+
[(1, 'A'), (2, 'G')], [(1, 'C'), (2, 'G')], [(1, 'G'), (2, 'G')],
42+
[(1, 'A'), (2, 'T')], [(1, 'C'), (2, 'T')], [(1, 'G'), (2, 'T')],
43+
]
44+
self.assertCountEqual(observed, expected)
45+
46+
if __name__ == "__main__":
47+
unittest.main()

0 commit comments

Comments
 (0)