From dc5a3e08109c6ff8f099db7bf2c9164ae804dba3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Wed, 19 Jun 2024 21:42:00 +0900
Subject: [PATCH] fix: stream api (#363)

---
 ChatTTS/core.py      | 39 +++++++++++++++++++++++++++++++++------
 ChatTTS/infer/api.py |  5 +++--
 ChatTTS/model/gpt.py |  5 ++---
 3 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 4b87e427b..9a00a4cc0 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -181,12 +181,17 @@ def _infer(
             self.logger.log(logging.INFO, f'Homophones replace: {t} -> {text[i]}')

         if not skip_refine_text:
-            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
+            text_tokens = refine_text(
+                self.pretrain_models,
+                text,
+                **params_refine_text,
+            )['ids']
             text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
             text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
             if refine_text_only:
-                return text
-
+                yield text
+                return
+
         text = [params_infer_code.get('prompt', '') + i for i in text]
         params_infer_code.pop('prompt', '')
         result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
@@ -220,9 +225,31 @@ def _infer(
                 mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
                 yield vocos_decode(mel_spec)

-    def infer(self, *args, **kwargs):
-        stream = kwargs.setdefault('stream', False)
-        res_gen = self._infer(*args, **kwargs)
+    def infer(
+        self,
+        text,
+        skip_refine_text=False,
+        refine_text_only=False,
+        params_refine_text={},
+        params_infer_code={'prompt':'[speed_5]'},
+        use_decoder=True,
+        do_text_normalization=True,
+        lang=None,
+        stream=False,
+        do_homophone_replacement=True,
+    ):
+        res_gen = self._infer(
+            text,
+            skip_refine_text,
+            refine_text_only,
+            params_refine_text,
+            params_infer_code,
+            use_decoder,
+            do_text_normalization,
+            lang,
+            stream,
+            do_homophone_replacement,
+        )
         if stream:
             return res_gen
         else:
diff --git a/ChatTTS/infer/api.py b/ChatTTS/infer/api.py
index acc1b8c66..2058efb14 100644
--- a/ChatTTS/infer/api.py
+++ b/ChatTTS/infer/api.py
@@ -70,7 +70,7 @@ def infer_code(
         stream = stream,
         **kwargs
     )
-    
+
     return result


@@ -122,6 +122,7 @@ def refine_text(
         eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
         max_new_token = max_new_token,
         infer_text = True,
+        stream = False,
         **kwargs
     )
-    return result
+    return next(result)
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 2e84671ce..510371bf2 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -36,8 +36,7 @@ def __init__(
         num_audio_tokens,
         num_text_tokens,
         num_vq=4,
-        **kwargs,
-        ):
+    ):
         super().__init__()
         self.logger = logging.getLogger(__name__)

@@ -291,7 +290,7 @@ def generate(
                 hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]

             if not finish.all():
-                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}') 
+                self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')

             del finish
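
Why the core.py hunk replaces `return text` with `yield text` plus a bare `return`: `_infer` contains `yield` statements further down, so Python compiles the whole function as a generator. Inside a generator, `return text` never hands `text` to the caller of `next()`; it only ends iteration, stashing the value on the `StopIteration` exception. The old `infer(..., refine_text_only=True)` path therefore raised instead of returning the refined text. A minimal standalone sketch of the pitfall, using illustrative `broken`/`fixed` names that are not from this codebase:

    def broken():
        return ['refined text']  # swallowed: ends iteration, value lands on StopIteration
        yield  # unreachable, but its presence makes this a generator function

    def fixed():
        yield ['refined text']   # delivered to the caller of next()
        return                   # then stop cleanly, mirroring the core.py hunk

    try:
        next(broken())
    except StopIteration as e:
        print('broken:', e.value)   # the list is only reachable via e.value

    print('fixed:', next(fixed()))  # prints: fixed: ['refined text']

The same shift explains the api.py hunks: `gpt.generate` now always yields, so `refine_text` pins `stream = False` and unwraps its single result with `next(result)`.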
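
For reviewers who consume the API rather than the internals, a hedged usage sketch of the patched entry point. `Chat.load_models()` and the list-of-waveforms result shape are assumptions carried over from this era of the codebase, not something this diff guarantees; the keyword arguments are exactly the ones the new `infer()` signature exposes:

    # Sketch only: Chat.load_models() and the result shapes are assumptions
    # from the surrounding repo, not part of this diff.
    import ChatTTS

    chat = ChatTTS.Chat()
    chat.load_models()  # assumed loader for this repo version

    texts = ['Streaming TTS is now a generator under the hood.']

    # refine_text_only: the yield-then-return fix is what lets this call
    # return the refined strings instead of raising StopIteration.
    refined = chat.infer(texts, refine_text_only=True)

    # Non-streaming (default): infer() drains its own generator with next()
    # and hands back a single batch, so existing callers are unaffected.
    wavs = chat.infer(texts, stream=False)

    # Streaming: infer() returns the generator itself; each item is one
    # chunk of decoded audio per input text, ready to flush incrementally.
    chunks = []
    for chunk in chat.infer(texts, stream=True):
        chunks.append(chunk)

Keeping `stream=False` as the default preserves the old single-batch behaviour: judging by the unchanged `else:` branch just past the hunk, the non-streaming path still consumes the generator once and returns that first batch.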