Skip to content
This repository was archived by the owner on Dec 11, 2023. It is now read-only.

Commit 3bad10b

Browse files
Anonymousebrevdo
authored andcommitted
Support coverage penalty for beam search decoder.
PiperOrigin-RevId: 212902045
1 parent b278487 commit 3bad10b

File tree

4 files changed

+16
-6
lines changed

4 files changed

+16
-6
lines changed

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,8 @@ decoder = tf.contrib.seq2seq.BeamSearchDecoder(
988988
initial_state=decoder_initial_state,
989989
beam_width=beam_width,
990990
output_layer=projection_layer,
991-
length_penalty_weight=0.0)
991+
length_penalty_weight=0.0,
992+
coverage_penalty_weight=0.0)
992993

993994
# Dynamic decoding
994995
outputs, _ = tf.contrib.seq2seq.dynamic_decode(decoder, ...)

nmt/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,12 +534,15 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams):
534534
start_tokens = tf.fill([self.batch_size], tgt_sos_id)
535535
end_token = tgt_eos_id
536536
utils.print_out(
537-
" decoder: infer_mode=%sbeam_width=%d, length_penalty=%f" % (
538-
infer_mode, hparams.beam_width, hparams.length_penalty_weight))
537+
" decoder: infer_mode=%sbeam_width=%d, "
538+
"length_penalty=%f, coverage_penalty=%f"
539+
% (infer_mode, hparams.beam_width, hparams.length_penalty_weight,
540+
hparams.coverage_penalty_weight))
539541

540542
if infer_mode == "beam_search":
541543
beam_width = hparams.beam_width
542544
length_penalty_weight = hparams.length_penalty_weight
545+
coverage_penalty_weight = hparams.coverage_penalty_weight
543546

544547
my_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
545548
cell=cell,
@@ -549,7 +552,8 @@ def _build_decoder(self, encoder_outputs, encoder_state, hparams):
549552
initial_state=decoder_initial_state,
550553
beam_width=beam_width,
551554
output_layer=self.output_layer,
552-
length_penalty_weight=length_penalty_weight)
555+
length_penalty_weight=length_penalty_weight,
556+
coverage_penalty_weight=coverage_penalty_weight)
553557
elif infer_mode == "sample":
554558
# Helper
555559
sampling_temperature = hparams.sampling_temperature

nmt/nmt.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,9 @@
3737

3838
INFERENCE_KEYS = ["src_max_len_infer", "tgt_max_len_infer", "subword_option",
3939
"infer_batch_size", "beam_width",
40-
"length_penalty_weight", "sampling_temperature",
41-
"num_translations_per_input", "infer_mode"]
40+
"length_penalty_weight", "coverage_penalty_weight",
41+
"sampling_temperature", "num_translations_per_input",
42+
"infer_mode"]
4243

4344

4445
def add_arguments(parser):
@@ -288,6 +289,8 @@ def add_arguments(parser):
288289
"""))
289290
parser.add_argument("--length_penalty_weight", type=float, default=0.0,
290291
help="Length penalty for beam search.")
292+
parser.add_argument("--coverage_penalty_weight", type=float, default=0.0,
293+
help="Coverage penalty for beam search.")
291294
parser.add_argument("--sampling_temperature", type=float,
292295
default=0.0,
293296
help=("""\
@@ -370,6 +373,7 @@ def create_hparams(flags):
370373
infer_mode=flags.infer_mode,
371374
beam_width=flags.beam_width,
372375
length_penalty_weight=flags.length_penalty_weight,
376+
coverage_penalty_weight=flags.coverage_penalty_weight,
373377
sampling_temperature=flags.sampling_temperature,
374378
num_translations_per_input=flags.num_translations_per_input,
375379

nmt/utils/standard_hparams_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ def create_standard_hparams():
9595
# only enable beam search during inference when beam_width > 0.
9696
beam_width=0,
9797
length_penalty_weight=0.0,
98+
coverage_penalty_weight=0.0,
9899
override_loaded_hparams=True,
99100
num_keep_ckpts=5,
100101
avg_ckpts=False,

0 commit comments

Comments
 (0)