Skip to content

Commit

Permalink
Merge branch '2noise:dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
ain-soph authored Oct 5, 2024
2 parents 1119066 + 71b42e0 commit 41f5e70
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 7 deletions.
7 changes: 4 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ def _load(
vq_config=asdict(self.config.dvae.vq),
dim=self.config.dvae.decoder.idim,
coef=coef,
device=self.device,
device=device,
)
.to(device)
.eval()
Expand All @@ -289,8 +289,8 @@ def _load(
self.config.embed.num_text_tokens,
self.config.embed.num_vq,
)
embed.from_pretrained(embed_path, device=self.device)
self.embed = embed.to(self.device)
embed.from_pretrained(embed_path, device=device)
self.embed = embed.to(device)
self.logger.log(logging.INFO, "embed loaded.")

gpt = GPT(
Expand Down Expand Up @@ -318,6 +318,7 @@ def _load(
decoder_config=asdict(self.config.decoder),
dim=self.config.decoder.idim,
coef=coef,
device=device,
)
.to(device)
.eval()
Expand Down
4 changes: 2 additions & 2 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def __init__(
hop_length=256,
n_mels=100,
padding: Literal["center", "same"] = "center",
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
self.device = device
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(
vq_config: Optional[dict] = None,
dim=512,
coef: Optional[str] = None,
device: torch.device = torch.device("cuda"),
device: torch.device = torch.device("cpu"),
):
super().__init__()
if coef is None:
Expand Down
6 changes: 5 additions & 1 deletion ChatTTS/model/velocity/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,11 @@ def execute_model(
for i in range(self.post_model.num_vq)
]
input_emb = torch.stack(code_emb, 3).sum(3)
start_idx = input_tokens_history.shape[-2] - 1 if input_tokens_history.shape[-2] > 0 else 0
start_idx = (
input_tokens_history.shape[-2] - 1
if input_tokens_history.shape[-2] > 0
else 0
)
else:
input_emb = self.post_model(input_tokens, text_mask)
# print(input_emb.shape)
Expand Down
1 change: 1 addition & 0 deletions ChatTTS/utils/dl.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ def download_all_assets(tmpdir: str, version="0.2.8"):
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/128.0.0.0 Safari/537.36 Edg/128.0.0.0"
},
)
cmdfile = os.path.join(tmpdir, "rvcmd")
if is_win:
download_and_extract_zip(RVCMD_URL, tmpdir)
cmdfile += ".exe"
Expand Down
1 change: 1 addition & 0 deletions docs/cn/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ ChatTTS 是一款专门为对话场景(例如 LLM 助手)设计的文本转

- [x] 开源 4 万小时基础模型和 spk_stats 文件。
- [x] 支持流式语音输出。
- [x] 开源 DVAE 编码器和零样本推理代码
- [ ] 开源具有多情感控制功能的 4 万小时版本。
- [ ] ChatTTS.cpp (欢迎在 2noise 组织中新建仓库)。

Expand Down
2 changes: 1 addition & 1 deletion examples/ipynb/colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@
"metadata": {},
"outputs": [],
"source": [
"from tools.audio import load_audio\n",
"from ChatTTS.tools.audio import load_audio\n",
"\n",
"spk_smp = chat.sample_audio_speaker(load_audio(\"sample.mp3\", 24000))\n",
"print(spk_smp) # save it in order to load the speaker without sample audio next time\n",
Expand Down

0 comments on commit 41f5e70

Please sign in to comment.