Commit

transformer pytorch

ruifanxu committed May 23, 2021
1 parent c239082 commit 259aa07
Showing 3 changed files with 249 additions and 1 deletion.
242 changes: 242 additions & 0 deletions pytorch/transformer.py
@@ -0,0 +1,242 @@
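# A small encoder-decoder Transformer, trained with teacher forcing on the toy
# date-translation task defined in utils.DateData (see train() below).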
import torch.nn as nn
from torch.nn.functional import cross_entropy, softmax, relu
import numpy as np
import torch
import utils
from torch.utils.data import DataLoader

MAX_LEN = 12

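# Multi-head attention: linear projections of q/k/v, scaled dot-product attention per head,
# then an output projection with dropout, a residual connection and LayerNorm.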
class MultiHead(nn.Module):
def __init__(self, n_head, model_dim, drop_rate):
super().__init__()
self.head_dim = model_dim // n_head
self.n_head = n_head
self.model_dim = model_dim
self.wq = nn.Linear(model_dim, n_head * self.head_dim)
self.wk = nn.Linear(model_dim, n_head * self.head_dim)
self.wv = nn.Linear(model_dim, n_head * self.head_dim)

self.o_dense = nn.Linear(model_dim, model_dim)
self.o_drop = nn.Dropout(drop_rate)
self.layer_norm = nn.LayerNorm(model_dim)

def forward(self,q,k,v,mask,training):
# residual connect
residual = q
dim_per_head= self.head_dim
num_heads = self.n_head
batch_size = q.size(0)

# linear projection
key = self.wk(k) # [n, step, num_heads * head_dim]
value = self.wv(v) # [n, step, num_heads * head_dim]
query = self.wq(q) # [n, step, num_heads * head_dim]

# split by head
query = self.split_heads(query) # [n, n_head, q_step, h_dim]
key = self.split_heads(key)
value = self.split_heads(value) # [n, h, step, h_dim]
context = self.scaled_dot_product_attention(query,key, value, mask) # [n, q_step, h*dv]
o = self.o_dense(context) # [n, step, dim]
o = self.o_drop(o)

o = self.layer_norm(residual+o)
return o

def split_heads(self, x):
x = torch.reshape(x,(x.shape[0], x.shape[1], self.n_head, self.head_dim))
return x.permute(0,2,1,3)

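    # attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, with masked positions set to -inf before the softmax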
def scaled_dot_product_attention(self, q, k, v, mask=None):
dk = torch.tensor(k.shape[-1]).type(torch.float)
        score = torch.matmul(q, k.permute(0, 1, 3, 2)) / (torch.sqrt(dk) + 1e-8)  # [n, n_head, step, step]
        if mask is not None:
            # the pad mask is bool but the look-ahead mask is 0/1 long; masked_fill expects a bool mask
            score = score.masked_fill(mask.bool(), -np.inf)
self.attention = softmax(score,dim=-1)
context = torch.matmul(self.attention,v)
context = context.permute(0,2,1,3)
context = context.reshape((context.shape[0], context.shape[1],-1))
return context

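# Position-wise feed-forward block: Linear(dim -> 4*dim) + ReLU + Linear(4*dim -> dim),
# followed by dropout, a residual connection and LayerNorm.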
class PositionWiseFFN(nn.Module):
def __init__(self,model_dim, dropout = 0.0):
super().__init__()
dff = model_dim*4
self.l = nn.Linear(model_dim,dff)
self.o = nn.Linear(dff,model_dim)
self.dropout = nn.Dropout(dropout)
self.layer_norm = nn.LayerNorm(model_dim)

def forward(self,x):
o = relu(self.l(x))
o = self.o(o)
o = self.dropout(o)

o = self.layer_norm(x + o)
return o # [n, step, dim]



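# One encoder layer: self-attention over the source sequence followed by the feed-forward block.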
class EncoderLayer(nn.Module):

def __init__(self, n_head, emb_dim, drop_rate):
super().__init__()
self.mh = MultiHead(n_head, emb_dim, drop_rate)
self.ffn = PositionWiseFFN(emb_dim)
self.drop = nn.Dropout(drop_rate)

def forward(self, xz, training, mask):
# xz: [n, step, emb_dim]
context = self.mh(xz, xz, xz, mask, training) # [n, step, emb_dim]
o = self.ffn(context)
return o

class Encoder(nn.Module):
def __init__(self, n_head, emb_dim, drop_rate, n_layer):
super().__init__()
self.encoder_layers = nn.ModuleList(
[EncoderLayer(n_head, emb_dim, drop_rate) for _ in range(n_layer)]
)
def forward(self, xz, training, mask):

for encoder in self.encoder_layers:
xz = encoder(xz,training,mask)
return xz # [n, step, model_dim]

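# One decoder layer: masked self-attention over the target, cross-attention against the
# encoder output, then the feed-forward block.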
class DecoderLayer(nn.Module):
def __init__(self,n_head,model_dim,drop_rate):
super().__init__()
        # use nn.ModuleList (not a plain Python list) so both attention blocks register their parameters
        self.mh = nn.ModuleList([MultiHead(n_head, model_dim, drop_rate) for _ in range(2)])
self.ffn = PositionWiseFFN(model_dim,drop_rate)

def forward(self,yz, xz, training, yz_look_ahead_mask,xz_pad_mask):
dec_output = self.mh[0](yz, yz, yz, yz_look_ahead_mask, training)
dec_output = self.mh[1](dec_output, xz, xz, xz_pad_mask, training)

dec_output = self.ffn(dec_output)

return dec_output

class Decoder(nn.Module):
def __init__(self, n_head, model_dim, drop_rate, n_layer):
super().__init__()

self.num_layers = n_layer

self.decoder_layers = nn.ModuleList(
[DecoderLayer(n_head, model_dim, drop_rate) for _ in range(n_layer)]
)

def forward(self, yz, xz, training, yz_look_ahead_mask, xz_pad_mask):
for decoder in self.decoder_layers:
yz = decoder(yz, xz, training, yz_look_ahead_mask, xz_pad_mask)
return yz

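# Learned token embedding plus a fixed sinusoidal position encoding added on top.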
class PositionEmbedding(nn.Module):
def __init__(self, max_len, emb_dim, n_vocab):
super().__init__()
pos = np.expand_dims(np.arange(max_len),1) # [max_len, 1]
pe = pos / np.power(1000, 2*np.expand_dims(np.arange(emb_dim),0)/emb_dim) # [max_len, emb_dim]
pe[:, 0::2] = np.sin(pe[:, 0::2])
pe[:, 1::2] = np.cos(pe[:, 1::2])
pe = np.expand_dims(pe,0) # [1, max_len, emb_dim]
        # fixed (non-trainable) table; registering it as a buffer keeps it on the module's device
        self.register_buffer("pe", torch.from_numpy(pe).type(torch.float32))
self.embeddings = nn.Embedding(n_vocab,emb_dim)
self.embeddings.weight.data.normal_(0,0.1)

def forward(self, x):
x_embed = self.embeddings(x) + self.pe # [n, step, dim]
return x_embed

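# Full encoder-decoder Transformer with a linear projection onto the vocabulary;
# the Adam optimizer is created inside the module and used by step().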
class Transformer(nn.Module):
def __init__(self, n_vocab, max_len, n_layer = 6, emb_dim=512, n_head = 8, drop_rate=0.1, padding_idx=0):
super().__init__()
self.max_len = max_len
self.padding_idx = torch.tensor(padding_idx)
self.dec_v_emb = n_vocab

self.embed = PositionEmbedding(max_len, emb_dim, n_vocab)
self.encoder = Encoder(n_head, emb_dim, drop_rate, n_layer)
self.decoder = Decoder(n_head, emb_dim, drop_rate, n_layer)
self.o = nn.Linear(emb_dim,n_vocab)
self.opt = torch.optim.Adam(self.parameters(),lr=0.001)

def forward(self,x,y,training= None):
x_embed, y_embed = self.embed(x), self.embed(y) # [n, step, emb_dim] * 2
pad_mask = self._pad_mask(x)
encoded_z = self.encoder(x_embed,training,pad_mask) # [n, step, emb_dim]
yz_look_ahead_mask = self._look_ahead_mask(y)
decoded_z = self.decoder(y_embed,encoded_z, training, yz_look_ahead_mask, pad_mask)
o = self.o(decoded_z) # [n, step, n_vocab]
return o

def step(self, x, y):
self.opt.zero_grad()
logits = self(x,y[:, :-1],training=True)
        pad_mask = ~torch.eq(y[:, 1:], self.padding_idx)  # [n, seq_len]; currently unused, so pad positions still count toward the loss
loss = cross_entropy(logits.reshape(-1, self.dec_v_emb),y[:,1:].reshape(-1))
loss.backward()
self.opt.step()
return loss.detach().numpy(), logits

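    # Masks are True/1 at positions attention should ignore: source pad tokens,
    # plus future target positions in the look-ahead mask.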
def _pad_bool(self, seqs):
return torch.eq(seqs,self.padding_idx)

def _pad_mask(self, seqs):
mask = self._pad_bool(seqs)
return mask[:, None, None, :]

def _look_ahead_mask(self,seqs):
batch_size, seq_len = seqs.shape
mask = torch.triu(torch.ones((seq_len,seq_len), dtype=torch.long), diagonal=1) # [seq_len ,seq_len]
mask = torch.where(self._pad_bool(seqs)[:,None,None,:],1,mask[None,None,:,:]) # [n, 1, seq_len, seq_len]
return mask

    def translate(self, src, v2i, i2v):
        # greedy decoding; switch to eval mode so dropout is disabled
        self.eval()
        src_pad = torch.from_numpy(utils.pad_zero(src, self.max_len)).long()
        target = torch.from_numpy(
            utils.pad_zero(np.array([[v2i["<GO>"]] for _ in range(len(src))]), self.max_len + 1)).long()
        x_embed = self.embed(src_pad)
        encoded_z = self.encoder(x_embed, False, mask=self._pad_mask(src_pad))
        for i in range(1, self.max_len):
            y = target[:, :-1]
            y_embed = self.embed(y)
            decoded_z = self.decoder(y_embed, encoded_z, False, self._look_ahead_mask(y), self._pad_mask(src_pad))
            # decoder output at position i-1 predicts token i (same alignment as in step())
            logits = self.o(decoded_z)[:, i - 1, :].data.numpy()
            idx = np.argmax(logits, axis=1)
            target[:, i] = torch.from_numpy(idx)
        return ["".join([i2v[int(i)] for i in target[j, 1:]]) for j in range(len(src))]




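# Toy training loop: pad each batch to a fixed length, run teacher-forced steps,
# and print the argmax prediction for the first sequence every 50 batches.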
def train():
dataset = utils.DateData(4000)
print("Chinese time order: yy/mm/dd ",dataset.date_cn[:3],"\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
print("Vocabularies: ", dataset.vocab)
print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
loader = DataLoader(dataset,batch_size=32,shuffle=True)
model = Transformer(emb_dim=16,max_len=MAX_LEN, n_layer=3, n_head=4, n_vocab=dataset.num_word, drop_rate=0.1)
for i in range(100):
for batch_idx , batch in enumerate(loader):
bx, by, decoder_len = batch
            bx = torch.from_numpy(utils.pad_zero(bx, max_len=MAX_LEN)).type(torch.LongTensor)
            by = torch.from_numpy(utils.pad_zero(by, MAX_LEN + 1)).type(torch.LongTensor)
loss, logits = model.step(bx,by)
if batch_idx%50 == 0:
logits = logits[0]
target = dataset.idx2str(by[0, 1:-1].data.numpy())
pred = logits.argmax(dim=1)
res = dataset.idx2str(pred.data.numpy())
src = dataset.idx2str(bx[0].data.numpy())
print(
"Epoch: ",i,
"| t: ", batch_idx,
"| loss: %.3f" % loss,
"| input: ", src,
"| target: ", target,
"| inference: ", res,
)

if __name__ == "__main__":
train()
6 changes: 6 additions & 0 deletions pytorch/utils.py
@@ -47,6 +47,12 @@ def idx2str(self,idx):
break
return "".join(x)

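# Right-pad every sequence with PAD_ID so each row has length max_len.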
def pad_zero(seqs, max_len):
padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int32)
for i, seq in enumerate(seqs):
padded[i, :len(seq)] = seq
return padded

class Dataset:
def __init__(self,x,y,v2i,i2v):
self.x,self.y = x,y
2 changes: 1 addition & 1 deletion utils.py
@@ -56,7 +56,7 @@ def num_word(self):


def pad_zero(seqs, max_len):
padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int32)
padded = np.full((len(seqs), max_len), fill_value=PAD_ID, dtype=np.int64)
for i, seq in enumerate(seqs):
padded[i, :len(seq)] = seq
return padded
