Skip to content

Commit

Permalink
optimize: spk_emb & generate & webui (#461)
Browse files Browse the repository at this point in the history
fix: ipynb param not defined
  • Loading branch information
fumiama authored Jun 26, 2024
1 parent 2cd5662 commit e5764c6
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 98 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/unitest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ jobs:
- name: Test Install
run: pip install .

- name: Install Dependencies
run: pip install -r requirements.txt

- name: Run Test
run: |
echo "TODO"
110 changes: 57 additions & 53 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ def _load(
tokenizer = torch.load(tokenizer_path, map_location=device, mmap=True)
tokenizer.padding_side = "left"
self.pretrain_models["tokenizer"] = tokenizer
self.tokenizer_len = len(tokenizer)
self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[spk_emb]")
self.tokenizer_eos_token: torch.Tensor = torch.tensor(
tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
).unsqueeze_(0)
self.logger.log(logging.INFO, "tokenizer loaded.")

self.coef = coef
Expand Down Expand Up @@ -342,38 +347,40 @@ def _infer(
for t in text
]

if not skip_refine_text:
refined = self._refine_text(
with torch.no_grad():

if not skip_refine_text:
refined = self._refine_text(
text,
self.device,
params_refine_text,
)
text_tokens = refined.ids
text_tokens = [
i[
i
< self.pretrain_models["tokenizer"].convert_tokens_to_ids(
"[break_0]"
)
]
for i in text_tokens
]
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
refined.destroy()
if refine_text_only:
yield text
return

length = [0 for _ in range(len(text))]
for result in self._infer_code(
text,
stream,
self.device,
params_refine_text,
)
text_tokens = refined.ids
text_tokens = [
i[
i
< self.pretrain_models["tokenizer"].convert_tokens_to_ids(
"[break_0]"
)
]
for i in text_tokens
]
text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
refined.destroy()
if refine_text_only:
yield text
return

length = [0 for _ in range(len(text))]
for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
):
wav = self._decode_to_wavs(result, length, use_decoder)
yield wav
use_decoder,
params_infer_code,
):
wav = self._decode_to_wavs(result, length, use_decoder)
yield wav

def _decode_to_wavs(
self, result: GPT.GenerationOutputs, start_seeks: List[int], use_decoder: bool
Expand All @@ -397,7 +404,7 @@ def _decode_to_wavs(
del_all(x)
return wavs

def _gen_gpt_inputs(self, text: str, device="cpu"):
def _text_to_token(self, text: str, device="cpu") -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:

gpt = self.gpt
tokenizer = self.pretrain_models["tokenizer"]
Expand All @@ -407,10 +414,14 @@ def _gen_gpt_inputs(self, text: str, device="cpu"):
)
text_token = text_token_tmp.to(device)
del text_token_tmp
input_ids = text_token["input_ids"][..., None].expand(-1, -1, gpt.num_vq)

input_ids = text_token["input_ids"].unsqueeze(-1).expand(-1, -1, gpt.num_vq)
text_mask = torch.ones(text_token["input_ids"].shape, dtype=bool, device=device)
attention_mask = text_token["attention_mask"]

del_all(text_token)

return input_ids, text_token, text_mask
return input_ids, attention_mask, text_mask

def _apply_spk_emb(
self,
Expand All @@ -419,14 +430,12 @@ def _apply_spk_emb(
input_ids: torch.Tensor,
text_len: int,
):

tokenizer = self.pretrain_models["tokenizer"]

n = F.normalize(
spk_emb.to(emb.dtype)[None].expand(text_len, -1), p=2.0, dim=1, eps=1e-12
).to(self.gpt.device_gpt)
emb[input_ids[..., 0] == tokenizer.convert_tokens_to_ids("[spk_emb]")] = n
del n
spk_emb.unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
).to(self.gpt.device_gpt).expand(emb.shape)
cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
torch.where(cond, n, emb, out=emb)
del cond, n

def _infer_code(
self,
Expand Down Expand Up @@ -457,7 +466,7 @@ def _infer_code(
else:
text = [f"[Stts][empty_spk]{i}[Ptts]" for i in text]

input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)
input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)

emb = gpt(input_ids, text_mask)
del text_mask
Expand All @@ -479,7 +488,7 @@ def _infer_code(
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=text_token["attention_mask"],
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
Expand All @@ -490,8 +499,7 @@ def _infer_code(
context=self.context,
)

del_all(text_token)
del emb, text_token, input_ids
del emb, input_ids
del_all(logits_warpers)
del_all(logits_processors)

Expand All @@ -505,17 +513,16 @@ def _refine_text(
):

gpt = self.gpt
tokenizer = self.pretrain_models["tokenizer"]

if not isinstance(text, list):
text = [text]

text = [f"[Sbreak]{i}[Pbreak]{params.prompt}" for i in text]

input_ids, text_token, text_mask = self._gen_gpt_inputs(text, gpt.device_gpt)
input_ids, attention_mask, text_mask = self._text_to_token(text, gpt.device_gpt)

logits_warpers, logits_processors = gen_logits(
num_code=len(tokenizer),
num_code=self.tokenizer_len,
top_P=params.top_P,
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
Expand All @@ -528,10 +535,8 @@ def _refine_text(
emb,
input_ids,
temperature=torch.tensor([params.temperature], device=device),
eos_token=torch.tensor(
tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
)[None],
attention_mask=text_token["attention_mask"],
eos_token=self.tokenizer_eos_token,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_warpers=logits_warpers,
Expand All @@ -541,8 +546,7 @@ def _refine_text(
context=self.context,
)

del_all(text_token)
del emb, text_token, input_ids
del emb, input_ids
del_all(logits_warpers)
del_all(logits_processors)

Expand Down
10 changes: 5 additions & 5 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def _prepare_generation_inputs(
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
position_ids.masked_fill_(attention_mask.eq(0), 1)
if past_key_values:
position_ids = position_ids.narrow(1, -input_ids.shape[1], input_ids.shape[1])

Expand Down Expand Up @@ -321,7 +321,7 @@ def generate(
inputs_ids: torch.Tensor,
temperature: torch.Tensor,
eos_token: Union[int, torch.Tensor],
attention_mask=None,
attention_mask: Optional[torch.Tensor]=None,
max_new_token=2048,
min_new_token=0,
logits_warpers: List[LogitsWarper] = [],
Expand Down Expand Up @@ -469,14 +469,14 @@ def generate(
if not infer_text:
# idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
idx_next = idx_next.view(-1, self.num_vq)
finish_or = (idx_next == eos_token).any(1)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_tmp = torch.cat(
[inputs_ids, idx_next.unsqueeze_(1)], 1
)
else:
finish_or = (idx_next == eos_token).any(1)
finish_or = idx_next.eq(eos_token).any(1)
finish.logical_or_(finish_or)
del finish_or
inputs_ids_tmp = torch.cat(
Expand All @@ -497,7 +497,7 @@ def generate(
if stream:
if (
end_idx.all()
and (end_idx % 24 == 0).any()
and end_idx.fmod(24).eq(0).any()
and minus_prev_end_index.add_(end_idx).any()
):
self.logger.debug("yield stream result, end: %d", end_idx)
Expand Down
1 change: 0 additions & 1 deletion examples/ipynb/colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,6 @@
"\n",
"wav = chat.infer(\n",
" \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
" params_refine_text=params_refine_text,\n",
" params_infer_code=params_infer_code,\n",
")"
]
Expand Down
1 change: 0 additions & 1 deletion examples/ipynb/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,6 @@
"\n",
"wav = chat.infer(\n",
" \"四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。\",\n",
" params_refine_text=params_refine_text,\n",
" params_infer_code=params_infer_code,\n",
")"
]
Expand Down
37 changes: 18 additions & 19 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,43 +95,34 @@ def reload_chat(coef: Optional[str]) -> str:
return chat.coef


def set_generate_buttons(generate_button, interrupt_button, is_reset=False):
def _set_generate_buttons(generate_button, interrupt_button, is_reset=False):
return gr.update(
value=generate_button, visible=is_reset, interactive=is_reset
), gr.update(value=interrupt_button, visible=not is_reset, interactive=not is_reset)


def refine_text(
text, text_seed_input, refine_text_flag, generate_button, interrupt_button
text, text_seed_input, refine_text_flag,
):
global chat, has_interrupted
has_interrupted = False
global chat

if not refine_text_flag:
sleep(1) # to skip fast answer of loading mark
return text, *set_generate_buttons(
generate_button, interrupt_button, is_reset=True
)
return text

with TorchSeedContext(text_seed_input):
text = chat.infer(
text,
skip_refine_text=False,
refine_text_only=True,
)
return text[0] if isinstance(text, list) else text, *set_generate_buttons(
generate_button, interrupt_button, is_reset=True
)


def text_output_listener(generate_button, interrupt_button):
return set_generate_buttons(generate_button, interrupt_button)

return text[0] if isinstance(text, list) else text

def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
global chat, has_interrupted

if not text or text == "𝕃𝕠𝕒𝕕𝕚𝕟𝕘..." or has_interrupted:
if not text or has_interrupted:
return None

with TorchSeedContext(audio_seed_input):
Expand All @@ -157,9 +148,8 @@ def generate_audio(text, temperature, top_P, top_K, audio_seed_input, stream):
if audio is not None and len(audio) > 0:
yield 24000, unsafe_float_to_int16(audio[0])
del audio
return

yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())
else:
yield 24000, unsafe_float_to_int16(np.array(wav[0]).flatten())


def interrupt_generate():
Expand All @@ -168,11 +158,20 @@ def interrupt_generate():
has_interrupted = True
chat.interrupt()

def set_buttons_before_generate(generate_button, interrupt_button):
global has_interrupted

has_interrupted = False

return _set_generate_buttons(
generate_button,
interrupt_button,
)

def set_buttons_after_generate(generate_button, interrupt_button, audio_output):
global has_interrupted

return set_generate_buttons(
return _set_generate_buttons(
generate_button,
interrupt_button,
audio_output is not None or has_interrupted,
Expand Down
28 changes: 9 additions & 19 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,19 +96,6 @@ def main():
reload_chat, inputs=dvae_coef_text, outputs=dvae_coef_text
)

generate_button.click(fn=lambda: "𝕃𝕠𝕒𝕕𝕚𝕟𝕘...", outputs=text_output)
generate_button.click(
refine_text,
inputs=[
text_input,
text_seed_input,
refine_text_checkbox,
generate_button,
interrupt_button,
],
outputs=[text_output, generate_button, interrupt_button],
)

interrupt_button.click(interrupt_generate)

@gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
Expand All @@ -122,12 +109,15 @@ def make_audio(autoplay, stream):
show_label=True,
format="mp3",
)
text_output.change(
text_output_listener,
inputs=[generate_button, interrupt_button],
outputs=[generate_button, interrupt_button],
)
text_output.change(
generate_button.click(fn=set_buttons_before_generate, inputs=[generate_button, interrupt_button], outputs=[generate_button, interrupt_button]).then(
refine_text,
inputs=[
text_input,
text_seed_input,
refine_text_checkbox,
],
outputs=text_output,
).then(
generate_audio,
inputs=[
text_output,
Expand Down

0 comments on commit e5764c6

Please sign in to comment.