Merge branch 'dev' into feature/torchserve-example
JaysonAlbert authored Jan 8, 2025
2 parents 9bedaa1 + 25cf2bc commit e4e692b
Showing 9 changed files with 162 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checksum.yml
@@ -13,7 +13,7 @@ jobs:
 
       - name: Run RVC-Models-Downloader
         run: |
-          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.9/rvcmd_linux_amd64.deb
+          wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.10/rvcmd_linux_amd64.deb
           sudo apt -y install ./rvcmd_linux_amd64.deb
           rm -f ./rvcmd_linux_amd64.deb
           rvcmd -notrs -w 1 -notui assets/chtts
166 changes: 124 additions & 42 deletions ChatTTS/core.py
@@ -1,4 +1,5 @@
 import os
+import re
 import logging
 import tempfile
 from dataclasses import dataclass, asdict
@@ -68,13 +69,13 @@ def download_models(
         custom_path: Optional[torch.serialization.FILE_LIKE] = None,
     ) -> Optional[str]:
         if source == "local":
-            download_path = os.getcwd()
+            download_path = custom_path if custom_path is not None else os.getcwd()
             if (
                 not check_all_assets(Path(download_path), self.sha256_map, update=True)
                 or force_redownload
             ):
                 with tempfile.TemporaryDirectory() as tmp:
-                    download_all_assets(tmpdir=tmp)
+                    download_all_assets(tmpdir=tmp, homedir=download_path)
                 if not check_all_assets(
                     Path(download_path), self.sha256_map, update=False
                 ):
@@ -83,10 +84,20 @@ def download_models(
                     )
                     return None
         elif source == "huggingface":
-            hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
             try:
-                download_path = get_latest_modified_file(
-                    os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
+                download_path = (
+                    get_latest_modified_file(
+                        os.path.join(
+                            os.getenv(
+                                "HF_HOME", os.path.expanduser("~/.cache/huggingface")
+                            ),
+                            "hub/models--2Noise--ChatTTS/snapshots",
+                        )
+                    )
+                    if custom_path is None
+                    else get_latest_modified_file(
+                        os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")
+                    )
                 )
             except:
                 download_path = None
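
Note: the rewrite above changes where the cached snapshot is looked up, not just how the expression is formatted. A minimal sketch of the resolution order it implements (helper and path names are taken from this hunk; the custom-path branch omits the hub/ prefix because huggingface_hub treats cache_dir itself as the hub root):

    import os

    def resolve_snapshot_dir(custom_path=None):
        # With no custom_path, fall back to the standard HF cache location.
        if custom_path is None:
            hf_home = os.getenv("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
            return os.path.join(hf_home, "hub/models--2Noise--ChatTTS/snapshots")
        # Otherwise custom_path acts as the cache root directly.
        return os.path.join(custom_path, "models--2Noise--ChatTTS/snapshots")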
@@ -99,23 +110,27 @@ def download_models(
                     download_path = snapshot_download(
                         repo_id="2Noise/ChatTTS",
                         allow_patterns=["*.yaml", "*.json", "*.safetensors"],
+                        cache_dir=custom_path,
                         force_download=force_redownload,
                     )
                 except:
                     download_path = None
-            else:
-                self.logger.log(
-                    logging.INFO, f"load latest snapshot from cache: {download_path}"
-                )
-            if download_path is None:
-                self.logger.error("download from huggingface failed.")
-                return None
+            else:
+                self.logger.log(
+                    logging.INFO,
+                    f"load latest snapshot from cache: {download_path}",
+                )
         elif source == "custom":
             self.logger.log(logging.INFO, f"try to load from local: {custom_path}")
             if not check_all_assets(Path(custom_path), self.sha256_map, update=False):
                 self.logger.error("check models in custom path %s failed.", custom_path)
                 return None
             download_path = custom_path
+
+        if download_path is None:
+            self.logger.error("Model download failed")
+            return None
+
         return download_path
 
     def load(
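
Note: taken together, these hunks make custom_path meaningful for every source: it overrides the working directory for "local", doubles as the Hugging Face cache root for "huggingface", and the per-branch failure handling collapses into one final check. A hedged usage sketch (Chat.load forwarding these keyword arguments is assumed from the rest of the file, which this diff does not show):

    import ChatTTS

    chat = ChatTTS.Chat()

    # source="local": assets are verified (and downloaded) under ./models
    # instead of the current working directory.
    chat.load(source="local", custom_path="./models")

    # source="huggingface": custom_path is passed to snapshot_download as
    # cache_dir, so the snapshot lands under
    # ./hf-cache/models--2Noise--ChatTTS/snapshots.
    chat.load(source="huggingface", custom_path="./hf-cache")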
@@ -199,10 +214,29 @@ def infer(
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
+        split_text=True,
+        max_split_batch=4,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
     ):
         self.context.set(False)
+
+        if split_text and isinstance(text, str):
+            if "\n" in text:
+                text = text.split("\n")
+            else:
+                text = re.split(r"(?<=[。(.\s)])", text)
+            nt = []
+            for t in text:
+                if t:
+                    nt.append(t)
+            text = nt
+            self.logger.info("split text into %d parts", len(text))
+            self.logger.debug("%s", str(text))
+
+        if len(text) == 0:
+            return []
+
         res_gen = self._infer(
             text,
             stream,
@@ -212,11 +246,21 @@ def infer(
             use_decoder,
             do_text_normalization,
             do_homophone_replacement,
+            split_text,
+            max_split_batch,
             params_refine_text,
             params_infer_code,
         )
         if stream:
             return res_gen
+        elif not refine_text_only:
+            stripped_wavs = []
+            for wavs in res_gen:
+                for wav in wavs:
+                    stripped_wavs.append(wav[np.abs(wav) > 1e-5])
+            if split_text:
+                return [np.concatenate(stripped_wavs)]
+            return stripped_wavs
         else:
             return next(res_gen)

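Note: with split_text enabled (the default), a plain-string input is split on newlines when present, otherwise after each character matched by the lookbehind class. A short demonstration of the committed splitting rule; because \s is in the class, English text is also split after every space:

    import re

    text = "今天天气不错。我们出去走走吧。"
    parts = [t for t in re.split(r"(?<=[。(.\s)])", text) if t]
    print(parts)  # ['今天天气不错。', '我们出去走走吧。']

In the non-streaming, non-refine path the per-segment waveforms are silence-stripped and, when split_text is on, concatenated back together, so chat.infer(text)[0] still returns a single waveform.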
@@ -336,14 +380,16 @@ def _load(
 
     def _infer(
         self,
-        text,
+        text: Union[List[str], str],
         stream=False,
         lang=None,
         skip_refine_text=False,
         refine_text_only=False,
         use_decoder=True,
         do_text_normalization=True,
         do_homophone_replacement=True,
+        split_text=True,
+        max_split_batch=4,
         params_refine_text=RefineTextParams(),
         params_infer_code=InferCodeParams(),
     ):
@@ -376,44 +422,80 @@ def _infer(
             text = self.tokenizer.decode(text_tokens)
             refined.destroy()
             if refine_text_only:
+                if split_text and isinstance(text, list):
+                    text = "\n".join(text)
                 yield text
                 return
 
-        if stream:
-            length = 0
-            pass_batch_count = 0
-        for result in self._infer_code(
-            text,
-            stream,
-            self.device,
-            use_decoder,
-            params_infer_code,
-        ):
+        if split_text and len(text) > 1 and params_infer_code.spk_smp is None:
+            refer_text = text[0]
+            result = next(
+                self._infer_code(
+                    refer_text,
+                    False,
+                    self.device,
+                    use_decoder,
+                    params_infer_code,
+                )
+            )
             wavs = self._decode_to_wavs(
                 result.hiddens if use_decoder else result.ids,
                 use_decoder,
             )
             result.destroy()
-            if stream:
-                pass_batch_count += 1
-                if pass_batch_count <= params_infer_code.pass_first_n_batches:
-                    continue
-                a = length
-                b = a + params_infer_code.stream_speed
-                if b > wavs.shape[1]:
-                    b = wavs.shape[1]
-                new_wavs = wavs[:, a:b]
-                length = b
-                yield new_wavs
-            else:
-                yield wavs
+            assert len(wavs), 1
+            params_infer_code.spk_smp = self.sample_audio_speaker(wavs[0])
+            params_infer_code.txt_smp = refer_text
 
-        if stream:
-            new_wavs = wavs[:, length:]
-            # Identify rows with non-zero elements using np.any
-            # keep_rows = np.any(array != 0, axis=1)
-            keep_cols = np.sum(new_wavs != 0, axis=0) > 0
-            # Filter both rows and columns using slicing
-            yield new_wavs[:][:, keep_cols]
+        length = 0
+        pass_batch_count = 0
+        if split_text:
+            n = len(text) // max_split_batch
+            if len(text) % max_split_batch:
+                n += 1
+        else:
+            n = 1
+            max_split_batch = len(text)
+        for i in range(n):
+            text_remain = text[i * max_split_batch :]
+            if len(text_remain) > max_split_batch:
+                text_remain = text_remain[:max_split_batch]
+            if split_text:
+                self.logger.info(
+                    "infer split %d~%d",
+                    i * max_split_batch,
+                    i * max_split_batch + len(text_remain),
+                )
+            for result in self._infer_code(
+                text_remain,
+                stream,
+                self.device,
+                use_decoder,
+                params_infer_code,
+            ):
+                wavs = self._decode_to_wavs(
+                    result.hiddens if use_decoder else result.ids,
+                    use_decoder,
+                )
+                result.destroy()
+                if stream:
+                    pass_batch_count += 1
+                    if pass_batch_count <= params_infer_code.pass_first_n_batches:
+                        continue
+                    a = length
+                    b = a + params_infer_code.stream_speed
+                    if b > wavs.shape[1]:
+                        b = wavs.shape[1]
+                    new_wavs = wavs[:, a:b]
+                    length = b
+                    yield new_wavs
+                else:
+                    yield wavs
+            if stream:
+                new_wavs = wavs[:, length:]
+                keep_cols = np.sum(np.abs(new_wavs) > 1e-5, axis=0) > 0
+                yield new_wavs[:][:, keep_cols]
 
     @torch.inference_mode()
     def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
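
Note: the restructured _infer does two things. It optionally synthesizes the first segment once and samples it as the speaker reference (sample_audio_speaker into spk_smp/txt_smp) so that later batches keep a consistent voice, and it walks the segments in ceiling-divided batches of at most max_split_batch. The batch arithmetic in isolation:

    # Self-contained sketch of the ceil-division batching used in _infer above.
    def iter_batches(items, max_split_batch=4):
        n = len(items) // max_split_batch
        if len(items) % max_split_batch:
            n += 1  # one extra batch for the remainder
        for i in range(n):
            yield items[i * max_split_batch : (i + 1) * max_split_batch]

    print(list(iter_batches(list(range(10)), 4)))
    # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

The committed loop expresses the same slice as a take-then-truncate pair: text[i * max_split_batch :] followed by a cap at max_split_batch elements.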
6 changes: 4 additions & 2 deletions ChatTTS/utils/dl.py
@@ -151,7 +151,7 @@ def download_dns_yaml(url: str, folder: str, headers: Dict[str, str]):
     logger.get_logger().info(f"downloaded into {folder}")
 
 
-def download_all_assets(tmpdir: str, version="0.2.9"):
+def download_all_assets(tmpdir: str, homedir: str, version="0.2.10"):
     import subprocess
     import platform
 
@@ -186,7 +186,7 @@ def download_all_assets(tmpdir: str, version="0.2.9"):
         else:
             download_and_extract_tar_gz(RVCMD_URL, tmpdir)
         os.chmod(cmdfile, 0o755)
-        subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
+        subprocess.run([cmdfile, "-notui", "-w", "0", "-H", homedir, "assets/chtts"])
     except Exception:
         BASE_URL = (
             "https://gitea.seku.su/fumiama/RVC-Models-Downloader/releases/download/"
@@ -215,6 +215,8 @@ def download_all_assets(tmpdir: str, version="0.2.9"):
                 "0",
                 "-dns",
                 os.path.join(tmpdir, "dns.yaml"),
+                "-H",
+                homedir,
                 "assets/chtts",
             ]
         )
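
Note: homedir is a new required positional parameter, forwarded on both download paths as rvcmd's -H flag so assets land in a caller-chosen directory instead of the process working directory. A hedged call-site sketch (the import path follows this file's location):

    import tempfile

    from ChatTTS.utils.dl import download_all_assets

    with tempfile.TemporaryDirectory() as tmp:
        # Fetch assets/chtts into ./models via `rvcmd -H ./models`.
        download_all_assets(tmpdir=tmp, homedir="./models")

Because the parameter has no default, any external caller still invoking download_all_assets(tmpdir) alone will now raise a TypeError and must be updated, as core.py is in this same commit.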
18 changes: 15 additions & 3 deletions examples/api/main.py
@@ -23,7 +23,10 @@
 
 
 from pydantic import BaseModel
-
+from fastapi.exceptions import RequestValidationError
+from fastapi.responses import JSONResponse
+from tools.normalizer.en import normalizer_en_nemo_text
+from tools.normalizer.zh import normalizer_zh_tn
 
 logger = get_logger("Command")

@@ -35,14 +38,23 @@ async def startup_event():
     global chat
 
     chat = ChatTTS.Chat(get_logger("ChatTTS"))
+    chat.normalizer.register("en", normalizer_en_nemo_text())
+    chat.normalizer.register("zh", normalizer_zh_tn())
+
     logger.info("Initializing ChatTTS...")
-    if chat.load():
+    if chat.load(source="huggingface"):
         logger.info("Models loaded successfully.")
     else:
         logger.error("Models load failed.")
         sys.exit(1)
 
 
+@app.exception_handler(RequestValidationError)
+async def validation_exception_handler(request, exc: RequestValidationError):
+    logger.error(f"Validation error: {exc.errors()}")
+    return JSONResponse(status_code=422, content={"detail": exc.errors()})
+
+
 class ChatTTSParams(BaseModel):
     text: list[str]
     stream: bool = False
@@ -52,7 +64,7 @@ class ChatTTSParams(BaseModel):
     use_decoder: bool = True
     do_text_normalization: bool = True
     do_homophone_replacement: bool = False
-    params_refine_text: ChatTTS.Chat.RefineTextParams
+    params_refine_text: ChatTTS.Chat.RefineTextParams = None
     params_infer_code: ChatTTS.Chat.InferCodeParams


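Note: params_refine_text becomes optional in the request schema, and validation failures now surface as a logged error plus a structured 422 response. A hedged client sketch (the route path and the accepted fields of params_infer_code are assumptions; neither appears in this diff):

    import requests

    body = {
        "text": ["ChatTTS is a generative speech model."],
        "stream": False,
        # params_refine_text may be omitted now that it defaults to None.
        "params_infer_code": {},
    }
    resp = requests.post("http://localhost:8000/generate_voice", json=body)
    if resp.status_code == 422:
        # Shape produced by the new RequestValidationError handler.
        print(resp.json()["detail"])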
5 changes: 5 additions & 0 deletions examples/web/funcs.py
@@ -133,6 +133,7 @@ def refine_text(
     temperature,
     top_P,
     top_K,
+    split_batch,
 ):
     global chat
 
@@ -150,6 +151,7 @@
             top_K=top_K,
             manual_seed=text_seed_input,
         ),
+        split_text=split_batch > 0,
     )
 
     return text[0] if isinstance(text, list) else text
@@ -165,6 +167,7 @@ def generate_audio(
     audio_seed_input,
     sample_text_input,
     sample_audio_code_input,
+    split_batch,
 ):
     global chat, has_interrupted
 
@@ -189,6 +192,8 @@
         skip_refine_text=True,
         params_infer_code=params_infer_code,
         stream=stream,
+        split_text=split_batch > 0,
+        max_split_batch=split_batch,
     )
     if stream:
         for gen in wav:
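
Note: the web UI exposes a single split_batch number and derives both new infer arguments from it; zero disables splitting entirely. A minimal sketch of that mapping (chat stands for the module's global ChatTTS.Chat instance; the slider wiring itself is outside this diff):

    def split_kwargs(split_batch: int) -> dict:
        # Positive values enable splitting and cap the segments per batch.
        return {"split_text": split_batch > 0, "max_split_batch": split_batch}

    # As used inside generate_audio:
    # wav = chat.infer(text, skip_refine_text=True, stream=stream, **split_kwargs(4))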