-
Notifications
You must be signed in to change notification settings - Fork 0
/
melspec_strech.py
102 lines (75 loc) · 3.48 KB
/
melspec_strech.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import numpy as np
import torch
import torch.nn as nn
from torchaudio.transforms import Spectrogram, MelSpectrogram , ComplexNorm
from torchaudio.transforms import TimeStretch, AmplitudeToDB
from torch.distributions import Uniform
def _num_stft_bins(lengths, fft_length, hop_length, pad):
    """Convert sample counts into STFT frame counts.

    Args:
        lengths: Number of samples; a scalar or anything supporting
            element-wise ``+``, ``-`` and ``//``.
        fft_length: Size of the analysis window in samples.
        hop_length: Step between consecutive frames in samples.
        pad: Amount of padding added on *each* side of the signal.

    Returns:
        ``(lengths + 2 * pad - fft_length + hop_length) // hop_length``,
        with the same type/shape semantics as ``lengths``.
    """
    # Signal length once the padding on both sides has been accounted for.
    padded = lengths + 2 * pad
    return (padded - fft_length + hop_length) // hop_length
class MelspectrogramStretch(MelSpectrogram):
    """Mel spectrogram with optional random time stretching and normalization.

    The pipeline in ``forward`` is: complex STFT -> (training only, with
    probability ``stretch_param[0]``) random time stretch -> power via
    ``ComplexNorm`` -> mel filter bank -> ``SpecNormalization``.

    Args:
        hop_length: Hop between STFT frames, forwarded to ``MelSpectrogram``.
        sample_rate: Audio sample rate, forwarded to ``MelSpectrogram``.
        num_mels: Number of mel bins (``n_mels`` of ``MelSpectrogram``).
        fft_length: FFT size (``n_fft`` of ``MelSpectrogram``).
        norm: Normalization type handed to ``SpecNormalization``.
        stretch_param: Two-element sequence ``(prob, max_perc)``: the
            probability of applying the stretch and the value passed to
            ``RandomTimeStretch`` as its ``max_perc``.
    """

    def __init__(self, hop_length=None,
                 sample_rate=44100,
                 num_mels=128,
                 fft_length=2048,
                 norm='whiten',
                 stretch_param=(0.4, 0.4)):
        super(MelspectrogramStretch, self).__init__(sample_rate=sample_rate,
                                                    n_fft=fft_length,
                                                    hop_length=hop_length,
                                                    n_mels=num_mels)

        # Complex-valued STFT (power=None) so the stretch can be applied
        # before taking magnitudes.
        # NOTE(review): n_fft / win_length / hop_length / pad are presumably
        # set by the MelSpectrogram constructor above - confirm.
        self.stft = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
                                hop_length=self.hop_length, pad=self.pad,
                                power=None, normalized=False)

        # Augmentation
        self.prob = stretch_param[0]
        self.random_stretch = RandomTimeStretch(stretch_param[1],
                                                self.hop_length,
                                                self.n_fft // 2 + 1,
                                                fixed_rate=None)

        # Normalization (post spec processing)
        self.complex_norm = ComplexNorm(power=2.)
        self.norm = SpecNormalization(norm)

    def forward(self, x, lengths=None):
        """Compute the normalized mel spectrogram of ``x``.

        Args:
            x: Input passed straight to the STFT transform.
            lengths: Optional tensor of per-item lengths in samples. When
                given, it is converted to frame counts and returned alongside
                the spectrogram.

        Returns:
            The normalized mel spectrogram, or a ``(spectrogram, lengths)``
            tuple when ``lengths`` was provided.
        """
        x = self.stft(x)

        if lengths is not None:
            # NOTE(review): pad = n_fft // 2 presumably mirrors the centered
            # STFT padding - confirm against the Spectrogram settings.
            lengths = _num_stft_bins(lengths, self.n_fft, self.hop_length, self.n_fft // 2)
            lengths = lengths.long()

        if torch.rand(1)[0] <= self.prob and self.training:
            # Stretch spectrogram in time using Phase Vocoder
            x, rate = self.random_stretch(x)
            # Modify the lengths accordingly. Guarded because ``lengths`` is
            # optional: without the check a stretch with lengths=None raised
            # AttributeError on ``None.float()``.
            if lengths is not None:
                lengths = (lengths.float() / rate).long() + 1

        x = self.complex_norm(x)
        x = self.mel_scale(x)

        # Normalize melspectrogram
        x = self.norm(x)

        if lengths is not None:
            return x, lengths
        return x

    def __repr__(self):
        return self.__class__.__name__ + '()'
class RandomTimeStretch(TimeStretch):
    """``TimeStretch`` variant that draws a fresh stretch rate on every call.

    The rate is sampled uniformly from ``[1 - max_perc, 1 + max_perc)``.

    Args:
        max_perc: Half-width of the rate interval around 1.
        hop_length: Forwarded to ``TimeStretch``.
        n_freq: Forwarded to ``TimeStretch``.
        fixed_rate: Forwarded to ``TimeStretch``.
    """

    def __init__(self, max_perc, hop_length=None, n_freq=201, fixed_rate=None):
        super().__init__(hop_length, n_freq, fixed_rate)
        lower_rate = 1. - max_perc
        upper_rate = 1 + max_perc
        self._dist = Uniform(lower_rate, upper_rate)

    def forward(self, x):
        """Stretch ``x`` by a random rate.

        Returns:
            A ``(stretched, rate)`` tuple, where ``rate`` is the sampled
            Python float handed to the parent ``forward``.
        """
        rate = self._dist.sample().item()
        stretched = super().forward(x, rate)
        return stretched, rate
class SpecNormalization(nn.Module):
    """Apply one of several normalizations to a spectrogram.

    Args:
        norm_type: ``'db'`` converts power to decibels with ``AmplitudeToDB``;
            ``'whiten'`` applies a per-sample z-transform; any other value
            leaves the input unchanged.
        top_db: Forwarded to ``AmplitudeToDB`` when ``norm_type`` is ``'db'``.
    """

    def __init__(self, norm_type, top_db=80.0):
        super(SpecNormalization, self).__init__()

        # Only modules / bound methods are stored (no lambdas) so that the
        # module stays picklable.
        if 'db' == norm_type:
            self._norm = AmplitudeToDB(stype='power', top_db=top_db)
        elif 'whiten' == norm_type:
            self._norm = self.z_transform
        else:
            self._norm = nn.Identity()

    def z_transform(self, x):
        """Standardize ``x`` to zero mean and unit std per leading index.

        Statistics are reduced over dimensions 1, 2 and 3, so ``x`` must have
        at least four dimensions. Where the std is zero or undefined (for
        example a constant, silent input) the centered values are returned
        unscaled instead of NaN/inf.
        """
        # Independent mean, std per batch
        non_batch_inds = (1, 2, 3)
        mean = x.mean(non_batch_inds, keepdim=True)
        std = x.std(non_batch_inds, keepdim=True)

        # A zero (or NaN) std would turn the whole sample into NaN/inf;
        # fall back to a scale of 1 for those samples only.
        safe_std = torch.where(std > 0, std, torch.ones_like(std))
        return (x - mean) / safe_std

    def forward(self, x):
        """Return ``x`` normalized according to the configured type."""
        return self._norm(x)