Skip to content

Commit

Permalink
optimize(dvae): remove einops (#383)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 21, 2024
1 parent e6412b1 commit ace4d0c
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import math
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

import torch
Expand Down Expand Up @@ -66,23 +65,32 @@ def __init__(self,
self.G = G
self.R = R

def _embed(self, x):
def _embed(self, x: torch.Tensor):
if self.transpose:
x = x.transpose(1,2)
"""
x = rearrange(
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
)
)
"""
x.view(-1, self.G, self.R).permute(2, 0, 1, 3)
feat = self.quantizer.get_output_from_indices(x)
return feat.transpose(1,2) if self.transpose else feat

def forward(self, x,):
if self.transpose:
x = x.transpose(1,2)
feat, ind = self.quantizer(x)
"""
ind = rearrange(
ind, "g b t r ->b t (g r)",
)
embed_onehot = F.one_hot(ind.long(), self.n_ind).to(x.dtype)
)
"""
ind = ind.permute(1, 2, 0, 3).contiguous()
ind = ind.view(ind.size(0), ind.size(1), -1)
embed_onehot_tmp = F.one_hot(ind.long(), self.n_ind)
embed_onehot = embed_onehot_tmp.to(x.dtype)
del embed_onehot_tmp
e_mean = torch.mean(embed_onehot, dim=[0,1])
e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1)
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1))
Expand Down

0 comments on commit ace4d0c

Please sign in to comment.