Skip to content

Commit 33d7961

Browse files
committed
v100
1 parent 760fba6 commit 33d7961

File tree

4 files changed

+150
-8
lines changed

4 files changed

+150
-8
lines changed
Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
#!/bin/bash
# SLURM training script for WMT'14/'16 En-De (post-commit version, v100):
# switched from 4 to 8 GPUs, from train.py to the installed fairseq-train
# entry point, and adjusted --update-freq / --max-tokens accordingly.

#SBATCH --job-name=wmt14_en_de
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task 1 # Number of CPUs per task
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --mem=30G # CPU memory per node


stage=0
# NOTE: this basename-derived name is immediately overridden below; kept
# only as a record of the original naming scheme.
exp=`basename $0 | sed -e 's/^run_//' -e 's/.sh$//'`
exp=local_joint_attention_wmt_en_de_big
echo $exp

DATA=data-bin/wmt16_en_de_bpe32k
SAVE="checkpoints/$exp"
mkdir -p $SAVE

# Launch fairseq-train under torch.distributed with one process per GPU.
# NOTE: the original command passed --min-lr and --warmup-init-lr twice
# with equal values; the duplicates are removed here (argparse keeps the
# last occurrence, so behavior is unchanged).
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
    $DATA --fp16 --log-interval 100 --no-progress-bar \
    --max-update 30000 --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --clip-norm 0.0 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --min-lr 1e-09 --update-freq 32 --keep-last-epochs 10 \
    --ddp-backend=no_c10d --max-tokens 1800 \
    --lr-scheduler cosine --warmup-init-lr 1e-07 --warmup-updates 10000 \
    --lr-shrink 1 --max-lr 0.0009 --lr 1e-7 \
    --t-mult 1 --lr-period-updates 20000 \
    --arch local_joint_attention_wmt_en_de_big --save-dir $SAVE \
    --dropout 0.3 --attention-dropout 0.3 \
    --user-dir models

# Checkpoint averaging
python scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"

# Evaluation
CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATA --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 32 --beam 5 \
    --user-dir models --remove-bpe --lenpen 0.35 --gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#!/bin/bash
# SLURM job: train the local_joint_attention_wmt_en_de_big model on
# WMT'16 En-De (32k BPE) with 8 GPUs, then average the last 10 epoch
# checkpoints and evaluate compound-split BLEU on the test set.

#SBATCH --job-name=wmt14_en_de
#SBATCH --gres=gpu:8
#SBATCH --cpus-per-task 1 # Number of CPUs per task
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=8
#SBATCH --mem=30G # CPU memory per node


stage=0
# NOTE: this basename-derived name is immediately overridden below; kept
# only as a record of the original naming scheme.
exp=`basename $0 | sed -e 's/^run_//' -e 's/.sh$//'`
exp=local_joint_attention_wmt_en_de_big
echo $exp

DATA=data-bin/wmt16_en_de_bpe32k
SAVE="checkpoints/$exp"
mkdir -p $SAVE

# Launch fairseq-train under torch.distributed with one process per GPU.
# --user-dir models makes the custom architecture visible to fairseq.
# NOTE: the original command passed --min-lr and --warmup-init-lr twice
# with equal values; the duplicates are removed here (argparse keeps the
# last occurrence, so behavior is unchanged).
python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
    $DATA --fp16 --log-interval 100 --no-progress-bar \
    --max-update 30000 --share-all-embeddings \
    --optimizer adam --adam-betas '(0.9, 0.98)' \
    --clip-norm 0.0 --weight-decay 0.0 \
    --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
    --min-lr 1e-09 --update-freq 32 --keep-last-epochs 10 \
    --ddp-backend=no_c10d --max-tokens 1800 \
    --lr-scheduler cosine --warmup-init-lr 1e-07 --warmup-updates 10000 \
    --lr-shrink 1 --max-lr 0.0009 --lr 1e-7 \
    --t-mult 1 --lr-period-updates 20000 \
    --arch local_joint_attention_wmt_en_de_big --save-dir $SAVE \
    --dropout 0.3 --attention-dropout 0.3 \
    --user-dir models

# Checkpoint averaging
python scripts/average_checkpoints.py --inputs $SAVE \
    --num-epoch-checkpoints 10 --output "${SAVE}/checkpoint_last10_avg.pt"

# Evaluation
CUDA_VISIBLE_DEVICES=0 fairseq-generate $DATA --path "${SAVE}/checkpoint_last10_avg.pt" --batch-size 32 --beam 5 \
    --user-dir models --remove-bpe --lenpen 0.35 --gen-subset test > wmt16_gen.txt
