Commit 8a0e837

init upload

0 parents · commit 8a0e837

12 files changed: +1327 additions, -0 deletions

.gitignore

Lines changed: 163 additions & 0 deletions
@@ -0,0 +1,163 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
*.ckpt
# C extensions
*.so
*.pt

# Distribution / packaging
.Python
outputs/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
asset/*
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

ChatTTS/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .core import Chat
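
Since __init__.py re-exports Chat from core.py, the class is reachable at package level; a trivial sketch:

import ChatTTS
chat = ChatTTS.Chat()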

ChatTTS/core.py

Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
import os
import logging
from omegaconf import OmegaConf

import torch
from vocos import Vocos
from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .infer.api import refine_text, infer_code

from huggingface_hub import snapshot_download

logging.basicConfig(level = logging.INFO)


class Chat:
    def __init__(self, ):
        self.pretrain_models = {}
        self.logger = logging.getLogger(__name__)

    def check_model(self, level = logging.INFO, use_decoder = False):
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']

        if use_decoder:
            check_list.append('decoder')
        else:
            check_list.append('dvae')

        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True

        if not not_finish:
            self.logger.log(level, f'All initialized.')

        return not not_finish

    def load_models(self, source='huggingface'):
        if source == 'huggingface':
            download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()})

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None
    ):
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            self.pretrain_models['gpt'] = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    def infer(self, text, skip_refine_text=False, params_refine_text={}, params_infer_code={}, use_decoder=False):
        assert self.check_model(use_decoder=use_decoder)
        if not skip_refine_text:
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)
        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0,2,1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0,2,1)) for i in result['ids']]
        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]
        return wav
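
For orientation, a minimal usage sketch of the Chat class above. The soundfile dependency and the 24000 Hz output rate are assumptions for illustration; neither is specified in this commit.

import soundfile as sf  # assumed audio-writing helper, not part of this commit
import ChatTTS

chat = ChatTTS.Chat()
chat.load_models()  # fetches *.pt / *.yaml from the 2Noise/ChatTTS Hugging Face repo

texts = ["Hello, this is a quick ChatTTS smoke test."]
wavs = chat.infer(texts)  # list of numpy arrays, one per input text, each shaped (1, n_samples)

# 24000 Hz is an assumed vocoder sample rate, not stated in this diff.
sf.write("output_0.wav", wavs[0][0], 24000)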

ChatTTS/experimental/llm.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
from openai import OpenAI

prompt_dict = {
    'kimi': [ {"role": "system", "content": "你是 Kimi,由 Moonshot AI 提供的人工智能助手,你更擅长中文和英文的对话。"},
        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,请注意你现在生成的文字要按照人日常生活的口吻,你的回复将会后续用TTS模型转为语音,并且请把回答控制在100字以内。并且标点符号仅包含逗号和句号,将数字等转为文字回答。"},
        {"role": "assistant", "content": "好的,我现在生成的文字将按照人日常生活的口吻, 并且我会把回答控制在一百字以内, 标点符号仅包含逗号和句号,将阿拉伯数字等转为中文文字回答。下面请开始对话。"},],
    'deepseek_TN': [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": "你好,现在我们在处理TTS的文本输入,下面将会给你输入一段文本,请你将其中的阿拉伯数字等等转为文字表达,并且输出的文本里仅包含逗号和句号这两个标点符号"},
        {"role": "assistant", "content": "好的,我现在对TTS的文本输入进行处理。这一般叫做text normalization。下面请输入"},
        {"role": "user", "content": "We paid $123 for this desk."},
        {"role": "assistant", "content": "We paid one hundred and twenty three dollars for this desk."},
        {"role": "user", "content": "详询请拨打010-724654"},
        {"role": "assistant", "content": "详询请拨打零幺零,七二四六五四"},
        {"role": "user", "content": "罗森宣布将于7月24日退市,在华门店超6000家!"},
        {"role": "assistant", "content": "罗森宣布将于七月二十四日退市,在华门店超过六千家。"},
    ],
}

class llm_api:
    def __init__(self, api_key, base_url, model):
        self.client = OpenAI(
            api_key = api_key,
            base_url = base_url,
        )
        self.model = model
    def call(self, user_question, temperature = 0.3, prompt_version='kimi', **kwargs):

        completion = self.client.chat.completions.create(
            model = self.model,
            messages = prompt_dict[prompt_version]+[{"role": "user", "content": user_question},],
            temperature = temperature,
            **kwargs
        )
        return completion.choices[0].message.content
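
A hedged usage sketch of llm_api: the base_url, model name, and environment variable below are placeholders for illustration, not values defined in this commit.

import os
from ChatTTS.experimental.llm import llm_api

client = llm_api(
    api_key = os.environ.get("LLM_API_KEY", ""),  # assumed env var name
    base_url = "https://api.deepseek.com",        # assumed OpenAI-compatible endpoint
    model = "deepseek-chat",                      # assumed model name
)

# Prepends the 'deepseek' few-shot prompt and returns the assistant's reply text.
reply = client.call("What is the weather like today?", prompt_version = 'deepseek')
print(reply)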
