
Commit d84cef7

[Frontend] Add /v1/audio/transcriptions OpenAI API endpoint (vllm-project#12909)
1 parent 37dfa60 commit d84cef7

20 files changed: +910 -19 lines

.buildkite/test-pipeline.yaml

Lines changed: 10 additions & 2 deletions
@@ -117,7 +117,7 @@ steps:
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
-  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py
+  - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/correctness/
   - pytest -v -s entrypoints/test_chat_utils.py
   - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests

@@ -205,7 +205,7 @@ steps:
   - VLLM_USE_V1=1 pytest -v -s v1/e2e
   # Integration test for streaming correctness (requires special branch).
   - pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
-  - pytest -v -s entrypoints/openai/test_accuracy.py::test_lm_eval_accuracy_v1_engine
+  - pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine

 - label: Examples Test # 25min
   working_dir: "/vllm-workspace/examples"

@@ -339,6 +339,14 @@ steps:
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - bash ./run-tests.sh -c configs/models-small.txt -t 1

+- label: OpenAI API correctness
+  source_file_dependencies:
+  - csrc/
+  - vllm/entrypoints/openai/
+  - vllm/model_executor/models/whisper.py
+  commands: # LMEval+Transcription WER check
+  - pytest -s entrypoints/openai/correctness/
+
 - label: Encoder Decoder tests # 5min
   source_file_dependencies:
   - vllm/

docs/source/serving/openai_compatible_server.md

Lines changed: 13 additions & 0 deletions
@@ -41,6 +41,8 @@ We currently support the following OpenAI APIs:
   - *Note: `parallel_tool_calls` and `user` parameters are ignored.*
 - [Embeddings API](#embeddings-api) (`/v1/embeddings`)
   - Only applicable to [embedding models](../models/pooling_models.md) (`--task embed`).
+- [Transcriptions API](#transcriptions-api) (`/v1/audio/transcriptions`)
+  - Only applicable to Automatic Speech Recognition (ASR) models (OpenAI Whisper) (`--task generate`).

 In addition, we have the following custom APIs:

@@ -296,6 +298,17 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 :end-before: end-chat-embedding-extra-params
 :::

+(transcriptions-api)=
+
+### Transcriptions API
+
+Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
+you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
+
+<!-- TODO: api enforced limits + uploading audios -->
+
+Code example: <gh-file:examples/online_serving/openai_transcription_client.py>
+
 (tokenizer-api)=

 ### Tokenizer API
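Since the new route mirrors OpenAI's multipart upload interface, it can also be exercised without the official client. Below is a minimal sketch using `requests`; it assumes a vLLM server is already serving a Whisper model on localhost:8000, and `sample.wav` is a placeholder path.

# Hypothetical sketch: post a local WAV file straight to the new endpoint
# with multipart/form-data, mirroring OpenAI's Transcriptions API fields.
import requests

with open("sample.wav", "rb") as f:  # placeholder audio file
    resp = requests.post(
        "http://localhost:8000/v1/audio/transcriptions",
        files={"file": f},
        data={
            "model": "openai/whisper-large-v3",
            "language": "en",
            "response_format": "json",
            "temperature": 0.0,
        },
    )
resp.raise_for_status()
print(resp.json()["text"])  # JSON responses carry the transcribed text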
examples/online_serving/openai_transcription_client.py

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
+# SPDX-License-Identifier: Apache-2.0
+from openai import OpenAI
+
+from vllm.assets.audio import AudioAsset
+
+mary_had_lamb = AudioAsset('mary_had_lamb').get_local_path()
+winning_call = AudioAsset('winning_call').get_local_path()
+
+# Modify OpenAI's API key and API base to use vLLM's API server.
+openai_api_key = "EMPTY"
+openai_api_base = "http://localhost:8000/v1"
+client = OpenAI(
+    api_key=openai_api_key,
+    base_url=openai_api_base,
+)
+with open(str(mary_had_lamb), "rb") as f:
+    transcription = client.audio.transcriptions.create(
+        file=f,
+        model="openai/whisper-large-v3",
+        language="en",
+        response_format="text",
+        temperature=0.0)
+print("transcription result:", transcription)

requirements-common.txt

Lines changed: 3 additions & 4 deletions
@@ -8,12 +8,11 @@ py-cpuinfo
 transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
 tokenizers >= 0.19.1 # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
-fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
-fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
+fastapi[standard] >= 0.107.0, < 0.113.0; python_version < '3.9'
+fastapi[standard] >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9'
 aiohttp
 openai >= 1.52.0 # Ensure modern openai package (ensure types module present and max_completion_tokens field support)
-uvicorn[standard]
-pydantic >= 2.9 # Required for fastapi >= 0.113.0
+pydantic >= 2.9
 prometheus_client >= 0.18.0
 pillow # Required for image processing
 prometheus-fastapi-instrumentator >= 7.0.0

requirements-test.in

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ pqdm
 ray[adag]==2.40.0
 sentence-transformers # required for embedding tests
 soundfile # required for audio tests
+jiwer # required for audio tests
 timm # required for internvl test
 torch==2.5.1
 torchaudio==2.5.1
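`jiwer` is what the new audio correctness test leans on for its Word Error Rate check (the Hugging Face `evaluate` "wer" metric used in the test wraps it). A quick sketch of the comparison it performs, with made-up strings:

# Quick sketch of the word-error-rate computation the audio tests depend on.
import jiwer

reference = "the quick brown fox jumps over the lazy dog"
hypothesis = "the quick brown fox jumped over a lazy dog"
print(f"WER: {100 * jiwer.wer(reference, hypothesis):.2f}%")  # % of word errors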

requirements-test.txt

Lines changed: 5 additions & 0 deletions
@@ -66,6 +66,7 @@ charset-normalizer==3.4.0
 click==8.1.7
     # via
     #   black
+    #   jiwer
     #   nltk
     #   ray
 colorama==0.4.6

@@ -187,6 +188,8 @@ jinja2==3.1.4
     # via
     #   datamodel-code-generator
     #   torch
+jiwer==3.0.5
+    # via -r requirements-test.in
 jmespath==1.0.1
     # via
     #   boto3

@@ -470,6 +473,8 @@ pyyaml==6.0.2
     #   timm
     #   transformers
     #   vocos
+rapidfuzz==3.12.1
+    # via jiwer
 ray[adag]==2.40.0
     # via -r requirements-test.in
 redis==5.2.0

tests/entrypoints/openai/correctness/__init__.py

Whitespace-only changes.

tests/entrypoints/openai/test_accuracy.py renamed to tests/entrypoints/openai/correctness/test_lmeval.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@

 from vllm.platforms import current_platform

-from ...utils import RemoteOpenAIServer
+from ....utils import RemoteOpenAIServer

 MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct"
 NUM_CONCURRENT = 500
New file under tests/entrypoints/openai/correctness/ (Transcription API WER correctness test)

Lines changed: 166 additions & 0 deletions
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Evaluate Transcription API correctness by computing Word Error Rate (WER)
+on a given ASR dataset. When provided, it will also compare the WER against
+a baseline.
+This simulates real-world usage of the API and makes sure that the frontend
+and AsyncLLMEngine are working correctly.
+"""
+import asyncio
+import io
+import time
+from statistics import mean, median
+from typing import List
+
+import librosa
+import pytest
+import soundfile
+import torch
+from datasets import load_dataset
+from evaluate import load
+from transformers import AutoTokenizer
+
+from ....utils import RemoteOpenAIServer
+
+
+def to_bytes(y, sr):
+    buffer = io.BytesIO()
+    soundfile.write(buffer, y, sr, format="WAV")
+    buffer.seek(0)
+    return buffer
+
+
+async def transcribe_audio(client, tokenizer, y, sr):
+    # Send loaded audio directly instead of loading from disk,
+    # don't account for that time though
+    with to_bytes(y, sr) as f:
+        start_time = time.perf_counter()
+        transcription = await client.audio.transcriptions.create(
+            file=f,
+            model=tokenizer.name_or_path,
+            language="en",
+            temperature=0.0,
+        )
+        end_time = time.perf_counter()
+    # NOTE there's no streaming in transcriptions, can't measure ttft
+    latency = end_time - start_time
+    num_output_tokens = len(
+        tokenizer(transcription.text, add_special_tokens=False).input_ids)
+    return latency, num_output_tokens, transcription.text
+
+
+async def bound_transcribe(model_name, sem, client, audio, reference):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    # Use semaphore to limit concurrent requests.
+    async with sem:
+        result = await transcribe_audio(client, tokenizer, *audio)
+    # Normalize *English* output/reference for evaluation.
+    out = tokenizer.normalize(result[2])
+    ref = tokenizer.normalize(reference)
+    return result[:2] + (out, ref)
+
+
+async def process_dataset(model, client, data, concurrent_request):
+    sem = asyncio.Semaphore(concurrent_request)
+
+    # Warmup call as the first `librosa.load` server-side is quite slow.
+    audio, sr = data[0]["audio"]["array"], data[0]["audio"]["sampling_rate"]
+    _ = await bound_transcribe(model, sem, client, (audio, sr), "")
+
+    tasks: List[asyncio.Task] = []
+    for sample in data:
+        audio, sr = sample["audio"]["array"], sample["audio"]["sampling_rate"]
+        task = asyncio.create_task(
+            bound_transcribe(model, sem, client, (audio, sr), sample["text"]))
+        tasks.append(task)
+    return await asyncio.gather(*tasks)
+
+
+def print_performance_metrics(results, total_time):
+    latencies = [res[0] for res in results]
+    total_tokens = sum([res[1] for res in results])
+
+    total = len(results)
+    print(f"Total Requests: {total}")
+    print(f"Successful Requests: {len(latencies)}")
+    print(f"Average Latency: {mean(latencies):.4f} seconds")
+    print(f"Median Latency: {median(latencies):.4f} seconds")
+    perc = sorted(latencies)[int(len(latencies) * 0.95) - 1]
+    print(f"95th Percentile Latency: {perc:.4f} seconds")
+    # Throughput
+    req_throughput = len(latencies) / total_time
+    print(f"Estimated req_Throughput: {req_throughput:.2f} requests/s")
+    throughput = total_tokens / total_time
+    print(f"Estimated Throughput: {throughput:.2f} tok/s")
+
+
+def add_duration(sample):
+    y, sr = sample['audio']["array"], sample['audio']["sampling_rate"]
+    sample['duration_ms'] = librosa.get_duration(y=y, sr=sr) * 1000
+    return sample
+
+
+def load_hf_dataset(dataset_repo: str, split='validation', **hf_kwargs):
+    ## Load and filter the dataset
+    dataset = load_dataset(dataset_repo, split=split, **hf_kwargs)
+    if 'duration_ms' not in dataset[0]:
+        # compute duration to filter
+        dataset = dataset.map(add_duration)
+
+    # Whisper max supported duration
+    dataset = dataset.filter(lambda example: example['duration_ms'] < 30000)
+    return dataset
+
+
+def run_evaluation(model: str,
+                   client,
+                   dataset,
+                   max_concurrent_reqs: int,
+                   n_examples: int = -1,
+                   print_metrics: bool = True):
+    if n_examples > 0:
+        dataset = dataset.select(range(n_examples))
+    start = time.perf_counter()
+    results = asyncio.run(
+        process_dataset(model, client, dataset, max_concurrent_reqs))
+    end = time.perf_counter()
+    total_time = end - start
+    print(f"Total Test Time: {total_time:.4f} seconds")
+    if print_metrics:
+        print_performance_metrics(results, total_time)
+    # Compute WER
+    predictions = [res[2] for res in results]
+    references = [res[3] for res in results]
+    wer = load("wer")
+    wer_score = 100 * wer.compute(references=references,
+                                  predictions=predictions)
+    print("WER:", wer_score)
+    return wer_score
+
+
+# alternatives "openai/whisper-large-v2", "openai/whisper-large-v3-turbo"..
+@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3"])
+# Original dataset is 20GB+ in size, hence we use a pre-filtered slice.
+@pytest.mark.parametrize(
+    "dataset_repo", ["D4nt3/esb-datasets-earnings22-validation-tiny-filtered"])
+# NOTE: Expected WER measured with equivalent hf.transformers args:
+# whisper-large-v3 + esb-datasets-earnings22-validation-tiny-filtered.
+@pytest.mark.parametrize("expected_wer", [12.744980])
+def test_wer_correctness(model_name,
+                         dataset_repo,
+                         expected_wer,
+                         n_examples=-1,
+                         max_concurrent_request=None):
+    with RemoteOpenAIServer(model_name, ['--enforce-eager']) as remote_server:
+        dataset = load_hf_dataset(dataset_repo)
+
+        if not max_concurrent_request:
+            # No max concurrency
+            max_concurrent_request = n_examples if n_examples > 0\
+                else len(dataset)
+
+        client = remote_server.get_async_client()
+        wer = run_evaluation(model_name, client, dataset,
+                             max_concurrent_request, n_examples)
+        if expected_wer:
+            torch.testing.assert_close(wer, expected_wer, atol=1e-1, rtol=1e-2)
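This suite is what the new Buildkite "OpenAI API correctness" step invokes via `pytest -s entrypoints/openai/correctness/`. The expected WER of 12.744980 was measured with equivalent Hugging Face transformers settings on the same pre-filtered dataset slice, and the final `assert_close` allows small deviations (atol=1e-1, rtol=1e-2), so minor numeric drift does not fail the check.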
