Skip to content

Commit b367dda

Browse files
authored
Merge branch '2noise:dev' into dev
2 parents b625783 + 0ec82fe commit b367dda

File tree

27 files changed

+465
-96
lines changed

27 files changed

+465
-96
lines changed

.github/workflows/checksum.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ on:
44

55
jobs:
66
checksum:
7-
runs-on: ubuntu-latest
7+
runs-on: ubuntu-24.04
88
steps:
99
- uses: actions/checkout@v4
1010

@@ -13,7 +13,7 @@ jobs:
1313

1414
- name: Run RVC-Models-Downloader
1515
run: |
16-
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.8/rvcmd_linux_amd64.deb
16+
wget https://github.com/fumiama/RVC-Models-Downloader/releases/download/v0.2.9/rvcmd_linux_amd64.deb
1717
sudo apt -y install ./rvcmd_linux_amd64.deb
1818
rm -f ./rvcmd_linux_amd64.deb
1919
rvcmd -notrs -w 1 -notui assets/chtts

.github/workflows/close-issue.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@ on:
55

66
jobs:
77
close-issues:
8-
runs-on: ubuntu-latest
8+
runs-on: ubuntu-24.04
99
permissions:
1010
issues: write
1111
pull-requests: write
1212
steps:
1313
- uses: actions/stale@v5
1414
with:
15-
exempt-issue-labels: "help wanted,good first issue,documentation,following up,todo list"
15+
exempt-issue-labels: "help wanted,following up,todo list,enhancement,algorithm,delayed,performance"
1616
days-before-issue-stale: 30
1717
days-before-issue-close: 15
1818
stale-issue-label: "stale"

.github/workflows/pull-format.yml

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
# This workflow closes invalid PR
99
change-or-close-pr:
1010
# The type of runner that the job will run on
11-
runs-on: ubuntu-latest
11+
runs-on: ubuntu-24.04
1212
permissions: write-all
1313

1414
# Steps represent a sequence of tasks that will be executed as part of the job
@@ -63,6 +63,14 @@ jobs:
6363
- name: Set up Python
6464
uses: actions/setup-python@v5
6565

66+
- name: Create venv
67+
run: python3 -m venv .venv
68+
69+
- name: Activate venv
70+
run: |
71+
. .venv/bin/activate
72+
echo PATH=$PATH >> $GITHUB_ENV
73+
6674
- name: Install Black
6775
run: pip install "black[jupyter]"
6876

.github/workflows/push-format.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ jobs:
2424
- name: Set up Python
2525
uses: actions/setup-python@v5
2626

27+
- name: Create venv
28+
run: python3 -m venv .venv
29+
30+
- name: Activate venv
31+
run: |
32+
. .venv/bin/activate
33+
echo PATH=$PATH >> $GITHUB_ENV
34+
2735
- name: Install Black
2836
run: pip install "black[jupyter]"
2937

.github/workflows/unitest.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ jobs:
2525
run: |
2626
sudo apt-get install -y portaudio19-dev python3-pyaudio
2727
28+
- name: Create venv
29+
run: python3 -m venv .venv
30+
31+
- name: Activate venv
32+
run: |
33+
. .venv/bin/activate
34+
echo PATH=$PATH >> $GITHUB_ENV
35+
2836
- name: Test Install
2937
run: pip install .
3038

ChatTTS/config/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33

44
@dataclass(repr=False, eq=False)
55
class Path:
6-
vocos_ckpt_path: str = "asset/Vocos.pt"
7-
dvae_ckpt_path: str = "asset/DVAE_full.pt"
6+
vocos_ckpt_path: str = "asset/Vocos.safetensors"
7+
dvae_ckpt_path: str = "asset/DVAE.safetensors"
88
gpt_ckpt_path: str = "asset/gpt"
9-
decoder_ckpt_path: str = "asset/Decoder.pt"
9+
decoder_ckpt_path: str = "asset/Decoder.safetensors"
1010
tokenizer_path: str = "asset/tokenizer"
1111
embed_path: str = "asset/Embed.safetensors"
1212

