-
Notifications
You must be signed in to change notification settings - Fork 11
/
audio_to_mel.py
50 lines (48 loc) · 1.68 KB
/
audio_to_mel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch.nn as nn
import torch.nn.functional as F
import torch
from librosa.filters import mel as librosa_mel_fn
from torch.nn.utils import weight_norm
import numpy as np
class Audio2Mel(nn.Module):
    """Turn raw audio into a log10-compressed mel spectrogram.

    The module holds a Hann analysis window and a mel filterbank as buffers.
    ``forward`` runs an STFT, takes the power of every bin (real**2 +
    imag**2, no square root), projects it onto the mel filterbank and
    returns ``log10`` of the result after flooring it at ``1e-5``.

    Args:
        n_fft: FFT size used by the STFT and by the mel filterbank.
        hop_length: Hop between successive STFT frames, in samples.
        win_length: Length of the Hann window, in samples.
        sampling_rate: Sampling rate handed to the mel filterbank builder.
        n_mel_channels: Number of mel bands produced.
        mel_fmin: Lowest frequency of the mel filterbank.
        mel_fmax: Highest frequency of the mel filterbank; ``None`` is
            forwarded unchanged to librosa.
        device: Device on which both buffers (window and mel filterbank)
            are created.
    """

    def __init__(
        self,
        n_fft=1024,
        hop_length=256,
        win_length=1024,
        sampling_rate=22050,
        n_mel_channels=80,
        mel_fmin=0.0,
        mel_fmax=None,
        device='cuda'
    ):
        super().__init__()

        # Analysis window and mel filterbank, built once and stored as
        # buffers so they follow the module through .to() / state_dict().
        window = torch.hann_window(win_length, device=device).float()

        # Keyword arguments: recent librosa releases reject positional
        # arguments here, while older releases accept these same names.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        # Place the filterbank on the requested device (the same one as the
        # window) instead of unconditionally on CUDA.
        mel_basis = torch.from_numpy(mel_basis).to(device).float()

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("window", window)

        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels

    def forward(self, audioin):
        """Compute the log10 mel power spectrogram of ``audioin``.

        Args:
            audioin: Audio tensor whose last axis is time. Axis 1 is
                squeezed after padding, so a layout of
                ``(batch, 1, samples)`` is assumed -- TODO confirm against
                callers.

        Returns:
            Tensor of ``log10(max(mel_power, 1e-5))`` with the mel bands on
            the second-to-last axis and STFT frames on the last axis.
        """
        # Reflect-pad by (n_fft - hop) / 2 on each side; the STFT below runs
        # with center=False, so this is the only padding applied.
        p = (self.n_fft - self.hop_length) // 2
        audio = F.pad(audioin, (p, p), "reflect").squeeze(1)

        # return_complex=True is required by current PyTorch for real
        # inputs; the legacy trailing (real, imag) axis is no longer
        # produced.
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            center=False,
            return_complex=True,
        )

        # Power per bin: identical to summing the squared real and
        # imaginary parts over the old trailing axis.
        power = fft.real.pow(2) + fft.imag.pow(2)

        mel_output = torch.matmul(self.mel_basis, power)
        # Floor before the log so silent frames map to -5 instead of -inf.
        log_mel_spec = torch.log10(torch.clamp(mel_output, min=1e-5))
        return log_mel_spec