# causal-transformer

A causal transformer-based language model (LM).

Example usage:

```python
model = transformer_lm(
    dim = 512,
    vocab_size = 29,
    depth = 10,
    heads = 8,
    dim_head = 64,
    dropout = 0.0,
    causal = True,
    shared_kv = True,
)

# intermediate_logits are the logits from intermediate layers, returned when intermediate losses are enabled (False by default)
logits, intermediate_logits, cached_kvs = model(labels[t], length = length[t])

# cached_kvs can then be passed back into the model for easy recurrent training or inference, like the transformer-xl
logits, intermediate_logits, cached_kvs = model(labels[t+1], length = length[t+1], cache = cached_kvs)
```
Unlike other implementations I have seen, this allows transformer-XL-like training with variable-length sequences. Implementations usually assume there is no padding in the cache, which makes them difficult to use with datasets that provide a series of variable-length sentences as inputs.

See the test function `caching_test()` for more details.
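
Below is a minimal sketch of such a recurrent training loop, reusing the `model` from the example above. The segment batching, length handling, and loss masking shown here are illustrative assumptions, not part of the library's API.

```python
import torch
import torch.nn.functional as F

batch_size, num_segments, seg_len = 4, 3, 128

# toy variable-length segments; in practice these would come from your dataset
segments = [torch.randint(0, 29, (batch_size, seg_len)) for _ in range(num_segments)]
lengths = [torch.randint(1, seg_len + 1, (batch_size,)) for _ in range(num_segments)]

cached_kvs = None
total_loss = 0.0

for tokens, length in zip(segments, lengths):
    # first segment has no cache; later segments reuse the returned cached_kvs
    if cached_kvs is None:
        logits, intermediate_logits, cached_kvs = model(tokens, length = length)
    else:
        logits, intermediate_logits, cached_kvs = model(tokens, length = length, cache = cached_kvs)

    # next-token prediction within the segment, ignoring padded positions (illustrative masking)
    targets = tokens[:, 1:]
    mask = torch.arange(seg_len - 1)[None, :] < (length[:, None] - 1)
    loss = F.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1))[mask.reshape(-1)],
        targets.reshape(-1)[mask.reshape(-1)],
    )
    total_loss = total_loss + loss

total_loss.backward()
```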

Currently has the following features:

- causal attention
- optional shared key/value projections (`shared_kv`)
- optional intermediate-layer losses
- transformer-XL-style KV caching that supports variable-length (padded) sequences

Some of the code is taken or adapted from lucidrains' (Phil Wang's) x-transformers library.