Skip to content

Commit

Permalink
chore(format): run black on main (#501)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] authored Jun 29, 2024
1 parent c98fd16 commit 7db7d08
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
24 changes: 19 additions & 5 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def _infer(
):
wav = self._decode_to_wavs(result, length, use_decoder)
yield wav

def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
return self.vocos.decode(spec.cpu()).cpu().numpy()
Expand Down Expand Up @@ -460,13 +460,27 @@ def _text_to_token(
attn_sz = attention_mask_lst[-1].size(0)
if attn_sz > max_attention_mask_len:
max_attention_mask_len = attn_sz
input_ids = torch.zeros(len(input_ids_lst), max_input_ids_len, device=device, dtype=input_ids_lst[0].dtype)
input_ids = torch.zeros(
len(input_ids_lst),
max_input_ids_len,
device=device,
dtype=input_ids_lst[0].dtype,
)
for i in range(len(input_ids_lst)):
input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(input_ids_lst[i])
input_ids.narrow(0, i, 1).narrow(1, 0, input_ids_lst[i].size(0)).copy_(
input_ids_lst[i]
)
del_all(input_ids_lst)
attention_mask = torch.zeros(len(attention_mask_lst), max_attention_mask_len, device=device, dtype=attention_mask_lst[0].dtype)
attention_mask = torch.zeros(
len(attention_mask_lst),
max_attention_mask_len,
device=device,
dtype=attention_mask_lst[0].dtype,
)
for i in range(len(attention_mask_lst)):
attention_mask.narrow(0, i, 1).narrow(1, 0, attention_mask_lst[i].size(0)).copy_(attention_mask_lst[i])
attention_mask.narrow(0, i, 1).narrow(
1, 0, attention_mask_lst[i].size(0)
).copy_(attention_mask_lst[i])
del_all(attention_mask_lst)

text_mask = torch.ones(input_ids.shape, dtype=bool, device=device)
Expand Down
21 changes: 15 additions & 6 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,19 @@ def main(texts: List[str], spk: Optional[str] = None):
else:
logger.error("Models load failed.")
sys.exit(1)

if spk is None:
spk = chat.sample_random_speaker()
logger.info("Use speaker:")
print(spk)

logger.info("Start inference.")
wavs = chat.infer(texts, params_infer_code=ChatTTS.Chat.InferCodeParams(
spk_emb=spk,
))
wavs = chat.infer(
texts,
params_infer_code=ChatTTS.Chat.InferCodeParams(
spk_emb=spk,
),
)
logger.info("Inference completed.")
# Save each generated wav file to a local file
for index, wav in enumerate(wavs):
Expand All @@ -55,9 +58,15 @@ def main(texts: List[str], spk: Optional[str] = None):
if __name__ == "__main__":
logger.info("Starting ChatTTS commandline demo...")
parser = argparse.ArgumentParser(
description="ChatTTS Command", usage='[--spk xxx] "Your text 1." " Your text 2."'
description="ChatTTS Command",
usage='[--spk xxx] "Your text 1." " Your text 2."',
)
parser.add_argument(
"--spk",
help="Speaker (empty to sample a random one)",
type=Optional[str],
default=None,
)
parser.add_argument("--spk", help="Speaker (empty to sample a random one)", type=Optional[str], default=None)
parser.add_argument(
"texts", help="Original text", default="YOUR TEXT HERE", nargs="*"
)
Expand Down

0 comments on commit 7db7d08

Please sign in to comment.