feat(core): add load_models source local (#361)
rename original `local` to `custom`
fumiama authored Jun 19, 2024
1 parent ce1c962 commit a63e9c2
Showing 9 changed files with 353 additions and 13 deletions.
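In API terms, callers that previously passed source='local' with local_path= now pass source='custom' with custom_path=, while the new 'local' source (now the default) verifies the bundled assets against sha256.env and downloads them when needed. A minimal before/after sketch, with a placeholder model directory:

import ChatTTS

chat = ChatTTS.Chat()

# before this commit:
# chat.load_models(source='local', local_path='/path/to/models')

# after this commit ('/path/to/models' is a placeholder):
chat.load_models(source='custom', custom_path='/path/to/models')

# new in this commit: the default 'local' source verifies assets in the
# working directory and fetches them via rvcmd when a checksum fails
chat.load_models(source='local')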
45 changes: 45 additions & 0 deletions .github/workflows/checksum.yml
@@ -0,0 +1,45 @@
name: Calculate and Sync SHA256
on:
push:
branches:
- main
- dev
jobs:
checksum:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@master

- name: Setup Go Environment
uses: actions/setup-go@master

- name: Run RVC-Models-Downloader
run: |
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.5/rvcmd_linux_amd64.deb
sudo apt -y install ./rvcmd_linux_amd64.deb
rm -f ./rvcmd_linux_amd64.deb
rvcmd -notrs -w 1 -notui assets/chtts
- name: Calculate all Checksums
run: go run tools/checksum/*.go

- name: Commit back
if: ${{ !github.head_ref }}
id: commitback
continue-on-error: true
run: |
git config --local user.name 'github-actions[bot]'
git config --local user.email 'github-actions[bot]@users.noreply.github.com'
git add --all
git commit -m "chore(env): sync checksum on ${{github.ref_name}}"
- name: Create Pull Request
if: steps.commitback.outcome == 'success'
continue-on-error: true
uses: peter-evans/create-pull-request@v5
with:
delete-branch: true
body: "Automatically sync checksum in .env"
title: "chore(env): sync checksum on ${{github.ref_name}}"
commit-message: "chore(env): sync checksum on ${{github.ref_name}}"
branch: checksum-${{github.ref_name}}
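The checksum tool itself (tools/checksum/*.go) is not part of this diff, so the following is only an assumed Python approximation of what that step produces: hashing every file under asset/ and config/ and regenerating the key/value lines that appear in sha256.env below, with dots in file names replaced by underscores.

# illustrative sketch of the assumed behavior of `go run tools/checksum/*.go`
import hashlib
from pathlib import Path

lines = []
for kind in ("asset", "config"):
    for path in sorted(Path(kind).iterdir()):
        if not path.is_file():
            continue
        digest = hashlib.sha256(path.read_bytes()).hexdigest()
        lines.append(f"sha256_{kind}_{path.name.replace('.', '_')} = {digest}")
Path("sha256.env").write_text("\n".join(lines) + "\n")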
35 changes: 26 additions & 9 deletions ChatTTS/core.py
@@ -3,18 +3,21 @@
import json
import logging
from functools import partial
from omegaconf import OmegaConf
from typing import Literal
import tempfile

import torch
from omegaconf import OmegaConf
from vocos import Vocos
from huggingface_hub import snapshot_download

from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.infer_utils import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

from huggingface_hub import snapshot_download
from .utils.download import check_all_assets, download_all_assets

logging.basicConfig(level = logging.INFO)

@@ -44,9 +47,23 @@ def check_model(self, level = logging.INFO, use_decoder = False):
self.logger.log(level, f'All initialized.')

return not not_finish

def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>', **kwargs):
if source == 'huggingface':

def load_models(
self,
source: Literal['huggingface', 'local', 'custom']='local',
force_redownload=False,
custom_path='<LOCAL_PATH>',
**kwargs,
):
if source == 'local':
download_path = os.getcwd()
if not check_all_assets(update=True):
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=False):
logging.error("could not satisfy all needed assets.")
exit(1)
elif source == 'huggingface':
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
try:
download_path = get_latest_modified_file(os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
@@ -57,9 +74,9 @@ def load_models(self, source='huggingface', force_redownload=False, local_path='
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
else:
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
elif source == 'local':
self.logger.log(logging.INFO, f'Load from local: {local_path}')
download_path = local_path
elif source == 'custom':
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
download_path = custom_path

self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)

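Because the new default source 'local' verifies files through check_all_assets, which reads the expected hashes from environment variables, sha256.env must be loaded before load_models() runs; this is why the example scripts below gain a load_dotenv call. A minimal usage sketch:

from dotenv import load_dotenv
load_dotenv("sha256.env")  # exposes the sha256_asset_* / sha256_config_* keys to os.environ

import ChatTTS

chat = ChatTTS.Chat()
chat.load_models()  # source defaults to 'local': verify assets, download via rvcmd if needed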
191 changes: 191 additions & 0 deletions ChatTTS/utils/download.py
@@ -0,0 +1,191 @@
import os
from pathlib import Path
import hashlib
import requests
from io import BytesIO
import logging

logger = logging.getLogger(__name__)


def sha256(f) -> str:
sha256_hash = hashlib.sha256()
# Read and update hash in chunks of 4M
for byte_block in iter(lambda: f.read(4 * 1024 * 1024), b""):
sha256_hash.update(byte_block)
return sha256_hash.hexdigest()


def check_model(
dir_name: Path, model_name: str, hash: str, remove_incorrect=False
) -> bool:
target = dir_name / model_name
relname = target.as_posix()
logger.debug(f"checking {relname}...")
if not os.path.exists(target):
logger.info(f"{target} does not exist.")
return False
with open(target, "rb") as f:
digest = sha256(f)
bakfile = f"{target}.bak"
if digest != hash:
logger.warn(f"{target} sha256 hash mismatch.")
logger.info(f"expected: {hash}")
logger.info(f"real val: {digest}")
logger.warn("please add parameter --update to download the latest assets.")
if remove_incorrect:
if not os.path.exists(bakfile):
os.rename(str(target), bakfile)
else:
os.remove(str(target))
return False
if remove_incorrect and os.path.exists(bakfile):
os.remove(bakfile)
return True


def check_all_assets(update=False) -> bool:
BASE_DIR = Path(__file__).resolve().parent.parent.parent

logger.info("checking assets...")
current_dir = BASE_DIR / "asset"
names = [
"Decoder.pt",
"DVAE.pt",
"GPT.pt",
"spk_stat.pt",
"tokenizer.pt",
"Vocos.pt",
]
for model in names:
menv = model.replace(".", "_")
if not check_model(
current_dir, model, os.environ[f"sha256_asset_{menv}"], update
):
return False

logger.info("checking configs...")
current_dir = BASE_DIR / "config"
names = [
"decoder.yaml",
"dvae.yaml",
"gpt.yaml",
"path.yaml",
"vocos.yaml",
]
for model in names:
menv = model.replace(".", "_")
if not check_model(
current_dir, model, os.environ[f"sha256_config_{menv}"], update
):
return False

logger.info("all assets are already up to date.")
return True


def download_and_extract_tar_gz(url: str, folder: str):
import tarfile

logger.info(f"downloading {url}")
response = requests.get(url, stream=True, timeout=(5, 10))
with BytesIO() as out_file:
out_file.write(response.content)
out_file.seek(0)
logger.info(f"downloaded.")
with tarfile.open(fileobj=out_file, mode="r:gz") as tar:
tar.extractall(folder)
logger.info(f"extracted into {folder}")


def download_and_extract_zip(url: str, folder: str):
import zipfile

logger.info(f"downloading {url}")
response = requests.get(url, stream=True, timeout=(5, 10))
with BytesIO() as out_file:
out_file.write(response.content)
out_file.seek(0)
logger.info(f"downloaded.")
with zipfile.ZipFile(out_file) as zip_ref:
zip_ref.extractall(folder)
logger.info(f"extracted into {folder}")


def download_dns_yaml(url: str, folder: str):
logger.info(f"downloading {url}")
response = requests.get(url, stream=True, timeout=(5, 10))
with open(os.path.join(folder, "dns.yaml"), "wb") as out_file:
out_file.write(response.content)
logger.info(f"downloaded into {folder}")


def download_all_assets(tmpdir: str, version="0.2.5"):
import subprocess
import platform

archs = {
"aarch64": "arm64",
"armv8l": "arm64",
"arm64": "arm64",
"x86": "386",
"i386": "386",
"i686": "386",
"386": "386",
"x86_64": "amd64",
"x64": "amd64",
"amd64": "amd64",
}
system_type = platform.system().lower()
architecture = platform.machine().lower()
is_win = system_type == "windows"

architecture = archs.get(architecture, None)
if not architecture:
logger.error(f"architecture {architecture} is not supported")
exit(1)
try:
BASE_URL = "https://github.com/fumiama/RVC-Models-Downloader/releases/download/"
suffix = "zip" if is_win else "tar.gz"
RVCMD_URL = BASE_URL + f"v{version}/rvcmd_{system_type}_{architecture}.{suffix}"
cmdfile = os.path.join(tmpdir, "rvcmd")
if is_win:
download_and_extract_zip(RVCMD_URL, tmpdir)
cmdfile += ".exe"
else:
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
os.chmod(cmdfile, 0o755)
subprocess.run([cmdfile, "-notui", "-w", "0", "assets/chtts"])
except Exception:
BASE_URL = "https://raw.gitcode.com/u011570312/RVC-Models-Downloader/assets/"
suffix = {
"darwin_amd64": "555",
"darwin_arm64": "556",
"linux_386": "557",
"linux_amd64": "558",
"linux_arm64": "559",
"windows_386": "562",
"windows_amd64": "563",
}[f"{system_type}_{architecture}"]
RVCMD_URL = BASE_URL + suffix
download_dns_yaml(
"https://raw.gitcode.com/u011570312/RVC-Models-Downloader/raw/main/dns.yaml",
tmpdir,
)
if is_win:
download_and_extract_zip(RVCMD_URL, tmpdir)
cmdfile += ".exe"
else:
download_and_extract_tar_gz(RVCMD_URL, tmpdir)
os.chmod(cmdfile, 0o755)
subprocess.run(
[
cmdfile,
"-notui",
"-w",
"0",
"-dns",
os.path.join(tmpdir, "dns.yaml"),
"assets/chtts",
]
)
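check_all_assets and download_all_assets can also be driven directly, for example to pre-fetch the models in CI or on an offline machine without constructing a Chat instance. A minimal sketch, assuming the package is importable and sha256.env has already been loaded:

import tempfile

from dotenv import load_dotenv

from ChatTTS.utils.download import check_all_assets, download_all_assets

load_dotenv("sha256.env")  # expected hashes consumed by check_all_assets
if not check_all_assets(update=True):  # update=True moves mismatching files to *.bak
    with tempfile.TemporaryDirectory() as tmp:
        download_all_assets(tmpdir=tmp)  # fetch via rvcmd; falls back to the gitcode mirror
    if not check_all_assets(update=False):
        raise SystemExit("assets could not be verified")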
3 changes: 3 additions & 0 deletions examples/cmd/run.py
@@ -6,6 +6,9 @@
now_dir = os.getcwd()
sys.path.append(now_dir)

from dotenv import load_dotenv
load_dotenv("sha256.env")

import wave
import ChatTTS
from IPython.display import Audio
11 changes: 7 additions & 4 deletions examples/web/webui.py
@@ -13,6 +13,9 @@
import gradio as gr
import numpy as np

from dotenv import load_dotenv
load_dotenv("sha256.env")

import ChatTTS

# 音色选项:用于预置合适的音色
@@ -132,18 +135,18 @@ def main():
parser.add_argument('--server_name', type=str, default='0.0.0.0', help='Server name')
parser.add_argument('--server_port', type=int, default=8080, help='Server port')
parser.add_argument('--root_path', type=str, default=None, help='Root Path')
parser.add_argument('--local_path', type=str, default=None, help='the local_path if need')
parser.add_argument('--custom_path', type=str, default=None, help='the custom model path')
args = parser.parse_args()

print("loading ChatTTS model...")
global chat
chat = ChatTTS.Chat()

if args.local_path == None:
if args.custom_path == None:
chat.load_models()
else:
print('local model path:', args.local_path)
chat.load_models('local', local_path=args.local_path)
print('local model path:', args.custom_path)
chat.load_models('custom', custom_path=args.custom_path)

demo.launch(server_name=args.server_name, server_port=args.server_port, root_path=args.root_path, inbrowser=True)

1 change: 1 addition & 0 deletions requirements.txt
@@ -8,6 +8,7 @@ transformers~=4.41.1
vocos
IPython
gradio
python-dotenv
pynini==2.1.5
WeTextProcessing
nemo_text_processing
12 changes: 12 additions & 0 deletions sha256.env
@@ -0,0 +1,12 @@
sha256_asset_Decoder_pt = 9964e36e840f0e3a748c5f716fe6de6490d2135a5f5155f4a642d51860e2ec38
sha256_asset_DVAE_pt = 613cb128adf89188c93ea5880ea0b798e66b1fe6186d0c535d99bcd87bfd6976
sha256_asset_GPT_pt = d7d4ee6461ea097a2be23eb40d73fb94ad3b3d39cb64fbb50cb3357fd466cadb
sha256_asset_spk_stat_pt = 3228d8a4cbbf349d107a1b76d2f47820865bd3c9928c4bdfe1cefd5c7071105f
sha256_asset_tokenizer_pt = e911ae7c6a7c27953433f35c44227a67838fe229a1f428503bdb6cd3d1bcc69c
sha256_asset_Vocos_pt = 09a670eda1c08b740013679c7a90ebb7f1a97646ea7673069a6838e6b51d6c58

sha256_config_decoder_yaml = 0890ab719716b0ad8abcb9eba0a9bf52c59c2e45ddedbbbb5ed514ff87bff369
sha256_config_dvae_yaml = 1b3a5aa0c6a314f766d4432ab36f84e882e29561648d837f71c04c7bea494fc6
sha256_config_gpt_yaml = 0c3c7277b674094bdd00b63b18b18aa3156502101dbd03c7f802e0fcf26cff51
sha256_config_path_yaml = 79829705c2d2a29b3f55e3b3f228bb81875e4e265211595fb50a73eb6434684b
sha256_config_vocos_yaml = 1ca837ce790dd8b55bdd5a16c6af8f813926b9c9b48f2a4da305e7e9ff0c9b0c
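These entries are what check_model looks up once load_dotenv("sha256.env") has run in the examples: each file name maps to its key by swapping dots for underscores, and the stored value is compared against the file's SHA-256. A small illustration of the lookup:

import os

from dotenv import load_dotenv

load_dotenv("sha256.env")
key = "GPT.pt".replace(".", "_")  # -> GPT_pt
print(os.environ[f"sha256_asset_{key}"])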