Remove Silence in Batched transcription #1297

Status: Open. Wants to merge 1 commit into base: master.
21 changes: 14 additions & 7 deletions faster_whisper/transcribe.py
@@ -25,7 +25,6 @@
     VadOptions,
     collect_chunks,
     get_speech_timestamps,
-    merge_segments,
 )


@@ -125,7 +124,7 @@ def forward(self, features, tokenizer, chunks_metadata, options):
         segmented_outputs = []
         segment_sizes = []
         for chunk_metadata, output in zip(chunks_metadata, outputs):
-            duration = chunk_metadata["end_time"] - chunk_metadata["start_time"]
+            duration = chunk_metadata["duration"]
             segment_size = int(ceil(duration) * self.model.frames_per_second)
             segment_sizes.append(segment_size)
             (
@@ -135,7 +134,7 @@ def forward(self, features, tokenizer, chunks_metadata, options):
             ) = self.model._split_segments_by_timestamps(
                 tokenizer=tokenizer,
                 tokens=output["tokens"],
-                time_offset=chunk_metadata["start_time"],
+                time_offset=chunk_metadata["offset"],
                 segment_size=segment_size,
                 segment_duration=duration,
                 seek=0,
@@ -153,7 +152,7 @@ def forward(self, features, tokenizer, chunks_metadata, options):
                         tokenizer.decode(subsegment["tokens"])
                     ),
                     seek=int(
-                        chunk_metadata["start_time"] * self.model.frames_per_second
+                        chunk_metadata["offset"] * self.model.frames_per_second
                     ),
                 )
                 for subsegment in subsegments
@@ -409,8 +408,7 @@ def transcribe(
                     **vad_parameters, max_speech_duration_s=chunk_length
                 )
 
-            active_segments = get_speech_timestamps(audio, vad_parameters)
-            clip_timestamps = merge_segments(active_segments, vad_parameters)
+            clip_timestamps = get_speech_timestamps(audio, vad_parameters)
         # run the audio if it is less than 30 sec even without clip_timestamps
         elif duration < chunk_length:
             clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
@@ -419,6 +417,15 @@
                 "No clip timestamps found. "
                 "Set 'vad_filter' to True or provide 'clip_timestamps'."
             )
+        else:
+            clip_timestamps = [
+                {k: int(v * sampling_rate) for k, v in segment.items()}
+                for segment in clip_timestamps
+            ]
+
+        audio_chunks, chunks_metadata = collect_chunks(
+            audio, clip_timestamps, max_duration=chunk_length
+        )
 
         duration_after_vad = (
             sum((segment["end"] - segment["start"]) for segment in clip_timestamps)
@@ -430,7 +437,6 @@ def transcribe(
             format_timestamp(duration - duration_after_vad),
         )
 
-        audio_chunks, chunks_metadata = collect_chunks(audio, clip_timestamps)
         features = (
             [self.model.feature_extractor(chunk)[..., :-1] for chunk in audio_chunks]
             if duration_after_vad
@@ -541,6 +547,7 @@ def transcribe(
             options,
             log_progress,
         )
+        segments = restore_speech_timestamps(segments, clip_timestamps, sampling_rate)
 
         return segments, info

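Taken together, these hunks change how batched transcription prepares its inputs: speech timestamps come straight from get_speech_timestamps, user-supplied clip_timestamps are converted from seconds to sample indices, and collect_chunks packs the speech (silence removed) into chunks of at most chunk_length seconds before feature extraction; afterwards restore_speech_timestamps shifts the decoded segments back onto the original timeline. A condensed sketch of that preparation step, assuming the patched faster_whisper.vad module; the helper name plan_batched_chunks and the simplified VadOptions handling are illustrative, not code from this PR:

import numpy as np

from faster_whisper.vad import VadOptions, collect_chunks, get_speech_timestamps


def plan_batched_chunks(
    audio: np.ndarray,
    chunk_length: int = 30,
    sampling_rate: int = 16000,
    vad_filter: bool = True,
    clip_timestamps=None,
):
    """Sketch of the chunk-preparation flow this PR introduces."""
    duration = audio.shape[0] / sampling_rate
    if vad_filter:
        # The real code merges user-supplied vad_parameters; defaults suffice here.
        vad_parameters = VadOptions(max_speech_duration_s=chunk_length)
        clip_timestamps = get_speech_timestamps(audio, vad_parameters)
    elif duration < chunk_length:
        # Short audio can run as a single chunk even without clip_timestamps.
        clip_timestamps = [{"start": 0, "end": audio.shape[0]}]
    elif not clip_timestamps:
        raise ValueError(
            "No clip timestamps found. "
            "Set 'vad_filter' to True or provide 'clip_timestamps'."
        )
    else:
        # User-provided clip_timestamps are in seconds; convert to sample indices.
        clip_timestamps = [
            {k: int(v * sampling_rate) for k, v in segment.items()}
            for segment in clip_timestamps
        ]

    # Drop the silence between segments and pack them into chunks of at most
    # chunk_length seconds; metadata carries each chunk's offset and duration.
    audio_chunks, chunks_metadata = collect_chunks(
        audio, clip_timestamps, max_duration=chunk_length
    )
    return audio_chunks, chunks_metadata, clip_timestamps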
104 changes: 48 additions & 56 deletions faster_whisper/vad.py
@@ -184,25 +184,62 @@ def get_speech_timestamps(
 
 
 def collect_chunks(
-    audio: np.ndarray, chunks: List[dict], sampling_rate: int = 16000
-) -> Tuple[List[np.ndarray], List[Dict[str, int]]]:
-    """Collects audio chunks."""
+    audio: np.ndarray,
+    chunks: List[dict],
+    sampling_rate: int = 16000,
+    max_duration: float = float("inf"),
+) -> Tuple[List[np.ndarray], List[Dict[str, float]]]:
+    """This function merges the chunks of audio into chunks of max_duration (s) length."""
     if not chunks:
         chunk_metadata = {
-            "start_time": 0,
-            "end_time": 0,
+            "offset": 0,
+            "duration": 0,
+            "segments": [],
         }
         return [np.array([], dtype=np.float32)], [chunk_metadata]
 
     audio_chunks = []
     chunks_metadata = []
+
+    current_segments = []
+    current_duration = 0
+    total_duration = 0
+    current_audio = np.array([], dtype=np.float32)
+
     for chunk in chunks:
-        chunk_metadata = {
-            "start_time": chunk["start"] / sampling_rate,
-            "end_time": chunk["end"] / sampling_rate,
-        }
-        audio_chunks.append(audio[chunk["start"] : chunk["end"]])
-        chunks_metadata.append(chunk_metadata)
+        if (
+            current_duration + chunk["end"] - chunk["start"]
+            > max_duration * sampling_rate
+        ):
+            audio_chunks.append(current_audio)
+            chunk_metadata = {
+                "offset": total_duration / sampling_rate,
+                "duration": current_duration / sampling_rate,
+                "segments": current_segments,
+            }
+            total_duration += current_duration
+            chunks_metadata.append(chunk_metadata)
+
+            current_segments = []
+
+            current_audio = audio[chunk["start"] : chunk["end"]]
+            current_duration = chunk["end"] - chunk["start"]
+        else:
+            current_segments.append(chunk)
+            current_audio = np.concatenate(
+                (current_audio, audio[chunk["start"] : chunk["end"]])
+            )
+
+            current_duration += chunk["end"] - chunk["start"]
+
+    audio_chunks.append(current_audio)
+
+    chunk_metadata = {
+        "offset": total_duration / sampling_rate,
+        "duration": current_duration / sampling_rate,
+        "segments": current_segments,
+    }
+    chunks_metadata.append(chunk_metadata)
     return audio_chunks, chunks_metadata
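
To make the merging behavior concrete, a small usage sketch of the reworked collect_chunks; the audio and segment boundaries below are invented for illustration and are not part of this diff:

import numpy as np

from faster_whisper.vad import collect_chunks

sampling_rate = 16000
# 60 s of zeros stands in for real audio; only the slicing matters here.
audio = np.zeros(60 * sampling_rate, dtype=np.float32)

# Hypothetical VAD output in sample indices: speech at 0-20 s, 25-45 s, 50-58 s.
chunks = [
    {"start": 0, "end": 20 * sampling_rate},
    {"start": 25 * sampling_rate, "end": 45 * sampling_rate},
    {"start": 50 * sampling_rate, "end": 58 * sampling_rate},
]

audio_chunks, chunks_metadata = collect_chunks(
    audio, chunks, sampling_rate=sampling_rate, max_duration=30
)

# The first two segments (20 s + 20 s) cannot share one 30 s chunk, so the
# second segment starts a new chunk and the 8 s segment is appended to it.
for chunk_audio, meta in zip(audio_chunks, chunks_metadata):
    print(len(chunk_audio) / sampling_rate, meta["offset"], meta["duration"])

Each metadata entry reports an "offset" into the silence-free timeline and a "duration" in seconds, which is what forward() in transcribe.py now reads directly instead of recomputing them from "start_time" and "end_time".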


@@ -325,48 +362,3 @@ def __call__(
 
         out = np.stack(decoder_outputs, axis=1).squeeze(-1)
         return out
-
-
-def merge_segments(segments_list, vad_options: VadOptions, sampling_rate: int = 16000):
-    if not segments_list:
-        return []
-
-    curr_end = 0
-    seg_idxs = []
-    merged_segments = []
-    edge_padding = vad_options.speech_pad_ms * sampling_rate // 1000
-    chunk_length = vad_options.max_speech_duration_s * sampling_rate
-
-    curr_start = segments_list[0]["start"]
-
-    for idx, seg in enumerate(segments_list):
-        # if any segment start timing is less than previous segment end timing,
-        # reset the edge padding. Similarly for end timing.
-        if idx > 0:
-            if seg["start"] < segments_list[idx - 1]["end"]:
-                seg["start"] += edge_padding
-        if idx < len(segments_list) - 1:
-            if seg["end"] > segments_list[idx + 1]["start"]:
-                seg["end"] -= edge_padding
-
-        if seg["end"] - curr_start > chunk_length and curr_end - curr_start > 0:
-            merged_segments.append(
-                {
-                    "start": curr_start,
-                    "end": curr_end,
-                    "segments": seg_idxs,
-                }
-            )
-            curr_start = seg["start"]
-            seg_idxs = []
-        curr_end = seg["end"]
-        seg_idxs.append((seg["start"], seg["end"]))
-    # add final
-    merged_segments.append(
-        {
-            "start": curr_start,
-            "end": curr_end,
-            "segments": seg_idxs,
-        }
-    )
-    return merged_segments
2 changes: 1 addition & 1 deletion tests/test_transcribe.py
@@ -71,7 +71,7 @@ def test_batched_transcribe(physcisworks_path):
             {"start": segment.start, "end": segment.end, "text": segment.text}
         )
     # number of near 30 sec segments
-    assert len(segments) == 7
+    assert len(segments) == 6
 
     result, info = batched_model.transcribe(
         physcisworks_path,