Skip to content

Commit a82e1df

Browse files
IrisSallydragon-hellogithub-actions[bot]
authored
fix: device mismatch error in embedding loading on branch dev (#737)
* fix device mismatch bug * fix additional device mismatch bug when zero-shot * chore(format): run black on dev --------- Co-authored-by: dragon <[email protected]> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 024f93e commit a82e1df

File tree

3 files changed

+10
-4
lines changed

3 files changed

+10
-4
lines changed

ChatTTS/core.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def _load(
272272
vq_config=asdict(self.config.dvae.vq),
273273
dim=self.config.dvae.decoder.idim,
274274
coef=coef,
275+
device=self.device,
275276
)
276277
.to(device)
277278
.eval()
@@ -288,8 +289,8 @@ def _load(
288289
self.config.embed.num_text_tokens,
289290
self.config.embed.num_vq,
290291
)
291-
embed.from_pretrained(embed_path)
292-
self.embed = embed
292+
embed.from_pretrained(embed_path, device=self.device)
293+
self.embed = embed.to(self.device)
293294
self.logger.log(logging.INFO, "embed loaded.")
294295

295296
gpt = GPT(

ChatTTS/model/dvae.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,8 +179,10 @@ def __init__(
179179
hop_length=256,
180180
n_mels=100,
181181
padding: Literal["center", "same"] = "center",
182+
device: torch.device = torch.device("cuda"),
182183
):
183184
super().__init__()
185+
self.device = device
184186
if padding not in ["center", "same"]:
185187
raise ValueError("Padding must be 'center' or 'same'.")
186188
self.padding = padding
@@ -197,6 +199,7 @@ def __call__(self, audio: torch.Tensor) -> torch.Tensor:
197199
return super().__call__(audio)
198200

199201
def forward(self, audio: torch.Tensor) -> torch.Tensor:
202+
audio = audio.to(self.device)
200203
mel: torch.Tensor = self.mel_spec(audio)
201204
features = torch.log(torch.clip(mel, min=1e-5))
202205
return features
@@ -210,6 +213,7 @@ def __init__(
210213
vq_config: Optional[dict] = None,
211214
dim=512,
212215
coef: Optional[str] = None,
216+
device: torch.device = torch.device("cuda"),
213217
):
214218
super().__init__()
215219
if coef is None:
@@ -227,7 +231,7 @@ def __init__(
227231
nn.Conv1d(dim, dim, 4, 2, 1),
228232
nn.GELU(),
229233
)
230-
self.preprocessor_mel = MelSpectrogramFeatures()
234+
self.preprocessor_mel = MelSpectrogramFeatures(device=device)
231235
self.encoder: Optional[DVAEDecoder] = DVAEDecoder(**encoder_config)
232236

233237
self.decoder = DVAEDecoder(**decoder_config)

ChatTTS/model/embed.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,13 @@ def __init__(
3434
)
3535

3636
@torch.inference_mode()
37-
def from_pretrained(self, filename: str):
37+
def from_pretrained(self, filename: str, device: torch.device):
3838
state_dict_tensors = {}
3939
with safe_open(filename, framework="pt") as f:
4040
for k in f.keys():
4141
state_dict_tensors[k] = f.get_tensor(k)
4242
self.load_state_dict(state_dict_tensors)
43+
self.to(device)
4344

4445
def __call__(
4546
self, input_ids: torch.Tensor, text_mask: torch.Tensor

0 commit comments

Comments
 (0)