Skip to content

Commit

Permalink
Update codes to use recent APIs. Include sample data.
Browse files Browse the repository at this point in the history
  • Loading branch information
edwardchoi-google committed Oct 1, 2019
1 parent 2a16b66 commit df54b9a
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 15 deletions.
24 changes: 12 additions & 12 deletions mime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
import cPickle as pickle
import tensorflow as tf
import sonnet as snt
#import sonnet as snt
from sklearn.metrics import roc_auc_score
from sklearn.metrics import average_precision_score
from sklearn.model_selection import train_test_split
Expand Down Expand Up @@ -56,34 +56,34 @@ def build_model(options):
rx_visit = emb_activation(rx_visit)
rx_visit = rx_visit * tf.reshape(rx_mask, (-1, options['max_rx_per_dx']))[:, :, None] ####Masking####
rx_visit = tf.reduce_sum(rx_visit, axis=1)
W_dr = snt.Sequential([snt.Linear(output_size=options['rx_emb_size'], name='W_dr'), order_activation])
W_dr = tf.keras.layers.Dense(options['rx_emb_size'], activation=order_activation, name='W_dr')
dr_visit = W_dr(dx_visit)
dr_visit = dr_visit * rx_visit

pr_visit = tf.nn.embedding_lookup(W_emb_pr, tf.reshape(pr_var, (-1, options['max_pr_per_dx'])))
pr_visit = emb_activation(pr_visit)
pr_visit = pr_visit * tf.reshape(pr_mask, (-1, options['max_pr_per_dx']))[:, :, None] ####Masking####
pr_visit = tf.reduce_sum(pr_visit, axis=1)
W_dp = snt.Sequential([snt.Linear(output_size=options['pr_emb_size'], name='W_dr'), order_activation])
W_dp = tf.keras.layers.Dense(options['pr_emb_size'], activation=order_activation, name='W_dp')
dp_visit = W_dp(dx_visit)
dp_visit = dp_visit * pr_visit

dx_obj = dx_visit + dr_visit + dp_visit
W_dx = snt.Sequential([snt.Linear(output_size=options['dxobj_emb_size'], name='W_dxobj'), order_activation])
W_dx = tf.keras.layers.Dense(options['dxobj_emb_size'], activation=order_activation, name='W_dxobj')
dx_obj = W_dx(dx_obj)
pre_visit = tf.reshape(dx_obj, (-1, options['max_dx_per_visit'], options['dxobj_emb_size']))
pre_visit = pre_visit * tf.reshape(dx_mask, (-1, options['max_dx_per_visit']))[:, :, None] ####Masking####
visit = tf.reduce_sum(pre_visit, axis=1)
seq_visit = tf.reshape(visit, (-1, options['batch_size'], options['visit_emb_size']))

seq_length = tf.placeholder(tf.int32, shape=(options['batch_size']), name='seq_length')
rnn = snt.GRU(options['rnn_size'], name='emb2rnn')
rnn2pred = snt.Sequential([snt.Linear(output_size=options['output_size'], name='rnn2pred'), tf.nn.sigmoid])
rnn2aux_dx = snt.Linear(output_size=options['num_dx'], name='rnn2aux_dx')
rnn2aux_rx = snt.Linear(output_size=options['num_rx'], name='rnn2aux_rx')
rnn2aux_pr = snt.Linear(output_size=options['num_pr'], name='rnn2aux_pr')
rnn_cell = tf.keras.layers.GRUCell(options['rnn_size'], name='emb2rnn')
rnn2pred = tf.keras.layers.Dense(options['output_size'], activation=tf.nn.sigmoid, name='rnn2pred')
rnn2aux_dx = tf.keras.layers.Dense(options['num_dx'], name='rnn2aux_dx')
rnn2aux_rx = tf.keras.layers.Dense(options['num_rx'], name='rnn2aux_rx')
rnn2aux_pr = tf.keras.layers.Dense(options['num_pr'], name='rnn2aux_pr')

_, final_states = tf.nn.dynamic_rnn(rnn, seq_visit, dtype=tf.float32, time_major=True, sequence_length=seq_length)
_, final_states = tf.nn.dynamic_rnn(rnn_cell, seq_visit, dtype=tf.float32, time_major=True, sequence_length=seq_length)
preds = tf.squeeze(rnn2pred(final_states))
labels = tf.placeholder(tf.float32, shape=(options['batch_size']), name='labels')
loss = -tf.reduce_mean(labels * tf.log(preds + 1e-10) + (1. - labels) * tf.log(1. - preds + 1e-10))
Expand Down Expand Up @@ -163,8 +163,8 @@ def train(
max_pr_per_dx=10,
regularize=1e-3,
aux_lambda=0.1,
min_threshold=5,
max_threshold=150,
min_threshold=1,
max_threshold=200,
train_ratio=1.0,
association_threshold=0.0,
):
Expand Down
6 changes: 3 additions & 3 deletions mime_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def load_data(input_path, min_threshold=5, max_threshold=150, seed=1234, train_r

seqs, labels = find_patients_with_many_associations(seqs, labels, association_threshold)

temp_seqs, test_seqs, temp_labels, test_labels = train_test_split(seqs, labels, test_size=0.2, random_state=seed)
train_seqs, valid_seqs, train_labels, valid_labels = train_test_split(temp_seqs, temp_labels, test_size=0.1, random_state=seed)
train_seqs, temp_seqs, train_labels, temp_labels = train_test_split(seqs, labels, test_size=0.8, random_state=seed)
valid_seqs, test_seqs, valid_labels, test_labels = train_test_split(temp_seqs, temp_labels, test_size=0.5, random_state=seed)

train_size = int(len(train_seqs) * train_ratio)
train_seqs = train_seqs[:train_size]
Expand Down Expand Up @@ -71,7 +71,7 @@ def find_patients_with_many_associations(seqs, labels, threshold):
code_count += 1
num_visits = float(len(patient))
if code_count / num_visits >= threshold:
new_seqs.append(patient)
new_seqs.append(patient)
new_labels.append(label)
return new_seqs, new_labels

Expand Down
Binary file added sample.labels
Binary file not shown.
Binary file added sample.seqs
Binary file not shown.

0 comments on commit df54b9a

Please sign in to comment.