Commit e5508b7: GRU and LSTM optimizations

Author: Nicholas Leonard
Parent: 71d80f8

8 files changed, +80 −16 lines

SeqGRU.lua

Lines changed: 9 additions & 3 deletions

```diff
@@ -97,7 +97,12 @@ function SeqGRU:updateOutput(input)
 
   local h = self.output
   h:resize(seqlen, batchsize, outputsize):zero()
-  self.gates:resize(seqlen, batchsize, 3 * outputsize):zero()
+
+  local nElement = self.gates:nElement()
+  self.gates:resize(seqlen, batchsize, 3 * outputsize)
+  if nElement ~= seqlen * batchsize * 3 * outputsize then
+    self.gates:zero()
+  end
 
   local prev_h = h0
   if input.nn and input.nn.StepGRU_updateOutput and not self.forceLua then
@@ -184,15 +189,16 @@ function SeqGRU:backward(input, gradOutput, scale)
   local u = self.gates[{t, {}, {outputsize + 1, 2 * outputsize}}]
   local hc = self.gates[{t, {}, {2 * outputsize + 1, 3 * outputsize}}]
 
-  local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize):zero()
+  local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize)
+
   local grad_ar = grad_a[{{}, {1, outputsize}}]
   local grad_au = grad_a[{{}, {outputsize + 1, 2 * outputsize}}]
   local grad_ahc = grad_a[{{}, {2 * outputsize + 1, 3 * outputsize}}]
 
   -- use grad_au as temporary buffer to compute grad_ahc.
 
   local grad_hc = grad_au:fill(0):addcmul(grad_next_h, -1, u, grad_next_h)
-  grad_ahc:fill(1):addcmul(-1, hc,hc):cmul(grad_hc)
+  grad_ahc:fill(1):addcmul(-1, hc, hc):cmul(grad_hc)
   local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * outputsize + 1, 3 * outputsize}}]:t() ):cmul(prev_h)
   grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r)
 
```
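
All four Lua modules apply the same optimization: `gates` is a buffer reused across calls, and Torch's `resize()` keeps the underlying storage when the element count is unchanged, so the unconditional `:zero()` on every forward was wasted work. The fill is only needed when `resize()` actually reallocates, since fresh storage may contain garbage. A minimal standalone sketch of the pattern (the `resizeGates` helper and the sizes are illustrative, not part of the commit):

```lua
require 'torch'

-- Illustrative helper: resize a reused gate buffer and zero it only when
-- the element count changed, i.e. when resize() may have allocated fresh
-- (uninitialized) storage.
local function resizeGates(gates, batchsize, outputsize)
  local nElement = gates:nElement()
  gates:resize(batchsize, 3 * outputsize)
  if nElement ~= batchsize * 3 * outputsize then
    gates:zero() -- pay for the fill only on shape changes
  end
  return gates
end

local gates = torch.Tensor()
resizeGates(gates, 32, 128) -- first call: allocates and zeros
resizeGates(gates, 32, 128) -- steady state: storage reused, no fill
```

On the steady-state path (fixed batch and output sizes) the fill is skipped entirely, which is presumably where the benchmark gains quoted further down come from.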

SeqLSTM.lua

Lines changed: 5 additions & 0 deletions

```diff
@@ -145,7 +145,12 @@ function SeqLSTM:updateOutput(input)
   local h, c = self.output, self.cell
   h:resize(seqlen, batchsize, outputsize)
   c:resize(seqlen, batchsize, hiddensize)
+
+  local nElement = self.gates:nElement()
   self.gates:resize(seqlen, batchsize, 4 * hiddensize)
+  if nElement ~= seqlen * batchsize * 4 * hiddensize then
+    self.gates:zero()
+  end
 
   local prev_h, prev_c = h0, c0
   if input.nn and input.nn.StepLSTM_updateOutput and not self.forceLua then
```

StepGRU.lua

Lines changed: 7 additions & 3 deletions

```diff
@@ -52,8 +52,12 @@ function StepGRU:updateOutput(input)
   local Wh = self.weight:narrow(1, inputsize + 1, self.outputsize)
 
   next_h:resize(batchsize, outputsize)
-  self.gates:resize(batchsize, 3 * outputsize):zero()
   local gates = self.gates
+  local nElement = gates:nElement()
+  gates:resize(batchsize, 3 * outputsize)
+  if nElement ~= batchsize * 3 * outputsize then
+    gates:zero()
+  end
 
   gates:addmm(bias_expand, cur_x, Wx)
   local sub_gates = gates:narrow(2, 1, 2 * outputsize)
@@ -92,7 +96,6 @@ function StepGRU:backward(input, gradOutput, scale)
   scale = scale or 1.0
   assert(scale == 1.0, 'must have scale=1')
 
-  --
   local grad_gates = torch.getBuffer('StepGRU', 'grad_gates', self.gates) -- batchsize x 3*outputsize
   local buffer = torch.getBuffer('StepGRU', 'buffer', self.gates) -- 1 x 3*outputsize
 
@@ -125,7 +128,8 @@ function StepGRU:backward(input, gradOutput, scale)
   local update_gate = gates:narrow(2, outputsize + 1, outputsize)
   local hidden_candidate = gates:narrow(2, 2 * outputsize + 1, outputsize)
 
-  grad_gates:resize(batchsize, 3 * outputsize):zero()
+  grad_gates:resize(batchsize, 3 * outputsize)
+
   local grad_reset_gate = grad_gates:narrow(2, 1, outputsize)
   local grad_update_gate = grad_gates:narrow(2, outputsize + 1, outputsize)
   local grad_hidden_candidate = grad_gates:narrow(2, 2 * outputsize + 1, outputsize)
```

StepLSTM.lua

Lines changed: 7 additions & 2 deletions

```diff
@@ -82,8 +82,12 @@ function StepLSTM:updateOutput(input)
   next_h:resize(batchsize, hiddensize)
   next_c:resize(batchsize, hiddensize)
 
-  self.gates:resize(batchsize, 4 * hiddensize):zero()
   local gates = self.gates
+  local nElement = gates:nElement()
+  gates:resize(batchsize, 4 * hiddensize)
+  if nElement ~= batchsize * 4 * hiddensize then
+    gates:zero()
+  end
 
   -- forward
   gates:addmm(bias_expand, cur_x, Wx)
@@ -182,7 +186,8 @@ function StepLSTM:backward(input, gradOutput, scale)
   local output_gate = gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
   local input_transform = gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]
 
-  grad_gates:resize(batchsize, 4 * hiddensize):zero()
+  grad_gates:resize(batchsize, 4 * hiddensize)
+
   local grad_input_gate = grad_gates[{{}, {1, hiddensize}}]
   local grad_forget_gate = grad_gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
   local grad_output_gate = grad_gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
```
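
In the backward passes, the `:zero()` on `grad_gates` (and on `grad_a_buffer` in SeqGRU) is dropped outright rather than made conditional: every slice of those buffers is written with `:fill()`, `:copy()`, or a destination-overwriting matrix-multiply chain before it is read, so stale values from a previous call never leak into the gradients. A toy illustration of that invariant (the tensor sizes and operations here are made up for the example):

```lua
require 'torch'

-- A reused gradient buffer may hold stale values from the previous call.
local grad_gates = torch.Tensor(4, 12):fill(42) -- pretend: stale data

local grad_reset  = grad_gates:narrow(2, 1, 4)
local grad_update = grad_gates:narrow(2, 5, 4)
local grad_hc     = grad_gates:narrow(2, 9, 4)

-- Each slice is overwritten before it is ever read,
-- so no upfront zero() is needed.
grad_reset:fill(1)                  -- overwrites the first slice
grad_update:copy(torch.randn(4, 4)) -- overwrites the second
grad_hc:fill(0):add(0.5)            -- overwrites, then accumulates safely
```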

benchmark/README.md

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,4 +1,4 @@
 # Benchmark
 
-On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 886 samples/sec compared to TF’s 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
-On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3950 compared to TF’s 4130 samples/sec.
+On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 900 samples/sec compared to TF’s 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
+On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3990 compared to TF’s 4130 samples/sec.
```
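
Taking the quoted numbers at face value, the commit moves Torch from 886 to 900 samples/sec at 512 hiddensize (900 / 809 ≈ 1.11, about 11% ahead of TF) and from 3950 to 3990 samples/sec at 128 hiddensize (3990 / 4130 ≈ 0.97, still about 3% behind).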

generic/StepGRU.c

Lines changed: 3 additions & 3 deletions

```diff
@@ -22,7 +22,10 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
   buffer->size[0] = batchsize;
 
   THTensor_(resize2d)(next_h, batchsize, outputsize);
+  long nElement = THTensor_(nElement)(gates);
   THTensor_(resize2d)(gates, batchsize, 3 * outputsize);
+  if (nElement != batchsize * 3 * outputsize)
+    THTensor_(fill)(gates, 0);
 
   THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
   THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
@@ -32,8 +35,6 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
   THTensor *update_gate = THTensor_(newNarrow)(gates, 1, outputsize, outputsize); // u = sig(Wx * x + Wh * prev_h + b)
   THTensor *hidden_candidate = THTensor_(newNarrow)(gates, 1, 2*outputsize, outputsize); // hc = tanh(Wx * x + Wh * r . prev_h + b)
 
-  //THTensor_(fill)(gates, 0);
-
   // forward
   THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
   THTensor_(addmm)(sub_gates, 1, sub_gates, 1, prev_h, sub_Wh);
@@ -84,7 +85,6 @@ static int nn_(StepGRU_backward)(lua_State *L) {
   THTensor_(resize2d)(grad_cur_x, batchsize, inputsize);
   THTensor_(resize2d)(grad_prev_h, batchsize, outputsize);
   THTensor_(resize2d)(grad_gates, batchsize, 3 * outputsize);
-  THTensor_(fill)(grad_gates, 0);
 
   THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
   THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
```

generic/StepLSTM.c

Lines changed: 3 additions & 3 deletions

```diff
@@ -29,9 +29,10 @@ static int nn_(StepLSTM_updateOutput)(lua_State *L) {
 
   THTensor_(resize2d)(next_h, batchsize, hiddensize);
   THTensor_(resize2d)(next_c, batchsize, hiddensize);
-
+  long nElement = THTensor_(nElement)(gates);
   THTensor_(resize2d)(gates, batchsize, 4 * hiddensize);
-  //THTensor_(fill)(gates, 0);
+  if (nElement != batchsize * 4 * hiddensize)
+    THTensor_(fill)(gates, 0);
 
   // forward
   THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
@@ -147,7 +148,6 @@ static int nn_(StepLSTM_backward)(lua_State *L) {
   THTensor *grad_Wh = THTensor_(newNarrow)(gradWeight, 0, inputsize, outputsize);
 
   THTensor_(resize2d)(grad_gates, batchsize, 4 * hiddensize);
-  THTensor_(fill)(grad_gates, 0);
 
   THTensor *grad_input_gate = THTensor_(newNarrow)(grad_gates, 1, 0, hiddensize);
   THTensor *grad_forget_gate = THTensor_(newNarrow)(grad_gates, 1, hiddensize, hiddensize);
```

test/README.md

Lines changed: 44 additions & 0 deletions

````diff
@@ -36,4 +36,48 @@ fast LSTM memory: 98.27904510498:2.2750110626221 MB
 step LSTM memory: 17.168065071106:2.1289348602295 MB
 rec LSTM memory: 13.374607086182:2.0407600402832 MB
 seq LSTM memory: 8.8895826339722:3.0098876953125 MB
+```
+
+More optimizations
+
+```
+th -lrnn -e 'rnn.bigtest({"LSTM","GRU"})'
+Running 3 tests
+ 1/3 LSTM_char_rnn ....................................................... [PASS]
+ 2/3 GRU ................................................................. [WAIT]CPU test
+old GRU time: 0.039725697040558 seconds
+step GRU time: 0.014464259147644 seconds
+luarec GRU time: 0.017707204818726 seconds
+rec GRU time: 0.013900947570801 seconds
+luaseq GRU time: 0.016570293903351 seconds
+seq GRU time: 0.012663447856903 seconds
+RecGRU-C 1.2738127907136 faster than RecGRU-Lua
+RecGRU 2.8577690001509 faster than old GRU
+SeqGRU 1.0977221786579 faster than RecGRU
+SeqGRU-C 1.3085136126113 faster than SeqGRU-Lua
+Memory test
+old GRU memory: 82.804834365845:1.833381652832 MB
+step GRU memory: 10.018351554871:1.5651426315308 MB
+rec GRU memory: 10.018255233765:1.5337238311768 MB
+seq GRU memory: 6.3827362060547:1.5385322570801 MB
+ 2/3 GRU ................................................................. [PASS]
+ 3/3 LSTM ................................................................ [WAIT]CPU test
+fast LSTM time: 0.044381546974182 seconds
+step LSTM time: 0.021313452720642 seconds
+luarec LSTM time: 0.021889448165894 seconds
+rec LSTM time: 0.017923295497894 seconds
+luaseq LSTM time: 0.018705642223358 seconds
+seq LSTM time: 0.016467046737671 seconds
+RecLSTM-C 1.2212847893104 faster than RecLSTM-Lua
+RecLSTM 2.476193453341 faster than FastLSTM
+SeqLSTM 1.0884341183591 faster than RecLSTM
+SeqLSTM-C 1.1359439565181 faster than SeqLSTM-Lua
+Memory test
+fast LSTM memory: 98.2790184021:2.2749843597412 MB
+step LSTM memory: 17.168484687805:2.1293544769287 MB
+rec LSTM memory: 13.375099182129:2.0412521362305 MB
+seq LSTM memory: 8.8264684677124:2.0093183517456 MB
+ 3/3 LSTM ................................................................ [PASS]
+Completed 0 asserts in 3 tests with 0 failures and 0 errors and 1 warning
+--------------------------------------------------------------------------------
 ```
````
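
The -Lua and -C variants in this log correspond to the two code paths visible in the diffs above: each module calls into the C kernel (input.nn.StepGRU_updateOutput, input.nn.StepLSTM_updateOutput) unless self.forceLua is set, in which case the pure-Lua implementation runs; the benchmark presumably toggles that flag to time both paths.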
