diff --git a/tests/test_transcribe.py b/tests/test_transcribe.py
index 599221af5..e68a8ba49 100644
--- a/tests/test_transcribe.py
+++ b/tests/test_transcribe.py
@@ -7,18 +7,50 @@
 from whisper.tokenizer import get_tokenizer
 
 
+class TestingProgressReceiver(whisper.TranscribeProgressReceiver):
+    """Progress receiver that accumulates the reported text and frame counts."""
+
+    # the "Test" name prefix matches pytest's default class collection rule; opt out
+    __test__ = False
+
+    def start(self, total: int):
+        self.result = ""
+        self.total = total
+        self.progress = 0
+        return self
+
+    def update_line(self, start: float, end: float, text: str):
+        self.result += text
+
+    def update(self, n: int):
+        self.progress += n
+
+    def get_result(self) -> str:
+        return self.result
+
+    def verify_total(self) -> bool:
+        return self.total == self.progress
+
+
 @pytest.mark.parametrize("model_name", whisper.available_models())
 def test_transcribe(model_name: str):
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model = whisper.load_model(model_name).to(device)
     audio_path = os.path.join(os.path.dirname(__file__), "jfk.flac")
+    receiver = TestingProgressReceiver()
 
     language = "en" if model_name.endswith(".en") else None
     result = model.transcribe(
-        audio_path, language=language, temperature=0.0, word_timestamps=True
+        audio_path,
+        language=language,
+        temperature=0.0,
+        word_timestamps=True,
+        progress_receiver=receiver,
     )
+    assert receiver.verify_total()
     assert result["language"] == "en"
     assert result["text"] == "".join([s["text"] for s in result["segments"]])
+    assert result["text"] == receiver.get_result()
 
     transcription = result["text"].lower()
     assert "my fellow americans" in transcription
diff --git a/whisper/__init__.py b/whisper/__init__.py
index e210718f3..5546ab5bd 100644
--- a/whisper/__init__.py
+++ b/whisper/__init__.py
@@ -11,7 +11,7 @@
 from .audio import load_audio, log_mel_spectrogram, pad_or_trim
 from .decoding import DecodingOptions, DecodingResult, decode, detect_language
 from .model import ModelDimensions, Whisper
-from .transcribe import transcribe
+from .transcribe import TranscribeProgressReceiver, transcribe
 from .version import __version__
 
 _MODELS = {
diff --git a/whisper/transcribe.py b/whisper/transcribe.py
index 8e1240bd6..1ed5bc256 100644
--- a/whisper/transcribe.py
+++ b/whisper/transcribe.py
@@ -34,12 +34,67 @@
 if TYPE_CHECKING:
     from .model import Whisper
 
 
+class TranscribeProgressReceiver:
+    """
+    Base class for receiving progress notifications from `transcribe()`.
+
+    Subclass it and override the callbacks below to handle transcription progress in a
+    customized manner. Every callback defined here is a no-op.
+    """
+
+    def start(self, total: int) -> "TranscribeProgressReceiver":
+        """
+        Called when the transcription starts, with the `total` amount of work in frames.
+
+        The returned object is used as a context manager around the transcription loop;
+        in most cases this method should return `self`.
+        """
+        return self
+
+    def update(self, n: int) -> None:
+        """
+        Called with the number of newly processed frames `n`, at the same points where
+        the built-in progress bar is advanced.
+        """
+
+    def update_line(self, start: float, end: float, text: str) -> None:
+        """
+        Called for every segment that has been transcribed.
+
+        Parameters
+        ----------
+        start: float
+            The start time of the segment in seconds
+
+        end: float
+            The end time of the segment in seconds
+
+        text: str
+            The transcribed text of the segment
+        """
+
+    def __enter__(self) -> "TranscribeProgressReceiver":
+        """
+        Override this method if resources need to be allocated when the transcription
+        starts. The returned object receives the `update()` and `update_line()` calls,
+        so in most cases this method should return `self`.
+        """
+        return self
+
+    def __exit__(self, exception_type, exception_value, exception_traceback):
+        """
+        Override this method if resources need to be released when the transcription is
+        finished or terminated.
+        """
+
+
 def transcribe(
     model: "Whisper",
     audio: Union[str, np.ndarray, torch.Tensor],
     *,
     verbose: Optional[bool] = None,
+    progress_receiver: Optional[TranscribeProgressReceiver] = None,
     temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
     compression_ratio_threshold: Optional[float] = 2.4,
     logprob_threshold: Optional[float] = -1.0,
@@ -253,7 +308,9 @@ def new_segment(
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
     with tqdm.tqdm(
         total=content_frames, unit="frames", disable=verbose is not False
-    ) as pbar:
+    ) as pbar, (
+        TranscribeProgressReceiver() if progress_receiver is None else progress_receiver
+    ).start(total=content_frames) as ext_progress:
         last_speech_timestamp = 0.0
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
@@ -459,10 +516,11 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
                 if last_word_end is not None:
                     last_speech_timestamp = last_word_end
 
-            if verbose:
-                for segment in current_segments:
-                    start, end, text = segment["start"], segment["end"], segment["text"]
-                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
-                    print(make_safe(line))
+            for segment in current_segments:
+                start, end, text = segment["start"], segment["end"], segment["text"]
+                ext_progress.update_line(start, end, make_safe(text))
+                if verbose:
+                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
+                    print(make_safe(line))
 
             # if a segment is instantaneous or does not contain text, clear it
@@ -490,6 +548,7 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]:
 
             # update progress bar
             pbar.update(min(content_frames, seek) - previous_seek)
+            ext_progress.update(min(content_frames, seek) - previous_seek)
 
     return dict(
         text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),