causal-transformer

A causal transformer-based language model (LM).

Example usage:

model = transformer_lm(
    dim = 512,
    vocab_size = 29,
    depth = 10,
    heads = 8,
    dim_head = 64,
    dropout = 0.0,
    causal = True,
    shared_kv = True,
)
# intermediate_logits contains the logits from intermediate layers, returned when intermediate losses are enabled (False by default)
logits, intermediate_logits, cached_kvs = model(labels[t], length = length[t])

# cached_kvs can then be passed back into the model for easy recurrent training or inference, like the transformer-xl
logits, intermediate_logits, cached_kvs = model(labels[t+1], length = length[t+1], cache = cached_kvs)
Unlike other implementations I have seen, this allows transformer-xl-style training with variable-length sequences. Implementations usually assume there is no padding in the cache, which makes them awkward to use with datasets that provide a series of variable-length sentences as inputs.

See the test function caching_test() for more details.
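
Below is a minimal sketch of carrying the cache across segments in a training loop. It relies only on the two model calls shown above; everything else (the `segments` variable holding `(labels, length)` pairs, the `(batch, seq_len, vocab)` logit shape, the next-token cross-entropy loss, and the optimizer) is an assumption for illustration and is not prescribed by this repository — see caching_test() for the authoritative usage.

import torch
import torch.nn.functional as F

# assumed: segments is a list of (labels, length) pairs, where labels is a
# (batch, seq_len) LongTensor of token ids and length is a (batch,) LongTensor
# of the true (unpadded) lengths for that segment
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

cached_kvs = None
for labels_t, length_t in segments:
    if cached_kvs is None:
        logits, intermediate_logits, cached_kvs = model(labels_t, length=length_t)
    else:
        logits, intermediate_logits, cached_kvs = model(labels_t, length=length_t, cache=cached_kvs)

    # assumed next-token loss: predict token i+1 from positions up to i
    # (in practice, padded positions could be excluded via ignore_index)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        labels_t[:, 1:].reshape(-1),
    )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # depending on the training setup, the cache may need to be detached here
    # so gradients do not flow across segment boundaries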

Currently has the following features:

Some of the code is taken or adapted from lucidrains' (Phil Wang's) x-transformers library.
