-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathTDNN_gpu.py
107 lines (89 loc) · 4.56 KB
/
TDNN_gpu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import math
class StatsPooling(nn.Module):
    """Statistics pooling over dimension 1.

    Reduces dimension 1 of the input to its mean and its standard deviation
    and concatenates the two results along dimension 1 (means first). For a
    3-D input of shape [B, T, D] the output therefore has shape [B, 2 * D].
    The standard deviation uses PyTorch's default (Bessel-corrected)
    estimator, so an input with a single frame yields NaN in the std half.
    """

    def __init__(self):
        super().__init__()

    def forward(self, varient_length_tensor):
        """Return ``cat(mean, std)`` of ``varient_length_tensor`` taken over dim 1.

        The parameter name (including its spelling) is kept unchanged so that
        keyword callers keep working.
        """
        frames = varient_length_tensor
        frame_mean = frames.mean(dim=1)
        frame_spread = frames.std(dim=1)
        return torch.cat([frame_mean, frame_spread], dim=1)
class FullyConnected(nn.Module):
    """Two-layer ReLU multilayer perceptron whose weights are float64.

    Computes ``relu(hidden2(relu(hidden1(x))))``. The default sizes
    (input_dim=3000, hidden_dim=512) reproduce the previously hard-coded
    layer shapes. NOTE(review): 3000 presumably equals 2 * 1500, the width
    of a StatsPooling output over 1500-dim frames -- confirm against callers.
    """

    def __init__(self, input_dim=3000, hidden_dim=512):
        """Build the two linear layers.

        Args:
            input_dim: size of the last dimension of the input (default 3000).
            hidden_dim: width of both hidden layers and of the output
                (default 512).
        """
        super(FullyConnected, self).__init__()
        # Module.double() converts the parameters in place, so they stay
        # registered; the attribute names are kept for state_dict compatibility.
        self.hidden1 = nn.Linear(input_dim, hidden_dim).double()
        self.hidden2 = nn.Linear(hidden_dim, hidden_dim).double()

    def forward(self, x):
        """Apply both layers, each followed by ReLU.

        Args:
            x: float64 tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``hidden_dim``.
        """
        x = F.relu(self.hidden1(x))
        return F.relu(self.hidden2(x))
"""Time Delay Neural Network as mentioned in the 1989 paper by Waibel et al. (Hinton) and the 2015 paper by Peddinti et al. (Povey)"""
class TDNN(nn.Module):
    """Time Delay Neural Network (TDNN) layer.

    As described in the 1989 paper by Waibel et al. (Hinton) and the 2015
    paper by Peddinti et al. (Povey). Each output step ``t`` is a learned
    linear combination of the input frames at ``t + offset`` for every
    offset in ``context``, followed by ReLU.

    Attributes:
        kernel: learnable float64 weights, [output_dim, input_dim, kernel_width].
        bias: learnable float64 bias, [output_dim].
        context: long buffer with the frame offsets actually used.
    """

    def __init__(self, context, input_dim, output_dim, full_context=True, device='cpu'):
        """Create the layer.

        Args:
            context: list of integer frame offsets in ascending order, defined
                as in the Peddinti paper, e.g. ``[-2, 2]``.
            input_dim: feature size of each input frame.
            output_dim: number of output channels.
            full_context: if True (the default) the context is expanded to
                every offset between its first and last element, i.e.
                ``[-2, 2]`` becomes ``[-2, -1, 0, 1, 2]`` (``range(-2, 3)``);
                if False only the listed offsets are used.
            device: device on which ``kernel``, ``bias`` and ``context`` are
                allocated. They are registered module state, so a later
                ``.to()`` / ``.cuda()`` moves them as well.

        Raises:
            AssertionError: if ``context[0] > context[-1]``.
        """
        super(TDNN, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.check_valid_context(context)
        self.kernel_width, context = self.get_kernel_width(context, full_context)
        self.register_buffer('context', torch.tensor(list(context), dtype=torch.long, device=device))
        self.full_context = full_context
        stdv = 1. / math.sqrt(input_dim)
        # Allocate with the final dtype/device *before* wrapping in nn.Parameter.
        # Calling .double()/.cuda() on an existing Parameter returns a new plain
        # tensor, which would silently drop out of self.parameters() (never
        # seen by an optimizer, never moved by .to()).
        self.kernel = nn.Parameter(
            torch.empty(output_dim, input_dim, self.kernel_width,
                        dtype=torch.double, device=device).normal_(0, stdv))
        self.bias = nn.Parameter(
            torch.empty(output_dim, dtype=torch.double, device=device).normal_(0, stdv))

    def forward(self, x):
        """Apply the time-delay transform followed by ReLU.

        Args:
            x: one batch of data, [batch_size, sequence_length, input_dim].
                sequence_length is the number of input frames (or of learned
                features if x already went through a convolutional network).
                Must have the same dtype/device as the parameters.

        Returns:
            Tensor of shape [batch_size, len(valid_steps), output_dim].
        """
        conv_out = self.special_convolution(x, self.kernel, self.context, self.bias)
        return F.relu(conv_out).transpose(1, 2).contiguous()

    def special_convolution(self, x, kernel, context, bias):
        """Weight the frames selected by ``context`` at every valid step.

        A plain convolution cannot be used directly because ``context`` may
        select only particular (non-contiguous) frames. For each valid output
        step ``t`` the frames ``t + context`` are gathered and contracted
        with ``kernel`` over the feature and context axes -- the same value
        ``F.conv1d`` gives on those ``kernel_width`` frames -- for all steps
        at once.

        Args:
            x: input of shape [batch_size, sequence_length, input_dim].
            kernel: weights of shape [output_dim, input_dim, len(context)].
            context: 1-D long tensor of ascending frame offsets.
            bias: tensor of shape [output_dim].

        Returns:
            Pre-activation tensor of shape
            [batch_size, output_dim, len(valid_steps)]; the last dimension is
            0 when the sequence is too short for the context.
        """
        input_size = x.shape
        assert len(input_size) == 3, 'Input tensor dimensionality is incorrect. Should be a 3D tensor'
        input_sequence_length = input_size[1]
        x = x.transpose(1, 2)  # -> [batch, input_dim, seq_len]
        valid_steps = self.get_valid_steps(context, input_sequence_length)
        # torch.tensor (not arange) so an empty range yields an empty index.
        steps = torch.tensor(list(valid_steps), dtype=torch.long, device=context.device)
        # frame_index[t, k] = valid_steps[t] + context[k] -> [n_steps, kernel_width]
        frame_index = steps.unsqueeze(1) + context.unsqueeze(0)
        frames = x[:, :, frame_index]  # -> [batch, input_dim, n_steps, kernel_width]
        out = torch.einsum('bdtk,odk->bot', frames, kernel)
        return out + bias.view(1, -1, 1)

    @staticmethod
    def check_valid_context(context):
        """Assert that the (still list-valued) context is in ascending order."""
        assert context[0] <= context[-1], 'Invalid context: the first offset must not be greater than the last offset'

    @staticmethod
    def get_kernel_width(context, full_context):
        """Return ``(kernel_width, context)``, expanding the context when ``full_context``."""
        if full_context:
            context = range(context[0], context[-1] + 1)
        return len(context), context

    @staticmethod
    def get_valid_steps(context, input_sequence_length):
        """Return the output positions whose whole context lies inside the input.

        Accepts a list or a 1-D tensor of ascending offsets.
        """
        start = 0 if context[0] >= 0 else -1 * context[0]
        end = input_sequence_length if context[-1] <= 0 else input_sequence_length - context[-1]
        return range(int(start), int(end))