From 7ebf605177fef7c1a716e69cac86c1ba9614052d Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 17 Feb 2024 01:01:10 -0800 Subject: [PATCH] [CLEANUP] --- example.py | 6 +- medpalm/__init__.py | 4 +- medpalm/attend.py | 218 +++++--- medpalm/model.py | 225 +++----- medpalm/transformer.py | 1193 ++++++++++++++++++++++++---------------- 5 files changed, 930 insertions(+), 716 deletions(-) diff --git a/example.py b/example.py index 5d14f15..afab2e6 100644 --- a/example.py +++ b/example.py @@ -1,12 +1,10 @@ import torch from medpalm.model import MedPalm -#usage +# usage img = torch.randn(1, 3, 256, 256) caption = torch.randint(0, 20000, (1, 1024)) model = MedPalm() output = model(img, caption) -print(output.shape) # (1, 1024, 20000) - - +print(output.shape) # (1, 1024, 20000) diff --git a/medpalm/__init__.py b/medpalm/__init__.py index 3a3169c..f5ec0f0 100644 --- a/medpalm/__init__.py +++ b/medpalm/__init__.py @@ -1 +1,3 @@ -from medpalm.model import MedPalm \ No newline at end of file +from medpalm.model import MedPalm + +__all__ = ["MedPalm"] diff --git a/medpalm/attend.py b/medpalm/attend.py index 4b7d902..00d0cf9 100644 --- a/medpalm/attend.py +++ b/medpalm/attend.py @@ -14,7 +14,10 @@ # constants -EfficientAttentionConfig = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) +EfficientAttentionConfig = namedtuple( + "EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"] +) + @dataclass class Intermediates: @@ -25,19 +28,25 @@ class Intermediates: def to_tuple(self): return (self.qk_similarities, self.pre_softmax_attn, self.post_softmax_attn) + # helpers + def exists(val): return val is not None + def default(val, d): return val if exists(val) else d + def compact(arr): return [*filter(exists, arr)] + def once(fn): called = False + @wraps(fn) def inner(x): nonlocal called @@ -45,63 +54,77 @@ def inner(x): return called = True return fn(x) + return inner + print_once = once(print) # functions for creating causal mask # need a special one for onnx cpu (no support for .triu) + def create_causal_mask(i, j, device): - return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) + return torch.ones((i, j), device=device, dtype=torch.bool).triu(j - i + 1) + def onnx_create_causal_mask(i, j, device): - r = torch.arange(i, device = device) - causal_mask = rearrange(r, 'i -> i 1') < rearrange(r, 'j -> 1 j') - causal_mask = F.pad(causal_mask, (j - i, 0), value = False) + r = torch.arange(i, device=device) + causal_mask = rearrange(r, "i -> i 1") < rearrange(r, "j -> 1 j") + causal_mask = F.pad(causal_mask, (j - i, 0), value=False) return causal_mask + # main class + class Attend(nn.Module): def __init__( self, *, - dropout = 0., - causal = False, - heads = None, - talking_heads = False, - sparse_topk = None, - scale = None, - qk_norm = False, - flash = False, - add_zero_kv = False, - onnxable = False + dropout=0.0, + causal=False, + heads=None, + talking_heads=False, + sparse_topk=None, + scale=None, + qk_norm=False, + flash=False, + add_zero_kv=False, + onnxable=False, ): super().__init__() self.scale = scale self.qk_norm = qk_norm self.causal = causal - self.create_causal_mask = onnx_create_causal_mask if onnxable else create_causal_mask + self.create_causal_mask = ( + onnx_create_causal_mask if onnxable else create_causal_mask + ) - self.attn_fn = partial(F.softmax, dtype = torch.float32) if not qk_norm else F.softmax + self.attn_fn = ( + partial(F.softmax, dtype=torch.float32) if not qk_norm else F.softmax + ) 
self.dropout = dropout self.attn_dropout = nn.Dropout(dropout) # talking heads - assert not (flash and talking_heads), 'talking heads not compatible with flash attention' + assert not ( + flash and talking_heads + ), "talking heads not compatible with flash attention" self.talking_heads = talking_heads if talking_heads: - self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) - self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias = False) + self.pre_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False) + self.post_softmax_talking_heads = nn.Conv2d(heads, heads, 1, bias=False) # sparse topk - assert not (flash and sparse_topk), 'sparse topk not compatible with flash attention' + assert not ( + flash and sparse_topk + ), "sparse topk not compatible with flash attention" self.sparse_topk = sparse_topk # add a key / value token composed of zeros @@ -112,7 +135,9 @@ def __init__( # flash attention self.flash = flash - assert not (flash and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' + assert not ( + flash and version.parse(torch.__version__) < version.parse("2.0.0") + ), "in order to use flash attention, you must be using pytorch 2.0 or above" # determine efficient attention configs for cuda and cpu @@ -122,31 +147,35 @@ def __init__( if not torch.cuda.is_available() or not flash: return - device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + device_properties = torch.cuda.get_device_properties(torch.device("cuda")) if device_properties.major == 8 and device_properties.minor == 0: - print_once('A100 GPU detected, using flash attention if input tensor is on cuda') + print_once( + "A100 GPU detected, using flash attention if input tensor is on cuda" + ) self.cuda_config = EfficientAttentionConfig(True, False, False) else: - print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') + print_once( + "Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda" + ) self.cuda_config = EfficientAttentionConfig(False, True, True) - def flash_attn( - self, - q, k, v, - mask = None, - attn_bias = None - ): - batch, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device + def flash_attn(self, q, k, v, mask=None, attn_bias=None): + batch, heads, q_len, _, k_len, is_cuda, device = ( + *q.shape, + k.shape[-2], + q.is_cuda, + q.device, + ) # Recommended for multi-query single-key-value attention by Tri Dao # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) if k.ndim == 3: - k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) + k = rearrange(k, "b ... -> b 1 ...").expand_as(q) if v.ndim == 3: - v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) + v = rearrange(v, "b ... 
-> b 1 ...").expand_as(q) # handle scale - by default they scale by dim_head ** -0.5, but need to take care if using cosine sim attention @@ -166,7 +195,7 @@ def flash_attn( # manually handle causal mask, if another mask was given if causal: - causal_mask = self.create_causal_mask(q_len, k_len, device = device) + causal_mask = self.create_causal_mask(q_len, k_len, device=device) mask = mask & ~causal_mask causal = False @@ -174,7 +203,9 @@ def flash_attn( # convert from bool to float if exists(attn_bias): - attn_bias = rearrange(attn_bias, 'h i j -> 1 h i j').expand(batch, heads, -1, -1) + attn_bias = rearrange(attn_bias, "h i j -> 1 h i j").expand( + batch, heads, -1, -1 + ) # if mask given, the mask would already contain the causal mask from above logic # otherwise, if no mask given but still causal, mask out alibi positional bias to a large negative number @@ -184,7 +215,7 @@ def flash_attn( if exists(mask): attn_bias = attn_bias.masked_fill(~mask, mask_value // 2) elif causal: - causal_mask = self.create_causal_mask(q_len, k_len, device = device) + causal_mask = self.create_causal_mask(q_len, k_len, device=device) attn_bias = attn_bias.masked_fill(causal_mask, mask_value // 2) causal = False @@ -198,24 +229,20 @@ def flash_attn( config = self.cuda_config if is_cuda else self.cpu_config # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale - + with torch.backends.cuda.sdp_kernel(**config._asdict()): out = F.scaled_dot_product_attention( - q, k, v, - attn_mask = mask, - dropout_p = self.dropout if self.training else 0., - is_causal = causal + q, + k, + v, + attn_mask=mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=causal, ) return out, Intermediates() - def forward( - self, - q, k, v, - mask = None, - attn_bias = None, - prev_attn = None - ): + def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None): """ einstein notation b - batch @@ -229,21 +256,23 @@ def forward( scale = default(self.scale, q.shape[-1] ** -0.5) if self.add_zero_kv: - k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value = 0.), (k, v)) + k, v = map(lambda t: F.pad(t, (0, 0, 1, 0), value=0.0), (k, v)) if exists(mask): - mask = F.pad(mask, (1, 0), value = True) + mask = F.pad(mask, (1, 0), value=True) if exists(attn_bias): - attn_bias = F.pad(attn_bias, (1, 0), value = 0.) 
+ attn_bias = F.pad(attn_bias, (1, 0), value=0.0) if self.flash: - assert not exists(prev_attn), 'residual attention not compatible with flash attention' - return self.flash_attn(q, k, v, mask = mask, attn_bias = attn_bias) + assert not exists( + prev_attn + ), "residual attention not compatible with flash attention" + return self.flash_attn(q, k, v, mask=mask, attn_bias=attn_bias) - kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d' + kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d" - dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale + dots = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale if exists(prev_attn): dots = dots + prev_attn @@ -261,7 +290,7 @@ def forward( mask_value = -torch.finfo(dots.dtype).max if exists(self.sparse_topk) and self.sparse_topk < j: - top_values, _ = dots.topk(self.sparse_topk, dim = -1) + top_values, _ = dots.topk(self.sparse_topk, dim=-1) sparse_topk_mask = dots < top_values[..., -1:] mask = (mask & sparse_topk_mask) if exists(mask) else sparse_topk_mask @@ -269,12 +298,12 @@ def forward( dots = dots.masked_fill(~mask, mask_value) if self.causal: - causal_mask = self.create_causal_mask(i, j, device = device) + causal_mask = self.create_causal_mask(i, j, device=device) dots = dots.masked_fill(causal_mask, mask_value) pre_softmax_attn = dots.clone() - attn = self.attn_fn(dots, dim = -1) + attn = self.attn_fn(dots, dim=-1) attn = attn.type(dtype) post_softmax_attn = attn.clone() @@ -284,35 +313,34 @@ def forward( if self.talking_heads: attn = self.post_softmax_talking_heads(attn) - out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v) + out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v) intermediates = Intermediates( - qk_similarities = qk_similarities, - pre_softmax_attn = pre_softmax_attn, - post_softmax_attn = post_softmax_attn + qk_similarities=qk_similarities, + pre_softmax_attn=pre_softmax_attn, + post_softmax_attn=post_softmax_attn, ) return out, intermediates + # cascading heads logic -def to_single_heads(t, dim = 1): - heads = t.unbind(dim = dim) + +def to_single_heads(t, dim=1): + heads = t.unbind(dim=dim) return tuple(head.unsqueeze(dim) for head in heads) + class CascadingHeads(nn.Module): def __init__(self, attend: Attend): super().__init__() self.attend = attend - def forward( - self, - q, k, v, - mask = None, - attn_bias = None, - prev_attn = None - ): - assert q.shape[-1] == v.shape[-1], 'cascading heads can only be done if query / key and value head dimensions are the same' + def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None): + assert ( + q.shape[-1] == v.shape[-1] + ), "cascading heads can only be done if query / key and value head dimensions are the same" # split inputs into per-head inputs @@ -324,8 +352,14 @@ def forward( mask = (mask,) * heads - attn_bias = to_single_heads(attn_bias, dim = 0) if exists(attn_bias) else ((None,) * heads) - prev_attn = to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads) + attn_bias = ( + to_single_heads(attn_bias, dim=0) + if exists(attn_bias) + else ((None,) * heads) + ) + prev_attn = ( + to_single_heads(prev_attn) if exists(prev_attn) else ((None,) * heads) + ) # now loop through each head, without output of previous head summed with the next head # thus cascading @@ -335,16 +369,14 @@ def forward( prev_head_out = None - for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip(queries, keys, values, mask, attn_bias, prev_attn): - + for h_q, h_k, h_v, h_mask, h_attn_bias, h_prev_attn in zip( + queries, keys, values, 
mask, attn_bias, prev_attn
+        ):
             if exists(prev_head_out):
                 h_q = h_q + prev_head_out
 
             out, intermediates = self.attend(
-                h_q, h_k, h_v,
-                mask = h_mask,
-                attn_bias = h_attn_bias,
-                prev_attn = h_prev_attn
+                h_q, h_k, h_v, mask=h_mask, attn_bias=h_attn_bias, prev_attn=h_prev_attn
             )
 
             prev_head_out = out
@@ -354,18 +386,28 @@ def forward(
 
         # cat all output heads
 
-        all_outs = torch.cat(all_outs, dim = 1)
+        all_outs = torch.cat(all_outs, dim=1)
 
         # cat all intermediates, if they exist
 
-        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(*map(lambda i: i.to_tuple(), all_intermediates))
+        qk_similarities, pre_softmax_attn, post_softmax_attn = zip(
+            *map(lambda i: i.to_tuple(), all_intermediates)
+        )
 
-        qk_similarities, pre_softmax_attn, post_softmax_attn = map(compact, (qk_similarities, pre_softmax_attn, post_softmax_attn))
+        qk_similarities, pre_softmax_attn, post_softmax_attn = map(
+            compact, (qk_similarities, pre_softmax_attn, post_softmax_attn)
+        )
 
         aggregated_intermediates = Intermediates(
-            qk_similarities = torch.cat(qk_similarities, dim = 1) if len(qk_similarities) > 0 else None,
-            pre_softmax_attn = torch.cat(pre_softmax_attn, dim = 1) if len(pre_softmax_attn) > 0 else None,
-            post_softmax_attn = torch.cat(post_softmax_attn, dim = 1) if len(post_softmax_attn) > 0 else None
+            qk_similarities=torch.cat(qk_similarities, dim=1)
+            if len(qk_similarities) > 0
+            else None,
+            pre_softmax_attn=torch.cat(pre_softmax_attn, dim=1)
+            if len(pre_softmax_attn) > 0
+            else None,
+            post_softmax_attn=torch.cat(post_softmax_attn, dim=1)
+            if len(post_softmax_attn) > 0
+            else None,
         )
 
-        return all_outs, aggregated_intermediates
\ No newline at end of file
+        return all_outs, aggregated_intermediates
diff --git a/medpalm/model.py b/medpalm/model.py
index 1925a9c..5791d93 100644
--- a/medpalm/model.py
+++ b/medpalm/model.py
@@ -1,6 +1,4 @@
-
 import torch
-import torch.nn as nn
 from transformers import AutoTokenizer, CLIPProcessor
 
 from medpalm.transformer import (
@@ -15,46 +13,49 @@ class MedPalmTokenizer:
     def __init__(self):
         try:
-
-            self.processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
+            self.processor = CLIPProcessor.from_pretrained(
+                "laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
+            )
             self.tokenizer = AutoTokenizer.from_pretrained(
                 "EleutherAI/gpt-neox-20b",
                 additional_special_tokens=["<image>", "</image>"],
-                eos_token ="<eos>",
+                eos_token="<eos>",
                 pad_token="<pad>",
                 extra_ids=0,
-                model_max_length=8192
+                model_max_length=8192,
             )
-            self.im_idx, self.im_end_idx = self.tokenizer.convert_tokens_to_ids(["<image>", "</image>"])
+            self.im_idx, self.im_end_idx = self.tokenizer.convert_tokens_to_ids(
+                ["<image>", "</image>"]
+            )
         except Exception as e:
             print(f"Error init tokenizer: {e}")
 
-
     def tokenize_texts(self, texts):
         try:
-
-            texts = self.tokenizer(texts, return_tensors="pt", padding=True, truncation=True).input_ids
-            image_tokens = torch.tensor([[self.im_idx, self.im_end_idx]] * texts.shape[0])
+            texts = self.tokenizer(
+                texts, return_tensors="pt", padding=True, truncation=True
+            ).input_ids
+            image_tokens = torch.tensor(
+                [[self.im_idx, self.im_end_idx]] * texts.shape[0]
+            )
             return torch.cat([texts[:, 0:1], image_tokens, texts[:, 1:]], dim=1), texts
         except Exception as e:
             print(f"Error tokenizing texts: {e}")
 
-
-
     def tokenize_images(self, images):
         try:
-
-            tokenized_images = self.processor(images=images, return_tensors="pt").pixel_values
+            tokenized_images = self.processor(
+                images=images, return_tensors="pt"
+            ).pixel_values
             print(f"Tokenized image: {tokenized_images.shape}")
 
             return tokenized_images
-        
+
         except Exception as e:
print(f"Error tokenizing texts: {e}") def tokenize(self, sample): try: - text_tokens, only_text_tokens = self.tokenize_texts(sample["target_text"]) attention_mask = text_tokens != self.tokenizer.pad_token_id dummy_image_features = torch.ones((text_tokens.shape[0], 64)) @@ -65,141 +66,63 @@ def tokenize(self, sample): "labels": only_text_tokens, "attention_mask": attention_mask, } - + except Exception as e: print(f"Error during tokenization {e}") - - -# class MedPalm(nn.Module): -# """ -# MedPalm is a transformer-based model architecture. It initializes with -# a Transformer and AutoregressiveWrapper with default or user-specified parameters. - -# Initialize the model with specified or default parameters. -# Args: -# - num_tokens: Number of tokens in the vocabulary -# - max_seq_len: Maximum sequence length -# - dim: Dimension of the model -# - depth: Depth of the model -# - dim_head: Dimension of the model head -# - heads: Number of heads -# - use_abs_pos_emb: Whether to use absolute position embedding -# - alibi_pos_bias: Alibi position bias -# - alibi_num_heads: Number of alibi heads -# - rotary_xpos: Rotary position -# - attn_flash: Attention flash -# - deepnorm: Deep normalization -# - shift_tokens: Number of tokens to shift -# - attn_one_kv_head: Attention one key/value head -# - qk_norm: Query-key normalization -# - attn_qk_norm: Attention query-key normalization -# - attn_qk_norm_dim_scale: Attention query-key normalization dimension scale -# - embedding_provider: Embedding provider module -# """ -# def __init__(self, -# num_tokens=20000, -# max_seq_len=4096, -# dim=2560, -# depth=32, -# dim_head=128, -# heads=24, -# use_abs_pos_emb=False, -# alibi_pos_bias=True, -# alibi_num_heads=12, -# rotary_xpos=True, -# attn_flash=True, -# image_size=256, -# patch_size=32, -# attn_one_kv_head=False, # multiquery attention -# qk_norm=True, -# attn_qk_norm=False, -# attn_qk_norm_dim_scale=False, -# ): -# super(MedPalm, self).__init__() - -# self.encoder = ViTransformerWrapper( -# image_size=image_size, -# patch_size=patch_size, -# attn_layers=Encoder( -# dim=dim, -# depth=depth, -# dim_head=dim_head, -# heads=heads -# ) -# ) - -# self.decoder = Transformer( -# num_tokens=num_tokens, -# max_seq_len=max_seq_len, -# use_abs_pos_emb=use_abs_pos_emb, -# attn_layers=Decoder( -# dim=dim, -# depth=depth, -# dim_head=dim_head, -# heads=heads, -# alibi_pos_bias=alibi_pos_bias, -# alibi_num_heads=alibi_num_heads, -# rotary_xpos=rotary_xpos, -# attn_flash=attn_flash, -# attn_one_kv_head=False, -# qk_norm=qk_norm, -# attn_qk_norm=False, -# attn_qk_norm_dim_scale=False, -# cross_attend=True -# ) -# ) - -# # self.decoder = AutoregressiveWrapper(self.decoder) - -# def forward(self, text_tokens, img, **kwargs): -# """ -# Forward pass through the model. It expects the input text_tokens. 
-# Args: -# - text_tokens: Input tokens -# - kwargs: Other arguments -# Returns: -# - output from the decoder -# """ -# try: -# print(f"Text tokens shape: {text_tokens.shape}") -# encoded = self.encoder(img, return_embeddings=True) -# print(encoded.shape) -# return self.decoder(text_tokens, context=encoded) -# except Exception as error: -# print(f"Failed in forward method: {error}") -# raise class MedPalm(torch.nn.Module): - def __init__(self, - image_size=256, - patch_size=32, - encoder_dim=512, - encoder_depth=6, - encoder_heads=8, - num_tokens=20000, - max_seq_len=1024, - decoder_dim=512, - decoder_depth=6, - decoder_heads=8, - alibi_num_heads=4, - use_abs_pos_emb=False, - cross_attend=True, - alibi_pos_bias=True, - rotary_xpos=True, - attn_flash=True, - qk_norm=True): - + """ + MedPalm model for medical image and text processing. + + Args: + image_size (int): Size of the input image (default: 256). + patch_size (int): Size of each image patch (default: 32). + encoder_dim (int): Dimensionality of the encoder (default: 512). + encoder_depth (int): Number of encoder layers (default: 6). + encoder_heads (int): Number of attention heads in the encoder (default: 8). + num_tokens (int): Number of tokens in the decoder (default: 20000). + max_seq_len (int): Maximum sequence length in the decoder (default: 1024). + decoder_dim (int): Dimensionality of the decoder (default: 512). + decoder_depth (int): Number of decoder layers (default: 6). + decoder_heads (int): Number of attention heads in the decoder (default: 8). + alibi_num_heads (int): Number of attention heads in the alibi mechanism (default: 4). + use_abs_pos_emb (bool): Whether to use absolute positional embeddings (default: False). + cross_attend (bool): Whether to enable cross-attention in the decoder (default: True). + alibi_pos_bias (bool): Whether to use positional bias in the alibi mechanism (default: True). + rotary_xpos (bool): Whether to use rotary positional embeddings (default: True). + attn_flash (bool): Whether to use attention flash in the decoder (default: True). + qk_norm (bool): Whether to normalize the query-key vectors in attention (default: True). + """ + + def __init__( + self, + image_size=256, + patch_size=32, + encoder_dim=512, + encoder_depth=6, + encoder_heads=8, + num_tokens=20000, + max_seq_len=1024, + decoder_dim=512, + decoder_depth=6, + decoder_heads=8, + alibi_num_heads=4, + use_abs_pos_emb=False, + cross_attend=True, + alibi_pos_bias=True, + rotary_xpos=True, + attn_flash=True, + qk_norm=True, + ): super(MedPalm, self).__init__() - + self.encoder = ViTransformerWrapper( image_size=image_size, patch_size=patch_size, attn_layers=Encoder( - dim=encoder_dim, - depth=encoder_depth, - heads=encoder_heads - ) + dim=encoder_dim, depth=encoder_depth, heads=encoder_heads + ), ) self.decoder = Transformer( @@ -216,15 +139,25 @@ def __init__(self, rotary_xpos=rotary_xpos, attn_flash=attn_flash, qk_norm=qk_norm, - ) + ), ) + self.decoder = AutoregressiveWrapper(self.decoder) + def forward(self, img, text): - try: + """ + Forward pass of the MedPalm model. + + Args: + img (torch.Tensor): Input image tensor. + text (torch.Tensor): Input text tensor. + + Returns: + torch.Tensor: Output tensor from the decoder. 
+ """ + try: encoded = self.encoder(img, return_embeddings=True) return self.decoder(text, context=encoded) except Exception as error: print(f"Failed in forward method: {error}") raise - - diff --git a/medpalm/transformer.py b/medpalm/transformer.py index 525003d..ee4b5fe 100644 --- a/medpalm/transformer.py +++ b/medpalm/transformer.py @@ -19,6 +19,7 @@ def exists(val): return val is not None + def eval_decorator(fn): def inner(self, *args, **kwargs): was_training = self.training @@ -26,11 +27,14 @@ def inner(self, *args, **kwargs): out = fn(self, *args, **kwargs) self.train(was_training) return out + return inner + # nucleus -def top_p(logits, thres = 0.9): + +def top_p(logits, thres=0.9): sorted_logits, sorted_indices = torch.sort(logits, descending=True) cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) @@ -38,37 +42,37 @@ def top_p(logits, thres = 0.9): sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() sorted_indices_to_remove[:, 0] = 0 - sorted_logits[sorted_indices_to_remove] = float('-inf') + sorted_logits[sorted_indices_to_remove] = float("-inf") return sorted_logits.scatter(1, sorted_indices, sorted_logits) + # topk -def top_k(logits, thres = 0.9): + +def top_k(logits, thres=0.9): k = ceil((1 - thres) * logits.shape[-1]) val, ind = torch.topk(logits, k) - probs = torch.full_like(logits, float('-inf')) + probs = torch.full_like(logits, float("-inf")) probs.scatter_(1, ind, val) return probs + # top_a + def top_a(logits, min_p_pow=2.0, min_p_ratio=0.02): probs = F.softmax(logits, dim=-1) limit = torch.pow(torch.max(probs), min_p_pow) * min_p_ratio - logits[probs < limit] = float('-inf') + logits[probs < limit] = float("-inf") logits[probs >= limit] = 1 return logits + # autoregressive wrapper class + class AutoregressiveWrapper(nn.Module): - def __init__( - self, - net, - ignore_index = -100, - pad_value = 0, - mask_prob = 0. - ): + def __init__(self, net, ignore_index=-100, pad_value=0, mask_prob=0.0): super().__init__() self.pad_value = pad_value self.ignore_index = ignore_index @@ -77,7 +81,7 @@ def __init__( self.max_seq_len = net.max_seq_len # paper shows masking (MLM) in conjunction with autoregressive decoder-only training leads to big improvements https://arxiv.org/abs/2210.13432 - assert mask_prob < 1. 
+ assert mask_prob < 1.0 self.mask_prob = mask_prob @torch.no_grad() @@ -86,32 +90,33 @@ def generate( self, start_tokens, seq_len, - eos_token = None, - temperature = 1., - filter_logits_fn = top_k, - filter_thres = 0.9, - min_p_pow = 2.0, - min_p_ratio = 0.02, - **kwargs + eos_token=None, + temperature=1.0, + filter_logits_fn=top_k, + filter_thres=0.9, + min_p_pow=2.0, + min_p_ratio=0.02, + **kwargs, ): - - start_tokens, ps = pack([start_tokens], '* n') + start_tokens, ps = pack([start_tokens], "* n") b, t = start_tokens.shape out = start_tokens for _ in range(seq_len): - x = out[:, -self.max_seq_len:] + x = out[:, -self.max_seq_len :] logits = self.net(x, **kwargs)[:, -1] if filter_logits_fn in {top_k, top_p}: - filtered_logits = filter_logits_fn(logits, thres = filter_thres) + filtered_logits = filter_logits_fn(logits, thres=filter_thres) probs = F.softmax(filtered_logits / temperature, dim=-1) elif filter_logits_fn is top_a: - filtered_logits = filter_logits_fn(logits, min_p_pow = min_p_pow, min_p_ratio= min_p_ratio) + filtered_logits = filter_logits_fn( + logits, min_p_pow=min_p_pow, min_p_ratio=min_p_ratio + ) probs = F.softmax(filtered_logits / temperature, dim=-1) sample = torch.multinomial(probs, 1) @@ -119,18 +124,18 @@ def generate( out = torch.cat((out, sample), dim=-1) if exists(eos_token): - is_eos_tokens = (out == eos_token) + is_eos_tokens = out == eos_token - if is_eos_tokens.any(dim = -1).all(): + if is_eos_tokens.any(dim=-1).all(): # mask out everything after the eos tokens shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) - mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 + mask = shifted_is_eos_tokens.float().cumsum(dim=-1) >= 1 out = out.masked_fill(mask, self.pad_value) break out = out[:, t:] - out, = unpack(out, ps, '* n') + (out,) = unpack(out, ps, "* n") return out @@ -139,20 +144,20 @@ def forward(self, x, return_loss=True, **kwargs): inp, target = x[:, :-1], x[:, 1:] - if self.mask_prob > 0.: - rand = torch.randn(inp.shape, device = x.device) - rand[:, 0] = -torch.finfo(rand.dtype).max # first token should not be masked out + if self.mask_prob > 0.0: + rand = torch.randn(inp.shape, device=x.device) + rand[:, 0] = -torch.finfo( + rand.dtype + ).max # first token should not be masked out num_mask = min(int(seq * self.mask_prob), seq - 1) - indices = rand.topk(num_mask, dim = -1).indices - mask = ~torch.zeros_like(inp).scatter(1, indices, 1.).bool() - kwargs.update(self_attn_context_mask = mask) + indices = rand.topk(num_mask, dim=-1).indices + mask = ~torch.zeros_like(inp).scatter(1, indices, 1.0).bool() + kwargs.update(self_attn_context_mask=mask) logits = self.net(inp, **kwargs) loss = F.cross_entropy( - rearrange(logits, 'b n c -> b c n'), - target, - ignore_index = ignore_index + rearrange(logits, "b n c -> b c n"), target, ignore_index=ignore_index ) if return_loss: @@ -161,9 +166,9 @@ def forward(self, x, return_loss=True, **kwargs): return logits - DEFAULT_DIM_HEAD = 64 + @dataclass class LayerIntermediates: hiddens: Optional[List[Tensor]] = None @@ -171,62 +176,80 @@ class LayerIntermediates: layer_hiddens: Optional[List[Tensor]] = None attn_z_loss: Optional[Tensor] = None + # helpers + def exists(val): return val is not None + def default(val, d): if exists(val): return val return d() if isfunction(d) else d + def cast_tuple(val, depth): return val if isinstance(val, tuple) else (val,) * depth + def maybe(fn): @wraps(fn) def inner(x, *args, **kwargs): if not exists(x): return x return fn(x, *args, **kwargs) + return inner -class always(): + 
+class always: def __init__(self, val): self.val = val + def __call__(self, *args, **kwargs): return self.val -class not_equals(): + +class not_equals: def __init__(self, val): self.val = val + def __call__(self, x, *args, **kwargs): return x != self.val -class equals(): + +class equals: def __init__(self, val): self.val = val + def __call__(self, x, *args, **kwargs): return x == self.val + def Sequential(*modules): return nn.Sequential(*filter(exists, modules)) + # tensor helpers + def max_neg_value(tensor): return -torch.finfo(tensor.dtype).max -def l2norm(t, groups = 1): - t = rearrange(t, '... (g d) -> ... g d', g = groups) - t = F.normalize(t, p = 2, dim = -1) - return rearrange(t, '... g d -> ... (g d)') -def pad_at_dim(t, pad, dim = -1, value = 0.): - dims_from_right = (- dim - 1) if dim < 0 else (t.ndim - dim - 1) - zeros = ((0, 0) * dims_from_right) - return F.pad(t, (*zeros, *pad), value = value) +def l2norm(t, groups=1): + t = rearrange(t, "... (g d) -> ... g d", g=groups) + t = F.normalize(t, p=2, dim=-1) + return rearrange(t, "... g d -> ... (g d)") + + +def pad_at_dim(t, pad, dim=-1, value=0.0): + dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1) + zeros = (0, 0) * dims_from_right + return F.pad(t, (*zeros, *pad), value=value) + def or_reduce(masks): head, *body = masks @@ -234,119 +257,139 @@ def or_reduce(masks): head = head | rest return head + # auxiliary loss helpers -def calc_z_loss( - pre_softmax_attns: List[Tensor], - mask = None, - weight = 1. -): + +def calc_z_loss(pre_softmax_attns: List[Tensor], mask=None, weight=1.0): # the same loss applied to the mixture of experts router logits in https://arxiv.org/abs/2202.08906 # in the paper, in a tiny footnote, they mention using it on attention logits with stabilizing effects # also used in PaLM as one of the measures - lse = 0. + lse = 0.0 for attn in pre_softmax_attns: - lse = lse + attn.logsumexp(dim = -1) + lse = lse + attn.logsumexp(dim=-1) loss = torch.square(lse) - loss = reduce(loss, 'b h n -> b n', 'sum') + loss = reduce(loss, "b h n -> b n", "sum") if not exists(mask): return loss.mean() * weight - loss = loss[mask].sum() / mask.sum().clamp(min = 1e-5) + loss = loss[mask].sum() / mask.sum().clamp(min=1e-5) return loss * weight + # init helpers + def init_zero_(layer): - nn.init.constant_(layer.weight, 0.) + nn.init.constant_(layer.weight, 0.0) if exists(layer.bias): - nn.init.constant_(layer.bias, 0.) 
+ nn.init.constant_(layer.bias, 0.0) + # keyword argument helpers + def pick_and_pop(keys, d): values = list(map(lambda key: d.pop(key), keys)) return dict(zip(keys, values)) + def group_dict_by_key(cond, d): - return_val = [dict(),dict()] + return_val = [dict(), dict()] for key in d.keys(): match = bool(cond(key)) ind = int(not match) return_val[ind][key] = d[key] return (*return_val,) + def string_begins_with(prefix, str): return str.startswith(prefix) + def group_by_key_prefix(prefix, d): return group_dict_by_key(partial(string_begins_with, prefix), d) + def groupby_prefix_and_trim(prefix, d): - kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) - kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) + kwargs_with_prefix, kwargs = group_dict_by_key( + partial(string_begins_with, prefix), d + ) + kwargs_without_prefix = dict( + map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) + ) return kwargs_without_prefix, kwargs + # initializations + def deepnorm_init( - transformer, - beta, - module_name_match_list = ['.ff.', '.to_v', '.to_out'] + transformer, beta, module_name_match_list=[".ff.", ".to_v", ".to_out"] ): for name, module in transformer.named_modules(): if type(module) != nn.Linear: continue - needs_beta_gain = any(map(lambda substr: substr in name, module_name_match_list)) + needs_beta_gain = any( + map(lambda substr: substr in name, module_name_match_list) + ) gain = beta if needs_beta_gain else 1 - nn.init.xavier_normal_(module.weight.data, gain = gain) + nn.init.xavier_normal_(module.weight.data, gain=gain) if exists(module.bias): nn.init.constant_(module.bias.data, 0) + # structured dropout, more effective than traditional attention dropouts + def dropout_seq(seq, mask, dropout): b, n, *_, device = *seq.shape, seq.device - logits = torch.randn(b, n, device = device) + logits = torch.randn(b, n, device=device) if exists(mask): mask_value = max_neg_value(logits) logits = logits.masked_fill(~mask, mask_value) - keep_prob = 1. 
- dropout - num_keep = max(1, int(keep_prob * n)) - keep_indices = logits.topk(num_keep, dim = 1).indices + keep_prob = 1.0 - dropout + num_keep = max(1, int(keep_prob * n)) + keep_indices = logits.topk(num_keep, dim=1).indices - batch_indices = torch.arange(b, device = device) - batch_indices = rearrange(batch_indices, 'b -> b 1') + batch_indices = torch.arange(b, device=device) + batch_indices = rearrange(batch_indices, "b -> b 1") seq = seq[batch_indices, keep_indices] if exists(mask): - seq_counts = mask.sum(dim = -1) + seq_counts = mask.sum(dim=-1) seq_keep_counts = torch.ceil(seq_counts * keep_prob).int() - keep_mask = torch.arange(num_keep, device = device) < rearrange(seq_keep_counts, 'b -> b 1') + keep_mask = torch.arange(num_keep, device=device) < rearrange( + seq_keep_counts, "b -> b 1" + ) mask = mask[batch_indices, keep_indices] & keep_mask return seq, mask + # activations + class ReluSquared(nn.Module): def forward(self, x): return F.relu(x) ** 2 + # embedding + class TokenEmbedding(nn.Module): - def __init__(self, dim, num_tokens, l2norm_embed = False): + def __init__(self, dim, num_tokens, l2norm_embed=False): super().__init__() self.l2norm_embed = l2norm_embed self.emb = nn.Embedding(num_tokens, dim) @@ -355,50 +398,56 @@ def forward(self, x): token_emb = self.emb(x) return l2norm(token_emb) if self.l2norm_embed else token_emb + # positional embeddings + class AbsolutePositionalEmbedding(nn.Module): - def __init__(self, dim, max_seq_len, l2norm_embed = False): + def __init__(self, dim, max_seq_len, l2norm_embed=False): super().__init__() - self.scale = dim ** -0.5 if not l2norm_embed else 1. + self.scale = dim**-0.5 if not l2norm_embed else 1.0 self.max_seq_len = max_seq_len self.l2norm_embed = l2norm_embed self.emb = nn.Embedding(max_seq_len, dim) - def forward(self, x, pos = None): + def forward(self, x, pos=None): seq_len, device = x.shape[1], x.device - assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}' + assert ( + seq_len <= self.max_seq_len + ), f"you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}" if not exists(pos): - pos = torch.arange(seq_len, device = device) + pos = torch.arange(seq_len, device=device) pos_emb = self.emb(pos) pos_emb = pos_emb * self.scale return l2norm(pos_emb) if self.l2norm_embed else pos_emb + class ScaledSinusoidalEmbedding(nn.Module): - def __init__(self, dim, theta = 10000): + def __init__(self, dim, theta=10000): super().__init__() assert (dim % 2) == 0 - self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5) + self.scale = nn.Parameter(torch.ones(1) * dim**-0.5) half_dim = dim // 2 freq_seq = torch.arange(half_dim).float() / half_dim - inv_freq = theta ** -freq_seq - self.register_buffer('inv_freq', inv_freq, persistent = False) + inv_freq = theta**-freq_seq + self.register_buffer("inv_freq", inv_freq, persistent=False) - def forward(self, x, pos = None): + def forward(self, x, pos=None): seq_len, device = x.shape[1], x.device if not exists(pos): - pos = torch.arange(seq_len, device = device) + pos = torch.arange(seq_len, device=device) - emb = einsum('i, j -> i j', pos, self.inv_freq) - emb = torch.cat((emb.sin(), emb.cos()), dim = -1) + emb = einsum("i, j -> i j", pos, self.inv_freq) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb * self.scale + class RelativePositionBias(nn.Module): - def __init__(self, scale, 
causal = False, num_buckets = 32, max_distance = 128, heads = 8): + def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8): super().__init__() self.scale = scale self.causal = causal @@ -407,7 +456,9 @@ def __init__(self, scale, causal = False, num_buckets = 32, max_distance = 128, self.relative_attention_bias = nn.Embedding(num_buckets, heads) @staticmethod - def _relative_position_bucket(relative_position, causal = True, num_buckets = 32, max_distance = 128): + def _relative_position_bucket( + relative_position, causal=True, num_buckets=32, max_distance=128 + ): ret = 0 n = -relative_position if not causal: @@ -420,10 +471,17 @@ def _relative_position_bucket(relative_position, causal = True, num_buckets = 32 max_exact = num_buckets // 2 is_small = n < max_exact - val_if_large = max_exact + ( - torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) - ).long() - val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + val_if_large = ( + max_exact + + ( + torch.log(n.float() / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).long() + ) + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1) + ) ret += torch.where(is_small, n, val_if_large) return ret @@ -434,34 +492,42 @@ def device(self): def forward(self, i, j): device = self.device - q_pos = torch.arange(j - i, j, dtype = torch.long, device = device) - k_pos = torch.arange(j, dtype = torch.long, device = device) + q_pos = torch.arange(j - i, j, dtype=torch.long, device=device) + k_pos = torch.arange(j, dtype=torch.long, device=device) rel_pos = k_pos[None, :] - q_pos[:, None] - rp_bucket = self._relative_position_bucket(rel_pos, causal = self.causal, num_buckets = self.num_buckets, max_distance = self.max_distance) + rp_bucket = self._relative_position_bucket( + rel_pos, + causal=self.causal, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) values = self.relative_attention_bias(rp_bucket) - bias = rearrange(values, 'i j h -> h i j') + bias = rearrange(values, "i j h -> h i j") return bias * self.scale + class DynamicPositionBias(nn.Module): - def __init__(self, dim, *, heads, depth, log_distance = False, norm = False): + def __init__(self, dim, *, heads, depth, log_distance=False, norm=False): super().__init__() - assert depth >= 1, 'depth for dynamic position bias MLP must be greater or equal to 1' + assert ( + depth >= 1 + ), "depth for dynamic position bias MLP must be greater or equal to 1" self.log_distance = log_distance self.mlp = nn.ModuleList([]) - self.mlp.append(Sequential( - nn.Linear(1, dim), - nn.LayerNorm(dim) if norm else None, - nn.SiLU() - )) + self.mlp.append( + Sequential( + nn.Linear(1, dim), nn.LayerNorm(dim) if norm else None, nn.SiLU() + ) + ) for _ in range(depth - 1): - self.mlp.append(Sequential( - nn.Linear(dim, dim), - nn.LayerNorm(dim) if norm else None, - nn.SiLU() - )) + self.mlp.append( + Sequential( + nn.Linear(dim, dim), nn.LayerNorm(dim) if norm else None, nn.SiLU() + ) + ) self.mlp.append(nn.Linear(dim, heads)) @@ -474,26 +540,31 @@ def forward(self, i, j): n, device = j, self.device # get the (n x n) matrix of distances - seq_arange = torch.arange(n, device = device) - context_arange = torch.arange(n, device = device) - indices = rearrange(seq_arange, 'i -> i 1') - rearrange(context_arange, 'j -> 1 j') - indices += (n - 1) + seq_arange = torch.arange(n, device=device) + context_arange = torch.arange(n, 
device=device) + indices = rearrange(seq_arange, "i -> i 1") - rearrange( + context_arange, "j -> 1 j" + ) + indices += n - 1 # input to continuous positions MLP - pos = torch.arange(-n + 1, n, device = device).float() - pos = rearrange(pos, '... -> ... 1') + pos = torch.arange(-n + 1, n, device=device).float() + pos = rearrange(pos, "... -> ... 1") if self.log_distance: - pos = torch.sign(pos) * torch.log(pos.abs() + 1) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) + pos = torch.sign(pos) * torch.log( + pos.abs() + 1 + ) # log of distance is sign(rel_pos) * log(abs(rel_pos) + 1) for layer in self.mlp: pos = layer(pos) - # get position biases + # get position biases bias = pos[indices] - bias = rearrange(bias, 'i j h -> h i j') + bias = rearrange(bias, "i j h -> h i j") return bias + class AlibiPositionalBias(nn.Module): def __init__(self, heads, total_heads, **kwargs): super().__init__() @@ -501,28 +572,35 @@ def __init__(self, heads, total_heads, **kwargs): self.total_heads = total_heads slopes = Tensor(self._get_slopes(heads)) - slopes = rearrange(slopes, 'h -> h 1 1') - self.register_buffer('slopes', slopes, persistent = False) - self.register_buffer('bias', None, persistent = False) - + slopes = rearrange(slopes, "h -> h 1 1") + self.register_buffer("slopes", slopes, persistent=False) + self.register_buffer("bias", None, persistent=False) + def get_bias(self, i, j, device): - i_arange = torch.arange(j - i, j, device = device) - j_arange = torch.arange(j, device = device) - bias = -torch.abs(rearrange(j_arange, 'j -> 1 1 j') - rearrange(i_arange, 'i -> 1 i 1')) + i_arange = torch.arange(j - i, j, device=device) + j_arange = torch.arange(j, device=device) + bias = -torch.abs( + rearrange(j_arange, "j -> 1 1 j") - rearrange(i_arange, "i -> 1 i 1") + ) return bias @staticmethod def _get_slopes(heads): def get_slopes_power_of_2(n): - start = (2**(-2**-(math.log2(n)-3))) + start = 2 ** (-(2 ** -(math.log2(n) - 3))) ratio = start - return [start*ratio**i for i in range(n)] + return [start * ratio**i for i in range(n)] if math.log2(heads).is_integer(): return get_slopes_power_of_2(heads) closest_power_of_2 = 2 ** math.floor(math.log2(heads)) - return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2] + return ( + get_slopes_power_of_2(closest_power_of_2) + + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][ + : heads - closest_power_of_2 + ] + ) @property def device(self): @@ -538,20 +616,21 @@ def forward(self, i, j): bias = bias * self.slopes num_heads_unalibied = h - bias.shape[0] - bias = pad_at_dim(bias, (0, num_heads_unalibied), dim = 0) - self.register_buffer('bias', bias, persistent = False) + bias = pad_at_dim(bias, (0, num_heads_unalibied), dim=0) + self.register_buffer("bias", bias, persistent=False) return self.bias + class RotaryEmbedding(nn.Module): def __init__( self, dim, - use_xpos = False, - scale_base = 512, - interpolation_factor = 1., - base = 10000, - base_rescale_factor = 1. + use_xpos=False, + scale_base=512, + interpolation_factor=1.0, + base=10000, + base_rescale_factor=1.0, ): super().__init__() # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning @@ -559,50 +638,55 @@ def __init__( # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/ base *= base_rescale_factor ** (dim / (dim - 2)) - inv_freq = 1. 
/ (base ** (torch.arange(0, dim, 2).float() / dim)) - self.register_buffer('inv_freq', inv_freq) + inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) - assert interpolation_factor >= 1. + assert interpolation_factor >= 1.0 self.interpolation_factor = interpolation_factor if not use_xpos: - self.register_buffer('scale', None) + self.register_buffer("scale", None) return scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) self.scale_base = scale_base - self.register_buffer('scale', scale) + self.register_buffer("scale", scale) def forward(self, seq_len, device): - t = torch.arange(seq_len, device = device).type_as(self.inv_freq) + t = torch.arange(seq_len, device=device).type_as(self.inv_freq) t = t / self.interpolation_factor - freqs = torch.einsum('i , j -> i j', t, self.inv_freq) - freqs = torch.cat((freqs, freqs), dim = -1) + freqs = torch.einsum("i , j -> i j", t, self.inv_freq) + freqs = torch.cat((freqs, freqs), dim=-1) if not exists(self.scale): - return freqs, 1. + return freqs, 1.0 - power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base - scale = self.scale ** rearrange(power, 'n -> n 1') - scale = torch.cat((scale, scale), dim = -1) + power = ( + torch.arange(seq_len, device=device) - (seq_len // 2) + ) / self.scale_base + scale = self.scale ** rearrange(power, "n -> n 1") + scale = torch.cat((scale, scale), dim=-1) return freqs, scale def rotate_half(x): - x = rearrange(x, '... (j d) -> ... j d', j = 2) - x1, x2 = x.unbind(dim = -2) - return torch.cat((-x2, x1), dim = -1) + x = rearrange(x, "... (j d) -> ... j d", j=2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(t, freqs, scale = 1): + +def apply_rotary_pos_emb(t, freqs, scale=1): seq_len = t.shape[-2] freqs = freqs[-seq_len:, :] return (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) + # norms + class Scale(nn.Module): def __init__(self, value, fn): super().__init__() @@ -611,44 +695,51 @@ def __init__(self, value, fn): def forward(self, x, **kwargs): out = self.fn(x, **kwargs) - scale_fn = lambda t: t * self.value + + def scale_fn(t): + return t * self.value if not isinstance(out, tuple): return scale_fn(out) return (scale_fn(out[0]), *out[1:]) + class ScaleNorm(nn.Module): - def __init__(self, dim, eps = 1e-5): + def __init__(self, dim, eps=1e-5): super().__init__() self.eps = eps - self.g = nn.Parameter(torch.ones(1) * (dim ** -0.5)) + self.g = nn.Parameter(torch.ones(1) * (dim**-0.5)) def forward(self, x): - norm = torch.norm(x, dim = -1, keepdim = True) - return x / norm.clamp(min = self.eps) * self.g + norm = torch.norm(x, dim=-1, keepdim=True) + return x / norm.clamp(min=self.eps) * self.g + class RMSNorm(nn.Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 self.g = nn.Parameter(torch.ones(dim)) def forward(self, x): - return F.normalize(x, dim = -1) * self.scale * self.g + return F.normalize(x, dim=-1) * self.scale * self.g + class SimpleRMSNorm(nn.Module): def __init__(self, dim): super().__init__() - self.scale = dim ** 0.5 + self.scale = dim**0.5 def forward(self, x): - return F.normalize(x, dim = -1) * self.scale + return F.normalize(x, dim=-1) * self.scale + # residual and residual gates + class Residual(nn.Module): - def __init__(self, dim, scale_residual = False, scale_residual_constant = 1.): + def __init__(self, dim, scale_residual=False, scale_residual_constant=1.0): super().__init__() 
self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None self.scale_residual_constant = scale_residual_constant @@ -662,8 +753,9 @@ def forward(self, x, residual): return x + residual + class GRUGating(nn.Module): - def __init__(self, dim, scale_residual = False, **kwargs): + def __init__(self, dim, scale_residual=False, **kwargs): super().__init__() self.gru = nn.GRUCell(dim, dim) self.residual_scale = nn.Parameter(torch.ones(dim)) if scale_residual else None @@ -673,24 +765,26 @@ def forward(self, x, residual): residual = residual * self.residual_scale gated_output = self.gru( - rearrange(x, 'b n d -> (b n) d'), - rearrange(residual, 'b n d -> (b n) d') + rearrange(x, "b n d -> (b n) d"), rearrange(residual, "b n d -> (b n) d") ) return gated_output.reshape_as(x) + # token shifting -def shift(t, amount, mask = None): + +def shift(t, amount, mask=None): if amount == 0: return t else: amount = min(amount, t.shape[1]) if exists(mask): - t = t.masked_fill(~mask[..., None], 0.) + t = t.masked_fill(~mask[..., None], 0.0) + + return pad_at_dim(t, (amount, -amount), dim=-2, value=0.0) - return pad_at_dim(t, (amount, -amount), dim = - 2, value = 0.) class ShiftTokens(nn.Module): def __init__(self, shifts, fn): @@ -699,49 +793,48 @@ def __init__(self, shifts, fn): self.shifts = tuple(shifts) def forward(self, x, **kwargs): - mask = kwargs.get('mask', None) + mask = kwargs.get("mask", None) shifts = self.shifts segments = len(shifts) feats_per_shift = x.shape[-1] // segments - splitted = x.split(feats_per_shift, dim = -1) + splitted = x.split(feats_per_shift, dim=-1) segments_to_shift, rest = splitted[:segments], splitted[segments:] - segments_to_shift = list(map(lambda args: shift(*args, mask = mask), zip(segments_to_shift, shifts))) - x = torch.cat((*segments_to_shift, *rest), dim = -1) + segments_to_shift = list( + map(lambda args: shift(*args, mask=mask), zip(segments_to_shift, shifts)) + ) + x = torch.cat((*segments_to_shift, *rest), dim=-1) return self.fn(x, **kwargs) + # feedforward + class GLU(nn.Module): - def __init__( - self, - dim_in, - dim_out, - activation: Callable, - mult_bias = False - ): + def __init__(self, dim_in, dim_out, activation: Callable, mult_bias=False): super().__init__() self.act = activation self.proj = nn.Linear(dim_in, dim_out * 2) - self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1. 
+ self.mult_bias = nn.Parameter(torch.ones(dim_out)) if mult_bias else 1.0 def forward(self, x): - x, gate = self.proj(x).chunk(2, dim = -1) + x, gate = self.proj(x).chunk(2, dim=-1) return x * self.act(gate) * self.mult_bias + class FeedForward(nn.Module): def __init__( self, dim, - dim_out = None, - mult = 4, - glu = False, - glu_mult_bias = False, - swish = False, - relu_squared = False, - post_act_ln = False, - dropout = 0., - no_bias = False, - zero_init_output = False + dim_out=None, + mult=4, + glu=False, + glu_mult_bias=False, + swish=False, + relu_squared=False, + post_act_ln=False, + dropout=0.0, + no_bias=False, + zero_init_output=False, ): super().__init__() inner_dim = int(dim * mult) @@ -755,18 +848,17 @@ def __init__( activation = nn.GELU() if glu: - project_in = GLU(dim, inner_dim, activation, mult_bias = glu_mult_bias) + project_in = GLU(dim, inner_dim, activation, mult_bias=glu_mult_bias) else: project_in = nn.Sequential( - nn.Linear(dim, inner_dim, bias = not no_bias), - activation + nn.Linear(dim, inner_dim, bias=not no_bias), activation ) self.ff = Sequential( project_in, nn.LayerNorm(inner_dim) if post_act_ln else None, nn.Dropout(dropout), - nn.Linear(inner_dim, dim_out, bias = not no_bias) + nn.Linear(inner_dim, dim_out, bias=not no_bias), ) # init last linear layer to 0 @@ -776,39 +868,41 @@ def __init__( def forward(self, x): return self.ff(x) + # attention. it is all we need + class Attention(nn.Module): def __init__( self, dim, - dim_head = DEFAULT_DIM_HEAD, - heads = 8, - causal = False, - flash = False, - talking_heads = False, - head_scale = False, - sparse_topk = None, - num_mem_kv = 0, - dropout = 0., - on_attn = False, - gate_values = False, - zero_init_output = False, - max_attend_past = None, - qk_norm = False, - qk_norm_groups = 1, - qk_norm_scale = 10, - qk_norm_dim_scale = False, - one_kv_head = False, - shared_kv = False, - value_dim_head = None, - tensor_product = False, # https://arxiv.org/abs/2208.06061 - cascading_heads = False, - add_zero_kv = False, # same as add_zero_attn in pytorch - onnxable = False + dim_head=DEFAULT_DIM_HEAD, + heads=8, + causal=False, + flash=False, + talking_heads=False, + head_scale=False, + sparse_topk=None, + num_mem_kv=0, + dropout=0.0, + on_attn=False, + gate_values=False, + zero_init_output=False, + max_attend_past=None, + qk_norm=False, + qk_norm_groups=1, + qk_norm_scale=10, + qk_norm_dim_scale=False, + one_kv_head=False, + shared_kv=False, + value_dim_head=None, + tensor_product=False, # https://arxiv.org/abs/2208.06061 + cascading_heads=False, + add_zero_kv=False, # same as add_zero_attn in pytorch + onnxable=False, ): super().__init__() - self.scale = dim_head ** -0.5 + self.scale = dim_head**-0.5 self.heads = heads self.causal = causal @@ -824,15 +918,17 @@ def __init__( v_dim = value_dim_head out_dim = v_dim * heads - self.to_q = nn.Linear(dim, q_dim, bias = False) - self.to_k = nn.Linear(dim, k_dim, bias = False) + self.to_q = nn.Linear(dim, q_dim, bias=False) + self.to_k = nn.Linear(dim, k_dim, bias=False) # shared key / values, for further memory savings during inference - assert not (shared_kv and value_dim_head != dim_head), 'key and value head dimensions must be equal for shared key / values' - self.to_v = nn.Linear(dim, v_dim, bias = False) if not shared_kv else None + assert not ( + shared_kv and value_dim_head != dim_head + ), "key and value head dimensions must be equal for shared key / values" + self.to_v = nn.Linear(dim, v_dim, bias=False) if not shared_kv else None # relations projection from 
tp-attention - self.to_r = nn.Linear(dim, v_dim, bias = False) if tensor_product else None + self.to_r = nn.Linear(dim, v_dim, bias=False) if tensor_product else None # add GLU gating for aggregated values, from alphafold2 self.to_v_gate = None @@ -854,22 +950,26 @@ def __init__( self.qk_norm_q_scale = nn.Parameter(torch.ones(dim_head)) self.qk_norm_k_scale = nn.Parameter(torch.ones(dim_head)) - assert (not qk_norm) or (dim_head % qk_norm_groups) == 0, 'dimension per attention head must be divisible by the qk norm groups' - assert not (qk_norm and (dim_head // qk_norm_groups) <= 2), 'the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)' + assert (not qk_norm) or ( + dim_head % qk_norm_groups + ) == 0, "dimension per attention head must be divisible by the qk norm groups" + assert not ( + qk_norm and (dim_head // qk_norm_groups) <= 2 + ), "the group dimension may be too small (2 was too small in my tests, but 4 still works, surprisingly)" # attend class - includes core attention algorithm + talking heads self.attend = Attend( - heads = heads, - causal = causal, - talking_heads = talking_heads, - dropout = dropout, - sparse_topk = sparse_topk, - qk_norm = qk_norm, - scale = qk_norm_scale if qk_norm else self.scale, - add_zero_kv = add_zero_kv, - flash = flash, - onnxable = onnxable + heads=heads, + causal=causal, + talking_heads=talking_heads, + dropout=dropout, + sparse_topk=sparse_topk, + qk_norm=qk_norm, + scale=qk_norm_scale if qk_norm else self.scale, + add_zero_kv=add_zero_kv, + flash=flash, + onnxable=onnxable, ) # head scaling @@ -888,7 +988,11 @@ def __init__( # attention on attention self.attn_on_attn = on_attn - self.to_out = nn.Sequential(nn.Linear(out_dim, dim * 2, bias = False), nn.GLU()) if on_attn else nn.Linear(out_dim, dim, bias = False) + self.to_out = ( + nn.Sequential(nn.Linear(out_dim, dim * 2, bias=False), nn.GLU()) + if on_attn + else nn.Linear(out_dim, dim, bias=False) + ) # init output projection 0 if zero_init_output: @@ -897,16 +1001,22 @@ def __init__( def forward( self, x, - context = None, - mask = None, - context_mask = None, - attn_mask = None, - rel_pos = None, - rotary_pos_emb = None, - prev_attn = None, - mem = None + context=None, + mask=None, + context_mask=None, + attn_mask=None, + rel_pos=None, + rotary_pos_emb=None, + prev_attn=None, + mem=None, ): - b, n, _, h, head_scale, device, has_context = *x.shape, self.heads, self.head_scale, x.device, exists(context) + b, n, _, h, head_scale, device, has_context = ( + *x.shape, + self.heads, + self.head_scale, + x.device, + exists(context), + ) kv_input = default(context, x) q_input = x @@ -915,23 +1025,24 @@ def forward( r_input = x if exists(mem): - k_input = torch.cat((mem, k_input), dim = -2) - v_input = torch.cat((mem, v_input), dim = -2) + k_input = torch.cat((mem, k_input), dim=-2) + v_input = torch.cat((mem, v_input), dim=-2) q = self.to_q(q_input) k = self.to_k(k_input) v = self.to_v(v_input) if exists(self.to_v) else k r = self.to_r(r_input) if exists(self.to_r) else None - q = rearrange(q, 'b n (h d) -> b h n d', h = h) + q = rearrange(q, "b n (h d) -> b h n d", h=h) if not self.one_kv_head: - k, v, r = map(lambda t: maybe(rearrange)(t, 'b n (h d) -> b h n d', h = h), (k, v, r)) + k, v, r = map( + lambda t: maybe(rearrange)(t, "b n (h d) -> b h n d", h=h), (k, v, r) + ) if self.qk_norm: - qk_l2norm = partial(l2norm, groups = self.qk_norm_groups) + qk_l2norm = partial(l2norm, groups=self.qk_norm_groups) q, k = map(qk_l2norm, (q, k)) - scale = 
self.qk_norm_scale q = q * self.qk_norm_q_scale k = k * self.qk_norm_k_scale @@ -940,51 +1051,68 @@ def forward( freqs, xpos_scale = rotary_pos_emb l = freqs.shape[-1] - q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale ** -1.) if exists(xpos_scale) else (1., 1.) - (ql, qr), (kl, kr), (vl, vr) = map(lambda t: (t[..., :l], t[..., l:]), (q, k, v)) + q_xpos_scale, k_xpos_scale = ( + (xpos_scale, xpos_scale**-1.0) if exists(xpos_scale) else (1.0, 1.0) + ) + (ql, qr), (kl, kr), (vl, vr) = map( + lambda t: (t[..., :l], t[..., l:]), (q, k, v) + ) - ql, kl, vl = map(lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale))) - q, k, v = map(lambda t: torch.cat(t, dim = -1), ((ql, qr), (kl, kr), (vl, vr))) + ql, kl, vl = map( + lambda arg: apply_rotary_pos_emb(arg[0], freqs, arg[1]), + ((ql, q_xpos_scale), (kl, k_xpos_scale), (vl, k_xpos_scale)), + ) + q, k, v = map( + lambda t: torch.cat(t, dim=-1), ((ql, qr), (kl, kr), (vl, vr)) + ) input_mask = context_mask if has_context else mask if self.num_mem_kv > 0: - mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v)) + mem_k, mem_v = map( + lambda t: repeat(t, "h n d -> b h n d", b=b), (self.mem_k, self.mem_v) + ) if self.qk_norm: mem_k = l2norm(mem_k) mem_k = mem_k * self.qk_norm_k_scale - k = torch.cat((mem_k, k), dim = -2) - v = torch.cat((mem_v, v), dim = -2) + k = torch.cat((mem_k, k), dim=-2) + v = torch.cat((mem_v, v), dim=-2) if exists(input_mask): - input_mask = pad_at_dim(input_mask, (self.num_mem_kv, 0), dim = -1, value = True) + input_mask = pad_at_dim( + input_mask, (self.num_mem_kv, 0), dim=-1, value=True + ) i, j = map(lambda t: t.shape[-2], (q, k)) # determine masking - mask_value = max_neg_value(q) + max_neg_value(q) masks = [] final_attn_mask = None if exists(input_mask): - input_mask = rearrange(input_mask, 'b j -> b 1 1 j') + input_mask = rearrange(input_mask, "b j -> b 1 1 j") masks.append(~input_mask) if exists(attn_mask): - assert 2 <= attn_mask.ndim <= 4, 'attention mask must have greater than 2 dimensions but less than or equal to 4' + assert ( + 2 <= attn_mask.ndim <= 4 + ), "attention mask must have greater than 2 dimensions but less than or equal to 4" if attn_mask.ndim == 2: - attn_mask = rearrange(attn_mask, 'i j -> 1 1 i j') + attn_mask = rearrange(attn_mask, "i j -> 1 1 i j") elif attn_mask.ndim == 3: - attn_mask = rearrange(attn_mask, 'h i j -> 1 h i j') + attn_mask = rearrange(attn_mask, "h i j -> 1 h i j") masks.append(~attn_mask) if exists(self.max_attend_past): - range_q = torch.arange(j - i, j, device = device) - range_k = torch.arange(j, device = device) - dist = rearrange(range_q, 'i -> 1 1 i 1') - rearrange(range_k, 'j -> 1 1 1 j') + range_q = torch.arange(j - i, j, device=device) + range_k = torch.arange(j, device=device) + dist = rearrange(range_q, "i -> 1 1 i 1") - rearrange( + range_k, "j -> 1 1 1 j" + ) max_attend_past_mask = dist > self.max_attend_past masks.append(max_attend_past_mask) @@ -1000,10 +1128,7 @@ def forward( # attention is all we need out, intermediates = self.attend( - q, k, v, - mask = final_attn_mask, - attn_bias = attn_bias, - prev_attn = prev_attn + q, k, v, mask=final_attn_mask, attn_bias=attn_bias, prev_attn=prev_attn ) # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients @@ -1018,7 +1143,7 @@ def forward( # merge heads - out = rearrange(out, 'b h n d -> b n (h d)') + out = rearrange(out, "b h n d -> b n (h d)") # alphafold2 styled gating of the values @@ 
-1031,66 +1156,67 @@ def forward( out = self.to_out(out) if exists(mask): - mask = rearrange(mask, 'b n -> b n 1') - out = out.masked_fill(~mask, 0.) + mask = rearrange(mask, "b n -> b n 1") + out = out.masked_fill(~mask, 0.0) return out, intermediates + class AttentionLayers(nn.Module): def __init__( self, dim, depth, - heads = 8, - causal = False, - cross_attend = False, - only_cross = False, - use_scalenorm = False, - use_rmsnorm = False, - use_simple_rmsnorm = False, - alibi_pos_bias = False, - alibi_num_heads = None, - rel_pos_bias = False, - rel_pos_num_buckets = 32, - rel_pos_max_distance = 128, - dynamic_pos_bias = False, - dynamic_pos_bias_log_distance = False, - dynamic_pos_bias_mlp_depth = 2, - dynamic_pos_bias_norm = False, - rotary_pos_emb = False, - rotary_emb_dim = None, - rotary_xpos = False, - rotary_interpolation_factor = 1., - rotary_xpos_scale_base = 512, - rotary_base_rescale_factor = 1., - custom_layers = None, - sandwich_coef = None, - par_ratio = None, - residual_attn = False, - cross_residual_attn = False, - macaron = False, - pre_norm = True, - pre_norm_has_final_norm = True, - gate_residual = False, - scale_residual = False, - scale_residual_constant = 1., - deepnorm = False, - shift_tokens = 0, - sandwich_norm = False, - resi_dual = False, - resi_dual_scale = 1., - zero_init_branch_output = False, - layer_dropout = 0., - cross_attn_tokens_dropout = 0., - **kwargs + heads=8, + causal=False, + cross_attend=False, + only_cross=False, + use_scalenorm=False, + use_rmsnorm=False, + use_simple_rmsnorm=False, + alibi_pos_bias=False, + alibi_num_heads=None, + rel_pos_bias=False, + rel_pos_num_buckets=32, + rel_pos_max_distance=128, + dynamic_pos_bias=False, + dynamic_pos_bias_log_distance=False, + dynamic_pos_bias_mlp_depth=2, + dynamic_pos_bias_norm=False, + rotary_pos_emb=False, + rotary_emb_dim=None, + rotary_xpos=False, + rotary_interpolation_factor=1.0, + rotary_xpos_scale_base=512, + rotary_base_rescale_factor=1.0, + custom_layers=None, + sandwich_coef=None, + par_ratio=None, + residual_attn=False, + cross_residual_attn=False, + macaron=False, + pre_norm=True, + pre_norm_has_final_norm=True, + gate_residual=False, + scale_residual=False, + scale_residual_constant=1.0, + deepnorm=False, + shift_tokens=0, + sandwich_norm=False, + resi_dual=False, + resi_dual_scale=1.0, + zero_init_branch_output=False, + layer_dropout=0.0, + cross_attn_tokens_dropout=0.0, + **kwargs, ): super().__init__() rotary_pos_emb = rotary_pos_emb or rotary_xpos - ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs) - attn_kwargs, kwargs = groupby_prefix_and_trim('attn_', kwargs) + ff_kwargs, kwargs = groupby_prefix_and_trim("ff_", kwargs) + attn_kwargs, kwargs = groupby_prefix_and_trim("attn_", kwargs) - dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD) + dim_head = attn_kwargs.get("dim_head", DEFAULT_DIM_HEAD) self.dim = dim self.depth = depth @@ -1100,39 +1226,81 @@ def __init__( rotary_emb_dim = max(default(rotary_emb_dim, dim_head // 2), 32) - assert not (rotary_xpos and not causal), 'rotary xpos is not compatible with bidirectional attention' - self.rotary_pos_emb = RotaryEmbedding(rotary_emb_dim, use_xpos = rotary_xpos, scale_base = rotary_xpos_scale_base, interpolation_factor = rotary_interpolation_factor, base_rescale_factor = rotary_base_rescale_factor) if rotary_pos_emb else None + assert not ( + rotary_xpos and not causal + ), "rotary xpos is not compatible with bidirectional attention" + self.rotary_pos_emb = ( + RotaryEmbedding( + rotary_emb_dim, + 
use_xpos=rotary_xpos, + scale_base=rotary_xpos_scale_base, + interpolation_factor=rotary_interpolation_factor, + base_rescale_factor=rotary_base_rescale_factor, + ) + if rotary_pos_emb + else None + ) - assert not (alibi_pos_bias and rel_pos_bias), 'you can only choose Alibi positional bias or T5 relative positional bias, not both' - assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance' + assert not ( + alibi_pos_bias and rel_pos_bias + ), "you can only choose Alibi positional bias or T5 relative positional bias, not both" + assert ( + rel_pos_num_buckets <= rel_pos_max_distance + ), "number of relative position buckets must be less than the relative position max distance" # relative positional bias - flash_attn = attn_kwargs.get('flash', False) - assert (int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias)) <= 1, 'you can only choose up to one of t5, alibi, or dynamic positional bias' + flash_attn = attn_kwargs.get("flash", False) + assert ( + int(rel_pos_bias) + int(dynamic_pos_bias) + int(alibi_pos_bias) + ) <= 1, "you can only choose up to one of t5, alibi, or dynamic positional bias" self.rel_pos = None if rel_pos_bias: - assert not flash_attn, 'flash attention not compatible with t5 relative positional bias' - self.rel_pos = RelativePositionBias(scale = dim_head ** 0.5, causal = causal, heads = heads, num_buckets = rel_pos_num_buckets, max_distance = rel_pos_max_distance) + assert ( + not flash_attn + ), "flash attention not compatible with t5 relative positional bias" + self.rel_pos = RelativePositionBias( + scale=dim_head**0.5, + causal=causal, + heads=heads, + num_buckets=rel_pos_num_buckets, + max_distance=rel_pos_max_distance, + ) elif dynamic_pos_bias: - assert not flash_attn, 'flash attention not compatible with dynamic positional bias' - self.rel_pos = DynamicPositionBias(dim = dim // 4, heads = heads, log_distance = dynamic_pos_bias_log_distance, depth = dynamic_pos_bias_mlp_depth, norm = dynamic_pos_bias_norm) + assert ( + not flash_attn + ), "flash attention not compatible with dynamic positional bias" + self.rel_pos = DynamicPositionBias( + dim=dim // 4, + heads=heads, + log_distance=dynamic_pos_bias_log_distance, + depth=dynamic_pos_bias_mlp_depth, + norm=dynamic_pos_bias_norm, + ) elif alibi_pos_bias: alibi_num_heads = default(alibi_num_heads, heads) - assert alibi_num_heads <= heads, 'number of ALiBi heads must be less than the total number of heads' - self.rel_pos = AlibiPositionalBias(heads = alibi_num_heads, total_heads = heads) + assert ( + alibi_num_heads <= heads + ), "number of ALiBi heads must be less than the total number of heads" + self.rel_pos = AlibiPositionalBias(heads=alibi_num_heads, total_heads=heads) # determine deepnorm and residual scale if deepnorm: - assert scale_residual_constant == 1, 'scale residual constant is being overridden by deep norm settings' + assert ( + scale_residual_constant == 1 + ), "scale residual constant is being overridden by deep norm settings" pre_norm = sandwich_norm = resi_dual = False scale_residual = True scale_residual_constant = (2 * depth) ** 0.25 - assert (int(sandwich_norm) + int(resi_dual)) <= 1, 'either sandwich norm or resiDual is selected, but not both' - assert not (not pre_norm and sandwich_norm), 'sandwich norm cannot be used when not using prenorm' + assert ( + int(sandwich_norm) + int(resi_dual) + ) <= 1, "either sandwich norm or resiDual is selected, but not both" + assert not ( + not pre_norm and 
sandwich_norm + ), "sandwich norm cannot be used when not using prenorm" if resi_dual: pre_norm = False @@ -1141,16 +1309,22 @@ def __init__( self.sandwich_norm = sandwich_norm self.resi_dual = resi_dual - assert 0 < resi_dual_scale <= 1., 'resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1.' + assert ( + 0 < resi_dual_scale <= 1.0 + ), "resiDual prenorm residual must be scaled by a factor greater than 0 and less than or equal to 1." self.resi_dual_scale = resi_dual_scale self.residual_attn = residual_attn self.cross_residual_attn = cross_residual_attn - assert not (flash_attn and (residual_attn or cross_residual_attn)), 'flash attention is not compatible with residual attention' + assert not ( + flash_attn and (residual_attn or cross_residual_attn) + ), "flash attention is not compatible with residual attention" self.cross_attend = cross_attend - assert (int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm)) <= 1, 'you can only use either scalenorm, rmsnorm, or simple rmsnorm' + assert ( + int(use_scalenorm) + int(use_rmsnorm) + int(use_simple_rmsnorm) + ) <= 1, "you can only use either scalenorm, rmsnorm, or simple rmsnorm" if use_scalenorm: norm_class = ScaleNorm @@ -1164,20 +1338,20 @@ def __init__( norm_fn = partial(norm_class, dim) if cross_attend and not only_cross: - default_block = ('a', 'c', 'f') + default_block = ("a", "c", "f") elif cross_attend and only_cross: - default_block = ('c', 'f') + default_block = ("c", "f") else: - default_block = ('a', 'f') + default_block = ("a", "f") if macaron: - default_block = ('f',) + default_block + default_block = ("f",) + default_block # zero init if zero_init_branch_output: - attn_kwargs = {**attn_kwargs, 'zero_init_output': True} - ff_kwargs = {**ff_kwargs, 'zero_init_output': True} + attn_kwargs = {**attn_kwargs, "zero_init_output": True} + ff_kwargs = {**ff_kwargs, "zero_init_output": True} # calculate layer block order @@ -1185,23 +1359,33 @@ def __init__( layer_types = custom_layers elif exists(par_ratio): par_depth = depth * len(default_block) - assert 1 < par_ratio <= par_depth, 'par ratio out of range' - default_block = tuple(filter(not_equals('f'), default_block)) - par_attn = par_depth // par_ratio - depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper + assert 1 < par_ratio <= par_depth, "par ratio out of range" + default_block = tuple(filter(not_equals("f"), default_block)) + par_attn = par_depth // par_ratio + depth_cut = ( + par_depth * 2 // 3 + ) # 2 / 3 attention layer cutoff suggested by PAR paper par_width = (depth_cut + depth_cut // par_attn) // par_attn - assert len(default_block) <= par_width, 'default block is too large for par_ratio' - par_block = default_block + ('f',) * (par_width - len(default_block)) + assert ( + len(default_block) <= par_width + ), "default block is too large for par_ratio" + par_block = default_block + ("f",) * (par_width - len(default_block)) par_head = par_block * par_attn - layer_types = par_head + ('f',) * (par_depth - len(par_head)) + layer_types = par_head + ("f",) * (par_depth - len(par_head)) elif exists(sandwich_coef): - assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth' - layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef + assert ( + sandwich_coef > 0 and sandwich_coef <= depth + ), "sandwich coefficient should be less than the depth" + layer_types = ( + ("a",) * sandwich_coef + + 
default_block * (depth - sandwich_coef) + + ("f",) * sandwich_coef + ) else: layer_types = default_block * depth self.layer_types = layer_types - self.num_attn_layers = len(list(filter(equals('a'), layer_types))) + self.num_attn_layers = len(list(filter(equals("a"), layer_types))) # stochastic depth @@ -1221,18 +1405,20 @@ def __init__( # iterate and construct layers - for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)): - is_last_layer = ind == (len(self.layer_types) - 1) + for ind, (layer_type, layer_shift_tokens) in enumerate( + zip(self.layer_types, shift_tokens) + ): + ind == (len(self.layer_types) - 1) - if layer_type == 'a': - layer = Attention(dim, heads = heads, causal = causal, **attn_kwargs) - elif layer_type == 'c': - layer = Attention(dim, heads = heads, **attn_kwargs) - elif layer_type == 'f': + if layer_type == "a": + layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs) + elif layer_type == "c": + layer = Attention(dim, heads=heads, **attn_kwargs) + elif layer_type == "f": layer = FeedForward(dim, **ff_kwargs) layer = layer if not macaron else Scale(0.5, layer) else: - raise Exception(f'invalid layer type {layer_type}') + raise Exception(f"invalid layer type {layer_type}") if layer_shift_tokens > 0: shift_range_upper = layer_shift_tokens + 1 @@ -1240,23 +1426,19 @@ def __init__( layer = ShiftTokens(range(shift_range_lower, shift_range_upper), layer) residual_fn = GRUGating if gate_residual else Residual - residual = residual_fn(dim, scale_residual = scale_residual, scale_residual_constant = scale_residual_constant) + residual = residual_fn( + dim, + scale_residual=scale_residual, + scale_residual_constant=scale_residual_constant, + ) pre_branch_norm = norm_fn() if pre_norm else None post_branch_norm = norm_fn() if sandwich_norm else None post_main_norm = norm_fn() if not pre_norm else None - norms = nn.ModuleList([ - pre_branch_norm, - post_branch_norm, - post_main_norm - ]) + norms = nn.ModuleList([pre_branch_norm, post_branch_norm, post_main_norm]) - self.layers.append(nn.ModuleList([ - norms, - layer, - residual - ])) + self.layers.append(nn.ModuleList([norms, layer, residual])) if deepnorm: init_gain = (8 * depth) ** -0.25 @@ -1265,15 +1447,17 @@ def __init__( def forward( self, x, - context = None, - mask = None, - context_mask = None, - attn_mask = None, - self_attn_context_mask = None, - mems = None, - return_hiddens = False + context=None, + mask=None, + context_mask=None, + attn_mask=None, + self_attn_context_mask=None, + mems=None, + return_hiddens=False, ): - assert not (self.cross_attend ^ exists(context)), 'context must be passed in if cross_attend is set to True' + assert not ( + self.cross_attend ^ exists(context) + ), "context must be passed in if cross_attend is set to True" hiddens = [] layer_hiddens = [] @@ -1286,25 +1470,31 @@ def forward( rotary_pos_emb = None if exists(self.rotary_pos_emb): - max_rotary_emb_length = max(list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems))) + max_rotary_emb_length = max( + list(map(lambda m: (m.shape[1] if exists(m) else 0) + x.shape[1], mems)) + ) rotary_pos_emb = self.rotary_pos_emb(max_rotary_emb_length, x.device) outer_residual = x * self.resi_dual_scale - for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate(zip(self.layer_types, self.layers, self.layer_dropouts)): - is_last = ind == (len(self.layers) - 1) + for ind, (layer_type, (norm, block, residual_fn), layer_dropout) in enumerate( + zip(self.layer_types, 
self.layers, self.layer_dropouts) + ): + ind == (len(self.layers) - 1) - if self.training and layer_dropout > 0. and random() < layer_dropout: + if self.training and layer_dropout > 0.0 and random() < layer_dropout: continue - if layer_type == 'a': + if layer_type == "a": if return_hiddens: hiddens.append(x) layer_mem = mems.pop(0) if mems else None - if layer_type == 'c': - if self.training and self.cross_attn_tokens_dropout > 0.: - context, context_mask = dropout_seq(context, context_mask, self.cross_attn_tokens_dropout) + if layer_type == "c": + if self.training and self.cross_attn_tokens_dropout > 0.0: + context, context_mask = dropout_seq( + context, context_mask, self.cross_attn_tokens_dropout + ) inner_residual = x @@ -1316,11 +1506,26 @@ def forward( if exists(pre_norm): x = pre_norm(x) - if layer_type == 'a': - out, inter = block(x, mask = mask, context_mask = self_attn_context_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, mem = layer_mem) - elif layer_type == 'c': - out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn) - elif layer_type == 'f': + if layer_type == "a": + out, inter = block( + x, + mask=mask, + context_mask=self_attn_context_mask, + attn_mask=attn_mask, + rel_pos=self.rel_pos, + rotary_pos_emb=rotary_pos_emb, + prev_attn=prev_attn, + mem=layer_mem, + ) + elif layer_type == "c": + out, inter = block( + x, + context=context, + mask=mask, + context_mask=context_mask, + prev_attn=prev_cross_attn, + ) + elif layer_type == "f": out = block(x) if self.resi_dual: @@ -1331,12 +1536,12 @@ def forward( x = residual_fn(out, inner_residual) - if layer_type in ('a', 'c') and return_hiddens: + if layer_type in ("a", "c") and return_hiddens: intermediates.append(inter) - if layer_type == 'a' and self.residual_attn: + if layer_type == "a" and self.residual_attn: prev_attn = inter.pre_softmax_attn - elif layer_type == 'c' and self.cross_residual_attn: + elif layer_type == "c" and self.cross_residual_attn: prev_cross_attn = inter.pre_softmax_attn if exists(post_main_norm): @@ -1352,28 +1557,32 @@ def forward( if return_hiddens: intermediates = LayerIntermediates( - hiddens = hiddens, - attn_intermediates = intermediates, - layer_hiddens = layer_hiddens + hiddens=hiddens, + attn_intermediates=intermediates, + layer_hiddens=layer_hiddens, ) return x, intermediates return x + class Encoder(AttentionLayers): def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on encoder' - super().__init__(causal = False, **kwargs) + assert "causal" not in kwargs, "cannot set causality on encoder" + super().__init__(causal=False, **kwargs) + class Decoder(AttentionLayers): def __init__(self, **kwargs): - assert 'causal' not in kwargs, 'cannot set causality on decoder' - super().__init__(causal = True, **kwargs) + assert "causal" not in kwargs, "cannot set causality on decoder" + super().__init__(causal=True, **kwargs) + class CrossAttender(AttentionLayers): def __init__(self, **kwargs): - super().__init__(cross_attend = True, only_cross = True, **kwargs) + super().__init__(cross_attend=True, only_cross=True, **kwargs) + class ViTransformerWrapper(nn.Module): def __init__( @@ -1382,26 +1591,26 @@ def __init__( image_size, patch_size, attn_layers, - channels = 3, - num_classes = None, - post_emb_norm = False, - emb_dropout = 0. 
+ channels=3, + num_classes=None, + post_emb_norm=False, + emb_dropout=0.0, ): super().__init__() - assert isinstance(attn_layers, Encoder), 'attention layers must be an Encoder' - assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size' + assert isinstance(attn_layers, Encoder), "attention layers must be an Encoder" + assert ( + image_size % patch_size == 0 + ), "image dimensions must be divisible by the patch size" dim = attn_layers.dim num_patches = (image_size // patch_size) ** 2 - patch_dim = channels * patch_size ** 2 + patch_dim = channels * patch_size**2 self.patch_size = patch_size self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim)) self.patch_to_embedding = nn.Sequential( - nn.LayerNorm(patch_dim), - nn.Linear(patch_dim, dim), - nn.LayerNorm(dim) + nn.LayerNorm(patch_dim), nn.Linear(patch_dim, dim), nn.LayerNorm(dim) ) self.post_emb_norm = nn.LayerNorm(dim) if post_emb_norm else nn.Identity() @@ -1409,16 +1618,14 @@ def __init__( self.attn_layers = attn_layers - self.mlp_head = nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() + self.mlp_head = ( + nn.Linear(dim, num_classes) if exists(num_classes) else nn.Identity() + ) - def forward( - self, - img, - return_embeddings = False - ): + def forward(self, img, return_embeddings=False): p = self.patch_size - x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p) + x = rearrange(img, "b c (h p1) (w p2) -> b (h w) (p1 p2 c)", p1=p, p2=p) x = self.patch_to_embedding(x) n = x.shape[1] @@ -1432,9 +1639,10 @@ def forward( if not exists(self.mlp_head) or return_embeddings: return x - x = x.mean(dim = -2) + x = x.mean(dim=-2) return self.mlp_head(x) + class Transformer(nn.Module): def __init__( self, @@ -1442,22 +1650,24 @@ def __init__( num_tokens, max_seq_len, attn_layers, - emb_dim = None, - max_mem_len = 0, - shift_mem_down = 0, - emb_dropout = 0., - post_emb_norm = False, - num_memory_tokens = None, - tie_embedding = False, - logits_dim = None, - use_abs_pos_emb = True, - scaled_sinu_pos_emb = False, - l2norm_embed = False, - emb_frac_gradient = 1., # GLM-130B and Cogview successfully used this, set at 0.1 - attn_z_loss_weight = 1e-4 + emb_dim=None, + max_mem_len=0, + shift_mem_down=0, + emb_dropout=0.0, + post_emb_norm=False, + num_memory_tokens=None, + tie_embedding=False, + logits_dim=None, + use_abs_pos_emb=True, + scaled_sinu_pos_emb=False, + l2norm_embed=False, + emb_frac_gradient=1.0, # GLM-130B and Cogview successfully used this, set at 0.1 + attn_z_loss_weight=1e-4, ): super().__init__() - assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder' + assert isinstance( + attn_layers, AttentionLayers + ), "attention layers must be one of Encoder or Decoder" dim = attn_layers.dim emb_dim = default(emb_dim, dim) @@ -1469,16 +1679,18 @@ def __init__( self.shift_mem_down = shift_mem_down self.l2norm_embed = l2norm_embed - self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed = l2norm_embed) + self.token_emb = TokenEmbedding(emb_dim, num_tokens, l2norm_embed=l2norm_embed) if not (use_abs_pos_emb and not attn_layers.has_pos_emb): self.pos_emb = always(0) elif scaled_sinu_pos_emb: self.pos_emb = ScaledSinusoidalEmbedding(emb_dim) else: - self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len, l2norm_embed = l2norm_embed) + self.pos_emb = AbsolutePositionalEmbedding( + emb_dim, max_seq_len, l2norm_embed=l2norm_embed + ) - self.emb_frac_gradient = emb_frac_gradient # fraction of the 
gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 + self.emb_frac_gradient = emb_frac_gradient # fraction of the gradient that should go to the embedding, https://arxiv.org/abs/2105.13290 self.post_emb_norm = nn.LayerNorm(emb_dim) if post_emb_norm else nn.Identity() self.emb_dropout = nn.Dropout(emb_dropout) @@ -1489,7 +1701,11 @@ def __init__( self.init_() logits_dim = default(logits_dim, num_tokens) - self.to_logits = nn.Linear(dim, logits_dim) if not tie_embedding else lambda t: t @ self.token_emb.emb.weight.t() + self.to_logits = ( + nn.Linear(dim, logits_dim) + if not tie_embedding + else lambda t: t @ self.token_emb.emb.weight.t() + ) # memory tokens (like [cls]) from Memory Transformers paper num_memory_tokens = default(num_memory_tokens, 0) @@ -1499,9 +1715,9 @@ def __init__( def init_(self): if self.l2norm_embed: - nn.init.normal_(self.token_emb.emb.weight, std = 1e-5) + nn.init.normal_(self.token_emb.emb.weight, std=1e-5) if not isinstance(self.pos_emb, always): - nn.init.normal_(self.pos_emb.emb.weight, std = 1e-5) + nn.init.normal_(self.pos_emb.emb.weight, std=1e-5) return nn.init.kaiming_normal_(self.token_emb.emb.weight) @@ -1509,27 +1725,34 @@ def init_(self): def forward( self, x, - return_embeddings = False, - return_logits_and_embeddings = False, - return_intermediates = False, - mask = None, - return_mems = False, - return_attn = False, - mems = None, - pos = None, - prepend_embeds = None, - sum_embeds = None, - return_attn_z_loss = False, - attn_z_loss_weight = 1e-4, - **kwargs + return_embeddings=False, + return_logits_and_embeddings=False, + return_intermediates=False, + mask=None, + return_mems=False, + return_attn=False, + mems=None, + pos=None, + prepend_embeds=None, + sum_embeds=None, + return_attn_z_loss=False, + attn_z_loss_weight=1e-4, + **kwargs, ): - b, n, device, num_mem, emb_frac_gradient = *x.shape, x.device, self.num_memory_tokens, self.emb_frac_gradient - return_hiddens = return_mems | return_attn | return_intermediates | return_attn_z_loss + b, n, device, num_mem, emb_frac_gradient = ( + *x.shape, + x.device, + self.num_memory_tokens, + self.emb_frac_gradient, + ) + return_hiddens = ( + return_mems | return_attn | return_intermediates | return_attn_z_loss + ) # absolute positional embedding external_pos_emb = exists(pos) and pos.dtype != torch.long - pos_emb = self.pos_emb(x, pos = pos) if not external_pos_emb else pos + pos_emb = self.pos_emb(x, pos=pos) if not external_pos_emb else pos x = self.token_emb(x) + pos_emb # for summing embeddings passed externally - needs this for self-conditioning in non-autoregressive training @@ -1545,9 +1768,11 @@ def forward( if exists(prepend_embeds): prepend_seq, prepend_dim = prepend_embeds.shape[1:] - assert prepend_dim == x.shape[-1], 'prepended embeddings need to have same dimensions as text model dimensions' + assert ( + prepend_dim == x.shape[-1] + ), "prepended embeddings need to have same dimensions as text model dimensions" - x = torch.cat((prepend_embeds, x), dim = -2) + x = torch.cat((prepend_embeds, x), dim=-2) # whether to reduce the gradient going to the embedding, from cogview paper, corroborated by GLM-130B model @@ -1562,21 +1787,23 @@ def forward( x = self.project_emb(x) if num_mem > 0: - mem = repeat(self.memory_tokens, 'n d -> b n d', b = b) - x = torch.cat((mem, x), dim = 1) + mem = repeat(self.memory_tokens, "n d -> b n d", b=b) + x = torch.cat((mem, x), dim=1) # auto-handle masking after appending memory tokens if exists(mask): - mask = pad_at_dim(mask, (num_mem, 0), 
dim = -1, value = True) + mask = pad_at_dim(mask, (num_mem, 0), dim=-1, value=True) if self.shift_mem_down and exists(mems): - mems_l, mems_r = mems[:self.shift_mem_down], mems[self.shift_mem_down:] + mems_l, mems_r = mems[: self.shift_mem_down], mems[self.shift_mem_down :] mems = [*mems_r, *mems_l] if return_hiddens: - x, intermediates = self.attn_layers(x, mask = mask, mems = mems, return_hiddens = True, **kwargs) + x, intermediates = self.attn_layers( + x, mask=mask, mems=mems, return_hiddens=True, **kwargs + ) else: - x = self.attn_layers(x, mask = mask, mems = mems, **kwargs) + x = self.attn_layers(x, mask=mask, mems=mems, **kwargs) mem, x = x[:, :num_mem], x[:, num_mem:] @@ -1588,8 +1815,12 @@ def forward( out = self.to_logits(x) if return_attn_z_loss: - pre_softmax_attns = list(map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates)) - intermediates.attn_z_loss = calc_z_loss(pre_softmax_attns, weight = attn_z_loss_weight) + pre_softmax_attns = list( + map(lambda t: t.pre_softmax_attn, intermediates.attn_intermediates) + ) + intermediates.attn_z_loss = calc_z_loss( + pre_softmax_attns, weight=attn_z_loss_weight + ) return_intermediates = True if return_intermediates: @@ -1597,12 +1828,20 @@ def forward( if return_mems: hiddens = intermediates.hiddens - new_mems = list(map(lambda pair: torch.cat(pair, dim = -2), zip(mems, hiddens))) if exists(mems) else hiddens - new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems)) + new_mems = ( + list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) + if exists(mems) + else hiddens + ) + new_mems = list( + map(lambda t: t[..., -self.max_mem_len :, :].detach(), new_mems) + ) return out, new_mems if return_attn: - attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates)) + attn_maps = list( + map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates) + ) return out, attn_maps - return out \ No newline at end of file + return out
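
For reference, a minimal usage sketch of the reformatted decoder stack. This assumes the classes keep the signatures shown in the hunks above and remain importable from medpalm.transformer; the token count, sequence length, and layer sizes are illustrative only.

import torch

from medpalm.transformer import Decoder, Transformer

# Decoder supplies the causal attention layers; Transformer wraps them with
# token/positional embeddings and the final logits projection.
model = Transformer(
    num_tokens=20000,
    max_seq_len=1024,
    attn_layers=Decoder(
        dim=512,
        depth=6,
        heads=8,
        attn_dim_head=64,  # "attn_"-prefixed kwargs are grouped and forwarded to Attention
    ),
)

tokens = torch.randint(0, 20000, (1, 1024))
logits = model(tokens)
print(logits.shape)  # (1, 1024, 20000)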
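
A similar sketch for the vision side, assuming ViTransformerWrapper is paired with an Encoder as its __init__ assertion requires; image size, patch size, and class count are illustrative.

import torch

from medpalm.transformer import Encoder, ViTransformerWrapper

# 256x256 RGB images are split into 32x32 patches, giving 64 patch tokens.
vit = ViTransformerWrapper(
    image_size=256,
    patch_size=32,
    num_classes=1000,  # with num_classes=None the head becomes nn.Identity()
    attn_layers=Encoder(dim=512, depth=6, heads=8),
)

img = torch.randn(1, 3, 256, 256)
preds = vit(img)  # (1, 1000) class logits
embeds = vit(img, return_embeddings=True)  # (1, 64, 512) patch embeddings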
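
The memory path in Transformer.forward (return_mems, mems, max_mem_len) can be exercised as below; this is a sketch of segment-level recurrence under the same assumptions as above, with hyperparameters chosen only for illustration.

import torch

from medpalm.transformer import Decoder, Transformer

model = Transformer(
    num_tokens=20000,
    max_seq_len=512,
    max_mem_len=256,  # cached hidden states are truncated to the last 256 positions
    attn_layers=Decoder(dim=512, depth=6, heads=8),
)

seg1 = torch.randint(0, 20000, (1, 512))
seg2 = torch.randint(0, 20000, (1, 512))

# return_mems=True returns one detached hidden-state tensor per attention layer
logits1, mems = model(seg1, return_mems=True)

# feeding mems back lets the next segment attend over the cached states as well
logits2, mems = model(seg2, mems=mems, return_mems=True)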