Skip to content

Commit

Permalink
Use correct features padding for encoder input (#1101)
Browse files · Browse the repository at this point in the history
* pad to 3000 instead of `feature_extractor.nb_max_frames`

* correct trimming for batched features
  • Branch information
MahmoudAshraf97 authored Oct 29, 2024
1 parent c2a1da1 commit 2386843
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
4 changes: 2 additions & 2 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,9 @@ def _resample_frames(frames, resampler):
yield from resampler.resample(frame)


def pad_or_trim(array, length: int, *, axis: int = -1):
def pad_or_trim(array, length: int = 3000, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
Pad or trim the Mel features array to 3000, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
Expand Down
17 changes: 10 additions & 7 deletions faster_whisper/transcribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,9 +441,12 @@ def transcribe(
features = (
torch.stack(
[
self.model.feature_extractor(chunk, to_cpu=to_cpu)[
..., : self.model.feature_extractor.nb_max_frames
]
pad_or_trim(
self.model.feature_extractor(chunk, to_cpu=to_cpu)[
...,
: chunk.shape[0] // self.model.feature_extractor.hop_length,
]
)
for chunk in audio_chunks
]
)
Expand Down Expand Up @@ -847,7 +850,7 @@ def transcribe(
segment = features[
:, seek : seek + self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(segment)
encoder_output = self.encode(pad_or_trim(segment))
# results is a list of tuple[str, float] with language names and
# probabilities.
results = self.model.detect_language(encoder_output)[0]
Expand Down Expand Up @@ -1105,7 +1108,7 @@ def generate_segments(
)
segment = features[:, seek : seek + segment_size]
segment_duration = segment_size * self.feature_extractor.time_per_frame
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
segment = pad_or_trim(segment)

if self.logger.isEnabledFor(logging.DEBUG):
self.logger.debug(
Expand Down Expand Up @@ -1766,7 +1769,7 @@ def detect_language(self, audio: torch.Tensor):
segment = self.feature_extractor(audio, padding=True, to_cpu=to_cpu)[
:, : self.feature_extractor.nb_max_frames
]
encoder_output = self.encode(segment)
encoder_output = self.encode(pad_or_trim(segment))
results = self.model.detect_language(encoder_output)
language_token, language_probability = results[0][0]
language = language_token[2:-2]
Expand Down Expand Up @@ -1895,7 +1898,7 @@ def detect_language_multi_segment(
for i in indices:
segment_features = features[:, i * nb_max_frames : (i + 1) * nb_max_frames]
try:
encoder_output = self.encode(segment_features)
encoder_output = self.encode(pad_or_trim(segment_features))
results = self.model.detect_language(encoder_output)[0]

except ValueError as e: # or RuntimeError
Expand Down

0 comments on commit 2386843

Please sign in to comment.