Skip to content

Commit

Permalink
Merge pull request #95 from casper-hansen/cache_refactor
Browse files Browse the repository at this point in the history
Refactor cache and embedding modules
  • Loading branch information
casper-hansen authored Oct 6, 2023
2 parents c9e4527 + b13e2a8 commit 1b68975
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 143 deletions.
240 changes: 100 additions & 140 deletions awq/modules/fused/attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,69 +2,88 @@
import math
import torch
import torch.nn as nn
import awq_inference_engine
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes

try:
import ft_inference_engine
FT_INSTALLED = True
except:
FT_INSTALLED = False

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)

def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)


def build_alibi_bias(
n_heads, seq_len, full=False, alibi_bias_max=8, dtype=torch.float32
):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
if full:
alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, seq_len, 1
class RoPE(nn.Module):
def __init__(self, hidden_size, n_heads, max_seq_len, device):
super(RoPE, self).__init__()

self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
requires_grad=False
)

@staticmethod
def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis

@staticmethod
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)

def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
alibi_bias = alibi_bias.abs().mul(-1)
slopes = gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)

return xq_out.type_as(xq), xk_out.type_as(xk)

class ALiBi(nn.Module):
def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
super(ALiBi, self).__init__()

# Initialize ALiBi slopes and bias
slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)

@staticmethod
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)

if _n_heads != n_heads:
slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

return slopes.view(1, n_heads, 1, 1)

@staticmethod
def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

def forward(self, scores, seqlen):
scores += self.bias[..., :seqlen]
return scores

class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
Expand All @@ -81,74 +100,27 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.max_seq_len = max_seq_len
self.attention_shapes = self._get_attention_shapes(attention_shapes, max_seq_len)
self.cache_v = ( torch.zeros(self.attention_shapes["cache_v"]).to(dev).half() )
self.cache_k = ( torch.zeros(self.attention_shapes["cache_k"]).to(dev).half() )

# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
)
# cache store that rolls cache
self.cache = WindowedCache(
self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], dev
)

if use_alibi:
alibi_slopes, alibi_bias = build_alibi_bias(self.n_heads, max_seq_len)
self.alibi_slopes = alibi_slopes.float().to(dev)
self.alibi_bias = alibi_bias.float().to(dev)
self.alibi = ALiBi(n_heads, max_seq_len, dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.freqs_cis = precompute_freqs_cis(
hidden_size // n_heads,
max_seq_len * 2,
).to(dev)
self.alibi = None
self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
self.rotary_dim = self.head_dim
self.alibi_slopes = None
self.is_neox = True

def _get_attention_shapes(self, attention_shapes, max_seq_len):
if attention_shapes is not None:
attention_shapes = attention_shapes

elif self.n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (-1, self.n_heads, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_heads, self.head_dim),
"xv_view": (self.n_heads, self.head_dim),
"xk_reshape": (self.n_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_heads, self.head_dim),
"single_xv_view": (self.n_heads, self.head_dim)
}

