
Commit e6a5160

NCE unit test + torch.Timer
1 parent da18b23 commit e6a5160

File tree: 3 files changed (+97 -43 lines)

NCEModule.lua

Lines changed: 54 additions & 1 deletion

@@ -255,4 +255,57 @@ function NCEModule:clearState()
    self.gradInput[2]:set()
 end

--- TODO : speedup unigram sampling using frequency bins...
+-- NOT IN USE : the following is experimental and not currently in use.
+-- ref.: https://hips.seas.harvard.edu/blog/2013/03/03/the-alias-method-efficient-sampling-with-many-discrete-outcomes/
+function NCEModule:aliasDraw(J, q)
+   local K = J:nElement()
+
+   -- Draw from the overall uniform mixture.
+   local kk = math.random(1,K)
+
+   -- Draw from the binary mixture, either keeping the
+   -- small one, or choosing the associated larger one.
+   if math.random() < q[kk] then
+      return kk
+   else
+      return J[kk]
+   end
+end
+
+function NCEModule:aliasSetup(probs)
+   assert(probs:dim() == 1)
+   local K = probs:nElement()
+   local q = probs.new(K)
+   local J = torch.LongTensor(K):zero()
+
+   -- Sort the data into the outcomes with probabilities
+   -- that are larger and smaller than 1/K.
+   local smaller, larger = {}, {}
+   for kk = 1,K do
+      local prob = probs[kk]
+      q[kk] = K*prob
+      if q[kk] < 1 then
+         table.insert(smaller, kk)
+      else
+         table.insert(larger, kk)
+      end
+   end
+
+   -- Loop through and create little binary mixtures that
+   -- appropriately allocate the larger outcomes over the
+   -- overall uniform mixture.
+   while #smaller > 0 and #larger > 0 do
+      local small = table.remove(smaller)
+      local large = table.remove(larger)
+
+      J[small] = large
+      q[large] = q[large] - (1.0 - q[small])
+
+      if q[large] < 1.0 then
+         table.insert(smaller, large)
+      else
+         table.insert(larger, large)
+      end
+   end
+   return J, q
+end

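The two functions added above implement the alias method from the linked blog post: aliasSetup does a one-time O(K) pass that builds the alias table J and threshold table q, after which every aliasDraw call samples from the original discrete distribution in O(1) (one uniform index plus one biased coin flip). A rough usage sketch, not part of the commit, assuming an already-constructed NCEModule instance named nce and a toy probability tensor:

   local probs = torch.Tensor{0.5, 0.25, 0.125, 0.125} -- toy unigram distribution
   local J, q = nce:aliasSetup(probs)                   -- one-time O(K) table build
   local counts = torch.zeros(probs:nElement())
   for i=1,100000 do
      local k = nce:aliasDraw(J, q)                     -- O(1) per sample
      counts[k] = counts[k] + 1
   end
   print(counts:div(100000))                            -- should roughly match probs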
Sequential.lua

Lines changed: 18 additions & 18 deletions

@@ -5,10 +5,10 @@ function Sequential:profile()
 function Sequential:updateOutput(input)
    local currentOutput = input
    for i=1,#self.modules do
-      local start = sys.clock()
+      local start = torch.Timer()
       currentOutput = self.modules[i]:updateOutput(currentOutput)
       if cutorch then cutorch.synchronize() end
-      print(torch.type(self.modules[i])..' updateOutput: '..sys.clock() - start.." s")
+      print(torch.type(self.modules[i])..' updateOutput: '..start:time().real.." s")
    end
    self.output = currentOutput
    return currentOutput

@@ -19,16 +19,16 @@ function Sequential:profile()
    local currentModule = self.modules[#self.modules]
    for i=#self.modules-1,1,-1 do
       local previousModule = self.modules[i]
-      local start = sys.clock()
+      local start = torch.Timer()
       currentGradOutput = currentModule:updateGradInput(previousModule.output, currentGradOutput)
       if cutorch then cutorch.synchronize() end
-      print(torch.type(currentModule)..' updateGradInput: '..sys.clock() - start.." s")
+      print(torch.type(currentModule)..' updateGradInput: '..start:time().real.." s")
       currentModule = previousModule
    end
-   local start = sys.clock()
+   local start = torch.Timer()
    currentGradOutput = currentModule:updateGradInput(input, currentGradOutput)
    if cutorch then cutorch.synchronize() end
-   print(torch.type(currentModule)..' updateGradInput: '..sys.clock() - start.." s")
+   print(torch.type(currentModule)..' updateGradInput: '..start:time().real.." s")
    self.gradInput = currentGradOutput
    return currentGradOutput
 end

@@ -40,18 +40,18 @@ function Sequential:profile()
    local currentModule = self.modules[#self.modules]
    for i=#self.modules-1,1,-1 do
       local previousModule = self.modules[i]
-      local start = sys.clock()
+      local start = torch.Timer()
       currentModule:accGradParameters(previousModule.output, currentGradOutput, scale)
       if cutorch then cutorch.synchronize() end
-      print(torch.type(currentModule)..' accGradParameters: '..sys.clock() - start.." s")
+      print(torch.type(currentModule)..' accGradParameters: '..start:time().real.." s")
       currentGradOutput = currentModule.gradInput
       currentModule = previousModule
    end

-   local start = sys.clock()
+   local start = torch.Timer()
    currentModule:accGradParameters(input, currentGradOutput, scale)
    if cutorch then cutorch.synchronize() end
-   print(torch.type(currentModule)..' accGradParameters: '..sys.clock() - start.." s")
+   print(torch.type(currentModule)..' accGradParameters: '..start:time().real.." s")
 end

 function Sequential:backward(input, gradOutput, scale)

@@ -60,17 +60,17 @@ function Sequential:profile()
    local currentModule = self.modules[#self.modules]
    for i=#self.modules-1,1,-1 do
       local previousModule = self.modules[i]
-      local start = sys.clock()
+      local start = torch.Timer()
       currentGradOutput = currentModule:backward(previousModule.output, currentGradOutput, scale)
       if cutorch then cutorch.synchronize() end
-      print(torch.type(currentModule)..' backward: '..sys.clock() - start.." s")
+      print(torch.type(currentModule)..' backward: '..start:time().real.." s")
       currentModule.gradInput = currentGradOutput
       currentModule = previousModule
    end
-   local start = sys.clock()
+   local start = torch.Timer()
    currentGradOutput = currentModule:backward(input, currentGradOutput, scale)
    if cutorch then cutorch.synchronize() end
-   print(torch.type(currentModule)..' backward: '..sys.clock() - start.." s")
+   print(torch.type(currentModule)..' backward: '..start:time().real.." s")
    self.gradInput = currentGradOutput
    return currentGradOutput
 end

@@ -80,18 +80,18 @@ function Sequential:profile()
    local currentModule = self.modules[#self.modules]
    for i=#self.modules-1,1,-1 do
       local previousModule = self.modules[i]
-      local start = sys.clock()
+      local start = torch.Timer()
       currentModule:accUpdateGradParameters(previousModule.output, currentGradOutput, lr)
       if cutorch then cutorch.synchronize() end
-      print(torch.type(currentModule)..' accUpdateGradParameters: '..sys.clock() - start.." s")
+      print(torch.type(currentModule)..' accUpdateGradParameters: '..start:time().real.." s")
       currentGradOutput = currentModule.gradInput
       currentModule = previousModule
    end

-   local start = sys.clock()
+   local start = torch.Timer()
    currentModule:accUpdateGradParameters(input, currentGradOutput, lr)
    if cutorch then cutorch.synchronize() end
-   print(torch.type(currentModule)..' accUpdateGradParameters: '..sys.clock() - start.." s")
+   print(torch.type(currentModule)..' accUpdateGradParameters: '..start:time().real.." s")
 end

 parent.profile(self)

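Every timing site in profile() swaps manual sys.clock() arithmetic for a torch.Timer object: the timer starts counting on construction and time().real reads elapsed wall-clock seconds, which also drops the dependency on the sys package. The idiom in isolation, as a small standalone sketch:

   require 'torch'
   local timer = torch.Timer()               -- starts timing on construction
   local x = torch.randn(1000, 1000)
   local y = torch.mm(x, x)                  -- the work being measured
   print(('torch.mm: %.4f s'):format(timer:time().real))
   timer:reset()                             -- reuse for the next measurement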
test/test.lua

Lines changed: 25 additions & 24 deletions

@@ -282,36 +282,37 @@ function dpnntest.Module_type()
       mytester:assertTensorEq(gradParams[i], gradParams2[i], 0.00001, " gradParams err "..i)
    end

+   local input = torch.randn(3,32,32)
+   local cnn = nn.Sequential()
+   cnn:add(nn.SpatialConvolution(3,8,5,5))
+   cnn:add(nn.ReLU())
+   cnn:add(nn.SpatialAveragePooling(2,2,2,2))
+   cnn:add(nn.SpatialConvolution(8,12,5,5))
+   cnn:add(nn.ReLU())
+   cnn:add(nn.SpatialAveragePooling(2,2,2,2))
+   local outsize = cnn:outside{1,3,32,32}
+   cnn:add(nn.Collapse(3))
+   cnn:add(nn.Linear(outsize[2]*outsize[3]*outsize[4],20))
+   cnn:add(nn.ReLU())
+   cnn:add(nn.Linear(20,10))
+   local output = cnn:forward(input):clone()
+   local gradOutput = output:clone()
+   local gradInput = cnn:backward(input, gradOutput):clone()
+   cnn:float()
+   local input3 = input:float()
+   local output3 = cnn:forward(input3):clone()
+   local gradOutput3 = output3:clone()
+   local gradInput3 = cnn:backward(input3, gradOutput3):clone()
+   local o1, o2 = output3:float(), output:float()
+   mytester:assertTensorEq(o1, o2, 0.000001)
+   mytester:assertTensorEq(gradInput3:float(), gradInput:float(), 0.00001, "type float bwd err")
    if pcall(function() require 'cunn' end) then
-      local input = torch.randn(3,32,32)
-      local cnn = nn.Sequential()
-      cnn:add(nn.SpatialConvolutionMM(3,8,5,5))
-      cnn:add(nn.ReLU())
-      cnn:add(nn.SpatialAveragePooling(2,2,2,2))
-      cnn:add(nn.SpatialConvolutionMM(8,12,5,5))
-      cnn:add(nn.ReLU())
-      cnn:add(nn.SpatialAveragePooling(2,2,2,2))
-      local outsize = cnn:outside{1,3,32,32}
-      cnn:add(nn.Collapse(3))
-      cnn:add(nn.Linear(outsize[2]*outsize[3]*outsize[4],20))
-      cnn:add(nn.ReLU())
-      cnn:add(nn.Linear(20,10))
-      local output = cnn:forward(input):clone()
-      local gradOutput = output:clone()
-      local gradInput = cnn:backward(input, gradOutput):clone()
-      cnn:float()
-      local input3 = input:float()
-      local output3 = cnn:forward(input3):clone()
-      local gradOutput3 = output3:clone()
-      local gradInput3 = cnn:backward(input3, gradOutput3):clone()
-      mytester:assertTensorEq(output3:float(), output:float(), 0.000001, "type float fwd err")
-      mytester:assertTensorEq(gradInput3:float(), gradInput:float(), 0.00001, "type float bwd err")
       cnn:cuda()
       local input2 = input3:cuda()
       local gradOutput2 = gradOutput3:cuda()
       local output2 = cnn:forward(input2)
       local gradInput2 = cnn:backward(input2, gradOutput2)
-      mytester:assertTensorEq(output2:float(), output3, 0.000001, "type cuda fwd err")
+      mytester:assertTensorEq(output2:float(), output3, 0.000001)
       mytester:assertTensorEq(gradInput2:float(), gradInput3, 0.00001, "type cuda bwd err")
    end
 end

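The float-precision half of Module_type now runs unconditionally; only the CUDA half stays behind the pcall guard, so machines without cunn still exercise the new nn.SpatialConvolution path. To run just this case, something like the following should work, assuming the package exposes a dpnn.test entry point that forwards test names to torch.Tester (check the footer of test/test.lua for the actual runner name):

   require 'dpnn'
   dpnn.test{'Module_type'}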