Skip to content

Commit

Permalink
optimize(gpt): apply type definitions (#407)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 22, 2024
1 parent 2ba7c6c commit b5fd820
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 112 deletions.
22 changes: 13 additions & 9 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from huggingface_hub import snapshot_download

from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .model.gpt import GPT
from .utils.gpu import select_device
from .utils.infer import count_invalid_characters, detect_language, apply_character_map, apply_half2full_map, HomophonesReplacer
from .utils.io import get_latest_modified_file, del_all
Expand Down Expand Up @@ -126,7 +126,7 @@ def _load(

if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
gpt = GPT_warpper(**cfg, device=device, logger=self.logger).eval()
gpt = GPT(**cfg, device=device, logger=self.logger).eval()
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
gpt.load_state_dict(torch.load(gpt_ckpt_path))
if compile and 'cuda' in str(device):
Expand Down Expand Up @@ -196,14 +196,16 @@ def _infer(
self.logger.log(logging.INFO, f'Homophones replace: {repl_res}')

if not skip_refine_text:
text_tokens = refine_text(
refined = refine_text(
self.pretrain_models,
text,
device=self.device,
**params_refine_text,
)['ids']
)
text_tokens = refined.ids
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
del_all(refined)
if refine_text_only:
yield text
return
Expand All @@ -219,10 +221,8 @@ def _infer(
stream=stream,
)
if use_decoder:
field = 'hiddens'
docoder_name = 'decoder'
else:
field = 'ids'
docoder_name = 'dvae'
if "mps" in str(self.device):
vocos_decode = lambda spec: [self.pretrain_models['vocos'].decode(
Expand All @@ -236,8 +236,9 @@ def _infer(

length = 0
for result in result_gen:
chunk_data = result[field][0]
assert len(result[field]) == 1
x = result.hiddens if use_decoder else result.ids
assert len(x) == 1
chunk_data = x[0]
start_seek = length
length = len(chunk_data)
self.logger.debug(f'{start_seek=} total len: {length}, new len: {length - start_seek = }')
Expand All @@ -248,14 +249,17 @@ def _infer(
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in [chunk_data]]
del_all(result)
del chunk_data
del_all(x)
wav = vocos_decode(mel_spec)
del_all(mel_spec)
self.logger.debug(f'yield wav chunk {len(wav[0])=} {len(wav[0][0])=}')
yield wav
return
result = next(result_gen)
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in result[field]]
x = result.hiddens if use_decoder else result.ids
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1).to(self.device)) for i in x]
del_all(result)
del_all(x)
wav = vocos_decode(mel_spec)
del_all(mel_spec)
yield wav
Expand Down
16 changes: 5 additions & 11 deletions ChatTTS/infer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from ..utils.infer import CustomRepetitionPenaltyLogitsProcessorRepeat
from ..utils.io import del_all
from ..model.gpt import GPT_warpper
from ..model.gpt import GPT

def infer_code(
models,
Expand All @@ -21,7 +21,7 @@ def infer_code(
**kwargs
):

gpt: GPT_warpper = models['gpt']
gpt: GPT = models['gpt']

if not isinstance(text, list):
text = [text]
Expand All @@ -40,10 +40,7 @@ def infer_code(
input_ids = text_token['input_ids'][...,None].expand(-1, -1, gpt.num_vq).to(gpt.device_gpt)
text_mask = torch.ones(text_token['input_ids'].shape, dtype=bool, device=gpt.device_gpt)

emb = gpt.get_emb(
input_ids=input_ids,
text_mask=text_mask,
)
emb = gpt(input_ids, text_mask)
del text_mask

if spk_emb is not None:
Expand Down Expand Up @@ -98,7 +95,7 @@ def refine_text(
**kwargs
):

gpt: GPT_warpper = models['gpt']
gpt: GPT = models['gpt']

if not isinstance(text, list):
text = [text]
Expand All @@ -121,10 +118,7 @@ def refine_text(
if repetition_penalty is not None and repetition_penalty != 1:
LogitsProcessors.append(CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16))

emb = gpt.get_emb(
input_ids=input_ids,
text_mask=text_mask,
)
emb = gpt(input_ids,text_mask)
del text_mask

result = gpt.generate(
Expand Down
2 changes: 1 addition & 1 deletion ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import math
from typing import List, Optional

import pybase16384 as b14
import numpy as np
import pybase16384 as b14
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down
Loading

0 comments on commit b5fd820

Please sign in to comment.