diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 04e89ab7a..3da975962 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -1,5 +1,4 @@ import math -from einops import rearrange from vector_quantize_pytorch import GroupedResidualFSQ import torch @@ -66,23 +65,32 @@ def __init__(self, self.G = G self.R = R - def _embed(self, x): + def _embed(self, x: torch.Tensor): if self.transpose: x = x.transpose(1,2) + """ x = rearrange( x, "b t (g r) -> g b t r", g = self.G, r = self.R, - ) + ) + """ + x.view(-1, self.G, self.R).permute(2, 0, 1, 3) feat = self.quantizer.get_output_from_indices(x) return feat.transpose(1,2) if self.transpose else feat - + def forward(self, x,): if self.transpose: x = x.transpose(1,2) feat, ind = self.quantizer(x) + """ ind = rearrange( ind, "g b t r ->b t (g r)", - ) - embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype) + ) + """ + ind = ind.permute(1, 2, 0, 3).contiguous() + ind = ind.view(ind.size(0), ind.size(1), -1) + embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind) + embed_onehot = embed_onehot_tmp.to(x.dtype) + del embed_onehot_tmp e_mean = torch.mean(embed_onehot, dim=[0,1]) e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))