Skip to content

Commit

Permalink
Merge branch 'dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jan 7, 2025
2 parents b367dda + ff77e25 commit 18e5ec5
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checksum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ jobs:

- name: Run RVC-Models-Downloader
run: |
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.9/rvcmd_linux_amd64.deb
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.10/rvcmd_linux_amd64.deb
sudo apt -y install ./rvcmd_linux_amd64.deb
rm -f ./rvcmd_linux_amd64.deb
rvcmd -notrs -w 1 -notui assets/chtts
Expand Down
38 changes: 26 additions & 12 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,13 @@ def download_models(
custom_path: Optional[torch.serialization.FILE_LIKE] = None,
) -> Optional[str]:
if source == "local":
download_path = os.getcwd()
download_path = custom_path if custom_path is not None else os.getcwd()
if (
not check_all_assets(Path(download_path), self.sha256_map, update=True)
or force_redownload
):
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
download_all_assets(tmpdir=tmp, homedir=download_path)
if not check_all_assets(
Path(download_path), self.sha256_map, update=False
):
Expand All @@ -83,10 +83,20 @@ def download_models(
)
return None
elif source == "huggingface":
hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
try:
download_path = get_latest_modified_file(
os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
download_path = (
get_latest_modified_file(
os.path.join(
os.getenv(
"HF_HOME", os.path.expanduser("~/.cache/huggingface")
),
"hub/models--2Noise--ChatTTS/snapshots",
)
)
if custom_path is None
else get_latest_modified_file(
os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
)
)
except:
download_path = None
Expand All @@ -99,23 +109,27 @@ def download_models(
download_path = snapshot_download(
repo_id="2Noise/ChatTTS",
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
cache_dir=custom_path,
force_download=force_redownload,
)
except:
download_path = None
else:
self.logger.log(
logging.INFO, f"load latest snapshot from cache: {download_path}"
)
if download_path is None:
self.logger.error("download from huggingface failed.")
return None
else:
self.logger.log(
logging.INFO,
f"load latest snapshot from cache: {download_path}",
)
elif source == "custom":
self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
self.logger.error("check models in custom path %s failed.", custom_path)
return None
download_path = custom_path

if download_path is None:
self.logger.error("Model download failed")
return None

return download_path

def load(
Expand Down
6 changes: 4 additions & 2 deletions ChatTTS/utils/dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]):
logger.get_logger().info(f"downloaded into {folder}")


def download_all_assets(tmpdir: str, version="0.2.9"):
def download_all_assets(tmpdir: str, homedir: str, version="0.2.10"):
import subprocess
import platform

Expand Down Expand Up @@ -186,7 +186,7 @@ def download_all_assets(tmpdir: str, version="0.2.9"):
else:
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
os.chmod(cmdfile, 0o755)
subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
subprocess.run([cmdfile, "-notui", "-w", "0", "-H", homedir, "assets/chtts"])
except Exception:
BASE_URL = (
"https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/"
Expand Down Expand Up @@ -215,6 +215,8 @@ def download_all_assets(tmpdir: str, version="0.2.9"):
"0",
"-dns",
os.path.join(tmpdir, "dns.yaml"),
"-H",
homedir,
"assets/chtts",
]
)
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ pip install --upgrade -r requirements.txt

#### 2. Install from conda
```bash
conda create -n chattts
conda create -n chattts python=3.11
conda activate chattts
pip install -r requirements.txt
```
Expand Down
18 changes: 15 additions & 3 deletions examples/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@


from pydantic import BaseModel

from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from tools.normalizer.en import normalizer_en_nemo_text
from tools.normalizer.zh import normalizer_zh_tn

logger = get_logger("Command")

Expand All @@ -35,14 +38,23 @@ async def startup_event():
global chat

chat = ChatTTS.Chat(get_logger("ChatTTS"))
chat.normalizer.register("en", normalizer_en_nemo_text())
chat.normalizer.register("zh", normalizer_zh_tn())

logger.info("Initializing ChatTTS...")
if chat.load():
if chat.load(source="huggingface"):
logger.info("Models loaded successfully.")
else:
logger.error("Models load failed.")
sys.exit(1)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc: RequestValidationError):
    """Log a request-validation failure and reply with HTTP 422.

    Args:
        request: The incoming request that failed validation (unused).
        exc: The validation error raised while parsing the request.

    Returns:
        A ``JSONResponse`` with status 422 whose body is
        ``{"detail": <list of validation errors>}``.
    """
    # Imported locally so the handler does not depend on a module-level
    # import being added elsewhere in the file.
    from fastapi.encoders import jsonable_encoder

    errors = exc.errors()
    logger.error(f"Validation error: {errors}")
    # Error entries may carry values that the plain JSON encoder cannot
    # serialize (e.g. raw bytes input or exception objects in "ctx").
    # Encoding them first keeps the reply a 422 instead of failing while
    # rendering the response.
    return JSONResponse(
        status_code=422,
        content=jsonable_encoder({"detail": errors}),
    )


