
Commit f9a912e

Add MuFuRU

Author: Jonathan Uesato
1 parent 6a169b4

File tree: 2 files changed, +141 -0 lines changed

Mufuru.lua

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
------------------------------------------------------------------------
--[[ MuFuRu - Multi-Function Recurrent Unit ]]--
-- Author: Jonathan Uesato
-- License: LICENSE.2nd.txt

-- Ref. A.: http://arxiv.org/pdf/1606.03002v1.pdf
------------------------------------------------------------------------

local MuFuRu, parent = torch.class('nn.MuFuRu', 'nn.GRU')

local SqrtDiffLayer = nn.Sequential()
   :add(nn.CSubTable())
   :add(nn.Abs())
   :add(nn.Sqrt())
   :add(nn.MulConstant(0.25))

local MaxLayer = nn.Sequential()
   :add(nn.MapTable(nn.Unsqueeze(1)))
   :add(nn.JoinTable(1))
   :add(nn.Max(1))

local MinLayer = nn.Sequential()
   :add(nn.MapTable(nn.Unsqueeze(1)))
   :add(nn.JoinTable(1))
   :add(nn.Min(1))

-- all operations take a table {oldState, newState} and return newState
local _operations = {
   max = MaxLayer,
   keep = nn.SelectTable(1),
   replace = nn.SelectTable(2),
   mul = nn.CMulTable(),
   min = MinLayer,
   diff = nn.CSubTable(),
   forget = nn.Sequential():add(nn.SelectTable(1)):add(nn.MulConstant(0.0)),
   sqrt_diff = SqrtDiffLayer
}
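-- For reference, given input {old, new} the ops above compute:
--   keep      -> old                            replace -> new
--   mul       -> old .* new (elementwise)       diff    -> old - new
--   forget    -> zeros                          max/min -> elementwise max/min
--   sqrt_diff -> 0.25 * sqrt(|old - new|)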
function MuFuRu:__init(inputSize, outputSize, ops, rho)
   -- Use all ops by default. To replicate GRU, use keep and replace only.
   self.ops = ops or {'keep', 'replace', 'mul', 'diff', 'forget', 'sqrt_diff', 'max', 'min'}
   self.num_ops = #self.ops
   self.operations = {}
   for i=1,self.num_ops do
      self.operations[i] = _operations[self.ops[i]]
   end
   self.inputSize = inputSize
   self.outputSize = outputSize
   -- rho is the maximum BPTT length, as in nn.GRU; 9999 is effectively unlimited
   parent.__init(self, inputSize, outputSize, rho or 9999)
end
-------------------------- factory methods -----------------------------
function MuFuRu:buildModel()
   -- input : {input, prevOutput}
   -- output : output

   local nonBatchDim = 2
   -- resetGate takes {input, prevOutput} to resetGate
   local resetGate = nn.Sequential()
      :add(nn.ParallelTable()
         :add(nn.Linear(self.inputSize, self.outputSize, false)) -- no bias here; the hidden-to-hidden Linear below carries it
         :add(nn.Linear(self.outputSize, self.outputSize))
      )
      :add(nn.CAddTable())
      :add(nn.Sigmoid())
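   -- i.e. r_t = sigmoid(W_r x_t + U_r s_(t-1) + b_r), the standard GRU reset gate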
   -- featureVec takes {input, prevOutput, reset} to feature
   local featureVec = nn.Sequential()
      :add(nn.ConcatTable()
         :add(nn.SelectTable(1))
         :add(nn.Sequential()
            :add(nn.NarrowTable(2,2))
            :add(nn.CMulTable())
         )
      )
      :add(nn.JoinTable(nonBatchDim)) -- [x_t, r dot s_t-1]
      :add(nn.Linear(self.inputSize + self.outputSize, self.outputSize))
      :add(nn.Sigmoid())
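   -- i.e. v_t = sigmoid(W [x_t, r_t .* s_(t-1)] + b), the candidate feature vector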
   -- opWeights takes {input, prevOutput, reset} to opWeights.
   -- Note that reset is not used
   local opWeights = nn.Sequential()
      :add(nn.NarrowTable(1,2))
      :add(nn.JoinTable(nonBatchDim)) -- k_t
      :add(nn.Linear(self.inputSize + self.outputSize, self.num_ops * self.outputSize)) -- p^_t
      :add(nn.View(self.num_ops, self.outputSize):setNumInputDims(1))
      :add(nn.Transpose({1,2}))
      :add(nn.SoftMax()) -- p_t
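   -- p_t holds one weight per (operation, hidden unit) pair, normalized over
   -- the operation pool so that each unit's weights sum to 1 (see Ref. A)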
   -- all_ops takes {oldState, newState} to {newState1, newState2, ...newStateN}
   local all_ops = nn.ConcatTable()
   for i=1,self.num_ops do
      -- an operation is any layer taking {prevHidden, featureVec} to newState
      all_ops:add(self.operations[i])
   end

   local all_op_activations = nn.Sequential()
      :add(nn.NarrowTable(1,2))
      :add(all_ops)
      :add(nn.MapTable(nn.Unsqueeze(1)))
      :add(nn.JoinTable(1,3))

   -- combine_ops takes {prevHidden, featureVec, opWeights} to nextHidden
   local combine_ops = nn.Sequential()
      :add(nn.ConcatTable()
         :add(all_op_activations)
         :add(nn.SelectTable(3))
      )
      :add(nn.CMulTable())
      :add(nn.Sum(1,3))
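   -- i.e. s_t = sum over ops o of p_t^o .* op_o(s_(t-1), v_t), a weighted sum
   -- of every candidate next state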
   local cell = nn.Sequential()
      :add(nn.ConcatTable()
         :add(nn.SelectTable(1))
         :add(nn.SelectTable(2))
         :add(resetGate)
      ) -- {input, prevOutput, reset}
      :add(nn.ConcatTable()
         :add(nn.SelectTable(2))
         :add(featureVec)
         :add(opWeights)
      ) -- {prevOutput, v_t, opWeights}
      :add(combine_ops)
   return cell
end

-- Factory methods are inherited from GRU

function MuFuRu:__tostring__()
   local op_str = '{ '
   for i=1,self.num_ops do
      op_str = op_str .. self.ops[i] .. ' '
   end
   op_str = op_str .. '}'
   return string.format('%s(%d -> %d) ', torch.type(self), self.inputSize, self.outputSize) .. op_str
end

function MuFuRu:migrate(params)
   error("Migrate not supported for MuFuRu")
end
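
As a quick orientation for readers, here is a minimal usage sketch. It assumes the rnn package is installed and loaded, follows the same nn.Sequencer conventions as nn.GRU, and uses made-up sizes and random data purely for illustration:

require 'rnn'

local inputSize, hiddenSize = 10, 20
local mufuru = nn.MuFuRu(inputSize, hiddenSize)  -- all eight ops by default
print(mufuru)  -- nn.MuFuRu(10 -> 20) { keep replace mul diff forget sqrt_diff max min }

-- restricting the op pool to keep/replace recovers GRU-style gating
local gruLike = nn.MuFuRu(inputSize, hiddenSize, {'keep', 'replace'})

-- feed a sequence of 5 steps, batch size 8, via Sequencer as with nn.GRU
local rnn = nn.Sequencer(mufuru)
local seq = {}
for t = 1, 5 do seq[t] = torch.randn(8, inputSize) end
local outputs = rnn:forward(seq)  -- table of 5 tensors, each 8 x hiddenSize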

init.lua

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@ torch.include('rnn', 'Recurrent.lua')
 torch.include('rnn', 'LSTM.lua')
 torch.include('rnn', 'FastLSTM.lua')
 torch.include('rnn', 'GRU.lua')
+torch.include('rnn', 'Mufuru.lua')
 torch.include('rnn', 'Recursor.lua')
 torch.include('rnn', 'Recurrence.lua')
 torch.include('rnn', 'NormStabilizer.lua')
