diff --git a/.github/workflows/unitest.yml b/.github/workflows/unitest.yml
index 407f6ec99..89565c9e1 100644
--- a/.github/workflows/unitest.yml
+++ b/.github/workflows/unitest.yml
@@ -21,6 +21,9 @@ jobs:
       - name: Test Install
         run: pip install .

+      - name: Install Dependencies
+        run: pip install -r requirements.txt
+
       - name: Run Test
         run: |
           echo "TODO"
tokenizer.convert_tokens_to_ids("[spk_emb]")] = n - del n + spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12 + ).to(self.gpt.device_gpt).expand(emb.shape) + cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape) + torch.where(cond, n, emb, out=emb) + del cond, n def _infer_code( self, @@ -457,7 +466,7 @@ def _infer_code( else: text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text] - input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt) + input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt) emb = gpt(input_ids, text_mask) del text_mask @@ -479,7 +488,7 @@ def _infer_code( input_ids, temperature=torch.tensor(temperature, device=device), eos_token=num_code, - attention_mask=text_token["attention_mask"], + attention_mask=attention_mask, max_new_token=params.max_new_token, min_new_token=params.min_new_token, logits_warpers=logits_warpers, @@ -490,8 +499,7 @@ def _infer_code( context=self.context, ) - del_all(text_token) - del emb, text_token, input_ids + del emb, input_ids del_all(logits_warpers) del_all(logits_processors) @@ -505,17 +513,16 @@ def _refine_text( ): gpt = self.gpt - tokenizer = self.pretrain_models["tokenizer"] if not isinstance(text, list): text = [text] text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text] - input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt) + input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt) logits_warpers, logits_processors = gen_logits( - num_code=len(tokenizer), + num_code=self.tokenizer_len, top_P=params.top_P, top_K=params.top_K, repetition_penalty=params.repetition_penalty, @@ -528,10 +535,8 @@ def _refine_text( emb, input_ids, temperature=torch.tensor([params.temperature], device=device), - eos_token=torch.tensor( - tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt - )[None], - attention_mask=text_token["attention_mask"], + eos_token=self.tokenizer_eos_token, + attention_mask=attention_mask, max_new_token=params.max_new_token, min_new_token=params.min_new_token, logits_warpers=logits_warpers, @@ -541,8 +546,7 @@ def _refine_text( context=self.context, ) - del_all(text_token) - del emb, text_token, input_ids + del emb, input_ids del_all(logits_warpers) del_all(logits_processors) diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index ee9467867..69d104326 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -240,7 +240,7 @@ def _prepare_generation_inputs( if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 - position_ids.masked_fill_(attention_mask == 0, 1) + position_ids.masked_fill_(attention_mask.eq(0), 1) if past_key_values: position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1]) @@ -321,7 +321,7 @@ def generate( inputs_ids: torch.Tensor, temperature: torch.Tensor, eos_token: Union[int, torch.Tensor], - attention_mask=None, + attention_mask: Optional[torch.Tensor]=None, max_new_token=2048, min_new_token=0, logits_warpers: List[LogitsWarper] = [], @@ -469,14 +469,14 @@ def generate( if not infer_text: # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) idx_next = idx_next.view(-1, self.num_vq) - finish_or = (idx_next == eos_token).any(1) + finish_or = idx_next.eq(eos_token).any(1) finish.logical_or_(finish_or) del finish_or inputs_ids_tmp = torch.cat( [inputs_ids, idx_next.unsqueeze_(1)], 1 ) else: - finish_or = 
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index ee9467867..69d104326 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -240,7 +240,7 @@ def _prepare_generation_inputs(
         if attention_mask is not None and position_ids is None:
             # create position_ids on the fly for batch generation
             position_ids = attention_mask.long().cumsum(-1) - 1
-            position_ids.masked_fill_(attention_mask == 0, 1)
+            position_ids.masked_fill_(attention_mask.eq(0), 1)
             if past_key_values:
                 position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1])

@@ -321,7 +321,7 @@ def generate(
         inputs_ids: torch.Tensor,
         temperature: torch.Tensor,
         eos_token: Union[int, torch.Tensor],
-        attention_mask=None,
+        attention_mask: Optional[torch.Tensor]=None,
         max_new_token=2048,
         min_new_token=0,
         logits_warpers: List[LogitsWarper] = [],
@@ -469,14 +469,14 @@ def generate(
             if not infer_text:
                 # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
                 idx_next = idx_next.view(-1, self.num_vq)
-                finish_or = (idx_next == eos_token).any(1)
+                finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
                 del finish_or
                 inputs_ids_tmp = torch.cat(
                     [inputs_ids, idx_next.unsqueeze_(1)], 1
                 )
             else:
-                finish_or = (idx_next == eos_token).any(1)
+                finish_or = idx_next.eq(eos_token).any(1)
                 finish.logical_or_(finish_or)
                 del finish_or
                 inputs_ids_tmp = torch.cat(
@@ -497,7 +497,7 @@ def generate(
             if stream:
                 if (
                     end_idx.all()
-                    and (end_idx % 24 == 0).any()
+                    and end_idx.fmod(24).eq(0).any()
                     and minus_prev_end_index.add_(end_idx).any()
                 ):
                     self.logger.debug("yield stream result, end: %d", end_idx)
diff --git a/examples/ipynb/colab.ipynb b/examples/ipynb/colab.ipynb
index 5a3900d55..ab11fec72 100644
--- a/examples/ipynb/colab.ipynb
+++ b/examples/ipynb/colab.ipynb
@@ -310,7 +310,6 @@
     "\n",
     "wav = chat.infer(\n",
     "    \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
-    "    params_refine_text=params_refine_text,\n",
     "    params_infer_code=params_infer_code,\n",
     ")"
    ]
diff --git a/examples/ipynb/example.ipynb b/examples/ipynb/example.ipynb
index 7c53ef87e..20e4e316a 100644
--- a/examples/ipynb/example.ipynb
+++ b/examples/ipynb/example.ipynb
@@ -253,7 +253,6 @@
     "\n",
     "wav = chat.infer(\n",
     "    \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
-    "    params_refine_text=params_refine_text,\n",
     "    params_infer_code=params_infer_code,\n",
     ")"
    ]
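NOTE: both notebook cells above drop `params_refine_text` from a plain `chat.infer(...)` call. If refined text is still wanted, the two-step flow the web demo uses works from a script as well. A sketch, assuming `load_models()` is the loader name in this era of the repo; only `skip_refine_text` and `refine_text_only` are taken from the code in this patch:

    import ChatTTS

    chat = ChatTTS.Chat()
    chat.load_models()  # assumed loader name; check your checkout

    # Step 1: refine only. Yields the rewritten text, no audio is produced.
    refined = chat.infer(
        "A short sentence to synthesize.",
        skip_refine_text=False,
        refine_text_only=True,
    )

    # Step 2: synthesize from the already-refined text.
    wav = chat.infer(refined, skip_refine_text=True)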
diff --git a/examples/web/funcs.py b/examples/web/funcs.py
index dcff4f5fe..4e2c2152d 100644
--- a/examples/web/funcs.py
+++ b/examples/web/funcs.py
@@ -95,23 +95,20 @@ def reload_chat(coef: Optional[str]) -> str:
     return chat.coef


-def set_generate_buttons(generate_button, interrupt_button, is_reset=False):
+def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
     return gr.update(
         value=generate_button, visible=is_reset, interactive=is_reset
     ), gr.update(value=interrupt_button, visible=not is_reset, interactive=not is_reset)


 def refine_text(
-    text, text_seed_input, refine_text_flag, generate_button, interrupt_button
+    text, text_seed_input, refine_text_flag,
 ):
-    global chat, has_interrupted
-    has_interrupted = False
+    global chat

     if not refine_text_flag:
         sleep(1)  # to skip fast answer of loading mark
-        return text, *set_generate_buttons(
-            generate_button, interrupt_button, is_reset=True
-        )
+        return text

     with TorchSeedContext(text_seed_input):
         text = chat.infer(
@@ -119,19 +116,13 @@ def refine_text(
             skip_refine_text=False,
             refine_text_only=True,
         )
-    return text[0] if isinstance(text, list) else text, *set_generate_buttons(
-        generate_button, interrupt_button, is_reset=True
-    )
-
-
-def text_output_listener(generate_button, interrupt_button):
-    return set_generate_buttons(generate_button, interrupt_button)
+    return text[0] if isinstance(text, list) else text


 def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
     global chat, has_interrupted

-    if not text or text == "𝕃𝕠𝕒𝕕𝕚𝕟𝕘..." or has_interrupted:
+    if not text or has_interrupted:
         return None

     with TorchSeedContext(audio_seed_input):
@@ -157,9 +148,8 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
             if audio is not None and len(audio) > 0:
                 yield 24000, unsafe_float_to_int16(audio[0])
                 del audio
-        return
-
-    yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())
+    else:
+        yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())


 def interrupt_generate():
@@ -168,11 +158,20 @@ def interrupt_generate():
     has_interrupted = True
     chat.interrupt()

+def set_buttons_before_generate(generate_button, interrupt_button):
+    global has_interrupted
+
+    has_interrupted = False
+
+    return _set_generate_buttons(
+        generate_button,
+        interrupt_button,
+    )

 def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
     global has_interrupted

-    return set_generate_buttons(
+    return _set_generate_buttons(
         generate_button,
         interrupt_button,
         audio_output is not None or has_interrupted,
diff --git a/examples/web/webui.py b/examples/web/webui.py
index af7f43f92..7d2023653 100644
--- a/examples/web/webui.py
+++ b/examples/web/webui.py
@@ -96,19 +96,6 @@ def main():
             reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
         )

-        generate_button.click(fn=lambda: "𝕃𝕠𝕒𝕕𝕚𝕟𝕘...", outputs=text_output)
-        generate_button.click(
-            refine_text,
-            inputs=[
-                text_input,
-                text_seed_input,
-                refine_text_checkbox,
-                generate_button,
-                interrupt_button,
-            ],
-            outputs=[text_output, generate_button, interrupt_button],
-        )
-
         interrupt_button.click(interrupt_generate)

         @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
@@ -122,12 +109,15 @@ def make_audio(autoplay, stream):
                 show_label=True,
                 format="mp3",
             )
-            text_output.change(
-                text_output_listener,
-                inputs=[generate_button, interrupt_button],
-                outputs=[generate_button, interrupt_button],
-            )
-            text_output.change(
+            generate_button.click(fn=set_buttons_before_generate, inputs=[generate_button, interrupt_button], outputs=[generate_button, interrupt_button]).then(
+                refine_text,
+                inputs=[
+                    text_input,
+                    text_seed_input,
+                    refine_text_checkbox,
+                ],
+                outputs=text_output,
+            ).then(
                 generate_audio,
                 inputs=[
                     text_output,
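NOTE: the webui.py hunks above replace two parallel `generate_button.click(...)` handlers and the `text_output.change(...)` listeners with one `.click(...).then(...).then(...)` chain, so the button lock, text refinement, and audio generation run strictly in order. A self-contained sketch of that Gradio pattern; the toy app and handler names are illustrative, not the real UI:

    import gradio as gr

    def lock(btn):
        return gr.update(interactive=False)   # disable while the chain runs

    def work(text):
        return text[::-1]                     # stand-in for the real generation step

    def unlock(btn):
        return gr.update(interactive=True)

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="in")
        out = gr.Textbox(label="out")
        btn = gr.Button("Run")
        # .then() guarantees ordering; two separate .click() handlers would race.
        btn.click(lock, inputs=btn, outputs=btn).then(
            work, inputs=inp, outputs=out
        ).then(unlock, inputs=btn, outputs=btn)

    demo.launch()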