forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathjson2sctm.py
86 lines (73 loc) · 3 KB
/
json2sctm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
#!/usr/bin/python
# -*- coding: utf-8 -*-
import argparse
import os
import subprocess
import sys
is_python2 = sys.version_info[0] == 2
def get_parser():
parser = argparse.ArgumentParser(description='convert json to sctm')
parser.add_argument('json', type=str, default=None, nargs='?',
help='input trn')
parser.add_argument('dict', type=str, help='dict')
parser.add_argument('--num-spkrs', type=int, default=1, nargs='?', help='number of speakers')
parser.add_argument('--refs', type=str, nargs='*', help='ref for all speakers')
parser.add_argument('--hyps', type=str, nargs='*', help='hyp for all outputs')
parser.add_argument('--orig-stm', type=str, nargs='?', help='orig stm')
parser.add_argument('--stm', type=str, default=None, nargs='+', help='output stm')
parser.add_argument('--ctm', type=str, default=None, nargs='+', help='output ctm')
parser.add_argument('--bpe', type=str, default=None, nargs='?', help='BPE model if applicable')
return parser
def main(args):
from utils import json2trn
from utils import trn2ctm
from utils import trn2stm
parser = get_parser()
args = parser.parse_args(args)
if args.refs is None:
refs = ["ref_tmp.trn"]
del_ref = True
else:
refs = args.refs
del_ref = False
if args.hyps is None:
hyps = ["hyp_tmp.trn"]
del_hyp = True
else:
hyps = args.hyps
del_hyp = False
json2trn.convert(args.json, args.dict, refs, hyps, args.num_spkrs)
for trn in refs + hyps:
# We don't remove non-lang-syms because kaldi already removes them when scoring
call_args = ["sed", "-i.bak2", "-r", "s/<blank> //g", trn]
subprocess.check_call(call_args)
if args.bpe is not None:
with open(wrd_name(trn), 'w') as out:
with open(trn, 'r') as spm_in:
sed_args = ["sed", "-e", "s/▁/ /g"]
sed = subprocess.Popen(sed_args, stdout=out, stdin=subprocess.PIPE)
spm_args = ["spm_decode", "--model=" + args.bpe, "--input_format=piece"]
subprocess.Popen(spm_args, stdin=spm_in)
sed.communicate()
else:
call_args = ["sed", "-e", "s/ //g", "-e", "s/(/ (/", "-e", "s/<space>/ /g", trn]
with open(wrd_name(trn), 'w') as out:
sed = subprocess.Popen(call_args, stdout=out)
sed.communicate()
for trn, stm in zip(refs, args.stm):
trn2stm.convert(wrd_name(trn), stm, args.orig_stm)
if del_ref:
os.remove(refs[0])
os.remove(refs[0] + ".bak2")
os.remove(wrd_name(refs[0]))
for trn, ctm in zip(hyps, args.ctm):
trn2ctm.convert(wrd_name(trn), ctm)
if del_hyp:
os.remove(hyps[0])
os.remove(hyps[0] + ".bak2")
os.remove(wrd_name(hyps[0]))
def wrd_name(trn):
split = trn.split(".")
return ".".join(split[:-1]) + ".wrd." + split[-1]
if __name__ == "__main__":
main(sys.argv[1:])