From 3b9005ea11788a4dd8b54555d5efee397cb31a24 Mon Sep 17 00:00:00 2001 From: BlinkDL Date: Fri, 13 Aug 2021 18:39:24 +0800 Subject: [PATCH] RWKV: now faster and less params --- src/model.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/model.py b/src/model.py index f1c47194..7cebae9f 100644 --- a/src/model.py +++ b/src/model.py @@ -68,9 +68,10 @@ def __init__(self, config, layer_id): self.layer_id = layer_id self.time_shift = nn.ZeroPad2d((0,0,1,0)) - self.key = nn.Linear(config.n_embd, 3 * config.n_embd) - self.value = nn.Linear(config.n_embd, 3 * config.n_embd) - self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R + self.key = nn.Linear(config.n_embd, hidden_sz) + self.value = nn.Linear(config.n_embd, hidden_sz) + self.weight = nn.Linear(hidden_sz, config.n_embd) self.receptance = nn.Linear(config.n_embd, config.n_embd) def forward(self, x): @@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module): def __init__(self, config, layer_id): super().__init__() self.layer_id = layer_id - self.key = nn.Linear(config.n_embd, 3 * config.n_embd) - self.value = nn.Linear(config.n_embd, 3 * config.n_embd) - self.weight = nn.Linear(3 * config.n_embd, config.n_embd) + hidden_sz = 3 * config.n_embd + self.key = nn.Linear(config.n_embd, hidden_sz) + self.value = nn.Linear(config.n_embd, hidden_sz) + self.weight = nn.Linear(hidden_sz, config.n_embd) def forward(self, x): k = self.key(x)