From 8110e028c7fe496287d9092d2255f3b7fa6bdd2d Mon Sep 17 00:00:00 2001 From: Casper Date: Sat, 4 Nov 2023 13:42:45 +0100 Subject: [PATCH] Create fused LlamaLikeModel (#152) --- awq/models/aquila.py | 147 +++++++++++++++---------------------- awq/models/llama.py | 145 ++++++++++++++++-------------------- awq/models/mistral.py | 147 ++++++++++++++++--------------------- awq/modules/fused/attn.py | 11 --- awq/modules/fused/block.py | 33 +++++++++ awq/modules/fused/model.py | 80 +++++++++++++++----- awq/utils/fused_utils.py | 57 ++++++++++++++ examples/basic_generate.py | 12 +-- 8 files changed, 343 insertions(+), 289 deletions(-) diff --git a/awq/models/aquila.py b/awq/models/aquila.py index c935ca98..e4f3d73d 100644 --- a/awq/models/aquila.py +++ b/awq/models/aquila.py @@ -1,40 +1,41 @@ -## Reference from llama.py +import tqdm +from typing import List, Tuple from .base import BaseAWQForCausalLM +from awq.utils.fused_utils import fuse_qkv +from awq.modules.fused.block import LlamaLikeBlock +from awq.modules.fused.model import LlamaLikeModel from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer as AquilaDecoderLayer, - LlamaForCausalLM as AquilaForCausalLM, - LlamaAttention as AquilaAttention, - LlamaRMSNorm as AquilaRMSNorm, - LlamaMLP as AquilaMLP + LlamaDecoderLayer as OldAquilaDecoderLayer, + LlamaForCausalLM as OldAquilaForCausalLM ) +from awq.modules.fused.mlp import QuantLlamaMLP +from awq.modules.fused.norm import FasterTransformerRMSNorm class AquilaAWQForCausalLM(BaseAWQForCausalLM): layer_type = "AquilaDecoderLayer" max_new_tokens_key = "max_position_embeddings" @staticmethod - def fuse_layers(model: AquilaForCausalLM): + def fuse_layers(model: OldAquilaForCausalLM): fuser = AquilaFuser(model) - fuser.fuse_attention() - fuser.fuse_rmsnorm() - fuser.fuse_mlp() + fuser.fuse_transformer() @staticmethod - def get_model_layers(model: AquilaForCausalLM): + def get_model_layers(model: OldAquilaForCausalLM): return model.model.layers @staticmethod - def get_act_for_scaling(module: AquilaDecoderLayer): + def get_act_for_scaling(module: OldAquilaDecoderLayer): return dict( is_scalable=False ) @staticmethod - def move_embed(model: AquilaForCausalLM, device: str): + def move_embed(model: OldAquilaForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod - def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs): layers = [] # attention input @@ -72,85 +73,57 @@ def get_layers_for_scaling(module: AquilaDecoderLayer, input_feat, module_kwargs return layers -import torch -from typing import List, Tuple, Union -from awq.utils.utils import set_module_name -from awq.modules.fused.mlp import QuantLlamaMLP -from awq.modules.fused.attn import QuantAttentionFused -from awq.modules.fused.norm import FasterTransformerRMSNorm -from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV class AquilaFuser: - def __init__(self, model): + def __init__(self, model: OldAquilaForCausalLM): self.model = model - self.attention_modules: List[Tuple[str, AquilaAttention]] = [ - (name, module) for name, module in self.model.named_modules() - if "AquilaAttention".lower() in module.__class__.__name__.lower() - ] - - self.rmsnorm_modules: List[Tuple[str, AquilaRMSNorm]] = [ - (name, module) for name, module in self.model.named_modules() - if "AquilaRMSNorm".lower() in module.__class__.__name__.lower() - ] - - self.mlp_modules: List[Tuple[str, 
AquilaMLP]] = [ + self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [ (name, module) for name, module in self.model.named_modules() - if "AquilaMLP".lower() in module.__class__.__name__.lower() + if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower() ] - def fuse_attention(self): - for name, module in self.attention_modules: - qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module) - attn = QuantAttentionFused( - module.hidden_size, - module.num_heads, - module.num_key_value_heads, - qkv_layer, - module.o_proj, - next(iter(qkv_layer.state_dict().values())).device, - self.model.config.max_new_tokens + def fuse_transformer(self): + blocks = [] + + module: OldAquilaDecoderLayer + for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): + device = next(iter(module.state_dict().values())).device + qkv = fuse_qkv( + module, + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj ) - set_module_name(self.model, name, attn) - - def _fuse_qkv(self, module: AquilaAttention): - q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - if isinstance(q_proj, WQLinear_GEMV): - q_linear = WQLinear_GEMV - else: - q_linear = WQLinear_GEMM - - qkv_layer = q_linear( - q_proj.w_bit, - q_proj.group_size, - q_proj.in_features, - q_proj.out_features + k_proj.out_features + v_proj.out_features, - q_proj.bias is not None, - next(iter(module.state_dict().values())).device - ) - - if isinstance(qkv_layer, WQLinear_GEMV): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) - qkv_layer.split_k_iters = q_proj.split_k_iters - else: - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + mlp = QuantLlamaMLP( + module.mlp.gate_proj, + module.mlp.down_proj, + module.mlp.up_proj + ) + norm_1 = FasterTransformerRMSNorm( + module.input_layernorm.weight, + module.input_layernorm.variance_epsilon + ) + norm_2 = FasterTransformerRMSNorm( + module.post_attention_layernorm.weight, + module.post_attention_layernorm.variance_epsilon + ) + blocks.append(LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_new_tokens + )) - qkv_layer.bias = bias - - return qkv_layer - - def fuse_rmsnorm(self): - for name, module in self.rmsnorm_modules: - norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon) - set_module_name(self.model, name, norm) - - def fuse_mlp(self): - for name, module in self.mlp_modules: - mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj) - set_module_name(self.model, name, mlp) + self.model.model = LlamaLikeModel( + self.model.config.vocab_size, + blocks, + self.model.model.embed_tokens, + self.model.model.norm, + ) diff --git a/awq/models/llama.py b/awq/models/llama.py index c67f39c7..88be8d3e 100644 --- a/awq/models/llama.py +++ 
b/awq/models/llama.py @@ -1,33 +1,41 @@ +import tqdm +from typing import List, Tuple from .base import BaseAWQForCausalLM -from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM +from awq.utils.fused_utils import fuse_qkv +from awq.modules.fused.block import LlamaLikeBlock +from awq.modules.fused.model import LlamaLikeModel +from transformers.models.llama.modeling_llama import ( + LlamaDecoderLayer as OldLlamaDecoderLayer, + LlamaForCausalLM as OldLlamaForCausalLM +) +from awq.modules.fused.mlp import QuantLlamaMLP +from awq.modules.fused.norm import FasterTransformerRMSNorm class LlamaAWQForCausalLM(BaseAWQForCausalLM): layer_type = "LlamaDecoderLayer" max_new_tokens_key = "max_position_embeddings" @staticmethod - def fuse_layers(model: LlamaForCausalLM): + def fuse_layers(model: OldLlamaForCausalLM): fuser = LlamaFuser(model) - fuser.fuse_attention() - fuser.fuse_rmsnorm() - fuser.fuse_mlp() + fuser.fuse_transformer() @staticmethod - def get_model_layers(model: LlamaForCausalLM): + def get_model_layers(model: OldLlamaForCausalLM): return model.model.layers @staticmethod - def get_act_for_scaling(module: LlamaDecoderLayer): + def get_act_for_scaling(module: OldLlamaDecoderLayer): return dict( is_scalable=False ) @staticmethod - def move_embed(model: LlamaForCausalLM, device: str): + def move_embed(model: OldLlamaForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod - def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs): layers = [] # attention input @@ -65,86 +73,57 @@ def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs) return layers -import torch -from typing import List, Tuple, Union -from awq.utils.utils import set_module_name -from awq.modules.fused.mlp import QuantLlamaMLP -from awq.modules.fused.attn import QuantAttentionFused -from awq.modules.fused.norm import FasterTransformerRMSNorm -from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV -from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP class LlamaFuser: - def __init__(self, model): + def __init__(self, model: OldLlamaForCausalLM): self.model = model - self.attention_modules: List[Tuple[str, LlamaAttention]] = [ - (name, module) for name, module in self.model.named_modules() - if isinstance(module, LlamaAttention) - ] - - self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [ - (name, module) for name, module in self.model.named_modules() - if isinstance(module, LlamaRMSNorm) - ] - - self.mlp_modules: List[Tuple[str, LlamaMLP]] = [ + self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [ (name, module) for name, module in self.model.named_modules() - if isinstance(module, LlamaMLP) + if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower() ] - def fuse_attention(self): - for name, module in self.attention_modules: - qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module) - attn = QuantAttentionFused( - module.hidden_size, - module.num_heads, - module.num_key_value_heads, - qkv_layer, - module.o_proj, - next(iter(qkv_layer.state_dict().values())).device, - self.model.config.max_new_tokens + def fuse_transformer(self): + blocks = [] + + module: OldLlamaDecoderLayer + for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): + device = next(iter(module.state_dict().values())).device + qkv = fuse_qkv( + 
module, + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj ) - set_module_name(self.model, name, attn) - - def _fuse_qkv(self, module: LlamaAttention): - q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - if isinstance(q_proj, WQLinear_GEMV): - q_linear = WQLinear_GEMV - else: - q_linear = WQLinear_GEMM - - qkv_layer = q_linear( - q_proj.w_bit, - q_proj.group_size, - q_proj.in_features, - q_proj.out_features + k_proj.out_features + v_proj.out_features, - q_proj.bias is not None, - next(iter(module.state_dict().values())).device - ) - - if isinstance(qkv_layer, WQLinear_GEMV): - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) - qkv_layer.split_k_iters = q_proj.split_k_iters - else: - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + mlp = QuantLlamaMLP( + module.mlp.gate_proj, + module.mlp.down_proj, + module.mlp.up_proj + ) + norm_1 = FasterTransformerRMSNorm( + module.input_layernorm.weight, + module.input_layernorm.variance_epsilon + ) + norm_2 = FasterTransformerRMSNorm( + module.post_attention_layernorm.weight, + module.post_attention_layernorm.variance_epsilon + ) + blocks.append(LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_new_tokens + )) - qkv_layer.bias = bias - - return qkv_layer - - def fuse_rmsnorm(self): - for name, module in self.rmsnorm_modules: - norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon) - set_module_name(self.model, name, norm) - - def fuse_mlp(self): - for name, module in self.mlp_modules: - mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj) - set_module_name(self.model, name, mlp) \ No newline at end of file + self.model.model = LlamaLikeModel( + self.model.config.vocab_size, + blocks, + self.model.model.embed_tokens, + self.model.model.norm, + ) diff --git a/awq/models/mistral.py b/awq/models/mistral.py index 9c9fd76a..0430a81e 100644 --- a/awq/models/mistral.py +++ b/awq/models/mistral.py @@ -1,33 +1,41 @@ +import tqdm +from typing import List, Tuple from .base import BaseAWQForCausalLM -from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralForCausalLM +from awq.utils.fused_utils import fuse_qkv +from awq.modules.fused.block import LlamaLikeBlock +from awq.modules.fused.model import LlamaLikeModel +from transformers.models.mistral.modeling_mistral import ( + MistralDecoderLayer as OldMistralDecoderLayer, + MistralForCausalLM as OldMistralForCausalLM +) +from awq.modules.fused.mlp import QuantLlamaMLP +from awq.modules.fused.norm import FasterTransformerRMSNorm class MistralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MistralDecoderLayer" max_new_tokens_key = "max_position_embeddings" @staticmethod - def fuse_layers(model: MistralForCausalLM): + def 
fuse_layers(model: OldMistralForCausalLM): fuser = MistralFuser(model) - fuser.fuse_attention() - fuser.fuse_rmsnorm() - fuser.fuse_mlp() - + fuser.fuse_transformer() + @staticmethod - def get_model_layers(model: MistralForCausalLM): + def get_model_layers(model: OldMistralForCausalLM): return model.model.layers @staticmethod - def get_act_for_scaling(module: MistralDecoderLayer): + def get_act_for_scaling(module: OldMistralDecoderLayer): return dict( is_scalable=False ) @staticmethod - def move_embed(model: MistralForCausalLM, device: str): + def move_embed(model: OldMistralForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) @staticmethod - def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwargs): + def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs): layers = [] # attention input @@ -65,86 +73,57 @@ def get_layers_for_scaling(module: MistralDecoderLayer, input_feat, module_kwarg return layers -import torch -from typing import List, Tuple, Union -from awq.utils.utils import set_module_name -from awq.modules.fused.mlp import QuantLlamaMLP -from awq.modules.fused.attn import QuantAttentionFused -from awq.modules.fused.norm import FasterTransformerRMSNorm -from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV -from transformers.models.mistral.modeling_mistral import MistralAttention, MistralRMSNorm, MistralMLP class MistralFuser: - def __init__(self, model): + def __init__(self, model: OldMistralForCausalLM): self.model = model - self.attention_modules: List[Tuple[str, MistralAttention]] = [ + self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [ (name, module) for name, module in self.model.named_modules() - if isinstance(module, MistralAttention) - ] - - self.rmsnorm_modules: List[Tuple[str, MistralRMSNorm]] = [ - (name, module) for name, module in self.model.named_modules() - if isinstance(module, MistralRMSNorm) - ] - - self.mlp_modules: List[Tuple[str, MistralMLP]] = [ - (name, module) for name, module in self.model.named_modules() - if isinstance(module, MistralMLP) + if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower() ] - def fuse_attention(self): - for name, module in self.attention_modules: - qkv_layer: Union[WQLinear_GEMM, WQLinear_GEMV] = self._fuse_qkv(module) - attn = QuantAttentionFused( - module.hidden_size, - module.num_heads, - module.num_key_value_heads, - qkv_layer, - module.o_proj, - next(iter(qkv_layer.state_dict().values())).device, - self.model.config.max_new_tokens + def fuse_transformer(self): + blocks = [] + + module: OldMistralDecoderLayer + for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."): + device = next(iter(module.state_dict().values())).device + qkv = fuse_qkv( + module, + module.self_attn.q_proj, + module.self_attn.k_proj, + module.self_attn.v_proj ) - set_module_name(self.model, name, attn) - - def _fuse_qkv(self, module: MistralAttention): - q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj - bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None - - if isinstance(q_proj, WQLinear_GEMV): - q_linear = WQLinear_GEMV - else: - q_linear = WQLinear_GEMM - - qkv_layer = q_linear( - q_proj.w_bit, - q_proj.group_size, - q_proj.in_features, - q_proj.out_features + k_proj.out_features + v_proj.out_features, - q_proj.bias is not None, - next(iter(module.state_dict().values())).device - ) - - if isinstance(qkv_layer, WQLinear_GEMV): - 
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) - qkv_layer.split_k_iters = q_proj.split_k_iters - else: - qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) - qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) - qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + mlp = QuantLlamaMLP( + module.mlp.gate_proj, + module.mlp.down_proj, + module.mlp.up_proj + ) + norm_1 = FasterTransformerRMSNorm( + module.input_layernorm.weight, + module.input_layernorm.variance_epsilon + ) + norm_2 = FasterTransformerRMSNorm( + module.post_attention_layernorm.weight, + module.post_attention_layernorm.variance_epsilon + ) + blocks.append(LlamaLikeBlock( + hidden_size=self.model.config.hidden_size, + n_heads=self.model.config.num_attention_heads, + n_kv_heads=self.model.config.num_key_value_heads, + qkv_layer=qkv, + o_proj=module.self_attn.o_proj, + mlp=mlp, + norm_1=norm_1, + norm_2=norm_2, + dev=device, + max_seq_len=self.model.config.max_new_tokens + )) - qkv_layer.bias = bias - - return qkv_layer - - def fuse_rmsnorm(self): - for name, module in self.rmsnorm_modules: - norm = FasterTransformerRMSNorm(module.weight, module.variance_epsilon) - set_module_name(self.model, name, norm) - - def fuse_mlp(self): - for name, module in self.mlp_modules: - mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj) - set_module_name(self.model, name, mlp) + self.model.model = LlamaLikeModel( + self.model.config.vocab_size, + blocks, + self.model.model.embed_tokens, + self.model.model.norm, + ) diff --git a/awq/modules/fused/attn.py b/awq/modules/fused/attn.py index d1436ad1..b80754cd 100644 --- a/awq/modules/fused/attn.py +++ b/awq/modules/fused/attn.py @@ -123,17 +123,6 @@ def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs): bsz, seqlen, _ = hidden_states.shape - # Check if we are under transformers caching regime - has_past_key_value = kwargs is not None and "past_key_value" in kwargs and kwargs["past_key_value"] is not None - - if has_past_key_value: - # In newest transformers version, when using caching the input hidden states do not consist of - # the last generated token only, but of the whole sequence - past-kvlength. We need to slice the last token - # and set `seqlen=1` - if seqlen > 1: - seqlen = 1 - hidden_states = hidden_states[:, -1:] - if bsz != self.cache_batch_size: raise RuntimeError( f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. " diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py index cbc7256a..25c42f58 100644 --- a/awq/modules/fused/block.py +++ b/awq/modules/fused/block.py @@ -2,6 +2,39 @@ import torch.nn as nn from awq.modules.fused.attn import QuantAttentionFused +class LlamaLikeBlock(nn.Module): + """ + LlamaLikeBlock is intended to be reused across blocks that have + an architecture that closely resembles Llama, e.g. Mistral and Aquila. 
+ """ + def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len): + super().__init__() + self.n_heads = n_heads + self.n_kv_heads = n_kv_heads + self.hidden_size = hidden_size + self.norm_1 = norm_1.to(dev) + self.attn = QuantAttentionFused( + self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj, + dev=dev, max_seq_len=max_seq_len, use_alibi=False + ).to(dev) + self.norm_2 = norm_2.to(dev) + self.mlp = mlp.to(dev) + + def forward( + self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None + ): + norm_out = self.norm_1(hidden_states) + attn_output, _, past_key_value = self.attn.forward( + hidden_states=norm_out, + past_key_value=past_key_value, + attention_mask=attention_mask + ) + + h = hidden_states + attn_output + out = h + self.mlp.forward(self.norm_2(h)) + + return out, None, past_key_value + class MPTBlock(nn.Module): def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len): super().__init__() diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py index 9a67e956..69d37596 100644 --- a/awq/modules/fused/model.py +++ b/awq/modules/fused/model.py @@ -1,8 +1,45 @@ import torch import torch.nn as nn from typing import List -from awq.modules.fused.block import MPTBlock, FalconDecoderLayer from transformers.modeling_outputs import BaseModelOutputWithPast +from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids +from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock + +class LlamaLikeModel(nn.Module): + """ + LlamaLikeModel is intended to be reused across models that have + an architecture that closely resembles Llama, e.g. Mistral and Aquila. + """ + def __init__(self, vocab_size, blocks, embedding, norm): + super().__init__() + self.vocab_size = vocab_size + self.embedding = embedding + self.blocks: List[LlamaLikeBlock] = blocks + self.norm = norm + self.last_forward_num_tokens = 0 + + @torch.inference_mode() + def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): + input_ids, self.last_forward_num_tokens = prepare_input_ids( + input_ids, + self.last_forward_num_tokens + ) + + _bsz, seqlen = input_ids.shape + h = self.embedding(input_ids) + + mask = prepare_attention_mask( + seqlen=seqlen, + start_pos=self.blocks[0].attn.start_pos, + device=input_ids.device, + type_as=h + ) + + for layer in self.blocks: + h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) + h = self.norm(h) + + return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) class MPTModel(nn.Module): def __init__(self, vocab_size, blocks, wte, norm_f): @@ -13,18 +50,24 @@ def __init__(self, vocab_size, blocks, wte, norm_f): self.norm_f = norm_f self.attn_uses_sequence_id = False self.prefix_lm = False + self.last_forward_num_tokens = 0 @torch.inference_mode() def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): + input_ids, self.last_forward_num_tokens = prepare_input_ids( + input_ids, + self.last_forward_num_tokens + ) + _bsz, seqlen = input_ids.shape h = self.wte(input_ids) - mask = None - if seqlen > 1: - mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device - ) - mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h) + mask = prepare_attention_mask( + seqlen=seqlen, + 
start_pos=self.blocks[0].attn.start_pos, + device=input_ids.device, + type_as=h + ) for layer in self.blocks: h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) @@ -41,23 +84,24 @@ def __init__(self, vocab_size, blocks, word_embeddings, ln_f): self.ln_f = ln_f self.attn_uses_sequence_id = False self.prefix_lm = False + self.last_forward_num_tokens = 0 @torch.inference_mode() def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): - # NOTE: falcon input ids contain full context - # after context is processed, slice to latest token - if self.blocks[0].attn.start_pos != 0 and input_ids.shape[-1] != 1: - input_ids = input_ids[:, self.blocks[0].attn.start_pos:] - + input_ids, self.last_forward_num_tokens = prepare_input_ids( + input_ids, + self.last_forward_num_tokens + ) + _bsz, seqlen = input_ids.shape h = self.word_embeddings(input_ids) - mask = None - if seqlen > 1: - mask = torch.full( - (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device - ) - mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h) + mask = prepare_attention_mask( + seqlen=seqlen, + start_pos=self.blocks[0].attn.start_pos, + device=input_ids.device, + type_as=h + ) for layer in self.blocks: h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py index ff5e9ad0..eaf186b4 100644 --- a/awq/utils/fused_utils.py +++ b/awq/utils/fused_utils.py @@ -1,3 +1,60 @@ +import torch +from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV + +def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int): + # NOTE: from transformers 4.35.0, input_ids includes full context during decoding + num_input_tokens = input_ids.shape[-1] + num_new_tokens = num_input_tokens + + if num_input_tokens != 1: + num_new_tokens = num_input_tokens - last_forward_num_tokens + + # after context is processed, slice to latest token + if num_new_tokens in [0,1]: + input_ids = input_ids[:, -1:] + + return input_ids, last_forward_num_tokens + num_new_tokens + +def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor): + mask = None + if seqlen > 1: + mask = torch.full( + (1, 1, seqlen, seqlen), float("-inf"), device=device + ) + mask = torch.triu(mask, diagonal=start_pos+ 1).type_as(type_as) + + return mask + +def fuse_qkv(module, q_proj, k_proj, v_proj): + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + + if isinstance(q_proj, WQLinear_GEMV): + q_linear = WQLinear_GEMV + else: + q_linear = WQLinear_GEMM + + qkv_layer = q_linear( + q_proj.w_bit, + q_proj.group_size, + q_proj.in_features, + q_proj.out_features + k_proj.out_features + v_proj.out_features, + q_proj.bias is not None, + next(iter(module.state_dict().values())).device + ) + + if isinstance(qkv_layer, WQLinear_GEMV): + qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0) + qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0) + qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0) + qkv_layer.split_k_iters = q_proj.split_k_iters + else: + qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + + qkv_layer.bias = bias + + return 
qkv_layer def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim): if attention_shapes is not None: diff --git a/examples/basic_generate.py b/examples/basic_generate.py index e9d9cf4f..b20e31a2 100644 --- a/examples/basic_generate.py +++ b/examples/basic_generate.py @@ -1,7 +1,7 @@ from awq import AutoAWQForCausalLM from transformers import AutoTokenizer, TextStreamer -quant_path = "TheBloke/Mistral-7B-OpenOrca-AWQ" +quant_path = "TheBloke/zephyr-7B-beta-AWQ" # Load model model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True) @@ -10,11 +10,11 @@ # Convert prompt to tokens prompt_template = """\ -<|im_start|>system -You are MistralOrca, a large language model trained by Alignment Lab AI. Write out your reasoning step-by-step to be sure you get the right answers!<|im_end|> -<|im_start|>user -{prompt}<|im_end|> -<|im_start|>assistant""" +<|system|> + +<|user|> +{prompt} +<|assistant|>""" prompt = "You're standing on the surface of the Earth. "\ "You walk one mile south, one mile west and one mile north. "\
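
Notes on the change, with small illustrative sketches.

The core of the refactor is that every Llama-style fuser (Aquila, Llama, Mistral) now delegates QKV fusion to the shared `fuse_qkv` helper, which concatenates the quantized q/k/v projections along the output dimension into a single WQLinear layer. The sketch below illustrates why that concatenation is equivalent to running the three projections separately; it deliberately uses plain float `nn.Linear` modules and toy sizes rather than the `WQLinear_GEMM`/`WQLinear_GEMV` path, so it is an analogy for the idea, not the patched code itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden, q_out, kv_out = 32, 32, 8  # toy sizes; a GQA model has kv_out < q_out

q_proj = nn.Linear(hidden, q_out, bias=False)
k_proj = nn.Linear(hidden, kv_out, bias=False)
v_proj = nn.Linear(hidden, kv_out, bias=False)

# Fused layer: stack the three weight matrices along the output dimension,
# analogous to how fuse_qkv concatenates qweight/qzeros/scales for WQLinear layers.
qkv = nn.Linear(hidden, q_out + 2 * kv_out, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(1, 4, hidden)
q, k, v = qkv(x).split([q_out, kv_out, kv_out], dim=-1)

assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
print("fused qkv matches separate projections")
```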
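
`prepare_input_ids` exists because, from transformers 4.35.0, the model receives the full token sequence on every decode step instead of only the newest token. A standalone trace of the helper (re-stated here from `awq/utils/fused_utils.py`; the nesting of the inner `if` is inferred from the hunk) shows the prefill pass-through and the per-step slicing:

```python
import torch

# Re-stated from awq/utils/fused_utils.py in this patch: track how many tokens
# were already consumed and slice the input down to the newest token during decoding.
def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    num_input_tokens = input_ids.shape[-1]
    num_new_tokens = num_input_tokens

    if num_input_tokens != 1:
        num_new_tokens = num_input_tokens - last_forward_num_tokens

        # after context is processed, slice to the latest token
        if num_new_tokens in [0, 1]:
            input_ids = input_ids[:, -1:]

    return input_ids, last_forward_num_tokens + num_new_tokens

# Prefill: a 5-token prompt passes through untouched.
ids, seen = prepare_input_ids(torch.arange(5).unsqueeze(0), last_forward_num_tokens=0)
print(ids.shape, seen)   # torch.Size([1, 5]) 5

# Decode step: transformers hands back all 6 tokens, only the last one is kept.
ids, seen = prepare_input_ids(torch.arange(6).unsqueeze(0), last_forward_num_tokens=5)
print(ids.shape, seen)   # torch.Size([1, 1]) 6
```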
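
`prepare_attention_mask` centralizes the causal-mask construction that MPTModel and FalconModel previously duplicated inline and that the new LlamaLikeModel also needs. A small check (again re-stating the helper from the patch) shows the prefill mask and that single-token decode steps get no mask at all:

```python
import torch

# Re-stated from awq/utils/fused_utils.py in this patch.
def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
    mask = None
    if seqlen > 1:
        mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)
    return mask

h = torch.zeros(1, dtype=torch.float16)

# Prefill of a 4-token prompt: each position may attend to itself and earlier positions.
print(prepare_attention_mask(seqlen=4, start_pos=0, device="cpu", type_as=h)[0, 0])
# [[0., -inf, -inf, -inf],
#  [0.,   0., -inf, -inf],
#  [0.,   0.,   0., -inf],
#  [0.,   0.,   0.,   0.]]

# Single-token decode steps skip the mask entirely.
print(prepare_attention_mask(seqlen=1, start_pos=4, device="cpu", type_as=h))  # None
```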