From 9f9579371dd54b53c4456176688cf718c97b0454 Mon Sep 17 00:00:00 2001
From: Victor
Date: Tue, 11 Jun 2024 00:21:35 +0200
Subject: [PATCH] Phi and Phi-2 support

---
 awq/models/__init__.py     |   1 +
 awq/models/auto.py         |   1 +
 awq/models/base.py         |   1 +
 awq/models/phi.py          | 139 +++++++++++++++++++++++++++++++++++++
 awq/modules/fused/block.py |  68 ++++++++++++++++++
 awq/modules/fused/model.py |  63 ++++++++++++++++-
 6 files changed, 272 insertions(+), 1 deletion(-)
 create mode 100644 awq/models/phi.py

diff --git a/awq/models/__init__.py b/awq/models/__init__.py
index dff2fd76..765d7f6f 100644
--- a/awq/models/__init__.py
+++ b/awq/models/__init__.py
@@ -17,5 +17,6 @@
 from .gemma import GemmaAWQForCausalLM
 from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
+from .phi import PhiAWQForCausalLM
 from .phi3 import Phi3AWQForCausalLM
 from .cohere import CohereAWQForCausalLM
diff --git a/awq/models/auto.py b/awq/models/auto.py
index 7c67e899..efd4be43 100644
--- a/awq/models/auto.py
+++ b/awq/models/auto.py
@@ -26,6 +26,7 @@
     "gemma": GemmaAWQForCausalLM,
     "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
+    "phi": PhiAWQForCausalLM,
     "phi3": Phi3AWQForCausalLM,
     "cohere": CohereAWQForCausalLM,
 }
diff --git a/awq/models/base.py b/awq/models/base.py
index f6ef9cb7..448f1b37 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -78,6 +78,7 @@
     "gemma": "AutoModelForCausalLM",
     "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
+    "phi": "AutoModelForCausalLM",
     "phi3": "AutoModelForCausalLM",
     "cohere": "AutoModelForCausalLM",
 }
diff --git a/awq/models/phi.py b/awq/models/phi.py
new file mode 100644
index 00000000..9dd5f52a
--- /dev/null
+++ b/awq/models/phi.py
@@ -0,0 +1,139 @@
+import tqdm
+from typing import List, Tuple
+from .base import BaseAWQForCausalLM
+from awq.utils.fused_utils import fuse_qkv
+from awq.modules.fused.block import PhiBlock
+from awq.modules.fused.model import PhiModel as AWQPhiModel
+from transformers.models.phi.modeling_phi import (
+    PhiDecoderLayer as OldPhiDecoderLayer,
+    PhiForCausalLM as OldPhiForCausalLM,
+)
+from awq.modules.fused.norm import FasterTransformerRMSNorm
+
+
+class PhiAWQForCausalLM(BaseAWQForCausalLM):
+    layer_type = "PhiDecoderLayer"
+    max_seq_len_key = "max_position_embeddings"
+
+    @staticmethod
+    def fuse_layers(model: OldPhiForCausalLM):
+        fuser = PhiFuser(model)
+        fuser.fuse_transformer()
+
+    @staticmethod
+    def get_model_layers(model: OldPhiForCausalLM):
+        return model.model.layers
+
+    @staticmethod
+    def get_act_for_scaling(module: OldPhiForCausalLM):
+        return dict(is_scalable=False)
+
+    @staticmethod
+    def move_embed(model: OldPhiForCausalLM, device: str):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+
+    @staticmethod
+    def get_layers_for_scaling(module: OldPhiDecoderLayer, input_feat, module_kwargs):
+        layers = []
+
+        # Attention:
+
+        # attention input
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[
+                    module.self_attn.q_proj,
+                    module.self_attn.k_proj,
+                    module.self_attn.v_proj,
+                ],
+                inp=input_feat["self_attn.q_proj"],
+                module2inspect=module.self_attn,
+                kwargs=module_kwargs,
+            )
+        )
+
+        # As in Llama and other models, the output projection is only scaled
+        # when its shape matches v_proj.
+        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
+        if module.self_attn.v_proj.weight.shape == module.self_attn.dense.weight.shape:
+            layers.append(
+                dict(
+                    prev_op=module.self_attn.v_proj,
+                    layers=[module.self_attn.dense],
+                    inp=input_feat["self_attn.dense"],
+                )
+            )
+
+        # MLP:
+
+        # linear 1
+        layers.append(
+            dict(
+                prev_op=module.input_layernorm,
+                layers=[module.mlp.fc1],
+                inp=input_feat["mlp.fc1"],
+                module2inspect=module.mlp,
+            )
+        )
+
+        # linear 2
+        layers.append(
+            dict(
+                prev_op=module.mlp.fc1,
+                layers=[module.mlp.fc2],
+                inp=input_feat["mlp.fc2"],
+            )
+        )
+
+        return layers
+
+
+class PhiFuser:
+    def __init__(self, model: OldPhiForCausalLM):
+        self.model = model
+
+        self.phi_blocks: List[Tuple[str, OldPhiDecoderLayer]] = [
+            (name, module)
+            for name, module in self.model.named_modules()
+            if "PhiDecoderLayer".lower() in module.__class__.__name__.lower()
+        ]
+
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldPhiDecoderLayer
+        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
+            device = next(iter(module.state_dict().values())).device
+            qkv = fuse_qkv(
+                module,
+                module.self_attn.q_proj,
+                module.self_attn.k_proj,
+                module.self_attn.v_proj,
+            )
+            norm_1 = FasterTransformerRMSNorm(
+                module.input_layernorm.weight, module.input_layernorm.eps
+            )
+            blocks.append(
+                PhiBlock(
+                    hidden_size=self.model.config.hidden_size,
+                    n_heads=self.model.config.num_attention_heads,
+                    n_kv_heads=self.model.config.num_key_value_heads,
+                    qkv_layer=qkv,
+                    dense=module.self_attn.dense,
+                    mlp=module.mlp,
+                    norm_1=norm_1,
+                    dev=device,
+                    max_seq_len=self.model.config.max_position_embeddings,
+                    rope_theta=self.model.config.rope_theta,
+                    rope_scaling=self.model.config.rope_scaling,
+                    # Phi applies rotary embeddings to only part of each head
+                    partial_rotary_factor=self.model.config.partial_rotary_factor,
+                )
+            )
+
+        self.model.model = AWQPhiModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.model.embed_tokens,
+            self.model.model.final_layernorm,
+        )
+        setattr(self.model.model, "blocks", self.model.model.blocks)
diff --git a/awq/modules/fused/block.py b/awq/modules/fused/block.py
index faefef3a..5a2d240a 100644
--- a/awq/modules/fused/block.py
+++ b/awq/modules/fused/block.py
@@ -442,4 +442,72 @@ def forward(
         h = hidden_states.to(attn_output.device) + attn_output
         out = h + self.mlp.forward(self.norm_2(h))
 
         return out, None, past_key_value
+
+
+class PhiBlock(nn.Module):
+    def __init__(
+        self,
+        hidden_size,
+        n_heads,
+        n_kv_heads,
+        qkv_layer,
+        dense,
+        mlp,
+        norm_1,
+        dev,
+        max_seq_len,
+        rope_theta=10000,
+        rope_scaling=None,
+        partial_rotary_factor=1.0,
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size,
+            self.n_heads,
+            self.n_kv_heads,
+            qkv_layer,
+            dense,
+            dev=dev,
+            max_seq_len=max_seq_len,
+            use_alibi=False,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            partial_rotary_factor=partial_rotary_factor,
+            head_dim=self.head_dim,
+        ).to(dev)
+
+        self.mlp = mlp.to(dev)
+        self.device = dev
+
+    def forward(
+        self,
+        hidden_states,
+        past_key_value,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+    ):
+        # Implemented following PhiDecoderLayer's `forward`
+        if hidden_states.device.type != self.device.type:
+            print(
+                f"Warning: hidden_states device {hidden_states.device.type} does not "
+                f"match PhiBlock device {self.device.type}."
+            )
+            hidden_states = hidden_states.to(self.device)
+
+        norm_out = self.norm_1(hidden_states)
+
+        attn_out, _, past_key_value = self.attn(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+        )
+
+        ff_hidden = self.mlp(norm_out)
+
+        # Phi uses parallel attention and MLP branches that share one residual
+        out = attn_out + ff_hidden + hidden_states
+        return out, None, past_key_value
\ No newline at end of file
diff --git a/awq/modules/fused/model.py b/awq/modules/fused/model.py
index d1fe5437..64b37674 100644
--- a/awq/modules/fused/model.py
+++ b/awq/modules/fused/model.py
@@ -11,6 +11,7 @@
     FalconDecoderLayer,
     LlamaLikeBlock,
     MixtralBlock,
+    PhiBlock,
     Phi3Block,
     CohereBlock,
 )
@@ -372,4 +373,64 @@
             past_key_values=None,
             hidden_states=(),
             attentions=(),
-        )
\ No newline at end of file
+        )
+
+
+class PhiModel(nn.Module):
+    def __init__(self, vocab_size, blocks, embedding, final_layernorm):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding = embedding
+        self.blocks: List[PhiBlock] = nn.ModuleList(blocks)
+        self.final_layernorm = final_layernorm
+        self.last_forward_num_tokens = 0
+
+    @property
+    def embed_tokens(self):
+        return self.embedding
+
+    @property
+    def layers(self):
+        return self.blocks
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+        *args,
+        **kwargs,
+    ):
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
+            input_ids, self.last_forward_num_tokens
+        )
+        _bsz, seqlen = input_ids.shape
+
+        fused_utils.prepare_cache(self.blocks, seqlen)
+
+        h = self.embedding(input_ids)
+
+        mask = fused_utils.prepare_attention_mask(
+            seqlen=seqlen,
+            start_pos=self.blocks[0].attn.start_pos,
+            device=input_ids.device,
+            type_as=h,
+        )
+
+        for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, _ = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.final_layernorm(h)
+
+        return BaseModelOutputWithPast(
+            last_hidden_state=h,
+            past_key_values=None,
+            hidden_states=(),
+            attentions=(),
+        )
+
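Usage note (not part of the patch): the sketch below shows how the new "phi" model type could be exercised end to end through AutoAWQ's existing AutoAWQForCausalLM entry points: quantize a Phi-2 checkpoint, save it, then reload it with fuse_layers=True so that PhiFuser swaps in the fused PhiBlock/PhiModel path. The model path, output directory, and quant_config values are illustrative assumptions, not something this patch prescribes.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "microsoft/phi-2"   # assumed source checkpoint
    quant_path = "phi-2-awq"         # assumed output directory
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # Quantize the FP16 model; model_type "phi" now resolves to PhiAWQForCausalLM.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

    # Reload the quantized weights; fuse_layers=True triggers PhiFuser.fuse_transformer().
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)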