Skip to content

Commit f871177

Browse files
author
Amartya Sanyal
committed
Measure functions
1 parent ae3d409 commit f871177

File tree

5 files changed

+74
-74
lines changed

5 files changed

+74
-74
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ SET(luasrc
8484
TotalDropout.lua
8585
VRClassReward.lua
8686
ReverseUnreverse.lua
87+
measure.lua
8788
deprecated/SeqLSTMP.lua
8889
deprecated/SeqReverseSequence.lua
8990
deprecated/BiSequencerLM.lua

init.lua

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ paths.require 'librnn'
2727
unpack = unpack or table.unpack
2828

2929
require('rnn.utils')
30-
3130
-- extensions to existing nn.Module
3231
require('rnn.Module')
3332
require('rnn.Container')
@@ -112,6 +111,9 @@ require('rnn.SeqLSTMP')
112111
require('rnn.SeqReverseSequence')
113112
require('rnn.BiSequencerLM')
114113

114+
115+
require('rnn.measure')
116+
115117
-- prevent likely name conflicts
116118
nn.rnn = rnn
117119

measure.lua

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
function nn.get_bleu(cand, ref, n)
2+
n = n or 4
3+
local smooth = 1
4+
if type(cand) ~= 'table' then
5+
cand = cand:totable()
6+
end
7+
if type(ref) ~= 'table' then
8+
ref = ref:totable()
9+
end
10+
local res = nn.utils.get_ngram_prec(cand, ref, n)
11+
local brevPen = math.exp(1-math.max(1, #ref/#cand))
12+
local correct = 0
13+
local total = 0
14+
local bleu = 1
15+
for i = 1, n do
16+
if res[i][1] > 0 then
17+
if res[i][2] == 0 then
18+
smooth = smooth*0.5
19+
res[i][2] = smooth
20+
end
21+
local prec = res[i][2]/res[i][1]
22+
bleu = bleu * prec
23+
end
24+
end
25+
bleu = bleu^(1/n)
26+
return bleu*brevPen
27+
end

scripts/evaluate-rnnlm.lua

Lines changed: 1 addition & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -42,77 +42,6 @@ local validerr = xplog.valnceloss or xplog.valppl
4242

4343
print(string.format("Error (epoch=%d): training=%f; validation=%f", xplog.epoch, trainerr[#trainerr], validerr[#validerr]))
4444

45-
46-
local function get_ngrams(sent, n, count)
47-
local ngrams = {}
48-
for beg = 1, #sent do
49-
for last= beg, math.min(beg+n-1, #sent) do
50-
local ngram = table.concat(sent, ' ', beg, last)
51-
local len = last-beg+1 -- keep track of ngram length
52-
if not count then
53-
table.insert(ngrams, ngram)
54-
else
55-
if ngrams[ngram] == nil then
56-
ngrams[ngram] = {1, len}
57-
else
58-
ngrams[ngram][1] = ngrams[ngram][1] + 1
59-
end
60-
end
61-
end
62-
end
63-
return ngrams
64-
end
65-
66-
local function get_ngram_prec(cand, ref, n)
67-
local results = {}
68-
for i = 1, n do
69-
results[i] = {0, 0}
70-
end
71-
local cand_ngrams = get_ngrams(cand, n, 1)
72-
local ref_ngrams = get_ngrams(ref, n, 1)
73-
for ngram, dist in pairs(cand_ngrams) do
74-
local freq = dist[1]
75-
local length = dist[2]
76-
results[length][1] = results[length][1] + freq
77-
local actual
78-
if ref_ngrams[ngram] == nil then
79-
actual = 0
80-
else
81-
actual = ref_ngrams[ngram][1]
82-
end
83-
results[length][2] = results[length][2] + math.min(actual, freq)
84-
end
85-
return results
86-
end
87-
88-
function get_bleu(cand, ref, n)
89-
n = n or 4
90-
local smooth = 1
91-
if type(cand) ~= 'table' then
92-
cand = cand:totable()
93-
end
94-
if type(ref) ~= 'table' then
95-
ref = ref:totable()
96-
end
97-
local res = get_ngram_prec(cand, ref, n)
98-
local brevPen = math.exp(1-math.max(1, #ref/#cand))
99-
local correct = 0
100-
local total = 0
101-
local bleu = 1
102-
for i = 1, n do
103-
if res[i][1] > 0 then
104-
if res[i][2] == 0 then
105-
smooth = smooth*0.5
106-
res[i][2] = smooth
107-
end
108-
local prec = res[i][2]/res[i][1]
109-
bleu = bleu * prec
110-
end
111-
end
112-
bleu = bleu^(1/n)
113-
return bleu*brevPen
114-
end
115-
11645
if opt.dumpcsv then
11746
local csvfile = opt.xplogpath:match('([^/]+)[.]t7$')..'.csv'
11847
paths.mkdir('learningcurves')
@@ -220,7 +149,7 @@ else
220149
if opt.bleu then
221150
max_ind = torch.multinomial(torch.exp(outputs:view(targets:nElement(), -1)), 1):view(targets:size(1),targets:size(2))
222151
for batchIdx=1, targets:size(2) do
223-
sum_bleu = sum_bleu + get_bleu(max_ind:select(2, batchIdx),
152+
sum_bleu = sum_bleu + nn.get_bleu(max_ind:select(2, batchIdx),
224153
targets:select(2, batchIdx),
225154
opt.blueN)
226155
num_sent = num_sent + 1

utils.lua

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,4 +290,45 @@ function nn.utils.setZeroMask(modules, zeroMask, cuda)
290290
for i,module in ipairs(torch.type(modules) == 'table' and modules or {modules}) do
291291
module:setZeroMask(zeroMask)
292292
end
293-
end
293+
end
294+
function nn.utils.get_ngrams(sent, n, count)
295+
local ngrams = {}
296+
for beg = 1, #sent do
297+
for last= beg, math.min(beg+n-1, #sent) do
298+
local ngram = table.concat(sent, ' ', beg, last)
299+
local len = last-beg+1 -- keep track of ngram length
300+
if not count then
301+
table.insert(ngrams, ngram)
302+
else
303+
if ngrams[ngram] == nil then
304+
ngrams[ngram] = {1, len}
305+
else
306+
ngrams[ngram][1] = ngrams[ngram][1] + 1
307+
end
308+
end
309+
end
310+
end
311+
return ngrams
312+
end
313+
314+
function nn.utils.get_ngram_prec(cand, ref, n)
315+
local results = {}
316+
for i = 1, n do
317+
results[i] = {0, 0}
318+
end
319+
local cand_ngrams = nn.utils.get_ngrams(cand, n, 1)
320+
local ref_ngrams = nn.utils.get_ngrams(ref, n, 1)
321+
for ngram, dist in pairs(cand_ngrams) do
322+
local freq = dist[1]
323+
local length = dist[2]
324+
results[length][1] = results[length][1] + freq
325+
local actual
326+
if ref_ngrams[ngram] == nil then
327+
actual = 0
328+
else
329+
actual = ref_ngrams[ngram][1]
330+
end
331+
results[length][2] = results[length][2] + math.min(actual, freq)
332+
end
333+
return results
334+
end

0 commit comments

Comments
 (0)