unify the apis of batched and sequential transcribe functions
MahmoudAshraf97 committed Jul 30, 2024
1 parent ca31ce3 commit bc302ad
Showing 1 changed file with 43 additions and 47 deletions.
90 changes: 43 additions & 47 deletions faster_whisper/transcribe.py
@@ -110,7 +110,6 @@ class BatchedInferencePipeline:
def __init__(
self,
model,
use_vad_model: bool = True,
options: Optional[NamedTuple] = None,
tokenizer=None,
language: Optional[str] = None,
@@ -119,7 +118,6 @@ def __init__(
self.tokenizer = tokenizer
self.options = options
self.preset_language = language
self.use_vad_model = use_vad_model
self.last_speech_timestamp = 0.0

def forward(self, features, segments_metadata, **forward_params):
@@ -227,10 +225,7 @@ def audio_split(audio, segments, sampling_rate):

def transcribe(
self,
audio: Union[str, torch.Tensor, np.ndarray],
vad_segments: Optional[List[dict]] = None,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
batch_size: int = 16,
audio: Union[str, BinaryIO, torch.Tensor, np.ndarray],
language: Optional[str] = None,
task: str = None,
log_progress: bool = False,
@@ -256,29 +251,26 @@ def transcribe(
prefix: Optional[str] = None,
suppress_blank: bool = True,
suppress_tokens: Optional[List[int]] = [-1],
without_timestamps: bool = True,
word_timestamps: bool = False,
prepend_punctuations: str = "\"'“¿([{-",
append_punctuations: str = "\"'.。,,!!??::”)]}、",
vad_filter: bool = True,
vad_parameters: Optional[Union[dict, VadOptions]] = None,
max_new_tokens: Optional[int] = None,
chunk_length: Optional[int] = None,
clip_timestamps: Optional[List[dict]] = None,
batch_size: int = 16,
hotwords: Optional[str] = None,
word_timestamps: bool = False,
without_timestamps: bool = True,
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
"""transcribe audio in chunks in batched fashion and return with language info.
Arguments:
audio: audio file as numpy array/path for batched transcription.
vad_segments: Optionally provide a list of dictionaries, each containing "start", "end",
and "segments" keys.
"start" and "end" specify the start and end of the voiced region within
a 30-second boundary. The additional "segments" key contains the start
and end of all voiced regions within that boundary as a list of tuples.
If no vad_segments are specified, the internal VAD model segments the audio automatically.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
batch_size: the maximum number of parallel requests to the model for decoding.
language: The language spoken in the audio.
task: either "transcribe" or "translate".
audio: Path to the input file (or a file-like object), or the audio waveform.
language: The language spoken in the audio. It should be a language code such
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
of audio.
task: Task to execute (transcribe or translate).
log_progress: Whether to show a progress bar.
beam_size: Beam size to use for decoding.
best_of: Number of candidates when sampling with non-zero temperature.
@@ -295,8 +287,8 @@ def transcribe(
log_prob_threshold: If the average log probability over sampled tokens is
below this value, treat as failed.
log_prob_low_threshold: This parameter alone is sufficient to skip an output text,
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
whereas log_prob_threshold also looks for appropriate no_speech_threshold value.
This value should be less than log_prob_threshold.
no_speech_threshold: If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
@@ -306,20 +298,29 @@ def transcribe(
suppress_blank: Suppress blank outputs at the beginning of the sampling.
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in `tokenizer.non_speech_tokens()`.
without_timestamps: Only sample text tokens.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
with the next word
append_punctuations: If word_timestamps is True, merge these punctuation symbols
with the previous word
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
parameters and default values in the class `VadOptions`).
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: The length of audio segments. If it is not None, it will overwrite the
default chunk_length of the FeatureExtractor.
clip_timestamps: Optionally provide a list of dictionaries, each containing "start" and
"end" keys that specify the start and end of the voiced region within
`chunk_length` boundary. vad_filter will be ignored if clip_timestamps is used.
batch_size: the maximum number of parallel requests to the model for decoding.
hotwords:
Hotwords/hint phrases to provide to the model. Has no effect if prefix is not None.
word_timestamps: Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
Set as False.
without_timestamps: Only sample text tokens.
Static params: (Fixed for batched version)
max_initial_timestamp: The initial timestamp cannot be later than this, set at 0.0.
@@ -337,24 +338,18 @@ def transcribe(
hallucination_silence_threshold: Optional[float]
When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected. Set as None.
clip_timestamps:
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
process. The last end timestamp defaults to the end of the file. Set as "0".
unused:
language_detection_threshold: If the maximum probability of the language tokens is
higher than this value, the language is detected.
language_detection_segments: Number of segments to consider for the language detection.
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
Returns:
A tuple with:
- a generator over transcribed batched segments.
- an instance of TranscriptionInfo.
- a generator over transcribed segments
- an instance of TranscriptionInfo
"""

sampling_rate = self.model.feature_extractor.sampling_rate
@@ -367,9 +362,8 @@ def transcribe(

chunk_length = chunk_length or self.model.feature_extractor.chunk_length
# if no segment split is provided, use the VAD model to generate segments
if not vad_segments:
# run the audio if it is less than 30 sec even without vad_segments
if self.use_vad_model:
if not clip_timestamps:
if vad_filter:
if vad_parameters is None:
vad_parameters = VadOptions(
onset=0.500,
Expand All @@ -380,30 +374,32 @@ def transcribe(
)
elif isinstance(vad_parameters, dict):
if "max_speech_duration_s" in vad_parameters.keys():
vad_parameters["max_speech_duration_s"] = chunk_length
vad_parameters.pop("max_speech_duration_s")

vad_parameters = VadOptions(**vad_parameters)
vad_parameters = VadOptions(
**vad_parameters, max_speech_duration_s=chunk_length
)

scores, timestamps = get_vad_scores(audio)
active_segments = get_active_regions(
scores.squeeze(), timestamps, vad_parameters
)
vad_segments = merge_segments(active_segments, vad_parameters)
clip_timestamps = merge_segments(active_segments, vad_parameters)
# run the audio if it is less than 30 sec even without clip_timestamps
elif duration < chunk_length:
vad_segments = [
{"start": 0.0, "end": duration, "segments": [(0.0, duration)]}
]
clip_timestamps = [{"start": 0.0, "end": duration}]
else:
raise RuntimeError(
"No vad segments found. Set 'use_vad_model' to True while loading the model"
"No clip timestamps found. "
"Set 'vad_filter' to True or provide 'clip_timestamps'."
)
if self.model.model.is_multilingual:
language = language or self.preset_language
elif language != "en":
if language is not None:
self.model.logger.warning(
f"English-only model is used, but {language} language is"
"chosen, setting language to 'en'."
" chosen, setting language to 'en'."
)
language = "en"

@@ -415,7 +411,7 @@ def transcribe(
) = self.get_language_and_tokenizer(audio, task, language)

duration_after_vad = sum(
segment["end"] - segment["start"] for segment in vad_segments
segment["end"] - segment["start"] for segment in clip_timestamps
)

# batched options: see the difference with default options in WhisperModel
@@ -463,7 +459,7 @@ def transcribe(
)

audio_segments, segments_metadata = self.audio_split(
audio, vad_segments, sampling_rate
audio, clip_timestamps, sampling_rate
)
to_cpu = (
self.model.model.device == "cuda" and len(self.model.model.device_index) > 1
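For reference, a minimal usage sketch of the unified API after this commit. The model size, device, compute type, and audio path below are illustrative, not taken from the diff; the point is that the batched call now mirrors the sequential `WhisperModel.transcribe` signature, with VAD controlled by `vad_filter`/`vad_parameters` instead of the removed `use_vad_model` and `vad_segments` arguments.

```python
# Minimal usage sketch of the unified transcribe API (assumes faster-whisper
# at this commit; model size, device, and audio path are illustrative).
from faster_whisper import BatchedInferencePipeline, WhisperModel

model = WhisperModel("large-v3", device="cuda", compute_type="float16")
batched_model = BatchedInferencePipeline(model=model)

# VAD is now toggled per-call with vad_filter instead of the removed
# use_vad_model constructor flag.
segments, info = batched_model.transcribe(
    "audio.mp3",
    language="en",
    vad_filter=True,
    batch_size=16,
)
for segment in segments:
    print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text}")
```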

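Continuing the sketch above: pre-computed voiced regions can be supplied through the new `clip_timestamps` argument, which replaces `vad_segments` and bypasses the internal Silero VAD. Per the docstring, each entry is a dict with "start" and "end" keys in seconds; the region boundaries below are invented for illustration.

```python
# Sketch: bypassing the internal VAD with pre-computed voiced regions.
# Region boundaries are made up for illustration.
segments, info = batched_model.transcribe(
    "audio.mp3",
    clip_timestamps=[
        {"start": 0.0, "end": 12.5},
        {"start": 15.0, "end": 28.7},
    ],
    vad_filter=False,  # ignored anyway when clip_timestamps is provided
)
```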
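Custom VAD settings still go through `vad_parameters`, either as a dict or a `VadOptions` instance. Note from the diff that any user-supplied `max_speech_duration_s` is popped and pinned to `chunk_length` by the pipeline. A sketch continuing from above, with an arbitrary `onset` value:

```python
# Sketch: tuning the Silero VAD via vad_parameters. A max_speech_duration_s
# key passed here would be discarded and pinned to chunk_length, per the
# diff above. The onset value is arbitrary.
segments, info = batched_model.transcribe(
    "audio.mp3",
    vad_filter=True,
    vad_parameters={"onset": 0.6},
)
```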