
Commit

update
Embedding committed Oct 19, 2021
1 parent 8a04f2f commit a938704
Showing 15 changed files with 93 additions and 4 deletions.
9 changes: 9 additions & 0 deletions uer/encoders/dual_encoder.py
@@ -4,6 +4,7 @@

class DualEncoder(nn.Module):
"""
Dual Encoder which enables siamese models like SBERT and CLIP.
"""
def __init__(self, args):
super(DualEncoder, self).__init__()
@@ -24,6 +25,14 @@ def __init__(self, args):
self.encoder_1 = self.encoder_0

def forward(self, emb, seg):
"""
Args:
emb: ([batch_size x seq_length x emb_size], [batch_size x seq_length x emb_size])
seg: ([batch_size x seq_length], [batch_size x seq_length])
Returns:
features_0: [batch_size x seq_length x hidden_size]
features_1: [batch_size x seq_length x hidden_size]
"""
features_0 = self.get_encode_0(emb[0], seg[0])
features_1 = self.get_encode_1(emb[1], seg[1])

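For orientation, a minimal sketch of the siamese pattern DualEncoder implements: two towers that optionally share parameters and consume a pair of inputs. All names and shapes below are illustrative assumptions, not UER-py's actual API.

import torch
import torch.nn as nn

class TinyDualEncoder(nn.Module):
    # Hypothetical sketch of a dual (siamese) encoder in the spirit of
    # SBERT/CLIP: two towers, optionally sharing parameters.
    def __init__(self, hidden_size, share_weights=True):
        super().__init__()
        self.encoder_0 = nn.Linear(hidden_size, hidden_size)
        # Assigning the same module makes the two towers identical (siamese).
        self.encoder_1 = self.encoder_0 if share_weights else nn.Linear(hidden_size, hidden_size)

    def forward(self, emb):
        # emb is a pair of tensors, each [batch_size x seq_length x hidden_size].
        return self.encoder_0(emb[0]), self.encoder_1(emb[1])

pair = (torch.randn(2, 8, 16), torch.randn(2, 8, 16))
features_0, features_1 = TinyDualEncoder(16)(pair)  # each [2 x 8 x 16]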
18 changes: 18 additions & 0 deletions uer/encoders/rnn_encoder.py
@@ -4,6 +4,9 @@


class RnnEncoder(nn.Module):
"""
RNN encoder.
"""
def __init__(self, args):
super(RnnEncoder, self).__init__()

@@ -38,6 +41,9 @@ def init_hidden(self, batch_size, device):


class LstmEncoder(RnnEncoder):
"""
LSTM encoder.
"""
def __init__(self, args):
super(LstmEncoder, self).__init__(args)

@@ -58,6 +64,9 @@ def init_hidden(self, batch_size, device):


class GruEncoder(RnnEncoder):
"""
GRU encoder.
"""
def __init__(self, args):
super(GruEncoder, self).__init__(args)

@@ -70,6 +79,9 @@ def __init__(self, args):


class BirnnEncoder(nn.Module):
"""
Bi-directional RNN encoder.
"""
def __init__(self, args):
super(BirnnEncoder, self).__init__()

@@ -112,6 +124,9 @@ def init_hidden(self, batch_size, device):


class BilstmEncoder(BirnnEncoder):
"""
Bi-directional LSTM encoder.
"""
def __init__(self, args):
super(BilstmEncoder, self).__init__(args)

@@ -133,6 +148,9 @@ def init_hidden(self, batch_size, device):


class BigruEncoder(BirnnEncoder):
"""
Bi-directional GRU encoder.
"""
def __init__(self, args):
super(BigruEncoder, self).__init__(args)

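As a reading aid for the encoder family above, here is a hedged sketch of a unidirectional GRU encoder with an explicit init_hidden; the names and argument handling are assumptions, and the real classes read their configuration from args.

import torch
import torch.nn as nn

class TinyGruEncoder(nn.Module):
    # Hypothetical sketch mirroring the RnnEncoder/GruEncoder split.
    def __init__(self, emb_size, hidden_size, layers_num=1):
        super().__init__()
        self.layers_num = layers_num
        self.hidden_size = hidden_size
        self.rnn = nn.GRU(emb_size, hidden_size, num_layers=layers_num, batch_first=True)

    def init_hidden(self, batch_size, device):
        # A GRU needs a single hidden-state tensor; an LSTM would need an
        # (h, c) pair, which is why LstmEncoder overrides init_hidden.
        return torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device)

    def forward(self, emb, seg):
        hidden = self.init_hidden(emb.size(0), emb.device)
        output, _ = self.rnn(emb, hidden)
        return output  # [batch_size x seq_length x hidden_size]

out = TinyGruEncoder(emb_size=16, hidden_size=32)(torch.randn(2, 8, 16), None)

A bidirectional variant would pass bidirectional=True and typically halve the per-direction width so the concatenated output keeps the same hidden size, which is one reason BirnnEncoder exists as a separate base class.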
30 changes: 30 additions & 0 deletions uer/layers/embeddings.py
@@ -7,6 +7,7 @@

from uer.layers.layer_norm import LayerNorm


class DualEmbedding(nn.Module):
"""
"""
@@ -34,6 +35,14 @@ def __init__(self, args, vocab_size=None):
self.embedding_0 = self.embedding_1

def forward(self, src, seg):
"""
Args:
src: ([batch_size x seq_length], [batch_size x seq_length])
seg: ([batch_size x seq_length], [batch_size x seq_length])
Returns:
emb_0: [batch_size x seq_length x emb_size]
emb_1: [batch_size x seq_length x emb_size]
"""
emb_0 = self.get_embedding_0(src[0], seg[0])
emb_1 = self.get_embedding_1(src[1], seg[1])

@@ -62,6 +71,13 @@ def __init__(self, args, vocab_size):
self.layer_norm = LayerNorm(args.emb_size)

def forward(self, src, _):
"""
Args:
src: [batch_size x seq_length]
seg: [batch_size x seq_length] (unused by this embedding)
Returns:
emb: [batch_size x seq_length x emb_size]
"""
emb = self.word_embedding(src)
if not self.remove_embedding_layernorm:
emb = self.layer_norm(emb)
@@ -86,6 +102,13 @@ def __init__(self, args, vocab_size):
self.layer_norm = LayerNorm(args.emb_size)

def forward(self, src, _):
"""
Args:
src: [batch_size x seq_length]
seg: [batch_size x seq_length] (unused by this embedding)
Returns:
emb: [batch_size x seq_length x emb_size]
"""
word_emb = self.word_embedding(src)
pos_emb = self.position_embedding(
torch.arange(0, word_emb.size(1), device=word_emb.device, dtype=torch.long)
@@ -117,6 +140,13 @@ def __init__(self, args, vocab_size):
self.layer_norm = LayerNorm(args.emb_size)

def forward(self, src, seg):
"""
Args:
src: [batch_size x seq_length]
seg: [batch_size x seq_length]
Returns:
emb: [batch_size x seq_length x emb_size]
"""
word_emb = self.word_embedding(src)
pos_emb = self.position_embedding(
torch.arange(0, word_emb.size(1), device=word_emb.device, dtype=torch.long)
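The docstrings above all share one shape contract: token indices in, [batch_size x seq_length x emb_size] out. A hedged sketch of the word + position composition (class and argument names are assumptions, not UER-py's exact classes):

import torch
import torch.nn as nn

class TinyWordPosEmbedding(nn.Module):
    # Hypothetical sketch: word plus learned position embeddings, then LayerNorm.
    def __init__(self, vocab_size, emb_size, max_seq_length=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, emb_size)
        self.position_embedding = nn.Embedding(max_seq_length, emb_size)
        self.layer_norm = nn.LayerNorm(emb_size)

    def forward(self, src, seg):
        # src: [batch_size x seq_length]; seg is ignored in this variant.
        word_emb = self.word_embedding(src)
        pos = torch.arange(0, src.size(1), device=src.device, dtype=torch.long)
        emb = word_emb + self.position_embedding(pos).unsqueeze(0)
        return self.layer_norm(emb)  # [batch_size x seq_length x emb_size]

out = TinyWordPosEmbedding(100, 16)(torch.randint(0, 100, (2, 8)), None)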
11 changes: 8 additions & 3 deletions uer/layers/layer_norm.py
@@ -3,6 +3,10 @@


class LayerNorm(nn.Module):
"""
Layer Normalization.
https://arxiv.org/abs/1607.06450
"""
def __init__(self, hidden_size, eps=1e-6):
super(LayerNorm, self).__init__()
self.eps = eps
@@ -18,10 +22,11 @@ def forward(self, x):


class T5LayerNorm(nn.Module):
"""
Construct a layernorm module in the T5 style: no bias and no subtraction of the mean.
"""
def __init__(self, hidden_size, eps=1e-6):
"""
Construct a layernorm module in the T5 style No bias and no subtraction of mean.
"""

super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
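The T5 variant is what is now commonly called RMSNorm: it rescales by the root mean square and learns only a weight, with no bias and no mean subtraction. A minimal sketch of the contrast, written as free functions for clarity:

import torch

def standard_layer_norm(x, gamma, beta, eps=1e-6):
    # Classic LayerNorm: center on the mean, scale by the standard deviation.
    mean = x.mean(-1, keepdim=True)
    std = x.std(-1, keepdim=True)
    return gamma * (x - mean) / (std + eps) + beta

def t5_layer_norm(x, weight, eps=1e-6):
    # T5 style: no centering and no bias; scale by the root mean square.
    variance = x.pow(2).mean(-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(2, 8, 16)
y = standard_layer_norm(x, torch.ones(16), torch.zeros(16))
z = t5_layer_norm(x, torch.ones(16))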
3 changes: 3 additions & 0 deletions uer/model_loader.py
@@ -2,6 +2,9 @@


def load_model(model, model_path):
"""
Load model from saved weights.
"""
if hasattr(model, "module"):
model.module.load_state_dict(torch.load(model_path, map_location="cpu"), strict=False)
else:
3 changes: 3 additions & 0 deletions uer/model_saver.py
@@ -2,6 +2,9 @@


def save_model(model, model_path):
"""
Save model weights to file.
"""
if hasattr(model, "module"):
torch.save(model.module.state_dict(), model_path)
else:
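Both helpers in model_saver.py and model_loader.py branch on hasattr(model, "module") because wrappers such as torch.nn.DataParallel keep the real network under a .module attribute; saving the inner state_dict makes checkpoints loadable with or without the wrapper. A hedged round-trip sketch:

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
wrapped = nn.DataParallel(model)  # introduces the .module indirection

# save_model(wrapped, "model.bin") effectively reduces to:
torch.save(wrapped.module.state_dict(), "model.bin")

# load_model(model, "model.bin") effectively reduces to:
model.load_state_dict(torch.load("model.bin", map_location="cpu"), strict=False)

strict=False lets a checkpoint load even when some heads are absent from the file, which suits loading pretrained weights into a downstream model.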
3 changes: 2 additions & 1 deletion uer/targets/albert_target.py
@@ -5,8 +5,9 @@

class AlbertTarget(MlmTarget):
"""
ALBERT exploits masked language modeling (MLM)
and sentence order prediction (SOP) for pretraining.
https://arxiv.org/abs/1909.11942
"""

def __init__(self, args, vocab_size):
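SOP differs from BERT's next sentence prediction in how negatives are built: instead of pairing a segment with a random one, ALBERT swaps the order of two consecutive segments. A hedged sketch of the pair construction, not UER-py's actual data pipeline:

import random

def build_sop_pair(segment_a, segment_b):
    # Positive example (label 1): consecutive segments in original order.
    # Negative example (label 0): the same two segments, order swapped.
    if random.random() < 0.5:
        return (segment_a, segment_b), 1
    return (segment_b, segment_a), 0

pair, label = build_sop_pair(["the", "cat"], ["sat", "down"])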
4 changes: 4 additions & 0 deletions uer/targets/bart_target.py
@@ -2,4 +2,8 @@


class BartTarget(T5Target):
"""
BART exploits a sequence-to-sequence denoising target.
https://arxiv.org/abs/1910.13461
"""
pass
1 change: 1 addition & 0 deletions uer/targets/bilm_target.py
@@ -4,6 +4,7 @@

class BilmTarget(LmTarget):
"""
Bi-directional Language Model Target.
"""
def __init__(self, args, vocab_size):
args.hidden_size = args.hidden_size // 2
1 change: 1 addition & 0 deletions uer/targets/cls_target.py
@@ -4,6 +4,7 @@

class ClsTarget(nn.Module):
"""
Classification Target.
"""
def __init__(self, args, vocab_size):
super(ClsTarget, self).__init__()
4 changes: 4 additions & 0 deletions uer/targets/gsg_target.py
@@ -2,4 +2,8 @@


class GsgTarget(T5Target):
"""
Gap Sentences Generation (GSG) target in the PEGASUS model.
https://arxiv.org/abs/1912.08777
"""
pass
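In GSG, whole sentences are removed from the input, replaced by a mask token, and regenerated by the decoder, so pretraining resembles abstractive summarization. A hedged sketch of the example construction; PEGASUS actually selects gap sentences by importance scoring, simplified here to explicit indices:

def make_gsg_example(sentences, gap_indices, mask_token="[MASK_SENT]"):
    # Replace the selected sentences in the source; the removed
    # sentences, in order, form the generation target.
    source = [mask_token if i in gap_indices else s
              for i, s in enumerate(sentences)]
    target = [sentences[i] for i in sorted(gap_indices)]
    return " ".join(source), " ".join(target)

src, tgt = make_gsg_example(
    ["It rained.", "We stayed in.", "The roof leaked."], gap_indices={1})
# src == "It rained. [MASK_SENT] The roof leaked."; tgt == "We stayed in."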
1 change: 1 addition & 0 deletions uer/targets/lm_target.py
@@ -4,6 +4,7 @@

class LmTarget(nn.Module):
"""
Language Model Target.
"""

def __init__(self, args, vocab_size):
4 changes: 4 additions & 0 deletions uer/targets/prefixlm_target.py
@@ -4,4 +4,8 @@


class PrefixlmTarget(LmTarget):
"""
Prefix Language Model Target, as used in the UniLM model.
https://arxiv.org/abs/1905.03197
"""
pass
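A prefix LM attends bidirectionally within the prefix (the source segment) and causally over the rest (the target segment). A hedged sketch of such an attention mask; UER-py builds its masks elsewhere, so this is purely illustrative:

import torch

def prefix_lm_mask(seq_length, prefix_length):
    # True means "may attend". Start from a causal (lower-triangular) mask,
    # then let every position attend to the whole prefix.
    mask = torch.tril(torch.ones(seq_length, seq_length, dtype=torch.bool))
    mask[:, :prefix_length] = True
    return mask

mask = prefix_lm_mask(seq_length=6, prefix_length=3)
# Rows 0-2 (prefix) see positions 0-2; rows 3-5 also see earlier target tokens.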
1 change: 1 addition & 0 deletions uer/targets/seq2seq_target.py
@@ -5,6 +5,7 @@

class Seq2seqTarget(LmTarget):
"""
Sequence-to-sequence Target.
"""

def __init__(self, args, vocab_size):
4 changes: 4 additions & 0 deletions uer/targets/t5_target.py
@@ -4,4 +4,8 @@


class T5Target(Seq2seqTarget):
"""
T5 Target.
https://www.jmlr.org/papers/volume21/20-074/20-074.pdf
"""
pass
