add TranscribeProgressReceiver for update monitoring #2398

Open · wants to merge 1 commit into main

21 changes: 20 additions & 1 deletion tests/test_transcribe.py
@@ -7,18 +7,37 @@
 from whisper.tokenizer import get_tokenizer
 
 
+class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
+    def start(self, total: int):
+        self.result = ""
+        self.total = total
+        self.progress = 0
+        return self
+    def update_line(self, start: float, end: float, text: str):
+        self.result += text
+    def update(self, n):
+        self.progress += n
+    def get_result(self):
+        return self.result
+    def verify_total(self):
+        return self.total == self.progress
+
 @pytest.mark.parametrize("model_name", whisper.available_models())
 def test_transcribe(model_name: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = whisper.load_model(model_name).to(device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
+    receiver = TestingProgressReceiver()
 
     language = "en" if model_name.endswith(".en") else None
     result = model.transcribe(
-        audio_path, language=language, temperature=0.0, word_timestamps=True
+        audio_path, language=language, temperature=0.0, word_timestamps=True,
+        progress_receiver=receiver
     )
+    assert receiver.verify_total()
     assert result["language"] == "en"
     assert result["text"] == "".join([s["text"] for s in result["segments"]])
+    assert result["text"] == receiver.get_result()
 
     transcription = result["text"].lower()
     assert "my fellow americans" in transcription

2 changes: 1 addition & 1 deletion whisper/__init__.py
@@ -11,7 +11,7 @@
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import ModelDimensions, Whisper
-from .transcribe import transcribe
+from .transcribe import TranscribeProgressReceiver, transcribe
 from .version import __version__
 
 _MODELS = {

60 changes: 54 additions & 6 deletions whisper/transcribe.py
@@ -2,7 +2,7 @@
 import os
 import traceback
 import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Self
 
 import numpy as np
 import torch
@@ -34,12 +34,57 @@
 if TYPE_CHECKING:
     from .model import Whisper
 
+class TranscribeProgressReceiver:
+    """
+    A base class that external code can subclass in order to receive and handle transcription
+    progress in a customized way.
+    """
+    def start(self, total: int) -> Self:
+        """
+        Called when the transcription starts, with the integer `total` given in frames.
+        In most cases this method should return `self`.
+        """
+        return self
+    def update(self, n: int):
+        """
+        Called with the increment `n`, in frames, whenever a segment has been transcribed.
+        """
+        pass
+    def update_line(self, start: float, end: float, text: str):
+        """
+        Called whenever a segment is transcribed.
+
+        Parameters
+        ----------
+        start: float
+            The start time of the segment in seconds
+
+        end: float
+            The end time of the segment in seconds
+
+        text: str
+            The transcribed text
+        """
+        pass
+    def __enter__(self) -> Self:
+        """
+        Override this method if resources need to be allocated when the transcription starts.
+        In most cases this method should return `self`.
+        """
+        return self
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        """
+        Override this method if resources need to be released when the transcription finishes
+        or is terminated.
+        """
+        pass
 
 def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
+    progress_receiver: TranscribeProgressReceiver = TranscribeProgressReceiver(),
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
@@ -253,7 +298,8 @@ def new_segment(
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     with tqdm.tqdm(
         total=content_frames, unit="frames", disable=verbose is not False
-    ) as pbar:
+    ) as pbar, \
+        progress_receiver.start(total=content_frames) as ext_progress:
         last_speech_timestamp = 0.0
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
@@ -459,10 +505,11 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                 if last_word_end is not None:
                     last_speech_timestamp = last_word_end
 
-            if verbose:
-                for segment in current_segments:
-                    start, end, text = segment["start"], segment["end"], segment["text"]
-                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+            for segment in current_segments:
+                start, end, text = segment["start"], segment["end"], segment["text"]
+                line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                ext_progress.update_line(start, end, make_safe(text))
+                if verbose:
                     print(make_safe(line))
 
             # if a segment is instantaneous or does not contain text, clear it
@@ -490,6 +537,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
 
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)
+            ext_progress.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
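
For illustration, here is a minimal usage sketch of the hook this PR introduces; it is not part of the diff. `PercentPrinter`, the `"base"` model, and `"audio.mp3"` are placeholder names, and the sketch assumes the API lands exactly as proposed above, i.e. `start`/`update`/`update_line` plus the context-manager hooks on `TranscribeProgressReceiver`.

import whisper


class PercentPrinter(whisper.TranscribeProgressReceiver):
    def start(self, total: int):
        # `total` is the overall amount of work, measured in audio frames
        self.total = total
        self.done = 0
        return self

    def update(self, n: int):
        # called with the frame increment after each transcribed window
        self.done += n
        print(f"progress: {100 * self.done / max(self.total, 1):.1f}%")

    def update_line(self, start: float, end: float, text: str):
        # called once per decoded segment with its timestamps and text
        print(f"[{start:.2f}s --> {end:.2f}s]{text}")


model = whisper.load_model("base")
result = model.transcribe("audio.mp3", progress_receiver=PercentPrinter())

Since `transcribe` enters the receiver as a context manager (`with ... progress_receiver.start(total=content_frames) as ext_progress:`), a subclass that needs, say, a file handle or a network connection can acquire it in `start` or `__enter__` and release it in `__exit__`, and the release happens even if the transcription raises.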