
Commit 6523a17

POST API for infer
1 parent bebbfbb commit 6523a17

File tree

1 file changed: +212 −0 lines changed

src/f5_tts/infer_api.py

Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@
import argparse
import base64
import io
import traceback
from importlib.resources import files

import torch
import torchaudio
from cached_path import cached_path
from fastapi import FastAPI, HTTPException, Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from infer.utils_infer import infer_batch_process, preprocess_ref_audio_text, load_vocoder, load_model
from model.backbones.dit import DiT

# Initialize FastAPI app
app = FastAPI()

class TTSStreamingProcessor:
    def __init__(self, ckpt_file, vocab_file, ref_audio, ref_text, device=None, dtype=torch.float32):
        # Pick the best available accelerator, falling back to CPU
        self.device = device or (
            "cuda"
            if torch.cuda.is_available()
            else "xpu"
            if torch.xpu.is_available()
            else "mps"
            if torch.backends.mps.is_available()
            else "cpu"
        )

        # Load the model using the provided checkpoint and vocab files
        self.model = load_model(
            model_cls=DiT,
            model_cfg=dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4),
            ckpt_path=ckpt_file,
            mel_spec_type="vocos",  # or "bigvgan", depending on the vocoder
            vocab_file=vocab_file,
            ode_method="euler",
            use_ema=True,
            device=self.device,
        ).to(self.device, dtype=dtype)

        # Load the vocoder
        self.vocoder = load_vocoder(is_local=False)

        # Output sampling rate; must match what clients expect
        self.sampling_rate = 24000

        # Reference audio and text that define the target voice
        self.ref_audio = ref_audio
        self.ref_text = ref_text

        # Warm up the model so the first real request is not slow
        self._warm_up()

    def _warm_up(self):
        """Warm up the model with a dummy input so it is ready for real-time processing."""
        print("Warming up the model...")
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)
        audio, sr = torchaudio.load(ref_audio)
        gen_text = "Warm-up text for the model."

        # Run one inference pass end to end, including the vocoder
        infer_batch_process((audio, sr), ref_text, [gen_text], self.model, self.vocoder, device=self.device)
        print("Warm-up completed.")

    def generate_audio(self, text):
        """Generate audio for the given text and return it as an in-memory WAV file."""
        # Preprocess the reference audio and text
        ref_audio, ref_text = preprocess_ref_audio_text(self.ref_audio, self.ref_text)

        # Load the reference audio
        audio, sr = torchaudio.load(ref_audio)

        # Run inference for the input text
        audio_chunk, final_sample_rate, _ = infer_batch_process(
            (audio, sr),
            ref_text,
            [text],
            self.model,
            self.vocoder,
            device=self.device,
        )

        # Serialize the generated audio to WAV bytes
        audio_buffer = io.BytesIO()
        torchaudio.save(audio_buffer, torch.tensor(audio_chunk).unsqueeze(0), final_sample_rate, format="wav")
        audio_buffer.seek(0)

        return audio_buffer

# Input data model for API requests
class TTSRequest(BaseModel):
    text: str
    response_type: str = Field("json", description="Response format: json, file, stream")


# The processor is created on startup and shared across requests
processor = None

@app.on_event("startup")
def load_model_on_startup():
    """Load the model when the server starts."""
    global processor

    # `args` is parsed in the __main__ block below, so the server must be
    # launched by running this script directly.
    try:
        processor = TTSStreamingProcessor(
            ckpt_file=args.ckpt_file,
            vocab_file=args.vocab_file,
            ref_audio=args.ref_audio,
            ref_text=args.ref_text,
            device=args.device,
            dtype=args.dtype,
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        traceback.print_exc()
        processor = None


@app.post("/tts/")
async def text_to_speech(request: TTSRequest):
    """Convert text to speech and return the audio in the requested format."""
    try:
        if processor is None:
            raise HTTPException(status_code=500, detail="TTS Processor not initialized")

        # Generate the audio buffer
        audio_buffer = processor.generate_audio(request.text)
        chunk_size = 1024  # Stream in 1024-byte chunks

        # (A) JSON-encoded Base64 (default)
        if request.response_type == "json":
            audio_base64 = base64.b64encode(audio_buffer.read()).decode("utf-8")
            return {"audio_base64": audio_base64, "message": "TTS generated successfully"}

        # (B) WAV file download
        elif request.response_type == "file":
            audio_buffer.seek(0)
            return Response(
                content=audio_buffer.read(),
                media_type="audio/wav",
                headers={"Content-Disposition": "attachment; filename=output.wav"},
            )

        # (C) Stream audio in small chunks (real-time playback)
        elif request.response_type == "stream":

            def audio_stream():
                audio_buffer.seek(0)

                # Send the WAV header first (a standard PCM WAV header is 44 bytes)
                wav_header = audio_buffer.read(44)
                yield wav_header

                # Stream the rest of the audio in fixed-size chunks
                while True:
                    chunk = audio_buffer.read(chunk_size)
                    if not chunk:
                        print("End of audio stream")
                        break  # Stop when all audio has been sent
                    print(f"Streaming chunk of size {len(chunk)} bytes")
                    yield chunk

            return StreamingResponse(audio_stream(), media_type="audio/wav")

        else:
            raise HTTPException(status_code=400, detail="Invalid response_type. Choose 'json', 'file', or 'stream'.")

    except HTTPException:
        # Let the explicit HTTP errors raised above pass through unchanged
        raise
    except Exception as e:
        print(f"Server Error: {e}")  # Log the error on the server side
        raise HTTPException(status_code=500, detail=f"Error generating speech: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)

    parser.add_argument(
        "--ckpt_file",
        default=str(cached_path("hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors")),
        help="Path to the model checkpoint file",
    )
    parser.add_argument(
        "--vocab_file",
        default="",
        help="Path to the vocab file if customized",
    )

    parser.add_argument(
        "--ref_audio",
        default=str(files("f5_tts").joinpath("infer/examples/basic/basic_ref_en.wav")),
        help="Reference audio to provide model with speaker characteristics",
    )
    parser.add_argument(
        "--ref_text",
        default="",
        help="Reference audio subtitle, leave empty to auto-transcribe",
    )

    parser.add_argument("--device", default=None, help="Device to run the model on")
    parser.add_argument(
        "--dtype",
        default="float32",
        type=lambda name: getattr(torch, name),  # convert e.g. "float16" to torch.float16
        help="torch dtype to use for model inference, e.g. float32 or float16",
    )

    args = parser.parse_args()

    # Start the FastAPI server
    uvicorn.run(app, host=args.host, port=args.port)
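
For reference, here is a minimal client sketch (not part of this commit) that exercises the endpoint. It assumes the server was started with the defaults above and is reachable at http://localhost:8000; the output filenames are illustrative.

import base64
import requests

# (A) JSON mode: decode the Base64 payload and write a WAV file
resp = requests.post(
    "http://localhost:8000/tts/",
    json={"text": "Hello from F5-TTS!", "response_type": "json"},
)
resp.raise_for_status()
with open("output.wav", "wb") as f:
    f.write(base64.b64decode(resp.json()["audio_base64"]))

# (C) Stream mode: write chunks to disk as they arrive
with requests.post(
    "http://localhost:8000/tts/",
    json={"text": "Hello again!", "response_type": "stream"},
    stream=True,
) as resp:
    resp.raise_for_status()
    with open("streamed.wav", "wb") as f:
        for chunk in resp.iter_content(chunk_size=1024):
            f.write(chunk)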