class ChatTTSParams(BaseModel):
text: list[str]
stream: bool = False
Expand All @@ -52,7 +64,7 @@ class ChatTTSParams(BaseModel):
use_decoder: bool = True
do_text_normalization: bool = True
do_homophone_replacement: bool = False
params_refine_text: ChatTTS.Chat.RefineTextParams
params_refine_text: ChatTTS.Chat.RefineTextParams = None
params_infer_code: ChatTTS.Chat.InferCodeParams


Expand Down
104 changes: 76 additions & 28 deletions tools/audio/av.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from io import BufferedWriter, BytesIO
from pathlib import Path
from typing import Dict
from typing import Dict, Tuple, Optional, Union, List

import av
from av.audio.frame import AudioFrame
from av.audio.resampler import AudioResampler
import numpy as np

Expand Down Expand Up @@ -39,41 +40,88 @@ def wav2(i: BytesIO, o: BufferedWriter, format: str):
inp.close()


def load_audio(file: str, sr: int) -> np.ndarray:
def load_audio(
file: Union[str, BytesIO, Path],
sr: Optional[int] = None,
format: Optional[str] = None,
mono=True,
) -> Union[np.ndarray, Tuple[np.ndarray, int]]:
"""
https://github.com/fumiama/Retrieval-based-Voice-Conversion-WebUI/blob/412a9950a1e371a018c381d1bfb8579c4b0de329/infer/lib/audio.py#L39
"""

if not Path(file).exists():
if (isinstance(file, str) and not Path(file).exists()) or (
isinstance(file, Path) and not file.exists()
):
raise FileNotFoundError(f"File not found: {file}")
rate = 0

container = av.open(file, format=format)
audio_stream = next(s for s in container.streams if s.type == "audio")
channels = 1 if audio_stream.layout == "mono" else 2
container.seek(0)
resampler = (
AudioResampler(format="fltp", layout=audio_stream.layout, rate=sr)
if sr is not None
else None
)

# Estimated maximum total number of samples to pre-allocate the array
# AV stores length in microseconds by default
estimated_total_samples = (
int(container.duration * sr // 1_000_000) if sr is not None else 48000
)
decoded_audio = np.zeros(
(
estimated_total_samples + 1
if channels == 1
else (channels, estimated_total_samples + 1)
),
dtype=np.float32,
)

offset = 0

def process_packet(packet: List[AudioFrame]):
frames_data = []
rate = 0
for frame in packet:
# frame.pts = None  # Clear the timestamp to avoid resampling issues
resampled_frames = (
resampler.resample(frame) if resampler is not None else [frame]
)
for resampled_frame in resampled_frames:
frame_data = resampled_frame.to_ndarray()
rate = resampled_frame.rate
frames_data.append(frame_data)
return (rate, frames_data)

try:
container = av.open(file)
resampler = AudioResampler(format="fltp", layout="mono", rate=sr)
def frame_iter(container):
for p in container.demux(container.streams.audio[0]):
yield p.decode()

# Estimated maximum total number of samples to pre-allocate the array
# AV stores length in microseconds by default
estimated_total_samples = int(container.duration * sr // 1_000_000)
decoded_audio = np.zeros(estimated_total_samples + 1, dtype=np.float32)
for r, frames_data in map(process_packet, frame_iter(container)):
if not rate:
rate = r
for frame_data in frames_data:
end_index = offset + len(frame_data[0])

offset = 0
for frame in container.decode(audio=0):
frame.pts = None # Clear presentation timestamp to avoid resampling issues
resampled_frames = resampler.resample(frame)
for resampled_frame in resampled_frames:
frame_data = resampled_frame.to_ndarray()[0]
end_index = offset + len(frame_data)
# Check whether decoded_audio has enough space, and resize it if necessary
if end_index > decoded_audio.shape[1]:
decoded_audio = np.resize(
decoded_audio, (decoded_audio.shape[0], end_index * 4)
)

np.copyto(decoded_audio[..., offset:end_index], frame_data)
offset += len(frame_data[0])

# Check if decoded_audio has enough space, and resize if necessary
if end_index > decoded_audio.shape[0]:
decoded_audio = np.resize(decoded_audio, end_index + 1)
container.close()

decoded_audio[offset:end_index] = frame_data
offset += len(frame_data)
# Truncate the array to the actual size
decoded_audio = decoded_audio[..., :offset]

# Truncate the array to the actual size
decoded_audio = decoded_audio[:offset]
except Exception as e:
raise RuntimeError(f"Failed to load audio: {e}")
if mono and decoded_audio.shape[0] > 1:
decoded_audio = decoded_audio.mean(0)

return decoded_audio
if sr is not None:
return decoded_audio
return decoded_audio, rate

0 comments on commit 18e5ec5

Please sign in to comment.