From 3836db8a911d810b0b601884ead07a270bc084a8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Mon, 1 Jul 2024 17:20:06 +0900
Subject: [PATCH] feat: add infer param `show_tqdm`
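
Add a `show_tqdm` switch (default `True`) to `RefineTextParams`, and make
`InferCodeParams` inherit from it so the flag covers both inference stages;
the duplicated `top_P`, `top_K`, and `min_new_token` fields are dropped in
the process. `GPT.generate` now only creates its tqdm progress bar when the
flag is set.

A minimal usage sketch (the usual `Chat` setup is assumed here and is not
part of this patch):

    import ChatTTS

    chat = ChatTTS.Chat()
    chat.load()

    # Run both inference stages without per-token progress bars.
    wavs = chat.infer(
        ["Hello, this is a quiet inference run."],
        params_refine_text=ChatTTS.Chat.RefineTextParams(show_tqdm=False),
        params_infer_code=ChatTTS.Chat.InferCodeParams(show_tqdm=False),
    )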
---
 ChatTTS/core.py      |   8 +-
 ChatTTS/model/gpt.py | 274 ++++++++++++++++++++++---------------------
 2 files changed, 145 insertions(+), 137 deletions(-)

diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 0d96905d9..78eaf8999 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -197,17 +197,15 @@ class RefineTextParams:
         repetition_penalty: float = 1.0
         max_new_token: int = 384
         min_new_token: int = 0
+        show_tqdm: bool = True

     @dataclass(repr=False, eq=False)
-    class InferCodeParams:
+    class InferCodeParams(RefineTextParams):
         prompt: str = "[speed_5]"
         spk_emb: Optional[str] = None
-        top_P: float = 0.7
-        top_K: int = 20
         temperature: float = 0.3
         repetition_penalty: float = 1.05
         max_new_token: int = 2048
-        min_new_token: int = 0

     def infer(
         self,
@@ -596,6 +594,7 @@ def _infer_code(
             infer_text=False,
             return_hidden=return_hidden,
             stream=stream,
+            show_tqdm=params.show_tqdm,
             context=self.context,
         )

@@ -644,6 +643,7 @@ def _refine_text(
             logits_processors=logits_processors,
             infer_text=True,
             stream=False,
+            show_tqdm=params.show_tqdm,
             context=self.context,
         )
     )
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index 1bbfdbcd1..7c1063333 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -335,6 +335,7 @@ def generate(
         return_attn=False,
         return_hidden=False,
         stream=False,
+        show_tqdm=True,
         context=Context(),
     ):

@@ -368,160 +369,167 @@ def generate(
         attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
             attention_mask
         )
+
+        pbar: Optional[tqdm] = None
+
+        if show_tqdm:
+            pbar = tqdm(
+                total=max_new_token,
+                desc="text" if infer_text else "code",
+                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
+            )

-        with tqdm(
-            total=max_new_token,
-            desc="text" if infer_text else "code",
-            bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
-        ) as pbar:
-
-            past_key_values = None
+        past_key_values = None

-            for i in range(max_new_token):
-                model_input = self._prepare_generation_inputs(
-                    inputs_ids,
-                    past_key_values,
-                    attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
-                    use_cache=True,
-                )
+        for i in range(max_new_token):
+            model_input = self._prepare_generation_inputs(
+                inputs_ids,
+                past_key_values,
+                attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
+                use_cache=True,
+            )

-                if i > 0:
-                    del emb
-                inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
-                if infer_text:
-                    emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
-                else:
-                    code_emb = [
-                        self.emb_code[i](inputs_ids_emb[:, :, i])
-                        for i in range(self.num_vq)
-                    ]
-                    emb = torch.stack(code_emb, 3).sum(3)
-                del inputs_ids_emb, model_input.input_ids
-                model_input.inputs_embeds = emb
+            if i > 0:
+                del emb
+            inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
+            if infer_text:
+                emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
+            else:
+                code_emb = [
+                    self.emb_code[i](inputs_ids_emb[:, :, i])
+                    for i in range(self.num_vq)
+                ]
+                emb = torch.stack(code_emb, 3).sum(3)
+            del inputs_ids_emb, model_input.input_ids
+            model_input.inputs_embeds = emb

-                model_input.to(self.device_gpt)
+            model_input.to(self.device_gpt)

-                outputs: BaseModelOutputWithPast = self.gpt(
-                    attention_mask=model_input.attention_mask,
-                    position_ids=model_input.position_ids,
-                    past_key_values=model_input.past_key_values,
-                    inputs_embeds=model_input.inputs_embeds,
-                    use_cache=model_input.use_cache,
-                    output_attentions=return_attn,
-                    cache_position=model_input.cache_position,
-                )
-                del_all(model_input)
-                attentions.append(outputs.attentions)
-                hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
-                past_key_values = outputs.past_key_values
-                del_all(outputs)
-                if return_hidden:
-                    hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))
+            outputs: BaseModelOutputWithPast = self.gpt(
+                attention_mask=model_input.attention_mask,
+                position_ids=model_input.position_ids,
+                past_key_values=model_input.past_key_values,
+                inputs_embeds=model_input.inputs_embeds,
+                use_cache=model_input.use_cache,
+                output_attentions=return_attn,
+                cache_position=model_input.cache_position,
+            )
+            del_all(model_input)
+            attentions.append(outputs.attentions)
+            hidden_states = outputs.last_hidden_state.to(self.device)  # 🐻
+            past_key_values = outputs.past_key_values
+            del_all(outputs)
+            if return_hidden:
+                hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))

-                with P.cached():
-                    if infer_text:
-                        logits: torch.Tensor = self.head_text(hidden_states)
-                    else:
-                        # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
-                        logits = torch.empty(
-                            hidden_states.size(0),
-                            hidden_states.size(1),
-                            self.num_audio_tokens,
-                            self.num_vq,
-                            dtype=torch.float,
-                            device=self.device,
-                        )
-                        for i in range(self.num_vq):
-                            x: torch.Tensor = self.head_code[i](hidden_states)
-                            logits[..., i] = x
-                            del x
+            with P.cached():
+                if infer_text:
+                    logits: torch.Tensor = self.head_text(hidden_states)
+                else:
+                    # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3)
+                    logits = torch.empty(
+                        hidden_states.size(0),
+                        hidden_states.size(1),
+                        self.num_audio_tokens,
+                        self.num_vq,
+                        dtype=torch.float,
+                        device=self.device,
+                    )
+                    for i in range(self.num_vq):
+                        x: torch.Tensor = self.head_code[i](hidden_states)
+                        logits[..., i] = x
+                        del x

-                # logits = logits[:, -1].float()
-                logits = logits.narrow(1, -1, 1).squeeze_(1).float()
+            # logits = logits[:, -1].float()
+            logits = logits.narrow(1, -1, 1).squeeze_(1).float()

-                if not infer_text:
-                    # logits = rearrange(logits, "b c n -> (b n) c")
-                    logits = logits.permute(0, 2, 1)
-                    logits = logits.reshape(-1, logits.size(2))
-                    # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
-                    inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
-                    logits_token = inputs_ids_sliced.reshape(
-                        inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
-                        -1,
-                    ).to(self.device)
-                else:
-                    logits_token = inputs_ids[:, start_idx:, 0].to(self.device)
+            if not infer_text:
+                # logits = rearrange(logits, "b c n -> (b n) c")
+                logits = logits.permute(0, 2, 1)
+                logits = logits.reshape(-1, logits.size(2))
+                # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
+                inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
+                logits_token = inputs_ids_sliced.reshape(
+                    inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
+                    -1,
+                ).to(self.device)
+            else:
+                logits_token = inputs_ids[:, start_idx:, 0].to(self.device)

-                logits /= temperature
+            logits /= temperature

-                for logitsProcessors in logits_processors:
-                    logits = logitsProcessors(logits_token, logits)
+            for logitsProcessors in logits_processors:
+                logits = logitsProcessors(logits_token, logits)

-                for logitsWarpers in logits_warpers:
-                    logits = logitsWarpers(logits_token, logits)
+            for logitsWarpers in logits_warpers:
+                logits = logitsWarpers(logits_token, logits)

-                del logits_token
+            del logits_token

-                if i < min_new_token:
-                    logits[:, eos_token] = -torch.inf
+            if i < min_new_token:
+                logits[:, eos_token] = -torch.inf

-                scores = F.softmax(logits, dim=-1)
+            scores = F.softmax(logits, dim=-1)

-                del logits
+            del logits

-                idx_next = torch.multinomial(scores, num_samples=1).to(
-                    finish.device
-                )
+            idx_next = torch.multinomial(scores, num_samples=1).to(
+                finish.device
+            )

-                if not infer_text:
-                    # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
-                    idx_next = idx_next.view(-1, self.num_vq)
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [inputs_ids, idx_next.unsqueeze_(1)], 1
-                    )
-                else:
-                    finish_or = idx_next.eq(eos_token).any(1)
-                    finish.logical_or_(finish_or)
-                    del finish_or
-                    inputs_ids_tmp = torch.cat(
-                        [
-                            inputs_ids,
-                            idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
-                        ],
-                        1,
-                    )
+            if not infer_text:
+                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
+                idx_next = idx_next.view(-1, self.num_vq)
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [inputs_ids, idx_next.unsqueeze_(1)], 1
+                )
+            else:
+                finish_or = idx_next.eq(eos_token).any(1)
+                finish.logical_or_(finish_or)
+                del finish_or
+                inputs_ids_tmp = torch.cat(
+                    [
+                        inputs_ids,
+                        idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
+                    ],
+                    1,
+                )

-                del inputs_ids
-                inputs_ids = inputs_ids_tmp
-                del inputs_ids_tmp, idx_next
+            del inputs_ids
+            inputs_ids = inputs_ids_tmp
+            del inputs_ids_tmp, idx_next

-                if stream:
-                    minus_prev_end_index = end_idx.neg()
-                end_idx.add_((finish.logical_not().to(end_idx.device)).int())
-                if stream:
-                    if (
-                        end_idx.all()
-                        and end_idx.fmod(24).eq(0).any()
-                        and minus_prev_end_index.add_(end_idx).any()
-                    ):
-                        self.logger.debug("yield stream result, end: %d", end_idx)
-                        yield self._prepare_generation_outputs(
-                            inputs_ids,
-                            start_idx,
-                            end_idx,
-                            attentions,
-                            hiddens,
-                            infer_text,
-                        )
-                    del minus_prev_end_index
+            if stream:
+                minus_prev_end_index = end_idx.neg()
+            end_idx.add_((finish.logical_not().to(end_idx.device)).int())
+            if stream:
+                if (
+                    end_idx.all()
+                    and end_idx.fmod(24).eq(0).any()
+                    and minus_prev_end_index.add_(end_idx).any()
+                ):
+                    self.logger.debug("yield stream result, end: %d", end_idx)
+                    yield self._prepare_generation_outputs(
+                        inputs_ids,
+                        start_idx,
+                        end_idx,
+                        attentions,
+                        hiddens,
+                        infer_text,
+                    )
+                del minus_prev_end_index

-                if finish.all() or context.get():
-                    break
+            if finish.all() or context.get():
+                break

-                pbar.update(1)
+            if pbar is not None:
+                pbar.update(1)
+
+        if pbar is not None:
+            pbar.close()

         if not finish.all():
             if context.get():