-
Notifications
You must be signed in to change notification settings - Fork 9
/
ExplainBrain.py
374 lines (298 loc) · 16.1 KB
/
ExplainBrain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
"""Pipeline for predicting brain activations.
"""
from evaluation.metrics import mean_explain_variance
from util.misc import get_folds
from voxel_preprocessing.preprocess_voxels import detrend
from voxel_preprocessing.preprocess_voxels import minus_average_resting_states
from voxel_preprocessing.preprocess_voxels import reduce_mean
from voxel_preprocessing.select_voxels import VarianceFeatureSelection
from voxel_preprocessing.select_voxels import TopkFeatureSelection
from language_preprocessing.tokenize import SpacyTokenizer
import tensorflow as tf
import numpy as np
import itertools
from tqdm import tqdm
import os
import pickle
from pathlib import Path
class ExplainBrain(object):
    """Pipeline that maps encoded text stimuli to brain (voxel) activations.

    Loads the brain measurements and the aligned stimuli of one subject,
    preprocesses the voxels, encodes the stimuli with ``stimuli_encoder``,
    trains one mapper per train delay and evaluates it on held-out blocks.
    """

    def __init__(self, hparams, brain_data_reader, stimuli_encoder, mapper_tuple):
        """
        :param hparams: hyper-parameter object. This class reads ``subject_id``,
            ``embedding_type``, ``past_window``, ``future_window``, ``root``,
            ``context_mode``, ``only_past``, ``load_data``, ``save_data``,
            ``load_encoded_stimuli``, ``save_encoded_stimuli`` and ``save_models``.
        :param brain_data_reader: object exposing ``read_all_events(subject_ids=...)``.
        :param stimuli_encoder: object exposing ``get_embeddings_values(...)``.
        :param mapper_tuple: (mapper constructor, dict of keyword args for it)
        """
        self.hparams = hparams
        self.brain_data_reader = brain_data_reader
        self.stimuli_encoder = stimuli_encoder
        print("class name:", self.stimuli_encoder.__class__)
        self.mapper_tuple = mapper_tuple
        self.blocks = None
        self.folds = None
        self.subject_id = self.hparams.subject_id
        self.voxel_selectors = [VarianceFeatureSelection()]
        self.post_train_voxel_selectors = [TopkFeatureSelection()]
        self.embedding_type = self.hparams.embedding_type
        self.past_window = self.hparams.past_window
        self.future_window = self.hparams.future_window
        window = "_window" + str(self.past_window) + "-" + str(self.future_window)
        # File-name prefix (not a directory) of the cached processed brain data.
        self.data_dir = (self.hparams.root + "/bridge_data/processed/harrypotter/"
                         + str(self.subject_id) + "_" + self.hparams.context_mode
                         + window + "_")
        # File-name prefix of saved models / encoded stimuli (files sit next to it).
        self.model_dir = os.path.join(
            "../bridge_models/",
            str(self.subject_id) + "_" + str(self.embedding_type) + "_"
            + self.hparams.context_mode + window)

    @staticmethod
    def _ensure_parent_dir(path_prefix):
        """Create the directory that files written with ``path_prefix`` go into."""
        parent = os.path.dirname(path_prefix)
        if parent:
            os.makedirs(parent, exist_ok=True)

    @staticmethod
    def _stimuli_span(stimuli):
        """Return (start, end) of the first contiguous run of steps that have a stimulus.

        :param stimuli: list of (context, stimuli_index) pairs; a step has no
            stimulus when its stimuli_index is None.
        :return: ``start`` is the first step with a stimulus and ``end`` is one
            past the last step of that run. Both equal ``len(stimuli)`` when no
            step has a stimulus.
        """
        num_steps = len(stimuli)
        start = 0
        while start < num_steps and stimuli[start][1] is None:
            start += 1
        end = start
        # Bounds check first so a run reaching the last step does not overflow.
        while end < num_steps and stimuli[end][1] is not None:
            end += 1
        return start, end

    def load_brain_experiment(self):
        """Load stimuli and brain measurements and set ``self.blocks``.

        Uses the cached ``.npy`` files under the ``self.data_dir`` prefix when
        ``hparams.load_data`` is set and the cache exists; otherwise reads the
        scan events through ``brain_data_reader``.

        :return:
            time_steps: {block_id: list of timestamps}
            brain_activations: {block_id: list of scans}
            stimuli_in_context: {block_id: list of (context, stimuli_index)}
            start_steps: {block_id: first step that has a stimulus}
            end_steps: {block_id: one past the last step of that first run of
                steps with a stimulus}
        """
        cache_file = Path(self.data_dir + "brain_activations.npy")
        print("Processed brain data Exist:", cache_file.exists())
        if self.hparams.load_data and cache_file.exists():
            # The cache holds pickled dicts written by np.save below; numpy >= 1.16.3
            # refuses to unpickle unless asked. Only load caches you produced yourself.
            brain_activations = np.load(str(cache_file), allow_pickle=True).item()
            stimuli_in_context = np.load(self.data_dir + "stimuli_in_context.npy", allow_pickle=True).item()
            time_steps = np.load(self.data_dir + "time_steps.npy", allow_pickle=True).item()
            # A list, consistent with what decompose_scan_events returns.
            blocks = list(brain_activations.keys())
        else:
            all_events = self.brain_data_reader.read_all_events(subject_ids=[self.subject_id])
            blocks, time_steps, brain_activations, stimuli_in_context = self.decompose_scan_events(
                all_events[self.subject_id])
        self.blocks = blocks
        start_steps = {}
        end_steps = {}
        for block in self.blocks:
            start_steps[block], end_steps[block] = self._stimuli_span(stimuli_in_context[block])
        print(start_steps)
        print(end_steps)
        if self.hparams.save_data:
            print("Saving the data ...")
            self._ensure_parent_dir(self.data_dir)
            np.save(self.data_dir + "brain_activations", brain_activations)
            np.save(self.data_dir + "stimuli_in_context", stimuli_in_context)
            np.save(self.data_dir + "time_steps", time_steps)
        return time_steps, brain_activations, stimuli_in_context, start_steps, end_steps

    def decompose_scan_events(self, scan_events):
        """Split the scan events of each block into aligned per-step lists.

        :param scan_events: iterable of blocks exposing ``block_id``,
            ``scan_events`` and ``get_stimuli_in_context(...)``.
        :return: (block_ids, time_steps, brain_activations, stimuli_in_context)
            where each dict maps block_id to a list with one entry per scan event:
            the event timestamp, the event scan, and the (context, stimuli_index)
            pair respectively.
        """
        tokenizer = SpacyTokenizer()
        timesteps = {}  # dict(block_id: list)
        brain_activations = {}  # dict(block_id: list)
        stimuli_in_context = {}  # dict(block_id: list)
        for block in scan_events:
            stimuli_in_context[block.block_id] = []
            brain_activations[block.block_id] = []
            timesteps[block.block_id] = []
            for event in tqdm(block.scan_events):
                # Get the stimuli in context (what is fed to the computational model).
                context, stimuli_index = block.get_stimuli_in_context(
                    scan_event=event,
                    tokenizer=tokenizer,
                    context_mode=self.hparams.context_mode,
                    past_window=self.hparams.past_window,
                    future_window=self.hparams.future_window,
                    only_past=self.hparams.only_past)
                stimuli_in_context[block.block_id].append((context, stimuli_index))
                brain_activations[block.block_id].append(event.scan)
                timesteps[block.block_id].append(event.timestamp)
        return list(stimuli_in_context.keys()), timesteps, brain_activations, stimuli_in_context

    def encode_stimuli(self, stimuli_in_context, integration_fn=np.mean):
        """Apply the text encoder to the stimuli of every block.

        :param stimuli_in_context: {block_id: list of (context, stimuli_index)}
        :param integration_fn: called as ``integration_fn(x, axis=0)`` to collapse
            the selected rows of a per-step encoding when the encoder output has
            more than two dimensions.
        :return: {block_id: encoded stimuli, one entry per step}
        """
        cache_file = self.model_dir + "encoded_stimuli_in_context.npy"
        if self.hparams.load_encoded_stimuli and Path(cache_file).exists():
            # Pickled dict written by np.save below; only load trusted files.
            encoded_stimuli_of_each_block = np.load(cache_file, allow_pickle=True).item()
        else:
            encoded_stimuli_of_each_block = {}
            for block in self.blocks:
                print("Encoding stimuli of block:", block)
                contexts, indexes = zip(*stimuli_in_context[block])
                print("class name:", self.stimuli_encoder.__class__)
                encoded_stimuli = self.stimuli_encoder.get_embeddings_values(
                    contexts, [len(c) for c in contexts], key=self.embedding_type)
                if len(encoded_stimuli.shape) > 2:
                    # Presumably one vector per token: keep only the rows at the
                    # stimulus indexes and integrate them into a single vector.
                    block_encodings = []
                    for encoded, index in zip(encoded_stimuli, indexes):
                        # TODO(samira): maybe there is a better fix for this?
                        if index is None:
                            index = [0]
                        block_encodings.append(integration_fn(encoded[index], axis=0))
                    encoded_stimuli_of_each_block[block] = block_encodings
                else:
                    encoded_stimuli_of_each_block[block] = np.asarray(encoded_stimuli)
        if self.hparams.save_encoded_stimuli:
            self._ensure_parent_dir(self.model_dir)
            np.save(self.model_dir + "encoded_stimuli_in_context", encoded_stimuli_of_each_block)
        return encoded_stimuli_of_each_block

    def metrics(self):
        """
        :return:
            Dictionary of {metric_name: metric_function}
        """
        return {'mean_EV': mean_explain_variance}

    def voxel_preprocess(self, **kwargs):
        """Return the list of (preprocessing function, keyword args) pairs to apply.

        :param kwargs: unused.
        :return: list of (function, dict of keyword args)
        """
        return [(detrend, {'t_r': 2.0}), (reduce_mean, {})]

    def eval(self, mapper, brain_activations, encoded_stimuli,
             test_blocks, time_steps, start_steps, end_steps, test_delay):
        """Evaluate a trained mapper on ``test_blocks`` at ``test_delay``.

        Test pairs are built with ``mapper.prepare_inputs`` and passed through
        the already-fitted voxel selectors (no re-fitting) before the metrics
        are printed by ``eval_mapper``.
        """
        test_encoded_stimuli, test_brain_activations = mapper.prepare_inputs(
            blocks=test_blocks,
            timed_targets=brain_activations,
            sorted_inputs=encoded_stimuli,
            sorted_timesteps=time_steps,
            delay=test_delay,
            start_steps=start_steps,
            end_steps=end_steps)
        test_brain_activations, _ = self.voxel_selection(test_brain_activations, fit=False)
        test_brain_activations, _ = self.post_train_voxel_selection(test_brain_activations)
        self.eval_mapper(mapper, test_encoded_stimuli, test_brain_activations)

    def eval_mapper(self, mapper, encoded_stimuli, brain_activations):
        """Print every metric of ``self.metrics()`` for the mapper's predictions.

        :param mapper: trained mapper exposing ``map(inputs=..., targets=...)``.
        :param encoded_stimuli: mapper inputs.
        :param brain_activations: targets the predictions are compared with.
        """
        mapper_output = mapper.map(inputs=encoded_stimuli, targets=brain_activations)
        predictions = mapper_output['predictions']
        print("number of voxels under evaluation:", len(brain_activations[0]), predictions.shape)
        for metric_name, metric_fn in self.metrics().items():
            metric_eval = metric_fn(predictions=predictions, targets=brain_activations)
            print(metric_name, ":", metric_eval)

    def get_blocks(self):
        """
        :return: blocks of the loaded brain data.
        """
        return self.blocks

    def get_folds(self, fold_index):
        """
        :param fold_index: index of the fold to return.
        :return:
            train block ids and test block ids for the given fold index.
        """
        if self.folds is None:
            # Computed lazily once self.blocks is known.
            self.folds = get_folds(self.blocks)
        return self.folds[fold_index]

    def voxel_selection(self, brain_activations, fit=False):
        """Apply ``self.voxel_selectors`` in sequence.

        :return: (reduced activations, indexes of the kept voxels in the
            original voxel space)
        """
        selected_voxels = np.arange(len(brain_activations[0]))
        reduced_brain_activations = brain_activations
        for selector in self.voxel_selectors:
            reduced_brain_activations = selector.select_featurs(reduced_brain_activations, fit)
            # Compose selections so indexes always refer to the original voxels.
            selected_voxels = np.asarray(selected_voxels)[selector.get_selected_indexes()]
        return reduced_brain_activations, selected_voxels

    def post_train_voxel_selection(self, brain_activations, predictions=None, labels=None, fit=False):
        """Voxel selection that is applied after training and based on training results.

        :param brain_activations: activations to reduce.
        :param predictions: mapper predictions, used only when fitting.
        :param labels: true activations, used only when fitting.
        :param fit: fit the selectors first (requires predictions and labels).
        :return: (reduced activations, indexes of the kept voxels)
        """
        if fit and labels is not None and predictions is not None:
            print("Fitting post training voxel selectors")
            for selector in self.post_train_voxel_selectors:
                selector.fit(predictions, labels)
        selected_voxels = np.arange(len(brain_activations[0]))
        reduced_brain_activations = brain_activations
        for selector in self.post_train_voxel_selectors:
            reduced_brain_activations = selector.select_featurs(reduced_brain_activations)
            selected_voxels = np.asarray(selected_voxels)[selector.get_selected_indexes()]
        return reduced_brain_activations, selected_voxels

    def preprocess_brain_activations(self, brain_activations, voxel_preprocessings, start_steps, end_steps,
                                     resting_norm=False):
        """Apply the voxel preprocessing functions to every block (in place) and return the dict.

        :param voxel_preprocessings: list of (function, keyword args) pairs.
        :param start_steps: {block_id: first step with a stimulus}; steps before
            it are treated as resting states when ``resting_norm`` is set.
        :param end_steps: currently unused.
        :param resting_norm: subtract the average resting state of each block.
        """
        for block in self.blocks:
            for voxel_preprocessing_fn, args in voxel_preprocessings:
                brain_activations[block] = voxel_preprocessing_fn(brain_activations[block], **args)
            if resting_norm:
                brain_activations[block] = minus_average_resting_states(
                    brain_activations[block], brain_activations[block][:start_steps[block]])
        return brain_activations

    def _save_mapper(self, mapper, fold_id, train_delay, selected_voxels, post_selected_voxels):
        """Pickle the mapper and hparams and save both voxel selections next to them."""
        print("Saving the model:")
        save_dir = self.model_dir + "_fold_" + str(fold_id) + "_delay_" + str(train_delay)
        self._ensure_parent_dir(save_dir)
        with open(save_dir, 'wb') as f:
            pickle.dump(mapper, f)
        with open(save_dir + '_params', 'wb') as f:
            pickle.dump(self.hparams, f)
        # NOTE(review): no "_" separator here unlike the sibling files; the name is
        # kept unchanged so existing loaders keep finding it.
        with open(save_dir + 'selected_voxels', 'wb') as f:
            np.save(f, selected_voxels)
        with open(save_dir + '_post_selected_voxels', 'wb') as f:
            np.save(f, post_selected_voxels)

    def train_mapper(self, brain_activations, encoded_stimuli, train_blocks,
                     test_blocks, time_steps, start_steps, end_steps, train_delay, test_delay, fold_id,
                     save=True, evaluate=True):
        """Train one mapper at ``train_delay``, optionally evaluate and save it.

        The mapper is trained on the variance-selected voxels, then the
        post-training selectors are fitted on its training predictions and the
        mapper is retrained on the post-selected voxels only.

        :param evaluate: evaluate on the training set and, if any, the test blocks.
        :return: the trained mapper
        """
        mapper = self.mapper_tuple[0](**self.mapper_tuple[1])
        tf.logging.info('Prepare train pairs ...')
        train_encoded_stimuli, train_brain_activations = mapper.prepare_inputs(
            blocks=train_blocks,
            timed_targets=brain_activations,
            sorted_inputs=encoded_stimuli,
            sorted_timesteps=time_steps,
            delay=train_delay,
            start_steps=start_steps,
            end_steps=end_steps)
        train_brain_activations, selected_voxels = self.voxel_selection(train_brain_activations, fit=True)
        print("number of selected voxels:", len(selected_voxels))
        print(train_brain_activations.shape)
        # Train the mapper
        tf.logging.info('Start training ...')
        mapper.train(inputs=train_encoded_stimuli, targets=train_brain_activations)
        tf.logging.info('Training done!')
        # Select voxels based on performance on the training set, then retrain on them.
        mapper_output = mapper.map(inputs=train_encoded_stimuli, targets=train_brain_activations)
        predictions = mapper_output['predictions']
        _, post_selected_voxels = self.post_train_voxel_selection(
            train_brain_activations, predictions=predictions, labels=train_brain_activations, fit=True)
        post_selected_targets = np.asarray(train_brain_activations)[:, post_selected_voxels]
        mapper.train(inputs=train_encoded_stimuli, targets=post_selected_targets)
        # Evaluate the mapper (previously tested the builtin `eval`, which is always truthy).
        if evaluate:
            tf.logging.info('Evaluating ...')
            self.eval_mapper(mapper, train_encoded_stimuli, post_selected_targets)
            if len(test_blocks) > 0:
                self.eval(mapper, brain_activations, encoded_stimuli,
                          test_blocks, time_steps, start_steps, end_steps, test_delay)
        if save:
            self._save_mapper(mapper, fold_id, train_delay, selected_voxels, post_selected_voxels)
        return mapper

    def train_mappers(self, delays=(0,), cross_delay=True, eval=True, fold_index=-1):
        """Load, preprocess and encode the data, then train one mapper per train delay.

        :param delays: time delays to use; each is cast to int.
        :param cross_delay: evaluate every (train_delay, test_delay) combination
            if True, otherwise only matching pairs.
        :param eval: whether to evaluate the mappers (name kept for backward
            compatibility even though it shadows the builtin).
        :param fold_index: index of the train/test fold to use.
        :return: {train_delay: trained mapper}
        """
        # Add different options to encode the stimuli (sentence based, word based, whole block based, whole story based)
        # Load the brain data
        tf.logging.info('Loading brain data ...')
        time_steps, brain_activations, stimuli, start_steps, end_steps = self.load_brain_experiment()
        tf.logging.info('Blocks: %s' % str(self.blocks))
        # Prefer block 1 as before, but do not crash when it is absent.
        example_block = 1 if 1 in stimuli else next(iter(stimuli), None)
        if example_block is not None:
            print('Example Stimuli %s' % str(stimuli[example_block][0]))
        # Preprocess brain activations
        brain_activations = self.preprocess_brain_activations(brain_activations,
                                                              voxel_preprocessings=self.voxel_preprocess(),
                                                              start_steps=start_steps, end_steps=end_steps,
                                                              resting_norm=False)
        # Encode the stimuli and get the representations from the computational model.
        tf.logging.info('Encoding the stimuli ...')

        def integration_fn(inputs, axis, max_size=512):
            """Average multi-dimensional inputs over ``axis`` and keep at most ``max_size`` leading entries."""
            if len(inputs.shape) > 1:
                inputs = np.mean(inputs, axis=axis)
            size = inputs.shape[-1]
            return inputs[:min(max_size, size)]

        encoded_stimuli = self.encode_stimuli(stimuli, integration_fn=integration_fn)
        # Get the test and training sets
        train_blocks, test_blocks = self.get_folds(fold_index)
        print("train blocks:", train_blocks)
        print("test blocks:", test_blocks)
        if cross_delay:
            delay_pairs = itertools.product(delays, repeat=2)
        else:
            delay_pairs = zip(delays, delays)
        trained_mapper_dic = {}
        for train_delay, test_delay in delay_pairs:
            train_delay, test_delay = int(train_delay), int(test_delay)
            print("Training with train time delay of %d and test time delay of %d" % (train_delay, test_delay))
            if train_delay not in trained_mapper_dic:
                trained_mapper_dic[train_delay] = self.train_mapper(brain_activations, encoded_stimuli,
                                                                    train_blocks, test_blocks,
                                                                    time_steps, start_steps, end_steps,
                                                                    train_delay, test_delay, fold_index,
                                                                    save=self.hparams.save_models,
                                                                    evaluate=eval)
            elif eval:
                # Reuse the mapper already trained at this delay; only evaluate.
                self.eval(trained_mapper_dic[train_delay], brain_activations, encoded_stimuli,
                          test_blocks, time_steps, start_steps, end_steps, test_delay)
        return trained_mapper_dic