Skip to content

Commit 11e3183

Browse files
committed
add opencl to MaskZero.lua
1 parent 88cbe99 commit 11e3183

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
build/

MaskZero.lua

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,11 @@ function MaskZero:updateOutput(input)
6868
local vectorDim = rmi:dim()
6969
self._zeroMask = self._zeroMask or rmi.new()
7070
self._zeroMask:norm(rmi, 2, vectorDim)
71-
self.zeroMask = self.zeroMask or ((torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor() or torch.ByteTensor())
71+
self.zeroMask = self.zeroMask or (
72+
(torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
73+
or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
74+
or torch.ByteTensor()
75+
)
7276
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
7377

7478
-- forward through decorated module

0 commit comments

Comments
 (0)