From b3d511b9f63a98dbde4e072129ef2c8fedf71ef3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90?= <110042431+zly-idleness@users.noreply.github.com> Date: Tue, 15 Oct 2024 23:56:18 +0800 Subject: [PATCH] compatibility: change gamma to weight (#733) --- ChatTTS/model/dvae.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 7e6b62a83..0f745cee4 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -36,7 +36,7 @@ def __init__( ) # pointwise/1x1 convs, implemented with linear layers self.act = nn.GELU() self.pwconv2 = nn.Linear(intermediate_dim, dim) - self.gamma = ( + self.weight = ( nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True) if layer_scale_init_value > 0 else None @@ -55,8 +55,8 @@ def forward(self, x: torch.Tensor, cond=None) -> torch.Tensor: del y y = self.pwconv2(x) del x - if self.gamma is not None: - y *= self.gamma + if self.weight is not None: + y *= self.weight y.transpose_(1, 2) # (B, T, C) -> (B, C, T) x = y + residual