We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 88cbe99 commit 11e3183Copy full SHA for 11e3183
.gitignore
@@ -0,0 +1 @@
1
+build/
MaskZero.lua
@@ -68,7 +68,11 @@ function MaskZero:updateOutput(input)
68
local vectorDim = rmi:dim()
69
self._zeroMask = self._zeroMask or rmi.new()
70
self._zeroMask:norm(rmi, 2, vectorDim)
71
- self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
+ self.zeroMask = self.zeroMask or (
72
+ (torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
73
+ or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
74
+ or torch.ByteTensor()
75
+ )
76
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
77
78
-- forward through decorated module
0 commit comments