
Commit d41eba7

Author: Nicholas Leonard
MaskZeroCriterion v2

1 parent 587be96 commit d41eba7
9 files changed: +335 −431 lines

AbstractSequencerCriterion.lua

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+------------------------------------------------------------------------
+--[[ AbstractSequencerCriterion ]]--
+-- Inherited by SequencerCriterion and RepeaterCriterion
+-- WARNING : assumes that the decorated criterion is stateless, i.e.
+-- the backward doesn't need to be preceded by a commensurate forward.
+------------------------------------------------------------------------
+local AbstractSequencerCriterion, parent = torch.class('nn.AbstractSequencerCriterion', 'nn.Criterion')
+
+function AbstractSequencerCriterion:__init(criterion, sizeAverage)
+   parent.__init(self)
+   self.criterion = criterion
+   if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then
+      error(torch.type(self).." shouldn't decorate a ModuleCriterion. "..
+            "Instead, try the other way around : "..
+            "ModuleCriterion decorates a ".. torch.type(self) .. ". "..
+            "Its modules can also be similarly decorated with a Sequencer.")
+   end
+   if sizeAverage ~= nil then
+      self.sizeAverage = sizeAverage
+   else
+      self.sizeAverage = false
+   end
+   self.clones = {}
+end
+
+function AbstractSequencerCriterion:getStepCriterion(step)
+   assert(step, "expecting step at arg 1")
+   local criterion = self.clones[step]
+   if not criterion then
+      criterion = self.criterion:clone()
+      self.clones[step] = criterion
+   end
+   return criterion
+end
+
+function AbstractSequencerCriterion:setZeroMask(zeroMask)
+   if zeroMask == false then
+      for k,stepcriterion in pairs(self.clones) do
+         stepcriterion:setZeroMask(zeroMask)
+      end
+   else
+      assert(zeroMask:dim() >= 2, "Expecting dim >= 2 for zeroMask. For example, seqlen x batchsize")
+      for step=1,zeroMask:size(1) do
+         local stepcriterion = self:getStepCriterion(step)
+         stepcriterion:setZeroMask(zeroMask[step])
+      end
+   end
+end
+
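As a usage sketch (not part of the commit): the header comment says `SequencerCriterion` inherits this class, so a `seqlen x batchsize` mask set on the decorator is sliced along the first dimension, with each step's clone receiving its own batchsize-sized row. The wrapped `nn.MSECriterion` and the tensor sizes here are stand-ins for illustration:

```lua
require 'rnn'

local seqlen, batchsize = 5, 3
-- a 1 marks a (step, sample) pair to ignore
local zeroMask = torch.ByteTensor(seqlen, batchsize):zero()
zeroMask[{3, 2}] = 1 -- sample 2 is masked at step 3

-- SequencerCriterion is assumed to inherit AbstractSequencerCriterion
local crit = nn.SequencerCriterion(nn.MaskZeroCriterion(nn.MSECriterion()))
-- setZeroMask hands zeroMask[step] to the clone from getStepCriterion(step)
crit:setZeroMask(zeroMask)

local input = torch.randn(seqlen, batchsize, 4)
local target = torch.randn(seqlen, batchsize, 4)
print(crit:forward(input, target))
```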

CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -15,6 +15,7 @@ SET(luasrc
    init.lua
    AbstractRecurrent.lua
    AbstractSequencer.lua
+   AbstractSequencerCriterion.lua
    BiSequencer.lua
    BiSequencerLM.lua
    CopyGrad.lua
@@ -33,7 +34,6 @@ SET(luasrc
    Padding.lua
    Recurrence.lua
    RecurrentAttention.lua
-   recursiveUtils.lua
    Recursor.lua
    Repeater.lua
    RepeaterCriterion.lua

Criterion.lua

Lines changed: 12 additions & 0 deletions
@@ -2,3 +2,15 @@ local Criterion = nn.Criterion
 
 Criterion.toBatch = nn.Module.toBatch
 Criterion.fromBatch = nn.Module.fromBatch
+
+
+function Criterion:setZeroMask(zeroMask)
+   if self.criterions then
+      for i, criterion in ipairs(self.criterions) do
+         criterion:setZeroMask(zeroMask)
+      end
+   end
+   if self.criterion then
+      self.criterion:setZeroMask(zeroMask)
+   end
+end
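A minimal sketch (not in the commit) of what this default buys: containers such as the stock `nn.ParallelCriterion` keep their members in `self.criterions`, and decorators keep one in `self.criterion`, so a single `setZeroMask` call propagates to every `MaskZeroCriterion` in the tree:

```lua
require 'rnn'

local batchsize = 4
local zeroMask = torch.ByteTensor(batchsize):zero()
zeroMask[1] = 1 -- ignore the first sample in every member criterion

-- ParallelCriterion stores members in self.criterions, so the default
-- Criterion:setZeroMask above reaches both decorated losses at once
local pc = nn.ParallelCriterion()
   :add(nn.MaskZeroCriterion(nn.MSECriterion()))
   :add(nn.MaskZeroCriterion(nn.AbsCriterion()))

pc:setZeroMask(zeroMask)
```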

MaskZeroCriterion.lua

Lines changed: 53 additions & 97 deletions
@@ -5,128 +5,84 @@
 ------------------------------------------------------------------------
 local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion")
 
-function MaskZeroCriterion:__init(criterion, nInputDim)
+function MaskZeroCriterion:__init(criterion)
    parent.__init(self)
    self.criterion = criterion
    assert(torch.isTypeOf(criterion, 'nn.Criterion'))
-   assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 2')
-   self.nInputDim = nInputDim
+   self.v2 = true
 end
 
-function MaskZeroCriterion:recursiveGetFirst(input)
-   if torch.type(input) == 'table' then
-      return self:recursiveGetFirst(input[1])
-   else
-      assert(torch.isTensor(input))
-      return input
-   end
-end
-
-function MaskZeroCriterion:recursiveMask(dst, src, mask)
-   if torch.type(src) == 'table' then
-      dst = torch.type(dst) == 'table' and dst or {}
-      for k,v in ipairs(src) do
-         dst[k] = self:recursiveMask(dst[k], v, mask)
+function MaskZeroCriterion:updateOutput(input, target)
+   if self.v2 then
+      assert(self.zeroMask ~= nil, "MaskZeroCriterion expecting zeroMask tensor or false")
+      if self.zeroMask == false then
+         self.output = self.criterion:updateOutput(input, target)
+         return self.output
       end
-   else
-      assert(torch.isTensor(src))
-      dst = torch.isTensor(dst) and dst or src.new()
-
-      dst:index(src, 1, mask)
+      assert(self.zeroMask:dim() == 1, "MaskZeroCriterion expecting zeroMask of size batchsize")
+   else -- backwards compat
+      self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask)
    end
-   return dst
-end
 
-function MaskZeroCriterion:updateOutput(input, target)
-   -- recurrent module input is always the first one
-   local rmi = self:recursiveGetFirst(input):contiguous()
-   if rmi:dim() == self.nInputDim then
-      error("does not support online (i.e. non-batch) mode")
-   elseif rmi:dim() - 1 == self.nInputDim then
-      rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims
+   self.isEmptyBatch = (self.zeroMask:sum() == self.zeroMask:nElement())
+   if self.isEmptyBatch then
+      self.output = 0
    else
-      error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
-   end
-
-   -- build mask
-   local vectorDim = rmi:dim()
-   self._zeroMask = self._zeroMask or rmi.new()
-   self._zeroMask:norm(rmi, 2, vectorDim)
-   local zeroMask = self._zeroMask
-   if torch.isTypeOf(zeroMask, 'torch.CudaTensor') or
-         torch.isTypeOf(zeroMask, 'torch.ClTensor') then
-      self.__zeroMask = self.__zeroMask or torch.FloatTensor()
-      self.__zeroMask:resize(self._zeroMask:size()):copy(self._zeroMask)
-      zeroMask = self._zeroMask
-   end
-
-   self.zeroMask = self.zeroMask or torch.LongTensor()
-   self.zeroMask:resize(self._zeroMask:size(1)):zero()
-
-   local i, j = 0, 0
-   zeroMask:apply(function(norm)
-      i = i + 1
-      if norm ~= 0 then
-         j = j + 1
-         self.zeroMask[j] = i
-      end
-   end)
-   self.zeroMask:resize(j)
-
-   if j > 0 then
-      self.input = self:recursiveMask(self.input, input, self.zeroMask)
-      self.target = self:recursiveMask(self.target, target, self.zeroMask)
-
+      -- e.g. 0,1,0 -> 1,0,1
+      self._oneMask = self._oneMask or self.zeroMask.new()
+      self._oneMask:lt(self.zeroMask, 1)
+      -- 1,0,1 -> 1,3
+      self._indices = self._indices or torch.LongTensor()
+      self._range = self._range or torch.LongTensor()
+      self._range:range(1,self._oneMask:nElement())
+      self._indices:maskedSelect(self._range, self._oneMask)
+      -- indexSelect the input
+      self.input = nn.utils.recursiveIndex(self.input, input, 1, self._indices)
+      self.target = nn.utils.recursiveIndex(self.target, target, 1, self._indices)
+
       -- forward through decorated criterion
       self.output = self.criterion:updateOutput(self.input, self.target)
-   else
-      -- when all samples are masked, then loss is zero (issue 128)
-      self.output = 0
    end
-
+
    return self.output
 end
 
-function MaskZeroCriterion:recursiveMaskGradInput(dst, mask, src, input)
-   if torch.type(input) == 'table' then
-      dst = (torch.type(dst) == 'table') and dst or {dst}
-      src = (torch.type(src) == 'table') and src or {src}
-      for key,_ in pairs(input) do
-         dst[key] = self:recursiveMaskGradInput(dst[key], mask, src[key], input[key])
-      end
-      for i=#input+1,#dst do
-         dst[i] = nil
-      end
-   elseif torch.isTensor(input) then
-      dst = torch.isTensor(dst) and dst or input.new()
-      dst:resizeAs(input):zero()
-      if mask:nElement() > 0 then
-         assert(src)
-         dst:indexCopy(1, mask, src)
-      end
-   else
-      error("expecting nested tensors or tables. Got "..
-            torch.type(dst).." and "..torch.type(input).." instead")
+function MaskZeroCriterion:updateGradInput(input, target)
+   if self.zeroMask == false then
+      self.gradInput = self.criterion:updateGradInput(input, target)
+      return self.gradInput
    end
-   return dst
-end
 
-function MaskZeroCriterion:updateGradInput(input, target)
-   if self.zeroMask:nElement() > 0 then
+   self._gradInput = nn.utils.recursiveResizeAs(self._gradInput, input)
+   nn.utils.recursiveFill(self._gradInput, 0)
+
+   if not self.isEmptyBatch then
       assert(self.input and self.target)
-      self._gradInput = self.criterion:updateGradInput(self.input, self.target)
+      local gradInput = self.criterion:updateGradInput(self.input, self.target)
+      nn.utils.recursiveIndexCopy(self._gradInput, 1, self._indices, gradInput)
    end
-   self.gradInput = self:recursiveMaskGradInput(self.gradInput, self.zeroMask, self._gradInput, input)
+
+   self.gradInput = self._gradInput
    return self.gradInput
 end
 
-function MaskZeroCriterion:type(type, ...)
+function MaskZeroCriterion:clearState()
    self.zeroMask = nil
-   self._zeroMask = nil
-   self.__zeroMask = nil
+   self._oneMask = nil
+   self._range = nil
+   self._indices = nil
    self.input = nil
    self.target = nil
+   self.output = nil
+   self.gradInput = nil
    self._gradInput = nil
-
+end
+
+function MaskZeroCriterion:type(type, ...)
+   self:clearState()
    return parent.type(self, type, ...)
 end
+
+function MaskZeroCriterion:setZeroMask(zeroMask)
+   self.zeroMask = zeroMask
+end
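To make the v2 index arithmetic concrete, here is a standalone sketch (assumed values; the variable names mirror the member tensors above) of what `updateOutput` computes for a 3-sample batch:

```lua
local zeroMask = torch.ByteTensor{0, 1, 0} -- mask the 2nd of 3 samples

local oneMask = zeroMask.new()
oneMask:lt(zeroMask, 1)                 -- 0,1,0 -> 1,0,1 (samples to keep)

local range = torch.LongTensor()
range:range(1, oneMask:nElement())      -- 1,2,3

local indices = torch.LongTensor()
indices:maskedSelect(range, oneMask)    -- -> 1,3 : indices of kept samples

-- recursiveIndex then does the equivalent of input:index(1, indices),
-- so the decorated criterion only ever sees unmasked rows; updateGradInput
-- scatters the gradient back with indexCopy, leaving masked rows at zero
local input = torch.randn(3, 5)
local kept = input:index(1, indices)    -- 2 x 5
assert(kept:size(1) == 2)
```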

README.md

Lines changed: 15 additions & 10 deletions
@@ -1192,20 +1192,25 @@ This lookup table makes it possible to pad sequences with different lengths in t
 
 <a name='rnn.MaskZeroCriterion'></a>
 ## MaskZeroCriterion ##
-This criterion zeroes the `err` and `gradInput` rows of the decorated criterion
-for commensurate `input` rows which are tensors of zeros.
+
+This criterion ignores samples (rows in the `input` and `target` tensors)
+where the `zeroMask` ByteTensor passed to `MaskZeroCriterion:setZeroMask(zeroMask)` is 1.
+This criterion only supports batch-mode.
 
 ```lua
-mzc = nn.MaskZeroCriterion(criterion, nInputDim)
+batchsize = 3
+zeroMask = torch.ByteTensor(batchsize):zero()
+zeroMask[2] = 1 -- the 2nd sample in the batch is ignored
+mzc = nn.MaskZeroCriterion(criterion)
+mzc:setZeroMask(zeroMask)
+loss = mzc:forward(input, target)
+gradInput = mzc:backward(input, target)
+assert(gradInput[2]:sum() == 0)
 ```
 
-The `gradInput` Tensor (or table thereof) of the decorated `criterion`
-will have each row (samples) zeroed when the commensurate row of the `input`
-is a tensor of zeros. The `err` will also disregard such zero rows.
-
-The `nInputDim` argument must specify the number of non-batch dims
-in the first Tensor of the `input`. In the case of an `input` table,
-the first Tensor is the first one encountered when doing a depth-first search.
+In the above example, the second row of the `gradInput` Tensor is zero
+because the commensurate row of the `zeroMask` is a one.
+The call to `forward` likewise disregards the second sample when measuring the `loss`.
 
 This decorator makes it possible to pad sequences with different lengths in the same batch with zero vectors.
 
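As an illustration of that last claim (a sketch, not part of the README; the `lengths` tensor and the `seqlen x batchsize` layout are assumptions), the mask for a batch of variable-length sequences can be derived from the sequence lengths:

```lua
local seqlen, batchsize = 6, 3
local lengths = torch.LongTensor{6, 4, 2} -- actual length of each sequence

-- mark every step past a sequence's end as masked
local zeroMask = torch.ByteTensor(seqlen, batchsize):zero()
for sample = 1, batchsize do
   for step = lengths[sample] + 1, seqlen do
      zeroMask[step][sample] = 1
   end
end

-- zeroMask can now be given to e.g. a SequencerCriterion decorating
-- a MaskZeroCriterion via setZeroMask(zeroMask)
```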

RepeaterCriterion.lua

Lines changed: 25 additions & 27 deletions
@@ -4,59 +4,57 @@
 -- same target (the target is repeated).
 -- Useful for nn.Repeater and nn.Sequencer.
 ------------------------------------------------------------------------
-assert(not nn.RepeaterCriterion, "update nnx package : luarocks install nnx")
-local RepeaterCriterion, parent = torch.class('nn.RepeaterCriterion', 'nn.Criterion')
+local RepeaterCriterion, parent = torch.class('nn.RepeaterCriterion', 'nn.AbstractSequencerCriterion')
 
-function RepeaterCriterion:__init(criterion)
-   parent.__init(self)
-   self.criterion = criterion
-   self.gradInput = {}
-   self.clones = {}
-end
-
-RepeaterCriterion.getStepCriterion = nn.SequencerCriterion.getStepCriterion
-
-function RepeaterCriterion:forward(input, target)
+function RepeaterCriterion:updateOutput(input, target)
    self.output = 0
-   local nStep
+   local seqlen
    if torch.isTensor(input) then
-      nStep = input:size(1)
+      seqlen = input:size(1)
    else
-      nStep = #input
+      seqlen = #input
    end
 
-
-   for i=1,nStep do
+   for i=1,seqlen do
       local criterion = self:getStepCriterion(i)
       self.output = self.output + criterion:forward(input[i], target)
    end
-
+
+
+   if self.sizeAverage then
+      self.output = self.output / seqlen
+   end
+
    return self.output
 end
 
-function RepeaterCriterion:backward(input, target)
+function RepeaterCriterion:updateGradInput(input, target)
    self.gradInput = {}
    if torch.isTensor(input) then
-      nStep = input:size(1)
+      seqlen = input:size(1)
    else
-      nStep = #input
+      seqlen = #input
    end
-
+
    local tableGradInput = {}
-   for i=1,nStep do
+   for i=1,seqlen do
      local criterion = self:getStepCriterion(i)
      tableGradInput[i] = criterion:backward(input[i], target)
   end
-
+
+   if self.sizeAverage then
+      nn.utils.recursiveDiv(tableGradInput, seqlen)
+   end
+
   if torch.isTensor(input) then
      self.gradInput = tableGradInput[1].new()
-      self.gradInput:resize(nStep, unpack(tableGradInput[1]:size():totable()))
-      for step=1,nStep do
+      self.gradInput:resize(seqlen, unpack(tableGradInput[1]:size():totable()))
+      for step=1,seqlen do
         self.gradInput[step]:copy(tableGradInput[step])
     end
  else
     self.gradInput = tableGradInput
  end
-
+
  return self.gradInput
 end
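Note: the committed `sizeAverage` branch in `updateGradInput` referenced `tableGradInput[i]` with the loop variable `i` out of scope; the listing above divides the whole `tableGradInput` table instead, matching the division of `self.output` by `seqlen` in `updateOutput`. A brief usage sketch (not in the commit); the classification setup and the `sizeAverage=true` second argument, inherited from `AbstractSequencerCriterion:__init`, are assumptions for illustration:

```lua
require 'rnn'

local seqlen, batchsize, nclass = 5, 4, 10

-- second arg is sizeAverage, inherited from AbstractSequencerCriterion
local rc = nn.RepeaterCriterion(nn.ClassNLLCriterion(), true)

-- stand-in for per-step log-probabilities, e.g. LogSoftMax output
local input = torch.randn(seqlen, batchsize, nclass)
-- a single target class per sample, applied at every step
local target = torch.LongTensor(batchsize):random(1, nclass)

local loss = rc:forward(input, target)        -- averaged over seqlen
local gradInput = rc:backward(input, target)  -- seqlen x batchsize x nclass
```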
