add.py
#!/usr/bin/python
# Author: Clara Vania

import tensorflow as tf


class AdditiveModel(object):
    """
    RNNLM using subword to word (S2W) model
    Code based on tensorflow tutorial on building a PTB LSTM model.
    https://www.tensorflow.org/versions/r0.7/tutorials/recurrent/index.html
    """
    def __init__(self, args, is_training, is_testing=False):
        self.batch_size = batch_size = args.batch_size
        self.num_steps = num_steps = args.num_steps
        self.model = model = args.model
        self.subword_vocab_size = subword_vocab_size = args.subword_vocab_size
        self.optimizer = args.optimization
        self.unit = args.unit
        rnn_size = args.rnn_size
        out_vocab_size = args.out_vocab_size
        tf_device = "/gpu:" + str(args.gpu)

        if is_testing:
            self.batch_size = batch_size = 1
            self.num_steps = num_steps = 1
        if model == 'rnn':
            cell_fn = tf.nn.rnn_cell.BasicRNNCell
        elif model == 'gru':
            cell_fn = tf.nn.rnn_cell.GRUCell
        elif model == 'lstm':
            cell_fn = tf.nn.rnn_cell.BasicLSTMCell
        else:
            raise Exception("model type not supported: {}".format(args.model))
        with tf.device(tf_device):
            # placeholders for data
            self._input_data = tf.placeholder(tf.float32, shape=[batch_size, num_steps, subword_vocab_size])
            self._targets = tf.placeholder(tf.int32, shape=[batch_size, num_steps])
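            # Note: each time step of _input_data presumably holds a bag-of-subwords
            # (count) vector for one word, which the projection below turns into an
            # additive word representation; _targets holds the next-word ids.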

            # ********************************************************************************
            # RNNLM
            # ********************************************************************************
            # forget_bias is only accepted by the LSTM cell
            if model == 'lstm':
                lm_cell = cell_fn(rnn_size, forget_bias=0.0)
            else:
                lm_cell = cell_fn(rnn_size)
            if is_training and args.keep_prob < 1:
                lm_cell = tf.nn.rnn_cell.DropoutWrapper(lm_cell, output_keep_prob=args.keep_prob)
            lm_cell = tf.nn.rnn_cell.MultiRNNCell([lm_cell] * args.num_layers)
            self._initial_lm_state = lm_cell.zero_state(batch_size, tf.float32)

            inputs = self._input_data
            if is_training and args.keep_prob < 1:
                inputs = tf.nn.dropout(inputs, args.keep_prob)

            softmax_win = tf.get_variable("softmax_win", [subword_vocab_size, rnn_size])
            softmax_bin = tf.get_variable("softmax_bin", [rnn_size])
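            # softmax_win presumably serves as the subword embedding matrix: multiplying
            # a subword-count vector by it sums the embeddings of that word's subwords,
            # i.e. the additive subword-to-word composition the class name refers to.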

            # split input into a list
            inputs = tf.split(1, num_steps, inputs)
            lm_inputs = []
            for input_ in inputs:
                input_ = tf.squeeze(input_, [1])
                input_ = tf.matmul(input_, softmax_win) + softmax_bin
                lm_inputs.append(input_)

            lm_outputs, lm_state = tf.nn.rnn(lm_cell, lm_inputs, initial_state=self._initial_lm_state)
            lm_outputs = tf.concat(1, lm_outputs)
            lm_outputs = tf.reshape(lm_outputs, [-1, rnn_size])

            softmax_w = tf.get_variable("softmax_w", [out_vocab_size, rnn_size])
            softmax_b = tf.get_variable("softmax_b", [out_vocab_size])

            # compute cross entropy loss
            logits = tf.matmul(lm_outputs, softmax_w, transpose_b=True) + softmax_b
            loss = tf.nn.seq2seq.sequence_loss_by_example(
                [logits],
                [tf.reshape(self._targets, [-1])],
                [tf.ones([batch_size * num_steps])])

            # compute cost
            self._cost = cost = tf.reduce_sum(loss) / batch_size
            self._final_state = lm_state

            if not is_training:
                return

            self._lr = tf.Variable(0.0, trainable=False)
            tvars = tf.trainable_variables()
            grads, _ = tf.clip_by_global_norm(tf.gradients(cost, tvars),
                                              args.grad_clip)
            optimizer = tf.train.GradientDescentOptimizer(self._lr)
            self._train_op = optimizer.apply_gradients(zip(grads, tvars))

            self._new_lr = tf.placeholder(tf.float32, shape=[], name="new_lr")
            self._lr_update = tf.assign(self._lr, self._new_lr)

    def assign_lr(self, session, lr_value):
        session.run(self._lr_update, feed_dict={self._new_lr: lr_value})

    @property
    def input_data(self):
        return self._input_data

    @property
    def targets(self):
        return self._targets

    @property
    def initial_lm_state(self):
        return self._initial_lm_state

    @property
    def cost(self):
        return self._cost

    @property
    def final_state(self):
        return self._final_state

    @property
    def lr(self):
        return self._lr

    @property
    def train_op(self):
        return self._train_op
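

# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file): builds the graph and
# runs one dummy training step on random data to show the expected feeds.
# The hyperparameter values below are illustrative placeholders; in the real
# pipeline they come from the training script's command-line arguments.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse
    import numpy as np

    args = argparse.Namespace(
        batch_size=20, num_steps=35, model='lstm', subword_vocab_size=200,
        optimization='sgd', unit='char', rnn_size=200, out_vocab_size=5000,
        gpu=0, keep_prob=0.5, num_layers=2, grad_clip=5.0)

    train_model = AdditiveModel(args, is_training=True)

    config = tf.ConfigProto(allow_soft_placement=True)
    with tf.Session(config=config) as sess:
        sess.run(tf.initialize_all_variables())
        train_model.assign_lr(sess, 1.0)

        # random stand-ins for one batch of subword-count vectors and target ids
        x = np.random.rand(args.batch_size, args.num_steps,
                           args.subword_vocab_size).astype(np.float32)
        y = np.random.randint(0, args.out_vocab_size,
                              size=(args.batch_size, args.num_steps)).astype(np.int32)

        cost, _ = sess.run([train_model.cost, train_model.train_op],
                           feed_dict={train_model.input_data: x,
                                      train_model.targets: y})
        print("dummy-step cost: {:.4f}".format(cost))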