streaming_special_torch_dataset.py (forked from jingyonghou/RPN_KWS)
#!/usr/bin/env python
#
# Copyright 2017 [email protected]
#
# MIT License
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import torch
from torch.utils.data import Dataset
import kaldi_io
import htk_io
import raw_io
FunctionDict = {'htk_reader': htk_io.htk_read,
                'htk_label_reader': htk_io.htk_read,
                'kaldi_reader': kaldi_io.read_mat,
                'kaldi_label_reader': kaldi_io.read_vec_int,
                'raw_list_reader': raw_io.read_list}
def get_fn_list(fn_name_list):
fn_list = [FunctionDict[x] for x in fn_name_list]
return fn_list
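# For example, get_fn_list(['kaldi_reader', 'kaldi_label_reader']) returns
# [kaldi_io.read_mat, kaldi_io.read_vec_int].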
def splice_feats(feat, left, right):
    """Splice features. First pad the utterance with `left` copies of the
    first frame and `right` copies of the last frame, then slide a window
    over the padded matrix to collect left+right+1 shifted views and
    concatenate them along the feature dimension.
    Args:
        feat (numpy.ndarray): Input feature matrix of an utterance (T x D).
        left: Left context for splicing.
        right: Right context for splicing.
    Returns:
        Spliced feature (numpy.ndarray) of shape T x ((left + right + 1) * D).
    """
    if left == 0 and right == 0:
        return feat
    sfeat = []
    num_row = feat.shape[0]
    f0 = feat[0, :]
    fT = feat[num_row - 1, :]
    # Repeat the first frame `left` times and the last frame `right` times.
    pad0 = np.tile(f0, (left, 1))
    padT = np.tile(fT, (right, 1))
    pad_feat = np.concatenate([pad0, feat, padT], 0)
    for i in range(left + right + 1):
        # Each shifted window contributes one context position per output frame.
        sfeat.append(pad_feat[i:i + num_row, :])
    spliced_feat = np.concatenate(sfeat, 1)
    return spliced_feat
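# A minimal sketch of what splicing does on a toy 3-frame, 1-dim feature
# matrix (this example is illustrative, not part of the original file):
#
#   >>> feat = np.array([[1.0], [2.0], [3.0]])
#   >>> splice_feats(feat, 1, 1)
#   array([[1., 1., 2.],
#          [1., 2., 3.],
#          [2., 3., 3.]])
#
# Row t holds frames [t-left, ..., t+right], with edge frames repeated.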
class StreamingTorchDataset(Dataset):
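    """Map-style Dataset that loads one utterance per item from a metadata
    file. Each metadata line is: utt_id feat_path [label_path].
    """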
def __init__(self, meta_file, fn_name_list, left_context, right_context, has_label=True):
with open(meta_file) as fid:
self.metadata = [ line.strip().split() for line in fid ]
self.load_fns = get_fn_list(fn_name_list)
self.left_context = left_context
self.right_context = right_context
self.has_label = has_label
def __len__(self):
return len(self.metadata)
def __getitem__(self, index):
items = self.metadata[index]
utt_id = items[0]
feat_path = items[1]
feat = self.load_fns[0](feat_path)
splice_feat = splice_feats(feat, self.left_context, self.right_context)
if self.has_label:
label_path = items[2]
label = self.load_fns[1](label_path)
            # The label here is usually a sequence-level label.
return utt_id, splice_feat, label
else:
return utt_id, splice_feat
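# A hypothetical construction, assuming Kaldi-format features and integer
# label vectors; 'train_meta.txt' is an illustrative file name:
#
#   >>> dataset = StreamingTorchDataset('train_meta.txt',
#   ...                                 ['kaldi_reader', 'kaldi_label_reader'],
#   ...                                 left_context=5, right_context=5)
#   >>> utt_id, feat, label = dataset[0]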
# Number of extra zero frames prepended to every utterance when padding a batch.
WINDOW_SIZE = 0
def collate_fn(batch):
"""Put each data field into a tensor with outer dimension batch size.
Args:
batch: A list of tuple (feat, label) for training or (utt_id, feat) for testing
"""
    # batch is a list; batch[i] is a tuple.
    if len(batch[0]) == 2:
        # Testing: sort utterances by length, longest first.
        batch.sort(key=lambda x: x[1].shape[0], reverse=True)
        keys = []
        lengths = []
        feats_padded = []
        max_len = batch[0][1].shape[0]
        for i in range(len(batch)):
            keys.append(batch[i][0])
            act_len = batch[i][1].shape[0]
            pad_len = max_len - act_len
            # Prepend WINDOW_SIZE zero frames and append pad_len zero frames.
            feats_padded.append(np.pad(batch[i][1], ((WINDOW_SIZE, pad_len), (0, 0)), "constant"))
            lengths.append(act_len + WINDOW_SIZE)
        return keys, torch.from_numpy(np.array(lengths)), torch.from_numpy(np.array(feats_padded))
    elif len(batch[0]) == 3:
        # Training: sort utterances by length, longest first.
        batch.sort(key=lambda x: x[1].shape[0], reverse=True)
        keys = []
        lengths = []
        feats_padded = []
        label_padded = []
        max_len = batch[0][1].shape[0]
        for i in range(len(batch)):
            keys.append(batch[i][0])
            act_len = batch[i][1].shape[0]
            pad_len = max_len - act_len
            feats_padded.append(np.pad(batch[i][1], ((WINDOW_SIZE, pad_len), (0, 0)), "constant"))
            # Sequence-level labels are collected into a plain list, unpadded.
            label_padded.append(batch[i][2])
            lengths.append(act_len + WINDOW_SIZE)
        return keys, torch.from_numpy(np.array(lengths)), torch.from_numpy(np.array(feats_padded)), label_padded
    else:
        raise ValueError("Unsupported batch element with %d fields" % len(batch[0]))