
optimize(chat): move gpt&decoders out of dict (#416)
for a clear type definition
fumiama authored Jun 24, 2024
1 parent 109376c commit 21c8ecc
Showing 4 changed files with 37 additions and 58 deletions.
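The gist of the change: the gpt, dvae, and decoder models move out of the untyped pretrain_models dict and become plain attributes on Chat, so each field carries a concrete type. A rough before/after of the access pattern (illustrative sketch; chat stands for a loaded Chat instance):

    # Before: dict lookup, value typed as Any
    gpt = chat.pretrain_models['gpt']

    # After: attribute with a known class, visible to type checkers
    gpt = chat.gpt
    decoder = chat.decoder if use_decoder else chat.dvae

The tokenizer and spk_stat entries stay in the dict for now.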
62 changes: 21 additions & 41 deletions ChatTTS/core.py
@@ -1,5 +1,4 @@
 import os
-import json
 import logging
 import tempfile
 from functools import partial
@@ -23,30 +22,28 @@
 
 class Chat:
     def __init__(self, logger=logging.getLogger(__name__)):
-        self.pretrain_models = {}
-        self.normalizer = {}
-        self.homophones_replacer = None
         self.logger = logger
         utils_logger.set_logger(logger)
 
+        self.pretrain_models = {}
+        self.normalizer = {}
+        self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))
+
     def has_loaded(self, use_decoder = False):
         not_finish = False
-        check_list = ['gpt', 'tokenizer']
+        check_list = ["vocos", "_vocos_decode", 'gpt', 'tokenizer']
 
         if use_decoder:
             check_list.append('decoder')
         else:
             check_list.append('dvae')
 
         for module in check_list:
-            if module not in self.pretrain_models:
+            if not hasattr(self, module) and module not in self.pretrain_models:
                 self.logger.warn(f'{module} not initialized.')
                 not_finish = True
 
-        if not hasattr(self, "_vocos_decode") or not hasattr(self, "vocos"):
-            self.logger.warn('vocos not initialized.')
-            not_finish = True
-
        if not not_finish:
            self.logger.info('all models has been initialized.')

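With the models on attributes, the readiness check probes hasattr first and falls back to the dict for entries that still live there; the separate vocos check collapses into the shared list. A condensed standalone sketch of the resulting logic (not the repository code verbatim):

    def is_loaded(chat, use_decoder=False):
        # "Present" means: an attribute on the instance, or a key left in
        # the legacy pretrain_models dict.
        check_list = ["vocos", "_vocos_decode", "gpt", "tokenizer"]
        check_list.append("decoder" if use_decoder else "dvae")
        return all(
            hasattr(chat, name) or name in chat.pretrain_models
            for name in check_list
        )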
@@ -149,37 +146,37 @@ def _load(
             dvae = DVAE(**cfg, coef=coef).to(device).eval()
             coef = str(dvae)
             assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
-            dvae.load_state_dict(torch.load(dvae_ckpt_path))
-            self.pretrain_models['dvae'] = dvae
+            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
+            self.dvae = dvae
             self.logger.log(logging.INFO, 'dvae loaded.')
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
             gpt = GPT(**cfg, device=device, logger=self.logger).eval()
             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
-            gpt.load_state_dict(torch.load(gpt_ckpt_path))
+            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
             if compile and 'cuda' in str(device):
                 try:
                     gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
                 except RuntimeError as e:
                     self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
-            self.pretrain_models['gpt'] = gpt
+            self.gpt = gpt
             spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
             assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
-            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path).to(device)
+            self.pretrain_models['spk_stat'] = torch.load(spk_stat_path, map_location=device).to(device)
             self.logger.log(logging.INFO, 'gpt loaded.')
 
         if decoder_config_path:
             cfg = OmegaConf.load(decoder_config_path)
             decoder = DVAE(**cfg, coef=coef).to(device).eval()
             coef = str(decoder)
             assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
-            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
-            self.pretrain_models['decoder'] = decoder
+            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
+            self.decoder = decoder
             self.logger.log(logging.INFO, 'decoder loaded.')
 
         if tokenizer_path:
-            tokenizer = torch.load(tokenizer_path, map_location='cpu')
+            tokenizer = torch.load(tokenizer_path, map_location=device)
             tokenizer.padding_side = 'left'
             self.pretrain_models['tokenizer'] = tokenizer
             self.logger.log(logging.INFO, 'tokenizer loaded.')
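A side change in this hunk: every torch.load in _load now passes map_location=device, so checkpoint tensors are remapped onto the target device as they are read, rather than defaulting to the device they were saved on. A minimal illustration (the file name is a placeholder):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # A CUDA-saved checkpoint loads cleanly even on a CPU-only machine.
    state = torch.load("dvae_ckpt.pt", map_location=device)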
@@ -223,15 +220,15 @@ def _infer(
                 if len(invalid_characters):
                     self.logger.warn(f'Invalid characters found! : {invalid_characters}')
                     text[i] = apply_character_map(t)
-                if do_homophone_replacement and self._init_homophones_replacer():
+                if do_homophone_replacement:
                     text[i], replaced_words = self.homophones_replacer.replace(text[i])
                     if replaced_words:
                         repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words])
                         self.logger.log(logging.INFO, f'Homophones replace: {repl_res}')
 
         if not skip_refine_text:
             refined = refine_text(
-                self.pretrain_models,
+                self.gpt, self.pretrain_models['tokenizer'],
                 text,
                 device=self.device,
                 **params_refine_text,
@@ -249,7 +246,7 @@
 
         length = [0 for _ in range(len(text))]
         for result in infer_code(
-            self.pretrain_models,
+            self.gpt, self.pretrain_models['tokenizer'],
             text,
             device=self.device,
             **params_infer_code,
@@ -290,7 +287,7 @@ def infer(
             return next(res_gen)
 
     def sample_random_speaker(self):
-        dim = self.pretrain_models['gpt'].gpt.layers[0].mlp.gate_proj.in_features
+        dim = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
         std, mean = self.pretrain_models['spk_stat'].chunk(2)
         return torch.randn(dim, device=std.device) * std + mean

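Only the dim lookup changes here; the draw itself is the same reparameterized Gaussian sample, with spk_stat packing the std and mean halves back to back. In isolation (shapes assumed from the code above):

    std, mean = spk_stat.chunk(2)                    # split the packed stats
    spk_emb = torch.randn(dim, device=std.device) * std + mean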
@@ -305,10 +302,7 @@ def decode_to_wavs(self, result: GPT.GenerationOutputs, start_seeks: List[int],
                 continue
             start_seeks[i] = length
             chunk_data = chunk_data[start_seek:]
-            if use_decoder:
-                decoder = self.pretrain_models['decoder']
-            else:
-                decoder = self.pretrain_models['dvae']
+            decoder = self.decoder if use_decoder else self.dvae
             mel_spec = decoder(chunk_data[None].permute(0,2,1).to(self.device))
             del chunk_data
             wavs.append(self._vocos_decode(mel_spec))
@@ -351,17 +345,3 @@ def _init_normalizer(self, lang) -> bool:
                 'Run: conda install -c conda-forge pynini=2.1.5 && pip install nemo_text_processing',
             )
             return False
-
-    def _init_homophones_replacer(self):
-        if self.homophones_replacer:
-            return True
-        else:
-            try:
-                self.homophones_replacer = HomophonesReplacer(os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json'))
-                self.logger.log(logging.INFO, 'successfully loaded HomophonesReplacer.')
-                return True
-            except (IOError, json.JSONDecodeError) as e:
-                self.logger.log(logging.WARNING, f'error loading homophones map: {e}')
-            except Exception as e:
-                self.logger.log(logging.WARNING, f'error loading HomophonesReplacer: {e}')
-            return False
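Since HomophonesReplacer is now constructed eagerly in __init__, the lazy _init_homophones_replacer helper, its call-site guard, and the json import all go away. One behavioral consequence, sketched (map_path as written in __init__ above): a missing or malformed homophones_map.json now raises when Chat() is built instead of being caught and logged at the first replacement attempt.

    map_path = os.path.join(os.path.dirname(__file__), 'res', 'homophones_map.json')
    replacer = HomophonesReplacer(map_path)  # raises immediately on a bad map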
18 changes: 8 additions & 10 deletions ChatTTS/infer/api.py
@@ -8,7 +8,8 @@
 from ..model.gpt import GPT
 
 def infer_code(
-    models,
+    gpt: GPT,
+    tokenizer,
     text,
     spk_emb = None,
     top_P = 0.7,
@@ -21,8 +22,6 @@ def infer_code(
     **kwargs
 ):
 
-    gpt: GPT = models['gpt']
-
     if not isinstance(text, list):
         text = [text]
 
@@ -34,7 +33,7 @@ def infer_code(
     else:
         text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
 
-    text_token_tmp = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True)
+    text_token_tmp = tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
     text_token = text_token_tmp.to(device)
     del text_token_tmp
     input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt)
@@ -45,7 +44,7 @@ def infer_code(
 
     if spk_emb is not None:
         n = F.normalize(spk_emb.to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12).to(gpt.device_gpt)
-        emb[input_ids[..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = n
+        emb[input_ids[..., 0] == tokenizer.convert_tokens_to_ids('[spk_emb]')] = n
         del n
 
     num_code = int(gpt.emb_code[0].num_embeddings - 1)
@@ -83,7 +82,8 @@ def infer_code(
 
 
 def refine_text(
-    models,
+    gpt: GPT,
+    tokenizer,
     text,
     top_P = 0.7,
     top_K = 20,
@@ -95,15 +95,13 @@ def refine_text(
     **kwargs
 ):
 
-    gpt: GPT = models['gpt']
-
     if not isinstance(text, list):
         text = [text]
 
     assert len(text), 'text should not be empty'
 
     text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
-    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
+    text_token = tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
     text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=device)
 
     input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq)
@@ -127,7 +125,7 @@ def refine_text(
         attention_mask = text_token['attention_mask'],
         LogitsWarpers = LogitsWarpers,
         LogitsProcessors = LogitsProcessors,
-        eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
+        eos_token = torch.tensor(tokenizer.convert_tokens_to_ids('[Ebreak]'), device=device)[None],
         max_new_token = max_new_token,
         infer_text = True,
         stream = False,
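Net effect on callers of this module: pass the model and tokenizer explicitly instead of a dict. A sketch of the new call shape, mirroring the _infer changes in core.py (argument values are illustrative):

    results = infer_code(
        chat.gpt,                            # was: models (a dict)
        chat.pretrain_models['tokenizer'],   # tokenizer still lives in the dict
        ['text to synthesize'],
        device=chat.device,
        top_P=0.7,
        top_K=20,
    )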
1 change: 0 additions & 1 deletion ChatTTS/utils/dl.py
@@ -30,7 +30,6 @@ def check_model(
         logger.get_logger().warn(f"{target} sha256 hash mismatch.")
         logger.get_logger().info(f"expected: {hash}")
         logger.get_logger().info(f"real val: {digest}")
-        logger.get_logger().warn("please add parameter --update to download the latest assets.")
         if remove_incorrect:
             if not os.path.exists(bakfile):
                 os.rename(str(target), bakfile)
14 changes: 8 additions & 6 deletions examples/web/funcs.py
@@ -79,13 +79,15 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
 
     with TorchSeedContext(audio_seed_input):
         rand_spk = chat.sample_random_speaker()
-        params_infer_code = {
-            'spk_emb': rand_spk,
-            'temperature': temperature,
-            'top_P': top_P,
-            'top_K': top_K,
-        }
 
+    params_infer_code = {
+        'spk_emb': rand_spk,
+        'temperature': temperature,
+        'top_P': top_P,
+        'top_K': top_K,
+    }
+
+    with TorchSeedContext(audio_seed_input):
         wav = chat.infer(
             text,
             skip_refine_text=True,
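The reshuffle splits one seeded block into two that reuse the same seed: speaker sampling and inference each start from an identical RNG state, and building params_infer_code, which consumes no randomness, moves out of the seeded region. The apparent intent, sketched (the infer call is truncated in the diff above; the params_infer_code keyword is assumed from core.py):

    with TorchSeedContext(audio_seed_input):
        rand_spk = chat.sample_random_speaker()   # advances the RNG

    params_infer_code = {'spk_emb': rand_spk, 'temperature': temperature,
                         'top_P': top_P, 'top_K': top_K}

    with TorchSeedContext(audio_seed_input):      # reset to the same state
        wav = chat.infer(text, skip_refine_text=True,
                         params_infer_code=params_infer_code)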
