Skip to content

Commit

Permalink
fix: stream api (#363)
Browse files (browse the repository at this point in the history)
  • Loading branch information
fumiama authored Jun 19, 2024
1 parent f0babd0 commit dc5a3e0
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 11 deletions.
39 changes: 33 additions & 6 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,17 @@ def _infer(
self.logger.log(logging.INFO, f'Homophones replace: {t} -> {text[i]}')

if not skip_refine_text:
text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
text_tokens = refine_text(
self.pretrain_models,
text,
**params_refine_text,
)['ids']
text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
if refine_text_only:
return text

yield text
return

text = [params_infer_code.get('prompt', '') + i for i in text]
params_infer_code.pop('prompt', '')
result_gen = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder, stream=stream)
Expand Down Expand Up @@ -220,9 +225,31 @@ def _infer(
mel_spec = [self.pretrain_models[docoder_name](i[None].permute(0,2,1)) for i in next(result_gen)[field]]
yield vocos_decode(mel_spec)

def infer(self, *args, **kwargs):
stream = kwargs.setdefault('stream', False)
res_gen = self._infer(*args, **kwargs)
def infer(
self,
text,
skip_refine_text=False,
refine_text_only=False,
params_refine_text={},
params_infer_code={'prompt':'[speed_5]'},
use_decoder=True,
do_text_normalization=True,
lang=None,
stream=False,
do_homophone_replacement=True,
):
res_gen = self._infer(
text,
skip_refine_text,
refine_text_only,
params_refine_text,
params_infer_code,
use_decoder,
do_text_normalization,
lang,
stream,
do_homophone_replacement,
)
if stream:
return res_gen
else:
Expand Down
5 changes: 3 additions & 2 deletions ChatTTS/infer/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def infer_code(
stream = stream,
**kwargs
)

return result


Expand Down Expand Up @@ -122,6 +122,7 @@ def refine_text(
eos_token = torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
max_new_token = max_new_token,
infer_text = True,
stream = False,
**kwargs
)
return result
return next(result)
5 changes: 2 additions & 3 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __init__(
num_audio_tokens,
num_text_tokens,
num_vq=4,
**kwargs,
):
):
super().__init__()

self.logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -291,7 +290,7 @@ def generate(
hiddens = [hiddens[idx, :i] for idx, i in enumerate(end_idx.int())]

if not finish.all():
self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')
self.logger.warn(f'Incomplete result. hit max_new_token: {max_new_token}')

del finish

Expand Down

1 comment on commit dc5a3e0

@bstr9
Copy link

@bstr9 bstr9 commented on dc5a3e0 Jun 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

awesome

Please sign in to comment.