Block ngram repeats (#1675)
* We avoid padding while mean pooling
* Placed batch dimension first for bmm
* Full rewrite of block_ngram_repeats for efficiency and accuracy (see the sketch below)
pltrdy authored and vince62s committed Dec 13, 2019
1 parent 456ee18 commit acc5fbc
Showing 3 changed files with 149 additions and 57 deletions.
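The core of the rewrite is a change of bookkeeping: instead of rescanning the whole hypothesis for repeated ngrams at every step, each decoding path keeps a dict that maps every (block_ngram_repeat - 1)-gram prefix seen so far to the tokens that followed it, so blocking becomes a single dict lookup per path. A minimal standalone sketch of that idea, in plain Python with made-up token ids and without the exclusion-token handling of the real implementation (which lives in onmt/translate/decode_strategy.py, diffed below):

# Standalone sketch, not OpenNMT-py code: incremental ngram bookkeeping.
block_ngram_repeat = 3           # block any repeated 3-gram
n = block_ngram_repeat

seq = []                         # one decoding path, BOS stripped
forbidden = {}                   # (n-1)-gram prefix -> set of banned next tokens

for tok in [5, 7, 9, 5, 7]:      # tokens produced so far (made-up ids)
    seq.append(tok)
    # analogous to maybe_update_forbidden_tokens(): register the ngram
    # that the path just completed
    if len(seq) >= n:
        gram = tuple(seq[-n:])
        forbidden.setdefault(gram[:-1], set()).add(gram[-1])

# analogous to block_ngram_repeats(): one dict lookup per path, per step
prefix = tuple(seq[-(n - 1):])
print(forbidden.get(prefix, set()))   # {9}: emitting 9 would repeat (5, 7, 9)

The previous implementation rebuilt the set of ngrams from the full hypothesis at every step, which is where the batch_size * beam_size * len(self) versus batch_size * beam_size complexity comparison in the new docstring comes from.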
101 changes: 67 additions & 34 deletions onmt/tests/test_beam_search.py
@@ -46,25 +46,45 @@ def test_advance_with_all_repeats_gets_blocked(self):
word_probs = torch.full(
(batch_sz * beam_sz, n_words), -float('inf'))
word_probs[0::beam_sz, repeat_idx] = 0

attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
if i <= ngram_repeat:

if i < ngram_repeat:
# before repeat, scores are either 0 or -inf
expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
elif i % ngram_repeat == 0:
# on repeat, `repeat_idx` score is BLOCKED_SCORE
# (but it's still the best score, thus we have
# [BLOCKED_SCORE, -inf, -inf, -inf, -inf])
expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
expected_scores[:, 0] = self.BLOCKED_SCORE
self.assertTrue(beam.topk_log_probs.equal(expected_scores))
else:
self.assertTrue(
beam.topk_log_probs.equal(
torch.tensor(self.BLOCKED_SCORE)
.repeat(batch_sz, beam_sz)))
# repetition keeps maximizing the score
# index 0 has been blocked, so repeating=>+0.0 score
# other indexes are -inf so repeating=>BLOCKED_SCORE
# which is higher
expected_scores = torch.tensor(
[0] + [-float('inf')] * (beam_sz - 1))\
.repeat(batch_sz, 1)
expected_scores[:, :] = self.BLOCKED_SCORE
expected_scores = torch.tensor(
self.BLOCKED_SCORE).repeat(batch_sz, beam_sz)

def test_advance_with_some_repeats_gets_blocked(self):
# beam 0 and beam >=2 will repeat (beam >= 2 repeat dummy scores)
beam_sz = 5
n_words = 100
repeat_idx = 47
ngram_repeat = 3
no_repeat_score = -2.3
repeat_score = -0.1
device_init = torch.zeros(1, 1)
for batch_sz in [1, 3]:
beam = BeamSearch(
@@ -81,31 +101,45 @@ def test_advance_with_some_repeats_gets_blocked(self):
# on initial round, only predicted scores for beam 0
# matter. Make two predictions. Top one will be repeated
# in beam zero, second one will live on in beam 1.
word_probs[0::beam_sz, repeat_idx] = -0.1
word_probs[0::beam_sz, repeat_idx + i + 1] = -2.3
word_probs[0::beam_sz, repeat_idx] = repeat_score
word_probs[0::beam_sz, repeat_idx +
i + 1] = no_repeat_score
else:
# predict the same thing in beam 0
word_probs[0::beam_sz, repeat_idx] = 0
# continue pushing around what beam 1 predicts
word_probs[1::beam_sz, repeat_idx + i + 1] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
if i <= ngram_repeat:
if i < ngram_repeat:
self.assertFalse(
beam.topk_log_probs[0::beam_sz].eq(
self.BLOCKED_SCORE).any())
self.assertFalse(
beam.topk_log_probs[1::beam_sz].eq(
self.BLOCKED_SCORE).any())
elif i == ngram_repeat:
# now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse(
beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any())

expected = torch.full([batch_sz, beam_sz], float("-inf"))
expected[:, 0] = no_repeat_score
expected[:, 1] = self.BLOCKED_SCORE
self.assertTrue(
beam.topk_log_probs[:, :].equal(expected))
else:
# now beam 0 dies (along with the others), beam 1 -> beam 0
self.assertFalse(
beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any())

expected = torch.full([batch_sz, beam_sz], float("-inf"))
expected[:, 0] = no_repeat_score
expected[:, 1:] = self.BLOCKED_SCORE
self.assertTrue(
beam.topk_log_probs[:, 1:].equal(
torch.tensor(self.BLOCKED_SCORE)
.repeat(batch_sz, beam_sz-1)))
beam.topk_log_probs.equal(expected))

def test_repeating_excluded_index_does_not_die(self):
# beam 0 and beam >= 2 will repeat (beam 2 repeats excluded idx)
@@ -139,7 +173,7 @@ def test_repeating_excluded_index_does_not_die(self):
word_probs[2::beam_sz, repeat_idx_ignored] = 0
attns = torch.randn(1, batch_sz * beam_sz, 53)
beam.advance(word_probs, attns)
if i <= ngram_repeat:
if i < ngram_repeat:
self.assertFalse(beam.topk_log_probs[:, 0].eq(
self.BLOCKED_SCORE).any())
self.assertFalse(beam.topk_log_probs[:, 1].eq(
@@ -158,10 +192,9 @@ def test_repeating_excluded_index_does_not_die(self):
self.assertFalse(beam.topk_log_probs[:, 1].eq(
self.BLOCKED_SCORE).all())
self.assertTrue(beam.topk_log_probs[:, 1].eq(-5.0).all())
self.assertTrue(
beam.topk_log_probs[:, 2:].equal(
torch.tensor(self.BLOCKED_SCORE)
.repeat(batch_sz, beam_sz - 2)))

self.assertTrue(beam.topk_log_probs[:, 2].eq(
self.BLOCKED_SCORE).all())

def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# beam 0 will always predict EOS. The other beams will predict
@@ -199,15 +232,15 @@ def test_doesnt_predict_eos_if_shorter_than_min_len(self):
# provide beam_sz other good predictions
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
all_attns.append(attns)
beam.advance(word_probs, attns)
if i < min_length:
expected_score_dist = \
(i+1) * valid_score_dist[1:].unsqueeze(0)
(i + 1) * valid_score_dist[1:].unsqueeze(0)
self.assertTrue(
beam.topk_log_probs.allclose(
expected_score_dist))
@@ -255,15 +288,15 @@ def test_beam_is_done_when_n_best_beams_eos_using_min_length(self):
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score
else:
word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
@@ -316,15 +349,15 @@ def test_beam_returns_attn_with_correct_length(self):
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score
else:
word_probs[0::beam_sz, eos_idx] = valid_score_dist[0]
word_probs[1::beam_sz, eos_idx] = valid_score_dist[0]
# provide beam_sz other good predictions in other beams
for k, (j, score) in enumerate(
zip(_non_eos_idxs, valid_score_dist[1:])):
beam_idx = min(beam_sz-1, k)
beam_idx = min(beam_sz - 1, k)
word_probs[beam_idx::beam_sz, j] = score

attns = torch.randn(1, batch_sz * beam_sz, 53)
@@ -357,7 +390,7 @@ def test_beam_returns_attn_with_correct_length(self):
inp_lens[b])
# first dim is equal to the time of death
# (beam 0 died at current step - adjust for SOS)
self.assertEqual(beam.attention[b][0].shape[0], i+1)
self.assertEqual(beam.attention[b][0].shape[0], i + 1)
# (beam 1 died at last step - adjust for SOS)
self.assertEqual(beam.attention[b][1].shape[0], i)
# behavior gets weird when beam is already done so just stop
@@ -399,9 +432,9 @@ def first_step(self, beam, expected_beam_scores, expected_len_pen):
# no EOS's yet
assert beam.is_finished.sum() == 0
scores_1 = torch.log_softmax(torch.tensor(
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 1.5, 0, 0, 0, 0, 0],
[0, 0, 0, 0, .49, .48, 0, 0],
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 1.5, 0, 0, 0, 0, 0],
[0, 0, 0, 0, .49, .48, 0, 0],
[0, 0, 0, .2, .2, .2, .2, .2],
[0, 0, 0, .2, .2, .2, .2, .2]]
), dim=1)
@@ -431,9 +464,9 @@ def first_step(self, beam, expected_beam_scores, expected_len_pen):
def second_step(self, beam, expected_beam_scores, expected_len_pen):
# assumes beam 2 finished on last step
scores_2 = torch.log_softmax(torch.tensor(
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 5000, .48, 0, 0], # beam 2 shouldn't continue
[[0, 0, 0, .3, 0, .51, .2, 0],
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 5000, .48, 0, 0], # beam 2 shouldn't continue
[0, 0, 50, .2, .2, .2, .2, .2], # beam 3 -> beam 0 should die
[0, 0, 0, .2, .2, .2, .2, .2]]
), dim=1)
@@ -470,9 +503,9 @@ def second_step(self, beam, expected_beam_scores, expected_len_pen):
def third_step(self, beam, expected_beam_scores, expected_len_pen):
# assumes beam 0 finished on last step
scores_3 = torch.log_softmax(torch.tensor(
[[0, 0, 5000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5000, 0, 0],
[[0, 0, 5000, 0, 5000, .51, .2, 0], # beam 0 shouldn't cont
[0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 5000, 0, 0],
[0, 0, 0, .2, .2, .2, .2, .2],
[0, 0, 50, 0, .2, .2, .2, .2]] # beam 4 -> beam 1 should die
), dim=1)
13 changes: 9 additions & 4 deletions onmt/translate/beam_search.py
@@ -88,7 +88,7 @@ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best,
self._coverage = None

self._stepwise_cov_pen = (
stepwise_penalty and self.global_scorer.has_cov_pen)
stepwise_penalty and self.global_scorer.has_cov_pen)
self._vanilla_cov_pen = (
not stepwise_penalty and self.global_scorer.has_cov_pen)
self._cov_pen = self.global_scorer.has_cov_pen
@@ -166,15 +166,17 @@ def advance(self, log_probs, attn):
# Multiply probs by the beam probability.
log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)

self.block_ngram_repeats(log_probs)

# if the sequence ends now, then the penalty is the current
# length + 1, to include the EOS token
length_penalty = self.global_scorer.length_penalty(
step + 1, alpha=self.global_scorer.alpha)

# Flatten probs into a list of possibilities.
curr_scores = log_probs / length_penalty

# Avoid any direction that would repeat unwanted ngrams
self.block_ngram_repeats(curr_scores)

# Flatten probs into a list of possibilities.
curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
torch.topk(curr_scores, self.beam_size, dim=-1,
out=(self.topk_scores, self.topk_ids))
@@ -194,6 +196,9 @@ def advance(self, log_probs, attn):
self.alive_seq = torch.cat(
[self.alive_seq.index_select(0, self.select_indices),
self.topk_ids.view(_B * self.beam_size, 1)], -1)

self.maybe_update_forbidden_tokens()

if self.return_attention or self._cov_pen:
current_attn = attn.index_select(1, self.select_indices)
if step == 1:
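For orientation, the order of operations inside advance() after this change is: fold the running beam scores into the step's log-probs, apply the length penalty, block forbidden ngram continuations on the penalized scores, take the top-k, extend alive_seq, and only then refresh the forbidden-token dicts. A simplified, self-contained sketch of one such step, with toy sizes, a GNMT-style penalty standing in for global_scorer.length_penalty, and a hand-written forbidden map rather than the actual BeamSearch state:

import torch

_B, beam_size, vocab_size = 2, 3, 7     # toy sizes
step, alpha = 4, 0.6

topk_log_probs = torch.zeros(_B, beam_size)
log_probs = torch.randn(_B * beam_size, vocab_size).log_softmax(-1)

# multiply probs by the beam probability (i.e. add in log space)
log_probs += topk_log_probs.view(_B * beam_size, 1)

# length penalty first (stand-in formula; the real one comes from the scorer)
length_penalty = ((5.0 + step + 1) / 6.0) ** alpha
curr_scores = log_probs / length_penalty

# then block unwanted ngram continuations on the scores that feed topk
forbidden = {0: [2, 5]}                 # pretend path 0 may not emit 2 or 5
for path_idx, banned in forbidden.items():
    curr_scores[path_idx, banned] = -10e20

# pick the top beam_size continuations per batch entry
curr_scores = curr_scores.reshape(_B, beam_size * vocab_size)
topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

# the real code then extends alive_seq with topk_ids and calls
# maybe_update_forbidden_tokens() to register the newly created ngrams

Refreshing the forbidden-token dicts after the surviving paths are known means the bookkeeping is done once per step, for exactly the paths that stay alive.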
92 changes: 73 additions & 19 deletions onmt/translate/decode_strategy.py
@@ -73,7 +73,11 @@ def __init__(self, pad, bos, eos, batch_size, parallel_paths,

self.min_length = min_length
self.max_length = max_length

self.block_ngram_repeat = block_ngram_repeat
n_paths = batch_size * parallel_paths
self.forbidden_tokens = [dict() for _ in range(n_paths)]

self.exclusion_tokens = exclusion_tokens
self.return_attention = return_attention

@@ -109,25 +113,75 @@ def ensure_max_length(self):
self.is_finished.fill_(1)

def block_ngram_repeats(self, log_probs):
cur_len = len(self)
if self.block_ngram_repeat > 0 and cur_len > 1:
for path_idx in range(self.alive_seq.shape[0]):
# skip BOS
hyp = self.alive_seq[path_idx, 1:]
ngrams = set()
fail = False
gram = []
for i in range(cur_len - 1):
# Last n tokens, n = block_ngram_repeat
gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
# skip the blocking if any token in gram is excluded
if set(gram) & self.exclusion_tokens:
continue
if tuple(gram) in ngrams:
fail = True
ngrams.add(tuple(gram))
if fail:
log_probs[path_idx] = -10e20
"""
We prevent the beam from going in any direction that would repeat any
ngram of size <block_ngram_repeat> more than once.
The way we do it: for each path we maintain a dict mapping every
(<block_ngram_repeat> - 1)-gram seen so far to the tokens that followed
it; it is updated each time the beam advances, and any token that would
complete a repeated ngram gets its score manually set to a very low value.
This improves on the previous version's complexity:
- previous version's complexity: batch_size * beam_size * len(self)
- current version's complexity: batch_size * beam_size
This improves on the previous version's accuracy:
- The previous version blocks the whole beam, whereas here we only
block specific tokens.
- Before, the translation would fail when all beams contained
repeated ngrams. This can no longer happen here.
"""

# we don't block anything if the user doesn't want it
if self.block_ngram_repeat <= 0:
return

# we can't block anything if the beam is too short
if len(self) < self.block_ngram_repeat:
return

n = self.block_ngram_repeat - 1
for path_idx in range(self.alive_seq.shape[0]):
# we check paths one by one

current_ngram = tuple(self.alive_seq[path_idx, -n:].tolist())
forbidden_tokens = self.forbidden_tokens[path_idx].get(
current_ngram, None)
if forbidden_tokens is not None:
log_probs[path_idx, list(forbidden_tokens)] = -10e20

def maybe_update_forbidden_tokens(self):
"""We complete and reorder the list of forbidden_tokens"""

# we don't forbid anything if the user doesn't want it
if self.block_ngram_repeat <= 0:
return

# we can't forbid anything if the beam is too short
if len(self) < self.block_ngram_repeat:
return

n = self.block_ngram_repeat

forbidden_tokens = list()
for path_idx, seq in zip(self.select_indices, self.alive_seq):

# Reordering forbidden_tokens following beam selection
# We rebuild a dict to get a copy of the values, not a reference
forbidden_tokens.append(
dict(self.forbidden_tokens[path_idx]))

# Grabbing the newly selected tokens and the associated ngram
current_ngram = tuple(seq[-n:].tolist())

# skip the blocking if any token in current_ngram is excluded
if set(current_ngram) & self.exclusion_tokens:
continue

forbidden_tokens[-1].setdefault(current_ngram[:-1], set())
forbidden_tokens[-1][current_ngram[:-1]].add(current_ngram[-1])

self.forbidden_tokens = forbidden_tokens

def advance(self, log_probs, attn):
"""DecodeStrategy subclasses should override :func:`advance()`.
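To illustrate the accuracy point from the docstring above -- only the offending continuations are penalized, never the whole beam -- here is a hedged, self-contained sketch of the two hooks acting on a toy batch of two alive paths. The shapes and the -10e20 sentinel follow the diff; the token ids and the pre-seeded ngram are invented, and the reordering via select_indices is omitted:

import torch

block_ngram_repeat = 2                   # block any repeated 2-gram
n = block_ngram_repeat
vocab_size = 6

alive_seq = torch.tensor([[0, 3, 4, 3],  # two alive paths, BOS = 0
                          [0, 1, 2, 5]])
forbidden_tokens = [dict() for _ in range(alive_seq.size(0))]

# like maybe_update_forbidden_tokens(): record the 2-gram each path just made
for path_idx, seq in enumerate(alive_seq):
    gram = tuple(seq[-n:].tolist())
    forbidden_tokens[path_idx].setdefault(gram[:-1], set()).add(gram[-1])

# pretend path 0 already produced the 2-gram (3, 4) at an earlier step,
# so 4 is banned right after a 3:
forbidden_tokens[0][(3,)] = {4}

# like block_ngram_repeats(): mask only the offending continuations
log_probs = torch.zeros(alive_seq.size(0), vocab_size)
for path_idx in range(alive_seq.size(0)):
    prefix = tuple(alive_seq[path_idx, -(n - 1):].tolist())
    banned = forbidden_tokens[path_idx].get(prefix, None)
    if banned is not None:
        log_probs[path_idx, list(banned)] = -10e20

print(log_probs)   # only token 4 of path 0 is blocked; path 1 is untouched

Because path 0 can still continue with any other token, the updated tests above can expect BLOCKED_SCORE at specific positions instead of an entirely dead beam.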
