Skip to content

Commit

Permalink
New PR for Faster Whisper: Batching Support, Speed Boosts, and Qualit…
Browse files Browse the repository at this point in the history
…y Enhancements (#856)

Batching Support, Speed Boosts, and Quality Enhancements

---------

Co-authored-by: Hargun Mujral <[email protected]>
Co-authored-by: MahmoudAshraf97 <[email protected]>
  • Loading branch information
3 people authored Jul 18, 2024
1 parent fbcf58b commit eb83902
Show file tree
Hide file tree
Showing 13 changed files with 1,693 additions and 419 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
include faster_whisper/assets/silero_vad.onnx
include requirements.txt
include requirements.conversion.txt
include faster_whisper/assets/pyannote_vad_model.bin
30 changes: 29 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ segments, info = model.transcribe("audio.mp3", beam_size=5, language="en")

* Python 3.8 or greater

Unlike openai-whisper, FFmpeg does **not** need to be installed on the system. The audio is decoded with the Python library [PyAV](https://github.com/PyAV-Org/PyAV) which bundles the FFmpeg libraries in its package.

### GPU

Expand Down Expand Up @@ -166,6 +165,35 @@ for segment in segments:
segments, _ = model.transcribe("audio.mp3")
segments = list(segments) # The transcription will actually run here.
```

### multi-segment language detection

To directly use the model for improved language detection, the following code snippet can be used:

```python
from faster_whisper import WhisperModel
model = WhisperModel("medium", device="cuda", compute_type="float16")
language_info = model.detect_language_multi_segment("audio.mp3")
```

### Batched faster-whisper


The batched version of faster-whisper is inspired by [whisper-x](https://github.com/m-bain/whisperX) licensed under the BSD-2 Clause license and integrates its VAD model to this library. We modify this implementation and also replaced the feature extraction with a faster torch-based implementation. Batched version improves the speed upto 10-12x compared to openAI implementation and 3-4x compared to the sequential faster_whisper version. It works by transcribing semantically meaningful audio chunks as batches leading to faster inference.

The following code snippet illustrates how to run inference with batched version on an example audio file. Please also refer to the test scripts of batched faster whisper.

```python
from faster_whisper import WhisperModel, BatchedInferencePipeline

model = WhisperModel("medium", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)
segments, info = batched_model.transcribe("audio.mp3", batch_size=16)

for segment in segments:
print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text))
```

### Faster Distil-Whisper

The Distil-Whisper checkpoints are compatible with the Faster-Whisper package. In particular, the latest [distil-large-v3](https://huggingface.co/distil-whisper/distil-large-v3)
Expand Down
5 changes: 4 additions & 1 deletion benchmark/wer_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import json
import os

from datasets import load_dataset
from evaluate import load
Expand All @@ -26,7 +27,9 @@

# define the evaluation metric
wer_metric = load("wer")
normalizer = EnglishTextNormalizer(json.load(open("normalizer.json")))

with open(os.path.join(os.path.dirname(__file__), "normalizer.json"), "r") as f:
normalizer = EnglishTextNormalizer(json.load(f))


def inference(batch):
Expand Down
3 changes: 2 additions & 1 deletion faster_whisper/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from faster_whisper.audio import decode_audio
from faster_whisper.transcribe import WhisperModel
from faster_whisper.transcribe import BatchedInferencePipeline, WhisperModel
from faster_whisper.utils import available_models, download_model, format_timestamp
from faster_whisper.version import __version__

__all__ = [
"available_models",
"decode_audio",
"WhisperModel",
"BatchedInferencePipeline",
"download_model",
"format_timestamp",
"__version__",
Expand Down
Binary file added faster_whisper/assets/pyannote_vad_model.bin
Binary file not shown.
105 changes: 22 additions & 83 deletions faster_whisper/audio.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,7 @@
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
system dependencies. FFmpeg does not need to be installed on the system.
However, the API is quite low-level so we need to manipulate audio frames directly.
"""

import gc
import io
import itertools

from typing import BinaryIO, Union

import av
import numpy as np
import torch
import torchaudio


def decode_audio(
Expand All @@ -29,91 +17,42 @@ def decode_audio(
split_stereo: Return separate left and right channels.
Returns:
A float32 Numpy array.
A float32 Torch Tensor.
If `split_stereo` is enabled, the function returns a 2-tuple with the
separated left and right channels.
"""
resampler = av.audio.resampler.AudioResampler(
format="s16",
layout="mono" if not split_stereo else "stereo",
rate=sampling_rate,
)

raw_buffer = io.BytesIO()
dtype = None

with av.open(input_file, mode="r", metadata_errors="ignore") as container:
frames = container.decode(audio=0)
frames = _ignore_invalid_frames(frames)
frames = _group_frames(frames, 500000)
frames = _resample_frames(frames, resampler)

for frame in frames:
array = frame.to_ndarray()
dtype = array.dtype
raw_buffer.write(array)

# It appears that some objects related to the resampler are not freed
# unless the garbage collector is manually run.
del resampler
gc.collect()

audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)

# Convert s16 back to f32.
audio = audio.astype(np.float32) / 32768.0
waveform, audio_sf = torchaudio.load(input_file) # waveform: channels X T

if audio_sf != sampling_rate:
waveform = torchaudio.functional.resample(
waveform, orig_freq=audio_sf, new_freq=sampling_rate
)
if split_stereo:
left_channel = audio[0::2]
right_channel = audio[1::2]
return left_channel, right_channel

return audio


def _ignore_invalid_frames(frames):
iterator = iter(frames)

while True:
try:
yield next(iterator)
except StopIteration:
break
except av.error.InvalidDataError:
continue


def _group_frames(frames, num_samples=None):
fifo = av.audio.fifo.AudioFifo()

for frame in frames:
frame.pts = None # Ignore timestamp check.
fifo.write(frame)

if num_samples is not None and fifo.samples >= num_samples:
yield fifo.read()

if fifo.samples > 0:
yield fifo.read()

return waveform[0], waveform[1]

def _resample_frames(frames, resampler):
# Add None to flush the resampler.
for frame in itertools.chain(frames, [None]):
yield from resampler.resample(frame)
return waveform.mean(0)


def pad_or_trim(array, length: int, *, axis: int = -1):
"""
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
"""
axis = axis % array.ndim
if array.shape[axis] > length:
array = array.take(indices=range(length), axis=axis)
idx = [Ellipsis] * axis + [slice(length)] + [Ellipsis] * (array.ndim - axis - 1)
return array[idx]

if array.shape[axis] < length:
pad_widths = [(0, 0)] * array.ndim
pad_widths[axis] = (0, length - array.shape[axis])
array = np.pad(array, pad_widths)
pad_widths = (
[
0,
]
* array.ndim
* 2
)
pad_widths[2 * axis] = length - array.shape[axis]
array = torch.nn.functional.pad(array, tuple(pad_widths[::-1]))

return array
Loading

2 comments on commit eb83902

@guicbrito
Copy link

@guicbrito guicbrito commented on eb83902 Jul 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Letting you know that this commit broke my app deps. I have ffmpeg 7.0.1 built from source at /usr/bin/ffmpeg which I use for other app services. Somehow faster-whisper gets stuck in a loop trying to load ffmpeg 6.

[+] Running 1/1
 ✔ Container faladoc-ai  Recreated                                                                         0.1s 
Attaching to container
2024-07-26 17:24:22.499 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.503 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.507 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.511 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.514 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.519 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.524 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.530 | DEBUG    | logging              | Loading FFmpeg6
2024-07-26 17:24:22.541 | DEBUG    | logging              | Loading FFmpeg6
...
<goes on forever>

I belive this has to do with removing av from deps and using torchaudio directly, as torchaudio is not compatible with ffmpeg 7 yet and probably tries to use the ffmpeg binary it finds at /usr/bin.
Reverting to commit c22db51 works fine with ffmpeg 7 around.

@aligokalppeker
Copy link

@aligokalppeker aligokalppeker commented on eb83902 Jul 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bullshit update, you split torch dependencies all over the place which does not bring any advantage but gigabytes of package dependency. Using Numpy is a more versatile and reusable way to put the library in a good place. This PR made faster whisper in a terrible state.

Please sign in to comment.