Skip to content

Commit 72091ff

Browse files
author
Nicholas Leonard
committed
ReverseTable+SeqReverseSequence -> ReverseSequnece
1 parent 6ca4c57 commit 72091ff

File tree

6 files changed

+106
-49
lines changed

6 files changed

+106
-49
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ SET(luasrc
4242
SeqGRU.lua
4343
SeqLSTM.lua
4444
deprecated/SeqLSTMP.lua
45-
SeqReverseSequence.lua
45+
deprecated/SeqReverseSequence.lua
4646
Sequencer.lua
4747
SequencerCriterion.lua
4848
ZeroGrad.lua
@@ -85,7 +85,7 @@ SET(luasrc
8585
ReinforceCategorical.lua
8686
ReinforceGamma.lua
8787
ReinforceNormal.lua
88-
ReverseTable.lua
88+
ReverseSequence.lua
8989
Sequential.lua
9090
Serial.lua
9191
SimpleColorTransform.lua

ReverseSequence.lua

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
local ReverseSequence, parent = torch.class("nn.ReverseSequence", "nn.Module")
2+
3+
function ReverseSequence:updateOutput(input)
4+
local seqlen
5+
if torch.isTensor(input) then
6+
seqlen = input:size(1)
7+
self.output = torch.isTensor(self.output) and self.output or input.new()
8+
self.output:resizeAs(input)
9+
10+
self._range = self._range or torch.isCudaTensor(input) and torch.CudaLongTensor() or torch.LongTensor()
11+
if self._range:nElement() ~= seqlen then
12+
self._range:range(seqlen,1,-1)
13+
end
14+
self.output:index(input, 1, self._range)
15+
else
16+
seqlen = #input
17+
self.output = torch.type(self.output) == 'table' and self.output or {}
18+
assert(torch.type(input) == 'table', "Expecting table or tensor at arg 1")
19+
20+
-- empty output table
21+
for k,v in ipairs(self.output) do
22+
self.output[k] = nil
23+
end
24+
25+
-- reverse input
26+
local k = 1
27+
for i=seqlen,1,-1 do
28+
self.output[k] = input[i]
29+
k = k + 1
30+
end
31+
end
32+
33+
return self.output
34+
end
35+
36+
function ReverseSequence:updateGradInput(input, gradOutput)
37+
local seqlen
38+
if torch.isTensor(input) then
39+
seqlen = input:size(1)
40+
self.gradInput = torch.isTensor(self.gradInput) and self.gradInput or input.new()
41+
self.gradInput:resizeAs(input)
42+
43+
self.gradInput:index(gradOutput, 1, self._range)
44+
else
45+
seqlen = #input
46+
self.gradInput = torch.type(self.gradInput) == 'table' and self.gradInput or {}
47+
assert(torch.type(gradOutput) == 'table', "Expecting table or tensor at arg 2")
48+
49+
-- empty gradInput table
50+
for k,v in ipairs(self.gradInput) do
51+
self.gradInput[k] = nil
52+
end
53+
54+
-- reverse gradOutput
55+
local k = 1
56+
for i=seqlen,1,-1 do
57+
self.gradInput[k] = gradOutput[i]
58+
k = k + 1
59+
end
60+
end
61+
62+
return self.gradInput
63+
end
64+
65+
function ReverseSequence:clearState()
66+
self.gradInput = torch.Tensor()
67+
self.output = torch.Tensor()
68+
self._range = nil
69+
end
70+
71+
function ReverseSequence:type(...)
72+
self:clearState()
73+
return parent.type(self, ...)
74+
end
75+

ReverseTable.lua

Lines changed: 0 additions & 39 deletions
This file was deleted.
File renamed without changes.

init.lua

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ require('rnn.Collapse')
5858
require('rnn.ZipTable')
5959
require('rnn.ZipTableOneToMany')
6060
require('rnn.CAddTensorTable')
61-
require('rnn.ReverseTable')
61+
require('rnn.ReverseSequence')
6262
require('rnn.Dictionary')
6363
require('rnn.Inception')
6464
require('rnn.Clip')
@@ -131,7 +131,6 @@ require('rnn.RecurrentAttention')
131131
-- sequencer + recurrent modules
132132
require('rnn.SeqLSTM')
133133
require('rnn.SeqGRU')
134-
require('rnn.SeqReverseSequence')
135134
require('rnn.SeqBRNN')
136135

137136
-- recurrent criterions:
@@ -144,6 +143,7 @@ require('rnn.MaskZeroCriterion')
144143
require('rnn.LSTM')
145144
require('rnn.FastLSTM')
146145
require('rnn.SeqLSTMP')
146+
require('rnn.SeqReverseSequence')
147147

148148
-- prevent likely name conflicts
149149
nn.rnn = rnn

test/test.lua

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4991,24 +4991,45 @@ function rnntest.CAddTensorTable()
49914991
mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1")
49924992
end
49934993

4994-
function rnntest.ReverseTable()
4994+
function rnntest.ReverseSequence()
4995+
-- test table
4996+
49954997
-- input : { a, b, c, d }
49964998
-- output : { c, b, a, d }
4997-
local r = nn.ReverseTable()
4999+
local r = nn.ReverseSequence()
49985000
local input = {torch.randn(3,4), torch.randn(3,4), torch.randn(3,4), torch.randn(3,4)}
49995001
local output = r:forward(input)
50005002

5001-
mytester:assert(#output == 4, "ReverseTable #output")
5003+
mytester:assert(#output == 4, "ReverseSequence #output")
50025004
local k = 1
50035005
for i=#input,1,-1 do
5004-
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseTable output err "..k)
5006+
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k)
50055007
k = k + 1
50065008
end
50075009

50085010
local gradInput = r:backward(input, output)
5009-
mytester:assert(#gradInput == 4, "ReverseTable #gradInput")
5011+
mytester:assert(#gradInput == 4, "ReverseSequence #gradInput")
50105012
for i=1,#input do
5011-
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseTable gradInput err "..i)
5013+
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i)
5014+
end
5015+
5016+
-- test tensor
5017+
5018+
local r = nn.ReverseSequence()
5019+
local input = torch.randn(5,4,3)
5020+
local output = r:forward(input)
5021+
5022+
mytester:assert(output:isSameSizeAs(input), "ReverseSequence #output")
5023+
local k = 1
5024+
for i=5,1,-1 do
5025+
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k)
5026+
k = k + 1
5027+
end
5028+
5029+
local gradInput = r:backward(input, output)
5030+
mytester:assert(gradInput:isSameSizeAs(input), "ReverseSequence #gradInput")
5031+
for i=1,5 do
5032+
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i)
50125033
end
50135034
end
50145035

0 commit comments

Comments
 (0)