Merge pull request #415 from 2noise/optimzie
optimize: all
fumiama authored Jun 24, 2024
2 parents 2eb97d2 + 51c2118 commit 46b007e
Showing 16 changed files with 411 additions and 268 deletions.
183 changes: 100 additions & 83 deletions ChatTTS/core.py
@@ -3,8 +3,9 @@
import logging
import tempfile
from functools import partial
from typing import Literal, Optional
from typing import Literal, Optional, List, Callable

import numpy as np
import torch
from omegaconf import OmegaConf
from vocos import Vocos
@@ -16,8 +17,8 @@
from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
from .utils.io import get_latest_modified_file, del_all
from .infer.api import refine_text, infer_code
from .utils.download import check_all_assets, download_all_assets
from .utils.log import set_utils_logger
from .utils.dl import check_all_assets, download_all_assets
from .utils.log import logger as utils_logger


class Chat:
@@ -26,45 +27,45 @@ def __init__(self, logger=logging.getLogger(__name__)):
self.normalizer = {}
self.homophones_replacer = None
self.logger = logger
set_utils_logger(logger)
utils_logger.set_logger(logger)

def check_model(self, level = logging.INFO, use_decoder = False):
def has_loaded(self, use_decoder = False):
not_finish = False
check_list = ['vocos', 'gpt', 'tokenizer']
check_list = ['gpt', 'tokenizer']

if use_decoder:
check_list.append('decoder')
else:
check_list.append('dvae')

for module in check_list:
if module not in self.pretrain_models:
self.logger.log(logging.WARNING, f'{module} not initialized.')
self.logger.warn(f'{module} not initialized.')
not_finish = True


if not hasattr(self, "_vocos_decode") or not hasattr(self, "vocos"):
self.logger.warn('vocos not initialized.')
not_finish = True

if not not_finish:
self.logger.log(level, f'All initialized.')
self.logger.info('all models has been initialized.')

return not not_finish

def load_models(
def download_models(
self,
source: Literal['huggingface', 'local', 'custom']='local',
force_redownload=False,
compile: bool = True,
custom_path: Optional[torch.serialization.FILE_LIKE]=None,
device: Optional[torch.device] = None,
coef: Optional[torch.Tensor] = None,
):
) -> Optional[str]:
if source == 'local':
torch.load
download_path = os.getcwd()
if not check_all_assets(update=True) or force_redownload:
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=False):
self.logger.error("counld not satisfy all assets needed.")
return False
self.logger.error("download to local path %s failed.", download_path)
return None
elif source == 'huggingface':
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
try:
@@ -73,18 +74,38 @@ def load_models(
download_path = None
if download_path is None or force_redownload:
self.logger.log(logging.INFO, f'Download from HF: https://huggingface.co/2Noise/ChatTTS')
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
try:
download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
except:
download_path = None
else:
self.logger.log(logging.INFO, f'Load from cache: {download_path}')
self.logger.log(logging.INFO, f'load latest snapshot from cache: {download_path}')
if download_path is None:
self.logger.error("download from huggingface failed.")
return None
elif source == 'custom':
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
self.logger.log(logging.INFO, f'try to load from local: {custom_path}')
download_path = custom_path

return download_path

def load_models(
self,
source: Literal['huggingface', 'local', 'custom']='local',
force_redownload=False,
compile: bool = True,
custom_path: Optional[torch.serialization.FILE_LIKE]=None,
device: Optional[torch.device] = None,
coef: Optional[torch.Tensor] = None,
) -> bool:
download_path = self.download_models(source, force_redownload, custom_path)
if download_path is None:
return False
return self._load(
device=device, compile=compile, coef=coef,
**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()},
)
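# A minimal usage sketch of the split above, assuming only names visible in
# this diff (the sample text is a hypothetical placeholder): load_models()
# now delegates path resolution to download_models(), which returns the
# snapshot directory or None, then maps each entry of config/path.yaml to a
# keyword argument of _load().
#
#     chat = Chat()
#     if chat.load_models(source='huggingface'):
#         wav = chat.infer(['Hello, world.'])  # hypothetical sample text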

def _load(
self,
vocos_config_path: str = None,
@@ -112,9 +133,17 @@ def _load(
).eval()
assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
vocos.load_state_dict(torch.load(vocos_ckpt_path))
self.pretrain_models['vocos'] = vocos
self.vocos = vocos
if "mps" in str(self.device):
self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
spec.cpu()
).cpu().numpy()
else:
self._vocos_decode: Callable[[torch.Tensor], np.ndarray] = lambda spec: self.vocos.decode(
spec
).cpu().numpy()
self.logger.log(logging.INFO, 'vocos loaded.')
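# The branch above binds a device-aware decode path once at load time instead
# of re-checking the device for every chunk; on Apple MPS the spectrogram is
# moved to the CPU first, presumably because vocos decoding is not fully
# supported on that backend. A standalone sketch of the same pattern,
# assuming any model object exposing a decode() method:
#
#     if "mps" in str(device):
#         decode = lambda spec: model.decode(spec.cpu()).cpu().numpy()
#     else:
#         decode = lambda spec: model.decode(spec).cpu().numpy()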

if dvae_config_path:
cfg = OmegaConf.load(dvae_config_path)
dvae = DVAE(**cfg, coef=coef).to(device).eval()
@@ -157,8 +186,13 @@ def _load(

self.coef = coef

return self.check_model()
return self.has_loaded()

def unload(self):
logger = self.logger
del_all(self)
self.__init__(logger)
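# unload() above resets the instance to its freshly-constructed state:
# del_all (imported from .utils.io) is presumably what releases the loaded
# model attributes, after which re-running __init__ with the preserved logger
# restores the empty bookkeeping fields.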

def _infer(
self,
text,
@@ -173,23 +207,23 @@ def _infer(
do_homophone_replacement=True
):

assert self.check_model(use_decoder=use_decoder)
assert self.has_loaded(use_decoder=use_decoder)

if not isinstance(text, list):
text = [text]
if do_text_normalization:
for i, t in enumerate(text):
_lang = detect_language(t) if lang is None else lang
if self.init_normalizer(_lang):
if self._init_normalizer(_lang):
text[i] = self.normalizer[_lang](t)
if _lang == 'zh':
text[i] = apply_half2full_map(text[i])
for i, t in enumerate(text):
invalid_characters = count_invalid_characters(t)
if len(invalid_characters):
self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
self.logger.warn(f'Invalid characters found! : {invalid_characters}')
text[i] = apply_character_map(t)
if do_homophone_replacement and self.init_homophones_replacer():
if do_homophone_replacement and self._init_homophones_replacer():
text[i], replaced_words = self.homophones_replacer.replace(text[i])
if replaced_words:
repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
@@ -205,64 +239,25 @@
text_tokens = refined.ids
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
del_all(refined)
refined.destroy()
if refine_text_only:
yield text
return

text = [params_infer_code.get('prompt', '') + i for i in text]
params_infer_code.pop('prompt', '')
result_gen = infer_code(

length = [0 for _ in range(len(text))]
for result in infer_code(
self.pretrain_models,
text,
device=self.device,
**params_infer_code,
return_hidden=use_decoder,
stream=stream,
)
if use_decoder:
docoder_name = 'decoder'
else:
docoder_name = 'dvae'
if "mps" in str(self.device):
vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
i.cpu()
).cpu().numpy() for i in spec]
else:
vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
i
).cpu().numpy() for i in spec]
if stream:

length = 0
for result in result_gen:
x = result.hiddens if use_decoder else result.ids
assert len(x) == 1
chunk_data = x[0]
start_seek = length
length = len(chunk_data)
self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
chunk_data = chunk_data[start_seek:]
if not len(chunk_data):
continue
self.logger.debug(f'new hidden {len(chunk_data)=}')
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in [chunk_data]]
del_all(result)
del chunk_data
del_all(x)
wav = vocos_decode(mel_spec)
del_all(mel_spec)
self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
yield wav
return
result = next(result_gen)
x = result.hiddens if use_decoder else result.ids
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in x]
del_all(result)
del_all(x)
wav = vocos_decode(mel_spec)
del_all(mel_spec)
yield wav
):
wav = self.decode_to_wavs(result, length, use_decoder)
yield wav

def infer(
self,
@@ -294,13 +289,35 @@ def infer(
else:
return next(res_gen)

def sample_random_speaker(self, ):

def sample_random_speaker(self):
dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
std, mean = self.pretrain_models['spk_stat'].chunk(2)
return torch.randn(dim, device=std.device) * std + mean
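# sample_random_speaker() above is a per-dimension Gaussian draw from stored
# statistics: spk_stat packs std and mean halves back-to-back, chunk(2)
# splits them, and randn(dim) * std + mean yields emb_i ~ N(mean_i, std_i^2).
# A sketch with a hypothetical 768-dimensional stat tensor:
#
#     std, mean = spk_stat.chunk(2)                    # spk_stat: (2 * 768,)
#     emb = torch.randn(768, device=std.device) * std + mean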

def init_normalizer(self, lang) -> bool:

def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool):
x = result.hiddens if use_decoder else result.ids
wavs: List[np.ndarray] = []
for i, chunk_data in enumerate(x):
start_seek = start_seeks[i]
length = len(chunk_data)
if length <= start_seek:
wavs.append(None)
continue
start_seeks[i] = length
chunk_data = chunk_data[start_seek:]
if use_decoder:
decoder = self.pretrain_models['decoder']
else:
decoder = self.pretrain_models['dvae']
mel_spec = decoder(chunk_data[None].permute(0,2,1).to(self.device))
del chunk_data
wavs.append(self._vocos_decode(mel_spec))
del_all(mel_spec)
result.destroy()
del_all(x)
return wavs
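# decode_to_wavs() above doubles as the streaming increment: start_seeks
# records how much of each sequence has already been decoded, so each call
# mel-decodes only the new tail (chunk_data[start_seek:]) and appends None
# for texts with nothing new; with stream=False it runs once over the full
# sequences, since _infer initializes the seek list to zeros. The generic
# bookkeeping pattern, sketched on plain lists:
#
#     seeks = [0] * len(texts)
#     for result in generator:                  # sequences grow over time
#         tails = [seq[seeks[i]:] for i, seq in enumerate(result)]
#         seeks = [len(seq) for seq in result]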

def _init_normalizer(self, lang) -> bool:

if lang in self.normalizer:
return True
@@ -335,16 +352,16 @@ def init_normalizer(self, lang) -> bool:
)
return False

def init_homophones_replacer(self):
def _init_homophones_replacer(self):
if self.homophones_replacer:
return True
else:
try:
self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))
self.logger.log(logging.INFO, 'homophones_replacer loaded.')
self.logger.log(logging.INFO, 'successfully loaded HomophonesReplacer.')
return True
except (IOError, json.JSONDecodeError) as e:
self.logger.log(logging.WARNING, f'Error loading homophones map: {e}')
self.logger.log(logging.WARNING, f'error loading homophones map: {e}')
except Exception as e:
self.logger.log(logging.WARNING, f'Error loading homophones_replacer: {e}')
self.logger.log(logging.WARNING, f'error loading HomophonesReplacer: {e}')
return False
27 changes: 14 additions & 13 deletions ChatTTS/model/dvae.py
@@ -164,20 +164,21 @@ def __repr__(self) -> str:
return b14.encode_to_string(self.coef.cpu().numpy().astype(np.float32).tobytes())

def forward(self, inp: torch.Tensor) -> torch.Tensor:
with torch.no_grad():

if self.vq_layer is not None:
vq_feats = self.vq_layer._embed(inp)
else:
vq_feats = inp.detach().clone()
if self.vq_layer is not None:
vq_feats = self.vq_layer._embed(inp)
else:
vq_feats = inp.detach().clone()

vq_feats = vq_feats.view(
(vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
).permute(0, 2, 3, 1).flatten(2)
vq_feats = vq_feats.view(
(vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
).permute(0, 2, 3, 1).flatten(2)

dec_out = self.out_conv(
self.decoder(
input=vq_feats.transpose_(1, 2),
).transpose_(1, 2),
)
dec_out = self.out_conv(
self.decoder(
input=vq_feats.transpose_(1, 2),
).transpose_(1, 2),
)

return torch.mul(dec_out, self.coef, out=dec_out)
return torch.mul(dec_out, self.coef, out=dec_out)
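
The dvae.py hunk removes the torch.no_grad() context manager from forward()
and de-indents the body (gradient suppression presumably moves to a decorator
or the caller, outside this hunk), while keeping the allocation-avoiding
in-place operations: transpose_ mutates its tensor rather than copying, and
torch.mul(..., out=dec_out) writes the coefficient scaling into the existing
output buffer. A minimal sketch of the out= pattern, with hypothetical shapes:

    import torch

    x = torch.randn(1, 100, 512)   # stand-in for the decoder output
    c = torch.rand(1, 100, 1)      # stand-in for the learned coefficient
    y = torch.mul(x, c, out=x)     # broadcasts and scales in place; y is x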