Implementation of Mogrifier LSTM Cell in PyTorch

This follows the implementation of the Mogrifier LSTM proposed in the paper "Mogrifier LSTM" (Melis et al., ICLR 2020).

The Mogrifier LSTM is an LSTM in which the two inputs x and h_prev modulate one another in an alternating fashion before the usual LSTM computation.
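Concretely, odd-numbered steps rescale x using h, and even-numbered steps rescale h using x, each time gating with 2 * sigmoid(.). A minimal sketch of this update, assuming the Q and R projections are stored as alternating linear layers in a mogrifier_list (the same layout the low-rank snippet below uses):

import torch

def mogrify(x, h, mogrifier_list):
    # mogrifier_list alternates Q and R: even indices hold Q (maps h to x's size),
    # odd indices hold R (maps x to h's size)
    for i, linear in enumerate(mogrifier_list):
        if i % 2 == 0:
            x = 2 * torch.sigmoid(linear(h)) * x   # x = 2*sigmoid(Q h) * x
        else:
            h = 2 * torch.sigmoid(linear(x)) * h   # h = 2*sigmoid(R x) * h
    return x, h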


You can define a MogrifierLSTMCell just like an nn.LSTMCell, with one additional parameter, mogrify_steps:

mog_lstm = MogrifierLSTMCell(input_size, hidden_size, mogrify_steps)
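The cell is then called one timestep at a time, with the same interface as nn.LSTMCell (this is exactly how the model below uses it):

x = torch.randn(batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)
h, c = mog_lstm(x, (h, c))   # returns the updated hidden and cell states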

Below is an example of a model with a two-layer Mogrifier LSTM.

from mog_lstm import MogrifierLSTMCell
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self, input_size, hidden_size, mogrify_steps, vocab_size, tie_weights, dropout):
        super(Model, self).__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, input_size)
        self.mogrifier_lstm_layer1 = MogrifierLSTMCell(input_size, hidden_size, mogrify_steps)
        self.mogrifier_lstm_layer2 = MogrifierLSTMCell(hidden_size, hidden_size, mogrify_steps)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.drop = nn.Dropout(dropout)
        if tie_weights:
            # tying the embedding and output weights requires input_size == hidden_size
            self.fc.weight = self.embedding.weight

    def forward(self, seq, max_len=10):
        embed = self.embedding(seq)
        batch_size = seq.shape[0]
        # initialize both layers' states on the same device as the input
        h1 = torch.zeros(batch_size, self.hidden_size, device=seq.device)
        c1 = torch.zeros(batch_size, self.hidden_size, device=seq.device)
        h2 = torch.zeros(batch_size, self.hidden_size, device=seq.device)
        c2 = torch.zeros(batch_size, self.hidden_size, device=seq.device)
        hidden_states = []
        outputs = []
        for step in range(max_len):
            x = self.drop(embed[:, step])
            h1, c1 = self.mogrifier_lstm_layer1(x, (h1, c1))
            h2, c2 = self.mogrifier_lstm_layer2(h1, (h2, c2))
            out = self.fc(self.drop(h2))
            hidden_states.append(h2.unsqueeze(1))
            outputs.append(out.unsqueeze(1))

        hidden_states = torch.cat(hidden_states, dim=1)   # (batch_size, max_len, hidden_size)
        outputs = torch.cat(outputs, dim=1)               # (batch_size, max_len, vocab_size)

        return outputs, hidden_states

input_size = 512
hidden_size = 512
vocab_size = 30
batch_size = 4
lr = 3e-3
mogrify_steps = 5        # 5 steps give optimal performance according to the paper
dropout = 0.5            # for simplicity, input dropout and output dropout are both 0.5; see appendix B in the paper for exact values
tie_weights = True       # in the paper, embedding weights and output weights are tied
betas = (0, 0.999)       # in the paper the momentum term in Adam is ignored
weight_decay = 2.5e-4    # weight decay is around this value, see appendix B in the paper
clip_norm = 10           # the paper uses a clip_norm of 10

model = Model(input_size, hidden_size, mogrify_steps, vocab_size, tie_weights, dropout)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=betas, eps=1e-08, weight_decay=weight_decay)

# seq of shape (batch_size, max_words)
seq = torch.LongTensor([[ 8, 29, 18,  1, 17,  3, 26,  6, 26,  5],
                        [ 8, 28, 15, 12, 13,  2, 26, 16, 20,  0],
                        [15,  4, 27, 14, 29, 28, 14,  1,  0,  0],
                        [20, 22, 29, 22, 23, 29,  0,  0,  0,  0]])
                        
outputs, hidden_states = model(seq)
print(outputs.shape)        # torch.Size([4, 10, 30])
print(hidden_states.shape)  # torch.Size([4, 10, 512])
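Note that clip_norm is defined above but only comes into play during training. A minimal, hypothetical training step (criterion and targets below are placeholders for illustration) would look like:

criterion = nn.CrossEntropyLoss()
targets = torch.randint(0, vocab_size, (batch_size, 10))  # dummy next-token targets

optimizer.zero_grad()
outputs, _ = model(seq)
loss = criterion(outputs.reshape(-1, vocab_size), targets.reshape(-1))
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)  # gradient clipping, as in the paper
optimizer.step()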

Factorization of Q and R as products of low-rank matrices

If you would like to factorize Q and R as products of low-rank matrices, as done in the paper, you can build the mogrifier's projection list as follows (this replaces the full-rank linear layers inside the cell's __init__):

k = 85  # low rank: (512 * 85) + (85 * 512) parameters per projection instead of 512 * 512
# start with q (maps h -> input_size to modulate x)
self.mogrifier_list = nn.ModuleList([nn.Sequential(nn.Linear(hidden_size, k, bias=False),
                                                   nn.Linear(k, input_size, bias=True))])
for i in range(1, mogrify_steps):
    if i % 2 == 0:
        # q: maps h -> input_size to modulate x
        self.mogrifier_list.append(nn.Sequential(nn.Linear(hidden_size, k, bias=False),
                                                 nn.Linear(k, input_size, bias=True)))
    else:
        # r: maps x -> hidden_size to modulate h
        self.mogrifier_list.append(nn.Sequential(nn.Linear(input_size, k, bias=False),
                                                 nn.Linear(k, hidden_size, bias=True)))
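To see the saving, compare the parameter counts for a single projection (with input_size = hidden_size = 512 and k = 85):

full = nn.Linear(512, 512, bias=True)
low_rank = nn.Sequential(nn.Linear(512, 85, bias=False),
                         nn.Linear(85, 512, bias=True))
print(sum(p.numel() for p in full.parameters()))      # 262656
print(sum(p.numel() for p in low_rank.parameters()))  # 87552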

Thanks to KFrank for his help on the factorization part.
