|
------------------------------------------------------------------------
--[[ MaskZeroCriterion ]]--
-- Decorates a criterion so that rows flagged by a zeroMask (i.e. padded
-- samples in a batch) do not contribute to the loss or gradients.
------------------------------------------------------------------------
local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion")
8 |
--- Constructor.
-- @param criterion the nn.Criterion instance to decorate; its loss and
--        gradients are computed only over unmasked (non-zero-padded) rows.
function MaskZeroCriterion:__init(criterion)
   parent.__init(self)
   -- validate BEFORE storing, and give the caller a useful message
   -- (a bare assert after assignment left self.criterion holding garbage)
   assert(torch.isTypeOf(criterion, 'nn.Criterion'),
      "MaskZeroCriterion expecting nn.Criterion instance at arg 1")
   self.criterion = criterion
   -- v2 API: zeroMask is provided externally via setZeroMask()
   self.v2 = true
end
|
15 | 14 |
|
16 |
--- Forward pass: compute the decorated criterion's loss over the
-- unmasked rows only.
-- In v2 mode, self.zeroMask must have been set beforehand (tensor of
-- size batchsize, or false to disable masking entirely).
-- In legacy mode, the mask is inferred from the input via
-- nn.utils.getZeroMaskBatch.
-- Returns 0 when every sample in the batch is masked.
function MaskZeroCriterion:updateOutput(input, target)
   if self.v2 then
      assert(self.zeroMask ~= nil, "MaskZeroCriterion expecting zeroMask tensor or false")
      if self.zeroMask == false then
         -- masking disabled: pass straight through to the wrapped criterion
         self.output = self.criterion:updateOutput(input, target)
         return self.output
      end
      assert(self.zeroMask:dim() == 1, "MaskZeroCriterion expecting zeroMask of size batchsize")
   else -- backwards compat
      self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask)
   end

   self.isEmptyBatch = (self.zeroMask:sum() == self.zeroMask:nElement())
   if self.isEmptyBatch then
      -- every sample is masked: loss is zero (issue 128)
      self.output = 0
   else
      -- invert the mask, e.g. 0,1,0 -> 1,0,1
      self._oneMask = self._oneMask or self.zeroMask.new()
      self._oneMask:lt(self.zeroMask, 1)
      -- turn the one-mask into row indices, e.g. 1,0,1 -> 1,3
      self._indices = self._indices or torch.LongTensor()
      self._range = self._range or torch.LongTensor()
      self._range:range(1, self._oneMask:nElement())
      self._indices:maskedSelect(self._range, self._oneMask)
      -- gather only the unmasked rows of input and target
      self.input = nn.utils.recursiveIndex(self.input, input, 1, self._indices)
      self.target = nn.utils.recursiveIndex(self.target, target, 1, self._indices)

      -- forward the reduced batch through the decorated criterion
      self.output = self.criterion:updateOutput(self.input, self.target)
   end

   return self.output
end
|
89 | 49 |
|
90 |
| -function MaskZeroCriterion:recursiveMaskGradInput(dst, mask, src, input) |
91 |
| - if torch.type(input) == 'table' then |
92 |
| - dst = (torch.type(dst) == 'table') and dst or {dst} |
93 |
| - src = (torch.type(src) == 'table') and src or {src} |
94 |
| - for key,_ in pairs(input) do |
95 |
| - dst[key] = self:recursiveMaskGradInput(dst[key], mask, src[key], input[key]) |
96 |
| - end |
97 |
| - for i=#input+1,#dst do |
98 |
| - dst[i] = nil |
99 |
| - end |
100 |
| - elseif torch.isTensor(input) then |
101 |
| - dst = torch.isTensor(dst) and dst or input.new() |
102 |
| - dst:resizeAs(input):zero() |
103 |
| - if mask:nElement() > 0 then |
104 |
| - assert(src) |
105 |
| - dst:indexCopy(1, mask, src) |
106 |
| - end |
107 |
| - else |
108 |
| - error("expecting nested tensors or tables. Got ".. |
109 |
| - torch.type(dst).." and "..torch.type(input).." instead") |
| 50 | +function MaskZeroCriterion:updateGradInput(input, target) |
| 51 | + if self.zeroMask == false then |
| 52 | + self.gradInput = self.criterion:updateGradInput(input, target) |
| 53 | + return self.gradInput |
110 | 54 | end
|
111 |
| - return dst |
112 |
| -end |
113 | 55 |
|
114 |
| -function MaskZeroCriterion:updateGradInput(input, target) |
115 |
| - if self.zeroMask:nElement() > 0 then |
| 56 | + self._gradInput = nn.utils.recursiveResizeAs(self._gradInput, input) |
| 57 | + nn.utils.recursiveFill(self._gradInput, 0) |
| 58 | + |
| 59 | + if not self.isEmptyBatch then |
116 | 60 | assert(self.input and self.target)
|
117 |
| - self._gradInput = self.criterion:updateGradInput(self.input, self.target) |
| 61 | + local gradInput = self.criterion:updateGradInput(self.input, self.target) |
| 62 | + nn.utils.recursiveIndexCopy(self._gradInput, 1, self._indices, gradInput) |
118 | 63 | end
|
119 |
| - self.gradInput = self:recursiveMaskGradInput(self.gradInput, self.zeroMask, self._gradInput, input) |
| 64 | + |
| 65 | + self.gradInput = self._gradInput |
120 | 66 | return self.gradInput
|
121 | 67 | end
|
122 | 68 |
|
123 |
| -function MaskZeroCriterion:type(type, ...) |
| 69 | +function MaskZeroCriterion:clearState() |
124 | 70 | self.zeroMask = nil
|
125 |
| - self._zeroMask = nil |
126 |
| - self.__zeroMask = nil |
| 71 | + self._oneMask = nil |
| 72 | + self._range = nil |
| 73 | + self._indices = nil |
127 | 74 | self.input = nil
|
128 | 75 | self.target = nil
|
| 76 | + self.output = nil |
| 77 | + self.gradInput = nil |
129 | 78 | self._gradInput = nil
|
130 |
| - |
| 79 | +end |
| 80 | + |
--- Convert to a different tensor type.
-- Cached buffers are cleared first so stale tensors of the old type
-- are not carried along.
function MaskZeroCriterion:type(type, ...)
   self:clearState()
   return parent.type(self, type, ...)
end
|
| 85 | + |
--- Set the mask for the next forward/backward (v2 API).
-- @param zeroMask a 1D tensor of size batchsize (non-zero entries mark
--        masked samples), or false to disable masking.
function MaskZeroCriterion:setZeroMask(zeroMask)
   self.zeroMask = zeroMask
end
0 commit comments