Skip to content

Commit d8b53d5

Browse files
committed
Add stream transcriber to preview the transcribe progress
1 parent c9b76f6 commit d8b53d5

File tree

4 files changed

+117
-5
lines changed

4 files changed

+117
-5
lines changed

cli.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from utils import to_srt
77

8-
from transcriber.Transcriber import Transcriber
8+
from transcriber.StreamTranscriber import StreamTranscriber
99

1010
# Configure logging first, before any imports
1111
logging.basicConfig(
@@ -78,7 +78,7 @@ def main():
7878
check_models()
7979

8080
logger.info("Transcribing %s", args.audio_file)
81-
transcriber = Transcriber(
81+
transcriber = StreamTranscriber(
8282
corrector="opencc", use_denoiser=args.denoise, with_punct=args.punct)
8383
transcribe_results = transcriber.transcribe(args.audio_file)
8484

transcriber/StreamTranscriber.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import logging
2+
import os
3+
import re
4+
from typing import List, Literal, Union, Generator, Iterator
5+
import inspect
6+
7+
import librosa
8+
import numpy as np
9+
import onnxruntime
10+
import torch
11+
from funasr_onnx import Fsmn_vad_online, SenseVoiceSmall
12+
from funasr_onnx.utils.sentencepiece_tokenizer import SentencepiecesTokenizer
13+
from resampy.core import resample
14+
from torchaudio.pipelines import MMS_FA as bundle
15+
from tqdm.auto import tqdm
16+
17+
from corrector.Corrector import Corrector
18+
from denoiser import denoiser
19+
from transcriber.TranscribeResult import TranscribeResult
20+
from transcriber.Transcriber import Transcriber
21+
22+
logger = logging.getLogger(__name__)
23+
24+
25+
class StreamTranscriber(Transcriber):
26+
"""
27+
StreamTranscriber class
28+
29+
"""
30+
31+
def transcribe(
32+
self,
33+
audio_file: str,
34+
) -> Generator[TranscribeResult, None, None]:
35+
"""
36+
Transcribe audio file to text with timestamps.
37+
38+
Args:
39+
audio_file (str): Path to audio file
40+
41+
Returns:
42+
Generator[TranscribeResult]: Generator of transcription results
43+
"""
44+
speech, sr = librosa.load(audio_file, sr=self.sr)
45+
46+
if self.use_denoiser:
47+
logger.info("Denoising speech...")
48+
speech, _ = denoiser(speech, sr)
49+
50+
if sr != 16_000:
51+
speech = resample(speech, sr, 16_000,
52+
filter="kaiser_best", parallel=True)
53+
54+
logger.info("Segmenting speech...")
55+
vad_segments = self._segment_speech(speech)
56+
57+
if not vad_segments:
58+
return []
59+
60+
61+
pgb_vad_segments = tqdm(
62+
enumerate(vad_segments),
63+
total=len(vad_segments),
64+
desc="Transcribing"
65+
)
66+
67+
result_generator = self._process_segments(speech, pgb_vad_segments)
68+
for result in self._convert_to_traditional_chinese(result_generator):
69+
pgb_vad_segments.set_description(result.text)
70+
yield result
71+
72+
def _process_segments(
73+
self,
74+
speech: np.ndarray,
75+
pgb_vad_segments: Iterator
76+
) -> Generator[TranscribeResult, None, None]:
77+
"""Process each speech segment"""
78+
speech_lengths = len(speech)
79+
80+
for _, segment in pgb_vad_segments:
81+
speech_j, _ = self._slice_padding_audio_samples(
82+
speech,
83+
speech_lengths,
84+
[[segment]]
85+
)
86+
87+
stt_results = self._asr(speech_j[0])
88+
timestamp_offset = ((segment[0] * 16) / 16_000) - 0.1
89+
90+
if not stt_results:
91+
continue
92+
93+
for result in stt_results:
94+
result.start_time += timestamp_offset
95+
result.end_time += timestamp_offset
96+
97+
yield result
98+
99+
def _convert_to_traditional_chinese(
100+
self,
101+
results: Iterator[TranscribeResult]
102+
) -> Generator[TranscribeResult, None, None]:
103+
"""Convert simplified Chinese to traditional Chinese"""
104+
if not results:
105+
return results
106+
107+
corrector = Corrector(self.corrector)
108+
109+
for result in results:
110+
result.text = corrector.correct(result.text)
111+
yield result

transcriber/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .Transcriber import Transcriber
2+
from .StreamTranscriber import StreamTranscriber
23
from .TranscribeResult import TranscribeResult
34

4-
__all__ = ["Transcriber", "TranscribeResult"]
5+
__all__ = ["Transcriber", "StreamTranscriber", "TranscribeResult"]

utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import os
33
import tempfile
4-
from typing import List
4+
from typing import Iterator
55

66
from pysrt import SubRipFile, SubRipItem, SubRipTime
77
from pytubefix import YouTube
@@ -46,7 +46,7 @@ def download_youtube_audio(video_id: str) -> str:
4646
return None
4747

4848

49-
def to_srt(results: List["TranscribeResult"]) -> str:
49+
def to_srt(results: Iterator["TranscribeResult"]) -> str:
5050
"""
5151
Convert the list of TranscribeResult objects into a SRT file
5252
"""

0 commit comments

Comments
 (0)