diff --git a/whisper/transcribe.py b/whisper/transcribe.py index 0a4cc3623..083a26026 100644 --- a/whisper/transcribe.py +++ b/whisper/transcribe.py @@ -2,7 +2,7 @@ import os import traceback import warnings -from typing import TYPE_CHECKING, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, Callable import numpy as np import torch @@ -52,6 +52,7 @@ def transcribe( append_punctuations: str = "\"'.。,,!!??::”)]}、", clip_timestamps: Union[str, List[float]] = "0", hallucination_silence_threshold: Optional[float] = None, + callback: Optional[Callable[[int, int, float], None]] = None, **decode_options, ): """ @@ -119,6 +120,10 @@ def transcribe( When word_timestamps is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected + callback: Optional[Callable[int, int, float]] = None, + After each step in the transcription process, call the callback function with + the arguments current posistion, total frames, estimated time to finish in seconds + Returns ------- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and @@ -504,8 +509,17 @@ def next_words_segment(segments: List[dict]) -> Optional[dict]: # do not feed the prompt tokens if a high temperature was used prompt_reset_since = len(all_tokens) + total_position = min(content_frames, seek) + increase = total_position - previous_seek + + if callback is not None: + rate = pbar.format_dict["rate"] + remaining = (pbar.total - pbar.n) / rate if rate and pbar.total else 0 + + callback(total_position, content_frames, remaining) + # update progress bar - pbar.update(min(content_frames, seek) - previous_seek) + pbar.update(increase) return dict( text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),