bash scripts/compound_split_bleu.sh wmt16_gen.txt

score.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) 2017-present, Facebook, Inc.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the license found in the LICENSE file in
6+
# the root directory of this source tree. An additional grant of patent rights
7+
# can be found in the PATENTS file in the same directory.
8+
"""
9+
BLEU scoring of generated translations against reference translations.
10+
"""
11+
12+
import argparse
13+
import os
14+
import sys
15+
16+
from fairseq import bleu, tokenizer
17+
from fairseq.data import dictionary
18+
19+
20+
def get_parser():
    """Build and return the argument parser for BLEU scoring.

    Options: -s/--sys (system output path, '-' = stdin), -r/--ref
    (required reference path), -o/--order (max n-gram order, default 4),
    --ignore-case, and --sacrebleu (use sacrebleu instead of fairseq's
    internal scorer).
    """
    p = argparse.ArgumentParser(description='Command-line script for BLEU scoring.')
    # fmt: off
    p.add_argument('-s', '--sys', default='-', help='system output')
    p.add_argument('-r', '--ref', required=True, help='references')
    p.add_argument('-o', '--order', default=4, metavar='N', type=int,
                   help='consider ngrams up to this order')
    p.add_argument('--ignore-case', action='store_true',
                   help='case-insensitive scoring')
    p.add_argument('--sacrebleu', action='store_true',
                   help='score with sacrebleu')
    # fmt: on
    return p
33+
34+
35+
def main():
    """Score system translations against references with BLEU.

    Parses command-line args (see ``get_parser``); reads the system
    output from the --sys file, or stdin when --sys is '-'.  Scores
    with sacrebleu when --sacrebleu is given, otherwise with fairseq's
    internal tokenizing BLEU scorer.  Prints the result to stdout.
    """
    parser = get_parser()
    args = parser.parse_args()
    print(args)

    assert args.sys == '-' or os.path.exists(args.sys), \
        "System output file {} does not exist".format(args.sys)
    assert os.path.exists(args.ref), \
        "Reference file {} does not exist".format(args.ref)

    # Renamed from `dict` so the builtin type is not shadowed.
    dct = dictionary.Dictionary()

    def readlines(fd):
        # Yield lines, lower-cased when case-insensitive scoring is on.
        for line in fd.readlines():
            if args.ignore_case:
                yield line.lower()
            else:
                yield line

    if args.sacrebleu:
        import sacrebleu

        def score(fdsys):
            with open(args.ref) as fdref:
                print(sacrebleu.corpus_bleu(fdsys, [fdref]))
    else:
        def score(fdsys):
            with open(args.ref) as fdref:
                scorer = bleu.Scorer(dct.pad(), dct.eos(), dct.unk())
                # Pair system and reference lines positionally; both
                # files are expected to have the same number of lines.
                for sys_tok, ref_tok in zip(readlines(fdsys), readlines(fdref)):
                    sys_tok = tokenizer.Tokenizer.tokenize(sys_tok, dct)
                    ref_tok = tokenizer.Tokenizer.tokenize(ref_tok, dct)
                    scorer.add(ref_tok, sys_tok)
                print(scorer.result_string(args.order))

    if args.sys == '-':
        score(sys.stdin)
    else:
        with open(args.sys, 'r') as f:
            score(f)


if __name__ == '__main__':
    main()

scripts/compound_split_bleu.sh

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#!/bin/bash
# Compute compound-split BLEU from fairseq-generate output: split
# hyphenated compounds ("x-y" -> "x ##AT##-##AT## y") in both the
# hypotheses and the references before scoring, per the WMT En-De
# evaluation convention.

if [ $# -ne 1 ]; then
    echo "usage: $0 GENERATE_PY_OUTPUT"
    exit 1
fi

# Quote all path expansions so filenames with spaces work.
GEN="$1"

SYS="$GEN.sys"
REF="$GEN.ref"

# fairseq-generate prints a final summary line containing "BLEU" only
# once generation has completed.
if [ "$(tail -n 1 "$GEN" | grep BLEU | wc -l)" -ne 1 ]; then
    echo "not done generating"
    exit
fi

# H-* lines carry the hypotheses (fields 3+), T-* lines the references
# (fields 2+); the perl one-liner inserts the compound-split marker.
grep ^H "$GEN" | cut -f3- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > "$SYS"
grep ^T "$GEN" | cut -f2- | perl -ple 's{(\S)-(\S)}{$1 ##AT##-##AT## $2}g' > "$REF"
python score.py --sys "$SYS" --ref "$REF"

0 commit comments

Comments
 (0)