fix: redundant [spk_emb]s in refine_text (fix #459) (#464)
fumiama authored Jun 26, 2024
1 parent 6aac04b commit 608bd19
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions ChatTTS/core.py

@@ -329,6 +329,7 @@ def _load(
         self.pretrain_models["tokenizer"] = tokenizer
         self.tokenizer_len = len(tokenizer)
         self.tokenizer_spk_emb_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[spk_emb]")
+        self.tokenizer_break_0_ids: torch.Tensor = tokenizer.convert_tokens_to_ids("[break_0]")
         self.tokenizer_eos_token: torch.Tensor = torch.tensor(
             tokenizer.convert_tokens_to_ids("[Ebreak]"), device=gpt.device_gpt
         ).unsqueeze_(0)

@@ -377,12 +378,7 @@ def _infer(
             )
             text_tokens = refined.ids
             text_tokens = [
-                i[
-                    i
-                    < self.pretrain_models["tokenizer"].convert_tokens_to_ids(
-                        "[break_0]"
-                    )
-                ]
+                i[i.less(self.tokenizer_break_0_ids)]
                 for i in text_tokens
             ]
             text = self.pretrain_models["tokenizer"].batch_decode(text_tokens)
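
Note on the two hunks above: _load now caches the id of "[break_0]" once as tokenizer_break_0_ids, and _infer keeps only token ids below it instead of re-resolving the token inside the comprehension. A minimal sketch of that masking with made-up ids (assuming, as the comparison implies, that control tokens such as "[break_0]" sit at the top of the vocabulary, so everything below the cached id is plain text):

    import torch

    break_0_id = 100                          # hypothetical id for "[break_0]"
    ids = torch.tensor([5, 42, 100, 103, 7])  # refined ids mixing text and control tokens

    kept = ids[ids.less(break_0_id)]          # boolean mask; same as ids[ids < break_0_id]
    print(kept)                               # tensor([ 5, 42,  7])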

@@ -458,7 +454,7 @@ def _apply_spk_emb(
                     filters=[{"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}],
                 ), dtype=np.float16).copy(),
             ).unsqueeze(0).expand(text_len, -1), p=2.0, dim=1, eps=1e-12
-        ).to(self.gpt.device_gpt).expand(emb.shape)
+        ).to(self.gpt.device_gpt).unsqueeze_(1).expand(emb.shape)
         cond = input_ids.narrow(-1, 0, 1).eq(self.tokenizer_spk_emb_ids).expand(emb.shape)
         torch.where(cond, n, emb, out=emb)
         del cond, n
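
Note: the one-line change in _apply_spk_emb inserts unsqueeze_(1) so the normalized speaker embedding gains a broadcastable sequence dimension before expand(emb.shape). A shape-only sketch with illustrative sizes (batch 2, sequence 5, hidden 768 are stand-ins, not values taken from the model):

    import torch

    emb = torch.zeros(2, 5, 768)   # (batch, seq, hidden)
    spk = torch.randn(2, 768)      # one normalized speaker vector per text

    # spk.expand(emb.shape) would fail here: (2, 768) cannot broadcast to (2, 5, 768)
    spk = spk.unsqueeze_(1)        # (2, 1, 768)
    spk = spk.expand(emb.shape)    # (2, 5, 768), repeated across the sequence dim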

@@ -483,6 +479,12 @@ def _infer_code(
             temperature = [params.temperature] * gpt.num_vq
         else:
             temperature = params.temperature
 
+        for i, t in enumerate(text):
+            text[i] = t.replace('[Stts]', '').replace('[spk_emb]', '').replace('[empty_spk]', '').strip()
+        """
+        see https://github.com/2noise/ChatTTS/issues/459
+        """
+
         if params.prompt:
             text = [params.prompt + i for i in text]
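
Note: the block added to _infer_code strips the [Stts], [spk_emb] and [empty_spk] control tokens that can survive in the text coming out of the refine pass, which, per the commit title and issue #459, is what produced redundant [spk_emb]s downstream. Roughly (illustrative string only):

    refined = "[Stts][spk_emb] hello world [uv_break]"
    cleaned = refined.replace('[Stts]', '').replace('[spk_emb]', '').replace('[empty_spk]', '').strip()
    print(cleaned)   # "hello world [uv_break]"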

@@ -557,7 +559,7 @@ def _refine_text(
         emb = gpt(input_ids, text_mask)
         del text_mask
 
-        result = gpt.generate(
+        result = next(gpt.generate(
             emb,
             input_ids,
             temperature=torch.tensor([params.temperature], device=device),

@@ -570,10 +572,10 @@ def _refine_text(
             infer_text=True,
             stream=False,
             context=self.context,
-        )
+        ))
 
         del emb, input_ids
         del_all(logits_warpers)
         del_all(logits_processors)
 
-        return next(result)
+        return result
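
Note: the last two hunks make _refine_text consume gpt.generate eagerly. Judging from the old return next(result), generate is a generator yielding a single result when stream=False; draining it with next(...) at the call site lets the temporaries (emb, input_ids, the logits warpers/processors) be released before returning, and callers now receive the result itself rather than a generator. A stand-in sketch of that control flow (the generate below is a dummy, not the library's API):

    def generate(stream=False):
        # dummy generator standing in for gpt.generate: one item when stream=False
        yield "final-result"

    result = next(generate(stream=False))  # consume the single yielded item up front
    # ...temporaries can be freed here, before the function returns...
    print(result)                          # "final-result"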
