audio_to_text.py
import logging
import os
from dataclasses import dataclass
from enum import Enum
from typing import List

import torch
from fastapi import File, UploadFile
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline

from app.pipelines.base import Pipeline
from app.pipelines.utils import get_model_dir, get_torch_device
from app.pipelines.utils.audio import AudioConverter
from app.utils.errors import InferenceError

logger = logging.getLogger(__name__)


class ModelName(Enum):
    """Enumeration mapping model names to their corresponding model IDs."""

    WHISPER_LARGE_V3 = "openai/whisper-large-v3"
    WHISPER_MEDIUM = "openai/whisper-medium"
    WHISPER_DISTIL_LARGE_V3 = "distil-whisper/distil-large-v3"

    @classmethod
    def list(cls):
        """Return a list of all model IDs."""
        return [model.value for model in cls]

    @classmethod
    def from_value(cls, value: str) -> Enum | None:
        """Return the enum member corresponding to the given value, or None if not
        found."""
        try:
            return cls(value)
        except ValueError:
            return None


@dataclass
class ModelConfig:
    """Model configuration parameters."""

    torch_dtype: torch.dtype = (
        torch.float16 if torch.cuda.is_available() else torch.float32
    )
    chunk_length_s: int = 30
    batch_size: int = 16


MODEL_CONFIGS = {
    ModelName.WHISPER_LARGE_V3: ModelConfig(),
    ModelName.WHISPER_MEDIUM: ModelConfig(torch_dtype=torch.float32),
    ModelName.WHISPER_DISTIL_LARGE_V3: ModelConfig(chunk_length_s=25),
}
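# Audio extensions the ASR pipeline cannot ingest directly; presumably consulted
# by the calling route so such files are re-encoded first (the list is not
# referenced elsewhere in this module).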
INCOMPATIBLE_EXTENSIONS = ["mp4", "m4a", "ac3"]


class AudioToTextPipeline(Pipeline):
    def __init__(self, model_id: str):
        self.model_id = model_id
        kwargs = {}

        torch_device = get_torch_device()

        # Enable FlashAttention based on device compatibility.
        attn_implementation = "eager"
        if torch_device.type == "cuda":
            device = torch.cuda.get_device_properties(0)
            major, minor = device.major, device.minor
            # FlashAttention requires CUDA Compute Capability >= 8.0.
            if (major, minor) >= (8, 0):
                attn_implementation = "flash_attention_2"
            else:
                attn_implementation = "sdpa"
                logger.warning(
                    f"GPU {device.name} (Compute Capability {major}.{minor}) is not "
                    "compatible with FlashAttention, so scaled_dot_product_attention "
                    "is being used instead."
                )
        else:
            logger.warning("FlashAttention disabled since it requires a CUDA device.")

        # Get model-specific configuration parameters.
        model_enum = ModelName.from_value(model_id)
        self._model_cfg: ModelConfig = MODEL_CONFIGS.get(model_enum, ModelConfig())
        kwargs["torch_dtype"] = self._model_cfg.torch_dtype

        logger.info(
            "AudioToText loading '%s' on device '%s' with '%s' variant",
            model_id,
            torch_device,
            kwargs["torch_dtype"],
        )

        model = AutoModelForSpeechSeq2Seq.from_pretrained(
            model_id,
            low_cpu_mem_usage=True,
            use_safetensors=True,
            cache_dir=get_model_dir(),
            attn_implementation=attn_implementation,
            device_map="auto",
            **kwargs,
        )

        processor = AutoProcessor.from_pretrained(model_id, cache_dir=get_model_dir())

        self.tm = pipeline(
            "automatic-speech-recognition",
            model=model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            device_map="auto",
            **kwargs,
        )

        self._audio_converter = AudioConverter()
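    # NOTE: Callers (presumably the FastAPI route layer) are expected to pass
    # `return_timestamps` via **kwargs and the audio `duration` in seconds;
    # an illustrative usage sketch is included at the bottom of this file.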
    def __call__(self, audio: UploadFile, duration: float, **kwargs) -> List[File]:
        audio_bytes = audio.file.read()

        # Re-encode the audio to match the pre-processing done in transformers: the
        # pipeline accepts an np.ndarray as-is, while string file paths and raw bytes
        # are converted to np.ndarray inside the pipeline.
        # https://github.com/huggingface/transformers/blob/47c29ccfaf56947d845971a439cbe75a764b63d7/src/transformers/pipelines/automatic_speech_recognition.py#L353
        # https://github.com/huggingface/transformers/blob/47c29ccfaf56947d845971a439cbe75a764b63d7/src/transformers/pipelines/audio_utils.py#L10
        audio_array = self._audio_converter.to_ndarray(audio_bytes)

        # Adjust batch size and chunk length based on timestamps and duration.
        # NOTE: Done to prevent CUDA OOM errors for large audio files.
        kwargs["batch_size"] = self._model_cfg.batch_size
        kwargs["chunk_length_s"] = self._model_cfg.chunk_length_s
        if kwargs.get("return_timestamps") == "word":
            if duration > 3600:
                raise InferenceError(
                    f"Word timestamps are only supported for audio files up to 60 "
                    f"minutes for model {self.model_id}"
                )
            if duration > 200:
                kwargs["batch_size"] = 4
        if duration <= kwargs["chunk_length_s"]:
            kwargs.pop("batch_size", None)
            kwargs.pop("chunk_length_s", None)
            inference_mode = "sequential"
        else:
            inference_mode = (
                f"chunked (batch_size={kwargs['batch_size']}, "
                f"chunk_length_s={kwargs['chunk_length_s']})"
            )

        logger.info(
            f"AudioToTextPipeline: Starting inference mode={inference_mode} with "
            f"duration={duration}"
        )

        try:
            outputs = self.tm(audio_array, **kwargs)
            outputs.setdefault("chunks", [])
        except torch.cuda.OutOfMemoryError as e:
            raise e
        except Exception as e:
            raise InferenceError(original_exception=e)

        return outputs

    def __str__(self) -> str:
        return f"AudioToTextPipeline model_id={self.model_id}"
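# --- Illustrative usage sketch (not part of the original module) -----------------
# A minimal example of how a caller might drive this pipeline, assuming an audio
# file on disk wrapped in FastAPI's UploadFile. The file name, duration, and
# `return_timestamps` value below are hypothetical, and UploadFile's constructor
# arguments can vary across FastAPI/Starlette versions.
#
#   from fastapi import UploadFile
#
#   pipe = AudioToTextPipeline("openai/whisper-large-v3")
#   with open("sample.wav", "rb") as fh:
#       upload = UploadFile(file=fh, filename="sample.wav")
#       result = pipe(upload, duration=42.0, return_timestamps=True)
#   print(result["text"])
#   print(result["chunks"])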