From e628c86d61a8af5a3b3c1040200acc089ae9e40a Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Mon, 1 Jul 2024 17:35:12 +0900
Subject: [PATCH] chore(format): run black on main (#504)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
---
 ChatTTS/model/gpt.py | 18 ++++++++----------
 1 file changed, 8 insertions(+), 10 deletions(-)

diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 7c1063333..e6007e6c8 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -369,7 +369,7 @@ def generate(
             attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                 attention_mask
             )
-
+
         pbar: Optional[tqdm] = None
 
         if show_tqdm:
@@ -473,9 +473,7 @@ def generate(
 
             del logits
 
-            idx_next = torch.multinomial(scores, num_samples=1).to(
-                finish.device
-            )
+            idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
 
             if not infer_text:
                 # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
@@ -483,9 +481,7 @@ def generate(
                 finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
                 del finish_or
-                inputs_ids_tmp = torch.cat(
-                    [inputs_ids, idx_next.unsqueeze_(1)], 1
-                )
+                inputs_ids_tmp = torch.cat([inputs_ids, idx_next.unsqueeze_(1)], 1)
             else:
                 finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
@@ -525,9 +521,11 @@ def generate(
 
             if finish.all() or context.get():
                 break
 
-            if pbar is not None: pbar.update(1)
-
-        if pbar is not None: pbar.close()
+            if pbar is not None:
+                pbar.update(1)
+
+        if pbar is not None:
+            pbar.close()
 
         if not finish.all():
             if context.get():