forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtext2vocabulary.py
executable file
·72 lines (62 loc) · 2.4 KB
/
text2vocabulary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#!/usr/bin/env python
# Copyright 2018 Mitsubishi Electric Research Laboratories (Takaaki Hori)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import argparse
import codecs
import collections
import logging
import sys

import six
# True when running under Python 2. The script uses this flag to choose the
# stream handed to the UTF-8 codec wrappers: ``sys.stdin``/``sys.stdout``
# directly on Python 2, and their ``.buffer`` attribute otherwise.
is_python2 = sys.version_info[0] == 2
def get_parser():
    """Build the command-line parser for the vocabulary tool.

    Returns:
        argparse.ArgumentParser: a parser that accepts

        * ``--output`` / ``-o`` (str, default ``''``): vocabulary file to
          write; the script falls back to stdout when it is empty.
        * ``--cutoff`` / ``-c`` (int, default ``0``): cut-off frequency.
        * ``--vocabsize`` / ``-s`` (int, default ``20000``): vocabulary size.
        * ``text_files`` (zero or more positionals): input text files.
    """
    parser = argparse.ArgumentParser(
        description='create a vocabulary file from text files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--output', '-o', default='', type=str,
                        help='output a vocabulary file')
    parser.add_argument('--cutoff', '-c', default=0, type=int,
                        help='cut-off frequency')
    parser.add_argument('--vocabsize', '-s', default=20000, type=int,
                        help='vocabulary size')
    parser.add_argument('text_files', nargs='*',
                        help='input text files')
    return parser
def count_tokens(lines, counts=None,
                 exclude=frozenset(['<sos>', '<eos>', '<unk>'])):
    """Count whitespace-separated tokens in an iterable of text lines.

    Args:
        lines: iterable of text lines (an open file works and is consumed
            line by line instead of being loaded whole).
        counts: optional ``collections.Counter`` to accumulate into, so that
            several inputs can share one tally. A new one is created when
            omitted.
        exclude: tokens that are never counted.

    Returns:
        collections.Counter: token -> number of occurrences. Keys keep the
        order in which tokens were first seen (on Pythons with ordered
        dicts), which is what breaks frequency ties later on.
    """
    if counts is None:
        counts = collections.Counter()
    for line in lines:
        counts.update(tok for tok in line.split() if tok not in exclude)
    return counts


def select_vocabulary(counts, cutoff, vocabsize):
    """Pick the most frequent words subject to a cut-off and a size limit.

    Args:
        counts: mapping of word -> occurrence count.
        cutoff: words occurring ``cutoff`` times or fewer are dropped.
        vocabsize: maximum number of words to keep.

    Returns:
        tuple: ``(vocabulary, oov_rate)`` where ``vocabulary`` is the list of
        kept words in descending frequency order and ``oov_rate`` is the
        percentage of counted tokens that fall outside it. ``oov_rate`` is
        ``0.0`` when nothing was counted, instead of dividing by zero.
    """
    total_count = sum(counts.values())
    invocab_count = 0
    vocabulary = []
    # sorted() is stable, so equally frequent words keep their input order.
    for word, count in sorted(counts.items(), key=lambda item: -item[1]):
        if count <= cutoff or len(vocabulary) >= vocabsize:
            break
        vocabulary.append(word)
        invocab_count += count
    if total_count:
        oov_rate = float(total_count - invocab_count) / total_count * 100
    else:
        oov_rate = 0.0
    return vocabulary, oov_rate


def write_vocabulary(vocabulary, fd):
    """Write ``<unk> 1`` followed by the sorted words numbered from 2.

    Args:
        vocabulary: iterable of words to emit.
        fd: writable text stream; one ``word id`` pair is printed per line.
    """
    six.print_('<unk> 1', file=fd)
    for index, word in enumerate(sorted(vocabulary), 2):
        six.print_('%s %d' % (word, index), file=fd)


def main():
    """Read the input texts, build the vocabulary and write it out."""
    args = get_parser().parse_args()

    # count the word occurrences; no file arguments means "read stdin"
    counts = collections.Counter()
    for fn in (args.text_files or ['-']):
        if fn == '-':
            # stdin is wrapped, not opened, so it is deliberately not closed
            reader = codecs.getreader("utf-8")(
                sys.stdin if is_python2 else sys.stdin.buffer)
            count_tokens(reader, counts)
        else:
            with codecs.open(fn, 'r', encoding="utf-8") as fd:
                count_tokens(fd, counts)

    # limit the vocabulary size
    vocabulary, oov_rate = select_vocabulary(
        counts, args.cutoff, args.vocabsize)
    logging.warning('OOV rate = %.2f %%', oov_rate)

    # write the vocabulary
    if args.output:
        with codecs.open(args.output, 'w', encoding="utf-8") as fd:
            write_vocabulary(vocabulary, fd)
    else:
        write_vocabulary(vocabulary, codecs.getwriter("utf-8")(
            sys.stdout if is_python2 else sys.stdout.buffer))


if __name__ == "__main__":
    main()