else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (self.cache_batch_size, self.n_kv_heads, max_seq_len, self.head_dim,),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (self.cache_batch_size, self.n_kv_heads, self.head_dim // 8, max_seq_len, 8,),
"xqkv_view": (self.n_heads + self.n_kv_heads * 2, self.head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0 : self.n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, self.n_heads : (self.n_heads + self.n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -self.n_kv_heads :],
"xq_view": (self.n_heads, self.head_dim),
"xk_view": (self.n_kv_heads, self.head_dim),
"xv_view": (self.n_kv_heads, self.head_dim),
"xk_reshape": (self.n_kv_heads, self.head_dim // 8, 8),
"single_xq_view": (self.n_heads, self.head_dim),
"single_xk_view": (self.n_kv_heads, self.head_dim),
"single_xv_view": (self.n_kv_heads, self.head_dim)
}

return attention_shapes

def forward(
self,
hidden_states:torch.Tensor, past_key_value=None, attention_mask=None, position_ids=None,
output_attentions=False, use_cache=False, *args, **kwargs
):
def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
bsz, seqlen, _ = hidden_states.shape
if bsz != self.cache_batch_size:
raise RuntimeError(
Expand All @@ -157,14 +129,8 @@ def forward(
)

if self.start_pos > self.max_seq_len or self.start_pos + seqlen > self.max_seq_len:
# Roll cache to the left
roll_len = self.start_pos
self.cache_v = torch.roll(self.cache_v, shifts=-roll_len, dims=2)
self.cache_k = torch.roll(self.cache_k, shifts=-roll_len, dims=3)
# Zero out the new part
self.cache_v[:, :, -roll_len:, :] = 0
self.cache_k[:, :, :, -roll_len:, :] = 0
self.start_pos = 0
excess_length = self.start_pos + seqlen - self.max_seq_len
self.start_pos = self.cache.roll_kv(excess_length, self.start_pos)

xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
Expand All @@ -179,10 +145,9 @@ def forward(
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

if not self.use_alibi:
xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

self.cache_k = self.cache_k.to(xq)
self.cache_v = self.cache_v.to(xq)
self.cache.to(xq)

values_store = xv.transpose(2, 1)
keys_store = (
Expand All @@ -191,13 +156,10 @@ def forward(
.contiguous()
)

self.cache_v[:bsz, :, self.start_pos : self.start_pos + seqlen, :] = values_store
self.cache_k[:bsz, :, :, self.start_pos : self.start_pos + seqlen, :] = keys_store
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

if seqlen == 1:
xv = self.cache_v[:bsz, :, : self.start_pos + seqlen, :].transpose(1, 2).contiguous()
xk = self.cache_k[:bsz, :, :, : self.start_pos + seqlen, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (self.head_dim,)).transpose(1, 2).contiguous()
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

keys = xk
values = xv
Expand All @@ -212,7 +174,7 @@ def forward(
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)

if self.use_alibi:
scores += self.alibi_bias[..., :seqlen]
scores = self.alibi.forward(scores, seqlen)

if attention_mask is not None:
scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen)
Expand All @@ -225,14 +187,15 @@ def forward(
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = ft_inference_engine.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache_k, # key cache
self.cache_v, # value cache
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
self.alibi_slopes, # alibi slopes
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
10000, # rotary embedding base
Expand All @@ -241,11 +204,8 @@ def forward(
attention_weight = attention_weight.reshape(bsz, 1, -1)

attn_output = self.o_proj(attention_weight)

if use_cache:
self.start_pos += seqlen
else:
self.start_pos = 0
self.start_pos += seqlen

# past_key_value is replaced with cache_v, cache_k, returning None
return attn_output, attention_weight, None
# past_key_value is replaced with cache_v, cache_k, returning empty data
past_key_value = [torch.Tensor([ [ [[0]], [[0]], [[0]] ] ])]
return attn_output, attention_weight, past_key_value
39 changes: 39 additions & 0 deletions awq/modules/fused/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import torch

class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, device):
"""
The window size is the same as the max_new_tokens. The window will
automatically roll once max_new_tokens is exceeded.
"""
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half()

def get_kv(self, batch_size, start_pos, seqlen, head_dim):
xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()

return xv, xk

def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store

def roll_kv(self, roll_len, start_pos):
# Roll only the necessary part of the cache to the left
self.v[:, :, :-roll_len, :] = self.v[:, :, roll_len:, :]
self.k[:, :, :, :-roll_len, :] = self.k[:, :, :, roll_len:, :]

# Zero out the new part
self.v[:, :, -roll_len:, :] = 0
self.k[:, :, :, -roll_len:, :] = 0

return start_pos - roll_len

def to(self, device):
self.k = self.k.to(device)
self.v = self.v.to(device)

Loading

0 comments on commit 1b68975

Please sign in to comment.