Skip to content

Commit

Permalink
RWKV: now faster and with fewer parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
www committed Aug 13, 2021
1 parent 546114c commit 3b9005e
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/model.py
Expand Up @@ -68,9 +68,10 @@ def __init__(self, config, layer_id):
self.layer_id = layer_id
self.time_shift = nn.ZeroPad2d((0,0,1,0))

self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
hidden_sz = 5 * config.n_embd // 2 # can use smaller hidden_sz because of R
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)
self.receptance = nn.Linear(config.n_embd, config.n_embd)

def forward(self, x):
Expand Down Expand Up @@ -166,9 +167,10 @@ class GeGLU(torch.nn.Module):
def __init__(self, config, layer_id):
super().__init__()
self.layer_id = layer_id
self.key = nn.Linear(config.n_embd, 3 * config.n_embd)
self.value = nn.Linear(config.n_embd, 3 * config.n_embd)
self.weight = nn.Linear(3 * config.n_embd, config.n_embd)
hidden_sz = 3 * config.n_embd
self.key = nn.Linear(config.n_embd, hidden_sz)
self.value = nn.Linear(config.n_embd, hidden_sz)
self.weight = nn.Linear(hidden_sz, config.n_embd)

def forward(self, x):
k = self.key(x)
Expand Down

0 comments on commit 3b9005e

Please sign in to comment.