Skip to content

Commit

Permalink
chore(format): run black on dev (#868)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] authored Jan 7, 2025
1 parent 8d7bcf0 commit e6ab5ca
Showing 1 changed file with 22 additions and 14 deletions.
36 changes: 22 additions & 14 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def infer(
if "\n" in text:
text = text.split("\n")
else:
text = re.split(r'(?<=[。(.\s)])', text)
text = re.split(r"(?<=[。(.\s)])", text)
nt = []
for t in text:
if t:
Expand All @@ -234,7 +234,8 @@ def infer(
self.logger.info("split text into %d parts", len(text))
self.logger.debug("%s", str(text))

if len(text) == 0: return []
if len(text) == 0:
return []

res_gen = self._infer(
text,
Expand All @@ -256,7 +257,7 @@ def infer(
stripped_wavs = []
for wavs in res_gen:
for wav in wavs:
stripped_wavs.append(wav[np.abs(wav)>1e-5])
stripped_wavs.append(wav[np.abs(wav) > 1e-5])
if split_text:
return [np.concatenate(stripped_wavs)]
return stripped_wavs
Expand Down Expand Up @@ -428,13 +429,15 @@ def _infer(

if split_text and len(text) > 1 and params_infer_code.spk_smp is None:
refer_text = text[0]
result = next(self._infer_code(
refer_text,
False,
self.device,
use_decoder,
params_infer_code,
))
result = next(
self._infer_code(
refer_text,
False,
self.device,
use_decoder,
params_infer_code,
)
)
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
Expand All @@ -449,16 +452,21 @@ def _infer(
pass_batch_count = 0
if split_text:
n = len(text) // max_split_batch
if len(text) % max_split_batch: n += 1
if len(text) % max_split_batch:
n += 1
else:
n = 1
max_split_batch = len(text)
for i in range(n):
text_remain = text[i*max_split_batch:]
text_remain = text[i * max_split_batch :]
if len(text_remain) > max_split_batch:
text_remain = text_remain[:max_split_batch]
if split_text:
self.logger.info("infer split %d~%d", i*max_split_batch, i*max_split_batch+len(text_remain))
self.logger.info(
"infer split %d~%d",
i * max_split_batch,
i * max_split_batch + len(text_remain),
)
for result in self._infer_code(
text_remain,
stream,
Expand Down Expand Up @@ -486,7 +494,7 @@ def _infer(
yield wavs
if stream:
new_wavs = wavs[:, length:]
keep_cols = np.sum(np.abs(new_wavs)>1e-5, axis=0) > 0
keep_cols = np.sum(np.abs(new_wavs) > 1e-5, axis=0) > 0
yield new_wavs[:][:, keep_cols]

@torch.inference_mode()
Expand Down

0 comments on commit e6ab5ca

Please sign in to comment.