 -- Decorator that zeroes the output rows of the encapsulated module
 -- for commensurate input rows which are tensors of zeros
 
--- The only difference from `MaskZero` is that it reduces computational costs
--- by varying a batch size, if any, for the case that varying lengths
--- are provided in the input. Notice that when the lengths are consistent,
--- `MaskZero` will be faster, because `TrimZero` has an operational cost.
+-- The only difference from `MaskZero` is that it reduces computational costs
+-- by varying a batch size, if any, for the case that varying lengths
+-- are provided in the input. Notice that when the lengths are consistent,
+-- `MaskZero` will be faster, because `TrimZero` has an operational cost.
 
 -- In short, the result is the same as `MaskZero`'s; however, `TrimZero` is
 -- faster than `MaskZero` only when sentence lengths vary.
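For context, a minimal usage sketch of the decorator described in the comments above (assuming the rnn package's `nn.TrimZero(module, nInputDim)` constructor, which mirrors `nn.MaskZero`; the layer sizes and inputs here are made up for illustration):

require 'rnn'

-- wrap a recurrent layer so that all-zero input rows (zero-padding)
-- yield all-zero output rows; nInputDim = 1 means each sample is a vector
local lstm = nn.TrimZero(nn.FastLSTM(10, 10), 1)

-- batch of 4 vectors, two of which are zero-padding rows
local input = torch.zeros(4, 10)
input[1]:uniform()
input[3]:uniform()

local output = lstm:forward(input) -- rows 2 and 4 come out as zeros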
@@ -38,7 +38,7 @@ function TrimZero:recursiveMask(output, input, mask)
    else
       assert(torch.isTensor(input))
       output = torch.isTensor(output) and output or input.new()
-      
+
       -- make sure mask has the same dimension as the input tensor
       if torch.type(mask) ~= 'torch.LongTensor' then
          local inputSize = input:size():fill(1)
@@ -48,7 +48,7 @@ function TrimZero:recursiveMask(output, input, mask)
          end
          mask:resize(inputSize)
       end
-      
+
       -- build mask
       if self.batchmode then
          assert(torch.find, 'install torchx package : luarocks install torchx')
@@ -67,11 +67,11 @@ function TrimZero:recursiveMask(output, input, mask)
          else
             output:index(input, 1, torch.LongTensor{1}):zero()
          end
-      else
-         if mask:dim() == 0 or mask:view(-1)[1] == 1 then
-            output:resize(input:size()):zero()
-         else
-            output:resize(input:size()):copy(input)
+      else
+         if mask:dim() == 0 or mask:view(-1)[1] == 1 then
+            output:resize(input:size()):zero()
+         else
+            output:resize(input:size()):copy(input)
          end
       end
    end
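The batch-mode branch above is where the trimming happens: `torch.find` (from the torchx package) collects the indices of rows whose mask entry is 0, i.e. the non-zero input rows, and `index` gathers only those rows, so the decorated module sees a smaller effective batch. A standalone sketch of that step, with made-up values:

require 'torchx' -- provides torch.find

local input = torch.Tensor{{1, 2}, {0, 0}, {3, 4}} -- row 2 is padding
local mask  = torch.ByteTensor{0, 1, 0}            -- 1 marks all-zero rows

-- indices of the rows to keep
local keep = torch.LongTensor(torch.find(mask, 0)) -- {1, 3}

-- trimmed batch of 2 rows instead of 3
local trimmed = input.new():index(input, 1, keep)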
@@ -87,14 +87,14 @@ function TrimZero:recursiveUnMask(output, input, mask)
    else
       assert(torch.isTensor(input))
       output = torch.isTensor(output) and output or input.new()
-      
+
       -- make sure output has the same dimension as the mask
       local inputSize = input:size()
       if self.batchmode then
          inputSize[1] = mask:size(1)
       end
       output:resize(inputSize):zero()
-      
+
       -- build mask
       if self.batchmode then
          assert(self._maskindices)
@@ -103,7 +103,7 @@ function TrimZero:recursiveUnMask(output, input, mask)
             output:indexCopy(1, mask, input)
          end
       else
-         if mask:view(-1)[1] == 0 then 
+         if mask:view(-1)[1] == 0 then
             output:copy(input)
          end
       end
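`recursiveUnMask` is the inverse step: the rows the decorated module produced for the trimmed batch are scattered back to their original positions with `indexCopy`, while the padding rows remain zero. Continuing the made-up example from the previous sketch:

-- output of the decorated module for the 2 kept rows
local trimmedOut = torch.Tensor{{10, 20}, {30, 40}}

-- zero the full-size output, then copy the kept rows back to slots 1 and 3
local full = torch.zeros(3, 2)
full:indexCopy(1, keep, trimmedOut)
-- full is now {{10, 20}, {0, 0}, {30, 40}}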
@@ -123,17 +123,17 @@ function TrimZero:updateOutput(input)
    else
       error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
    end
-   
+
    -- build mask
-   local vectorDim = rmi:dim() 
+   local vectorDim = rmi:dim()
    self._zeroMask = self._zeroMask or rmi.new()
    self._zeroMask:norm(rmi, 2, vectorDim)
    self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaTensor() or torch.ByteTensor())
    self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
-   
+
    -- forward through decorated module
    self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
-   output = self.module:updateOutput(self.temp)
+   output = self.modules[1]:updateOutput(self.temp)
    self.output = self:recursiveUnMask(self.output, output, self.zeroMask, true)
 
    return self.output
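Two things are worth noting in this hunk. First, the mask is built by taking the L2 norm of each input row along its last dimension and flagging rows whose norm is 0, so only rows that are entirely zero get masked. A standalone sketch with made-up values, using the same tensor calls as the code above:

local input = torch.Tensor{{1, 2}, {0, 0}, {3, 4}}

-- L2 norm of each row along the last dimension
local norms = input.new()
norms:norm(input, 2, input:dim()) -- {{2.236}, {0}, {5}}

-- compare the norms against 0 to flag the all-zero rows
local zeroMask = torch.ByteTensor()
norms.eq(zeroMask, norms, 0) -- {{0}, {1}, {0}}

Second, the switch from `self.module` to `self.modules[1]`: decorators here derive from nn.Container, which tracks the wrapped module in its `modules` table, so going through `modules[1]` presumably keeps the call consistent with the Container bookkeeping (e.g. across `:type()` conversions).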
@@ -143,7 +143,7 @@ function TrimZero:updateGradInput(input, gradOutput)
    self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
    self.gradTemp = self:recursiveMask(self.gradTemp, gradOutput, self.zeroMask)
 
-   local gradInput = self.module:updateGradInput(self.temp, self.gradTemp)
+   local gradInput = self.modules[1]:updateGradInput(self.temp, self.gradTemp)
 
    self.gradInput = self:recursiveUnMask(self.gradInput, gradInput, self.zeroMask)
 
@@ -152,5 +152,5 @@
 
 function TrimZero:accGradParameters(input, gradOutput, scale)
    self.temp = self:recursiveMask(self.temp, input, self.zeroMask)
-   self.module:accGradParameters(self.temp, gradOutput, scale)
+   self.modules[1]:accGradParameters(self.temp, gradOutput, scale)
 end