Skip to content

Commit

Permalink
feat: keep speaker in long-sentence infer
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jan 7, 2025
1 parent ff77e25 commit 8d7bcf0
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 30 deletions.
120 changes: 90 additions & 30 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import re
import logging
import tempfile
from dataclasses import dataclass, asdict
Expand Down Expand Up @@ -213,10 +214,28 @@ def infer(
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
split_text=True,
max_split_batch=4,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
):
self.context.set(False)

if split_text and isinstance(text, str):
if "\n" in text:
text = text.split("\n")
else:
text = re.split(r'(?<=[。(.\s)])', text)
nt = []
for t in text:
if t:
nt.append(t)
text = nt
self.logger.info("split text into %d parts", len(text))
self.logger.debug("%s", str(text))

if len(text) == 0: return []

res_gen = self._infer(
text,
stream,
Expand All @@ -226,11 +245,21 @@ def infer(
use_decoder,
do_text_normalization,
do_homophone_replacement,
split_text,
max_split_batch,
params_refine_text,
params_infer_code,
)
if stream:
return res_gen
elif not refine_text_only:
stripped_wavs = []
for wavs in res_gen:
for wav in wavs:
stripped_wavs.append(wav[np.abs(wav)>1e-5])
if split_text:
return [np.concatenate(stripped_wavs)]
return stripped_wavs
else:
return next(res_gen)

Expand Down Expand Up @@ -350,14 +379,16 @@ def _load(

def _infer(
self,
text,
text: Union[List[str], str],
stream=False,
lang=None,
skip_refine_text=False,
refine_text_only=False,
use_decoder=True,
do_text_normalization=True,
do_homophone_replacement=True,
split_text=True,
max_split_batch=4,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
):
Expand Down Expand Up @@ -390,44 +421,73 @@ def _infer(
text = self.tokenizer.decode(text_tokens)
refined.destroy()
if refine_text_only:
if split_text and isinstance(text, list):
text = "\n".join(text)
yield text
return

if stream:
length = 0
pass_batch_count = 0
for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
):
if split_text and len(text) > 1 and params_infer_code.spk_smp is None:
refer_text = text[0]
result = next(self._infer_code(
refer_text,
False,
self.device,
use_decoder,
params_infer_code,
))
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
if b > wavs.shape[1]:
b = wavs.shape[1]
new_wavs = wavs[:, a:b]
length = b
yield new_wavs
else:
yield wavs
assert len(wavs), 1
params_infer_code.spk_smp = self.sample_audio_speaker(wavs[0])
params_infer_code.txt_smp = refer_text

if stream:
new_wavs = wavs[:, length:]
# Identify rows with non-zero elements using np.any
# keep_rows = np.any(array != 0, axis=1)
keep_cols = np.sum(new_wavs != 0, axis=0) > 0
# Filter both rows and columns using slicing
yield new_wavs[:][:, keep_cols]
length = 0
pass_batch_count = 0
if split_text:
n = len(text) // max_split_batch
if len(text) % max_split_batch: n += 1
else:
n = 1
max_split_batch = len(text)
for i in range(n):
text_remain = text[i*max_split_batch:]
if len(text_remain) > max_split_batch:
text_remain = text_remain[:max_split_batch]
if split_text:
self.logger.info("infer split %d~%d", i*max_split_batch, i*max_split_batch+len(text_remain))
for result in self._infer_code(
text_remain,
stream,
self.device,
use_decoder,
params_infer_code,
):
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
if b > wavs.shape[1]:
b = wavs.shape[1]
new_wavs = wavs[:, a:b]
length = b
yield new_wavs
else:
yield wavs
if stream:
new_wavs = wavs[:, length:]
keep_cols = np.sum(np.abs(new_wavs)>1e-5, axis=0) > 0
yield new_wavs[:][:, keep_cols]

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
Expand Down
5 changes: 5 additions & 0 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ def refine_text(
temperature,
top_P,
top_K,
split_batch,
):
global chat

Expand All @@ -150,6 +151,7 @@ def refine_text(
top_K=top_K,
manual_seed=text_seed_input,
),
split_text=split_batch > 0,
)

return text[0] if isinstance(text, list) else text
Expand All @@ -165,6 +167,7 @@ def generate_audio(
audio_seed_input,
sample_text_input,
sample_audio_code_input,
split_batch,
):
global chat, has_interrupted

Expand All @@ -189,6 +192,8 @@ def generate_audio(
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
split_text=split_batch > 0,
max_split_batch=split_batch,
)
if stream:
for gen in wav:
Expand Down
10 changes: 10 additions & 0 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,14 @@ def main():
scale=1,
interactive=True,
)
split_batch_slider = gr.Slider(
minimum=0,
maximum=100,
step=1,
value=4,
label="Split Batch",
interactive=True,
)
generate_button = gr.Button(
"Generate", scale=2, variant="primary", interactive=True
)
Expand Down Expand Up @@ -208,6 +216,7 @@ def make_audio(autoplay, stream):
temperature_slider,
top_p_slider,
top_k_slider,
split_batch_slider,
],
outputs=text_output,
).then(
Expand All @@ -222,6 +231,7 @@ def make_audio(autoplay, stream):
audio_seed_input,
sample_text_input,
sample_audio_code_input,
split_batch_slider,
],
outputs=audio_output,
).then(
Expand Down

0 comments on commit 8d7bcf0

Please sign in to comment.