14 changes: 7 additions & 7 deletions chatterbox-tts/config.yaml
@@ -1,12 +1,12 @@
 model_name: Chatterbox TTS
-base_image:
-  image: jojobaseten/truss-numpy-1.26.0-gpu:0.4
-  python_executable_path: /usr/bin/python3
-python_version: py312
-requirements:
-- chatterbox-tts
+python_version: py311
+requirements: []
+build_commands:
+- "pip install --upgrade pip setuptools wheel"
+- "pip install numpy"
+- "pip install chatterbox-tts"
 resources:
-  accelerator: H100
+  accelerator: A100
   cpu: '1'
   memory: 40Gi
   use_gpu: true
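
A note on the new build_commands: pip resolves everything listed under `requirements` in a single pass, so sequential `build_commands` are what guarantee numpy is installed before `chatterbox-tts`. A quick way to confirm the built image came out consistent (a hypothetical smoke check; the `chatterbox.tts` import path is an assumption about the package layout):

import numpy

# Hypothetical smoke check for the built image: verifies the numpy installed
# by build_commands survived the chatterbox-tts install.
print("numpy:", numpy.__version__)

try:
    # Assumed import path for the chatterbox-tts package.
    from chatterbox.tts import ChatterboxTTS  # noqa: F401
    print("chatterbox-tts imports cleanly")
except ImportError as exc:
    print("chatterbox-tts import failed:", exc)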
19 changes: 19 additions & 0 deletions f5-tts/config.yaml
@@ -0,0 +1,19 @@
model_name: F5-TTS
python_version: py310
build_commands:
- "pip install f5-tts"
requirements:
- torch>=2.0.0
- torchaudio
- numpy
- scipy
- huggingface_hub
- soundfile
system_packages:
- ffmpeg
resources:
  accelerator: A100
  cpu: '4'
  memory: 24Gi
  use_gpu: true
secrets: {}
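
Once deployed with this config, the endpoint expects the JSON contract defined in model/model.py below. A minimal client sketch (the URL and API key are placeholders for whatever the serving platform issues):

import base64

import requests  # assumed available on the client side

URL = "https://<your-deployment-host>/predict"  # placeholder endpoint

# Base64-encode a short reference clip for voice cloning.
with open("reference.wav", "rb") as f:
    voice_b64 = base64.b64encode(f.read()).decode("utf-8")

resp = requests.post(
    URL,
    json={"text": "Testing one two three.", "voice": voice_b64},
    headers={"Authorization": "Api-Key <YOUR_API_KEY>"},  # placeholder auth
    timeout=300,
)
resp.raise_for_status()

# The response body carries the generated audio as a base64-encoded WAV.
with open("f5_output.wav", "wb") as f:
    f.write(base64.b64decode(resp.json()["audio"]))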
Empty file added f5-tts/model/__init__.py
94 changes: 94 additions & 0 deletions f5-tts/model/model.py
@@ -0,0 +1,94 @@
import base64
import io
import logging
import tempfile
from pathlib import Path
from contextlib import contextmanager
from typing import Dict

import torch
import soundfile as sf

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@contextmanager
def temp_audio_file(audio_b64: str, suffix: str = ".wav"):
    """Creates a temporary audio file from a base64 encoded string."""
    temp_file_path = None
    try:
        audio_bytes = base64.b64decode(audio_b64)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_file_path = Path(temp_file.name)
        yield temp_file_path
    finally:
        if temp_file_path:
            try:
                temp_file_path.unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"Failed to delete temporary file {temp_file_path}: {e}")


class Model:
    """F5-TTS Text-to-Speech model with zero-shot voice cloning.

    F5-TTS uses Flow Matching for high-quality voice cloning from a short
    audio sample. It can generate natural-sounding speech that matches
    the voice characteristics of the reference audio.
    """

    def __init__(self, **kwargs):
        self._model = None
        self._device = "cuda" if torch.cuda.is_available() else "cpu"

    def load(self) -> None:
        """Loads the F5-TTS model."""
        from f5_tts.api import F5TTS

        logger.info(f"Device: {self._device}")
        logger.info("Loading F5-TTS model...")

        self._model = F5TTS(device=self._device)

        logger.info("F5-TTS model loaded successfully")

    def predict(self, model_input: Dict[str, str]) -> Dict[str, str]:
        """Generates speech from text with voice cloning.

        Args:
            model_input: Dictionary containing:
                - text: The text to convert to speech
                - voice: Base64 encoded audio for voice cloning
                - ref_text: Optional transcript of the reference audio (improves quality)

        Returns:
            Dict containing the generated audio as a base64 encoded string
        """
        text = model_input["text"]
        voice_b64 = model_input.get("voice")
        ref_text = model_input.get("ref_text", "")

        if not voice_b64:
            raise ValueError("voice (base64 encoded audio) is required for voice cloning")

        with temp_audio_file(voice_b64) as reference_audio_path:
            logger.info("Generating speech with F5-TTS...")

            result = self._model.infer(
                ref_file=str(reference_audio_path),
                ref_text=ref_text,
                gen_text=text,
            )
            # F5-TTS returns (audio, sample_rate, spectrogram); use the
            # reported sample rate instead of hard-coding 24 kHz.
            audio, sample_rate, _ = result

        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="wav")
        buffer.seek(0)
        wav_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {"audio": wav_base64}
19 changes: 19 additions & 0 deletions fish-speech-tts/config.yaml
@@ -0,0 +1,19 @@
model_name: Fish Speech TTS
python_version: py312
build_commands: []
requirements:
- torch>=2.0.0
- torchaudio
- numpy
- fish-speech
system_packages:
- portaudio19-dev
- libsox-dev
- ffmpeg
resources:
  accelerator: A100
  cpu: '4'
  memory: 24Gi
  use_gpu: true
secrets:
  hf_access_token: null
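
The hf_access_token secret is declared here but never read in model/model.py below; if fishaudio/openaudio-s1-mini ever requires authentication, the token would need to be forwarded to the download call. A sketch of what that could look like (the `secrets` dict is how Truss-style runtimes typically hand secrets to `Model.__init__`; treat the wiring as an assumption):

from huggingface_hub import snapshot_download

# Hypothetical wiring: forward the hf_access_token secret to the weights
# download. `token` is a real snapshot_download parameter.
def download_weights(secrets: dict, checkpoint_dir: str) -> None:
    token = secrets.get("hf_access_token")  # may be None for public repos
    snapshot_download(
        repo_id="fishaudio/openaudio-s1-mini",
        local_dir=checkpoint_dir,
        token=token,
    )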
Empty file added fish-speech-tts/model/__init__.py
210 changes: 210 additions & 0 deletions fish-speech-tts/model/model.py
@@ -0,0 +1,210 @@
import base64
import io
import logging
import tempfile
from pathlib import Path
from contextlib import contextmanager
from typing import Dict

import torch
import torchaudio

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@contextmanager
def temp_audio_file(audio_b64: str, suffix: str = ".wav"):
    """Creates a temporary audio file from a base64 encoded string."""
    temp_file_path = None
    try:
        audio_bytes = base64.b64decode(audio_b64)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file.write(audio_bytes)
            temp_file_path = Path(temp_file.name)
        yield temp_file_path
    finally:
        if temp_file_path:
            try:
                temp_file_path.unlink(missing_ok=True)
            except Exception as e:
                logger.warning(f"Failed to delete temporary file {temp_file_path}: {e}")


class Model:
    """Text-to-speech model wrapper for Fish Speech (OpenAudio).

    This class provides an interface for generating speech from text,
    optionally using voice cloning with a reference audio prompt.
    """

    def __init__(self, **kwargs):
        self._llm_model = None
        self._decode_one_token = None
        self._codec_model = None
        self._device = "cuda" if torch.cuda.is_available() else "cpu"
        self._precision = torch.bfloat16
        self._checkpoint_path = Path("/app/checkpoints/openaudio-s1-mini")

    def load(self) -> None:
        """Loads the Fish Speech model."""
        import os
        import sys
        import traceback

        try:
            logger.info(f"Python version: {sys.version}")
            logger.info(f"PyTorch version: {torch.__version__}")

            from huggingface_hub import snapshot_download
            logger.info("Successfully imported huggingface_hub")

            logger.info(f"Device: {self._device}")
            logger.info(f"Checkpoint path: {self._checkpoint_path}")
            logger.info(f"Checkpoint path exists: {self._checkpoint_path.exists()}")

            logger.info("Downloading Fish Speech model weights...")
            if not self._checkpoint_path.exists():
                logger.info("Creating checkpoint directory")
                os.makedirs(self._checkpoint_path, exist_ok=True)

            try:
                snapshot_download(
                    repo_id="fishaudio/openaudio-s1-mini",
                    local_dir=str(self._checkpoint_path),
                )
                logger.info("Download complete")
            except Exception as e:
                logger.error(f"Error downloading model weights: {e}")
                logger.error(traceback.format_exc())
                # Without the weights, the loads below cannot succeed, so
                # surface the root cause instead of a confusing file-not-found.
                raise

            try:
                logger.info("Importing fish_speech modules...")
                import fish_speech
                logger.info(f"Fish Speech version: {getattr(fish_speech, '__version__', 'unknown')}")

                from fish_speech.models.text2semantic.inference import init_model
                from fish_speech.models.dac.inference import load_model as load_codec
                logger.info("Successfully imported fish_speech modules")

                logger.info("Loading Fish Speech LLM model...")
                self._llm_model, self._decode_one_token = init_model(
                    checkpoint_path=str(self._checkpoint_path),
                    device=self._device,
                    precision=self._precision,
                    compile=False,
                )
                logger.info("LLM model loaded")

                with torch.device(self._device):
                    logger.info(f"Setting up caches with max_seq_len={self._llm_model.config.max_seq_len}")
                    self._llm_model.setup_caches(
                        max_batch_size=1,
                        max_seq_len=self._llm_model.config.max_seq_len,
                        dtype=self._precision,
                    )
                logger.info("Caches set up")

                logger.info("Loading Fish Speech codec model...")
                self._codec_model = load_codec(
                    config_name="modded_dac_vq",
                    checkpoint_path=str(self._checkpoint_path / "codec.pth"),
                    device=self._device,
                )
                logger.info("Codec model loaded")

                logger.info("Fish Speech models loaded successfully")
            except Exception as e:
                logger.error(f"Error loading models: {e}")
                logger.error(traceback.format_exc())
                raise
        except Exception as e:
            logger.error(f"Fatal error in load(): {e}")
            logger.error(traceback.format_exc())
            raise

    def _encode_reference_audio(self, audio_path: Path) -> torch.Tensor:
        """Encode reference audio to VQ tokens."""
        audio, sr = torchaudio.load(str(audio_path))
        if audio.shape[0] > 1:
            # Downmix multi-channel input to mono.
            audio = audio.mean(0, keepdim=True)
        audio = torchaudio.functional.resample(audio, sr, self._codec_model.sample_rate)

        audios = audio[None].to(self._device)
        audio_lengths = torch.tensor([audios.shape[2]], device=self._device, dtype=torch.long)

        indices, _ = self._codec_model.encode(audios, audio_lengths)
        if indices.ndim == 3:
            indices = indices[0]

        return indices

    def _decode_codes_to_audio(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode semantic codes to audio."""
        indices_lens = torch.tensor([codes.shape[1]], device=self._device, dtype=torch.long)
        fake_audios, _ = self._codec_model.decode(codes, indices_lens)
        return fake_audios[0, 0]

    def predict(self, model_input: Dict[str, str]) -> Dict[str, str]:
        """Generates speech from text with optional voice cloning.

        Args:
            model_input: Dictionary containing:
                - text: The text to convert to speech
                - voice: Optional base64 encoded audio for voice cloning
                - voice_text: Optional transcript of the voice audio (improves cloning)

        Returns:
            Dict containing the generated audio as a base64 encoded string
        """
        from fish_speech.models.text2semantic.inference import generate_long

        text = model_input["text"]
        voice_b64 = model_input.get("voice")
        voice_text = model_input.get("voice_text")

        prompt_tokens = None
        prompt_text = None

        if voice_b64:
            logger.info("Using voice audio for cloning...")
            with temp_audio_file(voice_b64) as voice_path:
                prompt_tokens = [self._encode_reference_audio(voice_path)]
                prompt_text = [voice_text] if voice_text else [""]

        all_codes = []
        generator = generate_long(
            model=self._llm_model,
            device=self._device,
            decode_one_token=self._decode_one_token,
            text=text,
            num_samples=1,
            max_new_tokens=0,
            top_p=0.8,
            repetition_penalty=1.1,
            temperature=0.8,
            compile=False,
            iterative_prompt=True,
            chunk_length=300,
            prompt_text=prompt_text,
            prompt_tokens=prompt_tokens,
        )

        for response in generator:
            if response.action == "sample":
                all_codes.append(response.codes)
            elif response.action == "next":
                break

        if not all_codes:
            raise ValueError("No audio generated")

        codes = torch.cat(all_codes, dim=1)
        audio = self._decode_codes_to_audio(codes)

        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.cpu().unsqueeze(0), self._codec_model.sample_rate, format="wav")
        buffer.seek(0)
        wav_base64 = base64.b64encode(buffer.read()).decode("utf-8")

        return {"audio": wav_base64}