Update feature_extractor.py
BBC-Esq authored Oct 5, 2024
1 parent d57c5b4 commit 9c5975c
Showing 1 changed file with 83 additions and 33 deletions.
faster_whisper/feature_extractor.py: 116 changes (83 additions & 33 deletions)
@@ -1,36 +1,59 @@
import torch


# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
class FeatureExtractor:
def __init__(
self,
device: str = "auto",
feature_size=80,
sampling_rate=16000,
hop_length=160,
chunk_length=30,
n_fft=400,
feature_size: int = 80,
sampling_rate: int = 16000,
hop_length: int = 160,
chunk_length: int = 30,
n_fft: int = 400,
):
"""
Initializes the FeatureExtractor with the given parameters.
Args:
device (str): Device to perform computations on ("cuda", "cpu", or "auto").
feature_size (int): Number of Mel filter banks.
sampling_rate (int): Sampling rate of the input audio.
hop_length (int): Number of samples between successive frames.
chunk_length (int): Length of audio chunks in seconds.
n_fft (int): Number of FFT components.
"""
if device == "auto":
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:
self.device = device
self.n_fft = n_fft
self.hop_length = hop_length
self.chunk_length = chunk_length
self.sampling_rate = sampling_rate
self.n_samples = chunk_length * sampling_rate
self.nb_max_frames = self.n_samples // hop_length
self.time_per_frame = hop_length / sampling_rate
self.sampling_rate = sampling_rate

# Precompute and cache the Hann window on the appropriate device
self.window = torch.hann_window(self.n_fft).to(self.device)

# Precompute mel filters and move them to the device
self.mel_filters = self.get_mel_filters(
sampling_rate, n_fft, n_mels=feature_size
)
).to(self.device)

@staticmethod
def get_mel_filters(sr, n_fft, n_mels=128):
def get_mel_filters(sr: int, n_fft: int, n_mels: int = 128) -> torch.Tensor:
"""
Implementation of librosa.filters.mel in Pytorch
Implementation of librosa.filters.mel in PyTorch.
Args:
sr (int): Sampling rate.
n_fft (int): Number of FFT components.
n_mels (int): Number of Mel filter banks.
Returns:
torch.Tensor: Mel filter matrix of shape (n_mels, 1 + n_fft // 2).
"""
# Initialize the weights
n_mels = int(n_mels)
@@ -40,7 +63,7 @@ def get_mel_filters(sr, n_fft, n_mels=128):

# 'Center freqs' of mel bands - uniformly spaced between limits
min_mel = 0.0
max_mel = 45.245640471924965
max_mel = 45.245640471924965 # Corresponds to the maximum mel value used

mels = torch.linspace(min_mel, max_mel, n_mels + 2)

@@ -54,61 +77,88 @@
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
logstep = torch.log(torch.tensor(6.4)) / 27.0 # step size for log region

# If we have vector data, vectorize
# Apply nonlinear scaling for frequencies above min_log_hz
log_t = mels >= min_log_mel
freqs[log_t] = min_log_hz * torch.exp(logstep * (mels[log_t] - min_log_mel))

mel_f = freqs

# Compute the difference between adjacent mel frequencies
fdiff = torch.diff(mel_f)

# Create ramps for lower and upper edges
ramps = mel_f.view(-1, 1) - fftfreqs.view(1, -1)

lower = -ramps[:-2] / fdiff[:-1].unsqueeze(1)
upper = ramps[2:] / fdiff[1:].unsqueeze(1)

# Intersect them with each other and zero, vectorized across all i
# Intersect them with each other and zero
weights = torch.maximum(torch.zeros_like(lower), torch.minimum(lower, upper))

# Slaney-style mel is scaled to be approx constant energy per channel
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
weights *= enorm.unsqueeze(1)
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels]).unsqueeze(1)
weights *= enorm

return weights

def __call__(self, waveform, padding=True, chunk_length=None, to_cpu=False):
"""
Compute the log-Mel spectrogram of the provided audio.
def __call__(
self,
waveform: torch.Tensor,
padding: bool = True,
chunk_length: int = None,
to_cpu: bool = False,
) -> torch.Tensor:
"""
Compute the log-Mel spectrogram of the provided audio waveform.
Args:
waveform (torch.Tensor): Input audio waveform tensor of shape (..., n_samples).
padding (bool): Whether to pad the waveform to the required number of samples.
chunk_length (int, optional): Length of audio chunks in seconds. Overrides initialization if provided.
to_cpu (bool): Whether to move the output spectrogram to CPU.
Returns:
torch.Tensor: Log-Mel spectrogram tensor.
"""
# Use local variables instead of modifying instance attributes
if chunk_length is not None:
self.n_samples = chunk_length * self.sampling_rate
self.nb_max_frames = self.n_samples // self.hop_length
n_samples = chunk_length * self.sampling_rate
nb_max_frames = n_samples // self.hop_length
else:
n_samples = self.n_samples
nb_max_frames = self.nb_max_frames

# Ensure waveform is float32
if waveform.dtype is not torch.float32:
waveform = waveform.to(torch.float32)

waveform = (
waveform.to(self.device)
if self.device == "cuda" and not waveform.is_cuda
else waveform
)
# Move waveform to the target device if necessary
if self.device == "cuda" and not waveform.is_cuda:
waveform = waveform.to(self.device)

# Apply padding if required
if padding:
waveform = torch.nn.functional.pad(waveform, (0, self.n_samples))

window = torch.hann_window(self.n_fft).to(waveform.device)
waveform = torch.nn.functional.pad(waveform, (0, n_samples))

# Perform Short-Time Fourier Transform (STFT) using the cached window
stft = torch.stft(
waveform, self.n_fft, self.hop_length, window=window, return_complex=True
waveform,
n_fft=self.n_fft,
hop_length=self.hop_length,
window=self.window,
return_complex=True,
)
magnitudes = stft[..., :-1].abs() ** 2

mel_spec = self.mel_filters.to(waveform.device) @ magnitudes
# Compute the power spectrogram
magnitudes = stft.abs() ** 2

# Apply the precomputed mel filters
mel_spec = self.mel_filters @ magnitudes

# Compute the log-Mel spectrogram
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
log_spec = (log_spec + 4.0) / 4.0

# When the model is running on multiple GPUs, the output should be moved
# to the CPU since we don't know which GPU will handle the next job.
# Move the spectrogram to CPU if requested
return log_spec.cpu() if to_cpu else log_spec
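
With the changes above, the Hann window and mel filter bank are built once in __init__ and reused on every call. A minimal usage sketch (not part of this commit; the random 16 kHz mono waveform is a stand-in for real decoded audio):

import torch

from faster_whisper.feature_extractor import FeatureExtractor

# Stand-in for a 30-second, 16 kHz mono recording; a real pipeline would
# decode and resample audio to 16 kHz before this point.
waveform = torch.randn(16000 * 30, dtype=torch.float32)

extractor = FeatureExtractor(device="auto", feature_size=80)

# Returns a log-Mel spectrogram of shape (n_mels, n_frames); to_cpu=True
# moves the result off the GPU if one was used.
features = extractor(waveform, padding=True, to_cpu=True)
print(features.shape)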

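The get_mel_filters docstring presents the method as a PyTorch port of librosa.filters.mel. A quick cross-check against librosa is sketched below; it assumes librosa is installed and that librosa's default Slaney mel scale and Slaney normalization match the constants used in the port, so treat it as a sanity check rather than part of the commit:

import librosa
import numpy as np

from faster_whisper.feature_extractor import FeatureExtractor

# Whisper-style settings: 16 kHz audio, 400-point FFT, 80 mel bins.
torch_filters = FeatureExtractor.get_mel_filters(16000, 400, n_mels=80)
librosa_filters = librosa.filters.mel(sr=16000, n_fft=400, n_mels=80)

# Both should be (80, 201) matrices with near-identical float32 values.
print(np.allclose(torch_filters.cpu().numpy(), librosa_filters, atol=1e-5))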