ChatTTS/core.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .config import Config
1616
from .model import DVAE, Embed, GPT, gen_logits, Tokenizer, Speaker
1717
from .utils import (
18+
load_safetensors,
1819
check_all_assets,
1920
download_all_assets,
2021
select_device,
@@ -97,7 +98,7 @@ def download_models(
9798
try:
9899
download_path = snapshot_download(
99100
repo_id="2Noise/ChatTTS",
100-
allow_patterns=["*.pt", "*.yaml", "*.json", "*.safetensors"],
101+
allow_patterns=["*.yaml", "*.json", "*.safetensors"],
101102
)
102103
except:
103104
download_path = None
@@ -253,34 +254,34 @@ def _load(
253254
vocos = (
254255
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
255256
.to(
256-
# vocos on mps will crash, use cpu fallback
257+
# Vocos on mps will crash, use cpu fallback.
258+
# Plus, complex dtype used in the decode process of Vocos is not supported in torch_npu now,
259+
# so we put this calculation of data on CPU instead of NPU.
257260
"cpu"
258-
if "mps" in str(device)
261+
if "mps" in str(device) or "npu" in str(device)
259262
else device
260263
)
261264
.eval()
262265
)
263266
assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
264-
vocos.load_state_dict(torch.load(vocos_ckpt_path, weights_only=True, mmap=True))
267+
vocos.load_state_dict(load_safetensors(vocos_ckpt_path))
265268
self.vocos = vocos
266269
self.logger.log(logging.INFO, "vocos loaded.")
267270

268-
dvae = (
269-
DVAE(
270-
decoder_config=asdict(self.config.dvae.decoder),
271-
encoder_config=asdict(self.config.dvae.encoder),
272-
vq_config=asdict(self.config.dvae.vq),
273-
dim=self.config.dvae.decoder.idim,
274-
coef=coef,
275-
device=device,
276-
)
277-
.to(device)
278-
.eval()
271+
# computation of MelSpectrogram on npu is not supported now, use cpu fallback.
272+
dvae_device = torch.device("cpu") if "npu" in str(self.device) else device
273+
dvae = DVAE(
274+
decoder_config=asdict(self.config.dvae.decoder),
275+
encoder_config=asdict(self.config.dvae.encoder),
276+
vq_config=asdict(self.config.dvae.vq),
277+
dim=self.config.dvae.decoder.idim,
278+
coef=coef,
279+
device=dvae_device,
279280
)
280281
coef = str(dvae)
281282
assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
282-
dvae.load_state_dict(torch.load(dvae_ckpt_path, weights_only=True, mmap=True))
283-
self.dvae = dvae
283+
dvae.load_pretrained(dvae_ckpt_path, dvae_device)
284+
self.dvae = dvae.eval()
284285
self.logger.log(logging.INFO, "dvae loaded.")
285286

286287
embed = Embed(
@@ -289,7 +290,7 @@ def _load(
289290
self.config.embed.num_text_tokens,
290291
self.config.embed.num_vq,
291292
)
292-
embed.from_pretrained(embed_path, device=device)
293+
embed.load_pretrained(embed_path, device=device)
293294
self.embed = embed.to(device)
294295
self.logger.log(logging.INFO, "embed loaded.")
295296

@@ -303,7 +304,7 @@ def _load(
303304
logger=self.logger,
304305
).eval()
305306
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
306-
gpt.from_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
307+
gpt.load_pretrained(gpt_ckpt_path, embed_path, experimental=experimental)
307308
gpt.prepare(compile=compile and "cuda" in str(device))
308309
self.gpt = gpt
309310
self.logger.log(logging.INFO, "gpt loaded.")
@@ -313,22 +314,16 @@ def _load(
313314
)
314315
self.logger.log(logging.INFO, "speaker loaded.")
315316

316-
decoder = (
317-
DVAE(
318-
decoder_config=asdict(self.config.decoder),
319-
dim=self.config.decoder.idim,
320-
coef=coef,
321-
device=device,
322-
)
323-
.to(device)
324-
.eval()
317+
decoder = DVAE(
318+
decoder_config=asdict(self.config.decoder),
319+
dim=self.config.decoder.idim,
320+
coef=coef,
321+
device=device,
325322
)
326323
coef = str(decoder)
327324
assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
328-
decoder.load_state_dict(
329-
torch.load(decoder_ckpt_path, weights_only=True, mmap=True)
330-
)
331-
self.decoder = decoder
325+
decoder.load_pretrained(decoder_ckpt_path, device)
326+
self.decoder = decoder.eval()
332327
self.logger.log(logging.INFO, "decoder loaded.")
333328

334329
if tokenizer_path:
@@ -422,7 +417,7 @@ def _infer(
422417

423418
@torch.inference_mode()
424419
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
425-
if "mps" in str(self.device):
420+
if "mps" in str(self.device) or "npu" in str(self.device):
426421
return self.vocos.decode(spec.cpu()).cpu().numpy()
427422
else:
428423
return self.vocos.decode(spec).cpu().numpy()

ChatTTS/model/dvae.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
import pybase16384 as b14
66
import torch
77
import torch.nn as nn
8-
import torch.nn.functional as F
98
import torchaudio
109
from vector_quantize_pytorch import GroupedResidualFSQ
1110

11+
from ..utils import load_safetensors
12+
1213

1314
class ConvNeXtBlock(nn.Module):
1415
def __init__(
@@ -36,7 +37,7 @@ def __init__(
3637
) # pointwise/1x1 convs, implemented with linear layers
3738
self.act = nn.GELU()
3839
self.pwconv2 = nn.Linear(intermediate_dim, dim)
39-
self.gamma = (
40+
self.weight = (
4041
nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
4142
if layer_scale_init_value > 0
4243
else None
@@ -55,8 +56,8 @@ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor:
5556
del y
5657
y = self.pwconv2(x)
5758
del x
58-
if self.gamma is not None:
59-
y *= self.gamma
59+
if self.weight is not None:
60+
y *= self.weight
6061
y.transpose_(1, 2) # (B, T, C) -> (B, C, T)
6162

6263
x = y + residual
@@ -251,6 +252,12 @@ def __call__(
251252
) -> torch.Tensor:
252253
return super().__call__(inp, mode)
253254

255+
@torch.inference_mode()
256+
def load_pretrained(self, filename: str, device: torch.device):
257+
state_dict_tensors = load_safetensors(filename)
258+
self.load_state_dict(state_dict_tensors)
259+
self.to(device)
260+
254261
@torch.inference_mode()
255262
def forward(
256263
self, inp: torch.Tensor, mode: Literal["encode", "decode"] = "decode"

ChatTTS/model/embed.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from safetensors.torch import safe_open
21
import torch
32
import torch.nn as nn
43
from torch.nn.utils.parametrizations import weight_norm
54

5+
from ..utils import load_safetensors
6+
67

78
class Embed(nn.Module):
89
def __init__(
@@ -34,11 +35,8 @@ def __init__(
3435
)
3536

3637
@torch.inference_mode()
37-
def from_pretrained(self, filename: str, device: torch.device):
38-
state_dict_tensors = {}
39-
with safe_open(filename, framework="pt") as f:
40-
for k in f.keys():
41-
state_dict_tensors[k] = f.get_tensor(k)
38+
def load_pretrained(self, filename: str, device: torch.device):
39+
state_dict_tensors = load_safetensors(filename)
4240
self.load_state_dict(state_dict_tensors)
4341
self.to(device)
4442

ChatTTS/model/gpt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def __init__(
5656
self.head_text = embed.head_text.__call__
5757
self.head_code = [hc.__call__ for hc in embed.head_code]
5858

59-
def from_pretrained(
59+
def load_pretrained(
6060
self, gpt_folder: str, embed_file_path: str, experimental=False
6161
):
6262
if self.is_vllm and platform.system().lower() == "linux":

0 commit comments

Comments
 (0)