forked from espnet/espnet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmergejson.py
executable file
·127 lines (110 loc) · 5.1 KB
/
mergejson.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#!/usr/bin/env python
# encoding: utf-8
# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import codecs
import json
import logging
import os
import sys
from espnet.utils.cli_utils import get_commandline_args
# True when the interpreter is Python 2. The output code at the end of the
# script uses it to choose between sys.stdout (Python 2) and sys.stdout.buffer
# (Python 3) as the stream wrapped by the UTF-8 writer.
is_python2 = sys.version_info[0] == 2
def get_parser():
    """Build the command-line parser of the json merging script.

    Returns:
        argparse.ArgumentParser: parser defining ``--input-jsons``,
        ``--output-jsons`` and ``--jsons`` (each may be repeated; every
        occurrence takes one or more paths and contributes one inner list),
        ``--verbose``/``-V`` and ``-O`` (stored as ``output``).
    """
    parser = argparse.ArgumentParser(
        description='merge json files',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input-jsons', type=str, nargs='+', action='append',
                        default=[], help='Json files for the inputs')
    parser.add_argument('--output-jsons', type=str, nargs='+', action='append',
                        default=[], help='Json files for the outputs')
    parser.add_argument('--jsons', type=str, nargs='+', action='append',
                        default=[],
                        help='The json files except for the input and outputs')
    parser.add_argument('--verbose', '-V', default=0, type=int,
                        help='Verbose option')
    parser.add_argument('-O', dest='output', type=str, help='Output json file')
    return parser
def load_json_groups(groups):
    """Load the json files of every group and intersect their utterance keys.

    Args:
        groups: Iterable of ``(jtype, jsons_list)`` pairs. ``jsons_list`` is a
            list of lists of file paths (one inner list per occurrence of the
            corresponding command-line option).

    Returns:
        A tuple ``(js_dict, intersec_ks)``. ``js_dict`` maps each ``jtype`` to
        a list with one entry per inner path list, each entry being the list
        of json objects loaded from it. ``intersec_ks`` is the set of
        utterance keys common to the loaded files; it is an empty set when
        no file was loaded at all.
    """
    js_dict = {}  # Dict[str, List[List[Dict[str, Dict[str, dict]]]]]
    # make intersection set for utterance keys
    intersec_ks = None  # Set[str]
    for jtype, jsons_list in groups:
        js_dict[jtype] = []
        for jsons in jsons_list:
            js = []
            for x in jsons:
                # NOTE(review): paths that are not regular files are skipped
                # without any message; presumably this tolerates unexpanded
                # shell globs passed by the calling scripts -- confirm.
                if not os.path.isfile(x):
                    continue
                with codecs.open(x, encoding="utf-8") as f:
                    j = json.load(f)
                ks = set(j['utts'])
                logging.info(x + ': has ' + str(len(ks)) + ' utterances')
                if intersec_ks is None:
                    intersec_ks = ks
                else:
                    intersec_ks = intersec_ks.intersection(ks)
                    if len(intersec_ks) == 0:
                        logging.warning("No intersection")
                        # Only the remaining files of this inner list are
                        # abandoned; the file that emptied the intersection
                        # is not kept either.
                        break
                js.append(j)
            js_dict[jtype].append(js)
    if intersec_ks is None:
        # Nothing was loaded: report "no utterances" instead of failing later
        # on len(None) / iteration over None.
        intersec_ks = set()
    return js_dict, intersec_ks


def build_io_entry(jtype, idx, dic):
    """Convert the merged attributes of one input/output group into an entry.

    Args:
        jtype: ``'input'`` or ``'output'``.
        idx: 1-based position of the group, used in the generated name.
        dic: Attributes of one utterance merged over the jsons of the group.

    Returns:
        dict with ``name`` (``input<idx>`` or ``target<idx>``), an optional
        ``shape`` and every key of ``dic`` other than the length/dimension
        and ``shape`` keys. A ``name`` key in ``dic`` overrides the generated
        one, because the remaining keys are copied last.
    """
    entry = {}
    fields = {'input': ('input', 'ilen', 'idim'),
              'output': ('target', 'olen', 'odim')}.get(jtype)
    if fields is not None:
        prefix, len_key, dim_key = fields
        entry['name'] = '{}{}'.format(prefix, idx)
        # FIXME(kamo): ad-hoc way to change str to List[int]
        # (length, dim), (length,) or (dim,) depending on which keys exist.
        # The output side checks olen/odim here; it used to test ilen/idim
        # and then read olen/odim.
        shape = tuple(int(dic[key]) for key in (len_key, dim_key)
                      if key in dic)
        if shape:
            entry['shape'] = shape
    if 'shape' in dic:
        # shape: "80,1000" -> [80, 1000]
        entry['shape'] = list(map(int, dic['shape'].split(',')))
    for k2, v in dic.items():
        if k2 not in ('ilen', 'idim', 'olen', 'odim', 'shape'):
            entry[k2] = v
    return entry


def merge_utterances(js_dict, utt_ids):
    """Build the merged ``utts`` dictionary for the given utterance keys.

    Args:
        js_dict: Mapping with the keys ``'input'``, ``'output'`` and
            ``'other'`` as returned by ``load_json_groups``.
        utt_ids: Iterable of utterance keys present in every loaded json.

    Returns:
        dict mapping each utterance key to a dict with the lists ``input``
        and ``output`` (one entry per group) plus the attributes of the
        ``other`` jsons merged in at the top level.
    """
    new_dic = {}
    for k in utt_ids:
        new_dic[k] = {'input': [], 'output': []}
        for jtype in ['input', 'output', 'other']:
            for idx, js in enumerate(js_dict[jtype], 1):
                # Merge dicts from jsons into a dict
                dic = {k2: v for j in js for k2, v in j['utts'][k].items()}
                if jtype == 'other':
                    new_dic[k].update(dic)
                else:
                    new_dic[k][jtype].append(build_io_entry(jtype, idx, dic))
    return new_dic


if __name__ == '__main__':
    parser = get_parser()
    args = parser.parse_args()

    # logging info
    logfmt = "%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s"
    if args.verbose > 0:
        logging.basicConfig(level=logging.INFO, format=logfmt)
    else:
        logging.basicConfig(level=logging.WARN, format=logfmt)
    logging.info(get_commandline_args())

    js_dict, intersec_ks = load_json_groups([('input', args.input_jsons),
                                             ('output', args.output_jsons),
                                             ('other', args.jsons)])
    logging.info('new json has ' + str(len(intersec_ks)) + ' utterances')
    new_dic = merge_utterances(js_dict, intersec_ks)

    # ensure "ensure_ascii=False", which is a bug
    jsonstring = json.dumps({'utts': new_dic},
                            indent=4, ensure_ascii=False,
                            sort_keys=True, separators=(',', ': '))
    if args.output is not None:
        # The context manager closes (and flushes) the output file; the
        # global sys.stdout is no longer rebound to a handle left open.
        with codecs.open(args.output, "w", encoding="utf-8") as out:
            print(jsonstring, file=out)
    else:
        out = codecs.getwriter("utf-8")(
            sys.stdout if is_python2 else sys.stdout.buffer)
        print(jsonstring, file=out)