diff --git a/awq/models/base.py b/awq/models/base.py index 8144e572..99a1c42c 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -35,6 +35,7 @@ from awq.models._config import AwqConfig from awq.modules.act import ScaledActivation +from awq.modules.moe import ScaledMixtralSparseMoeBlock from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV from awq.modules.exllama import WQLinear_Exllama from awq.modules.exllamav2 import WQLinear_ExllamaV2 @@ -96,11 +97,13 @@ def quantize( split="train", text_column="text", duo_scaling=True, - modules_to_not_convert=None, export_compatible=False, ): self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + self.quantizer = AwqQuantizer( self, self.model, @@ -112,7 +115,7 @@ def quantize( split, text_column, duo_scaling, - modules_to_not_convert=modules_to_not_convert, + modules_to_not_convert=self.quant_config.modules_to_not_convert, export_compatible=export_compatible, ) self.quantizer.quantize() @@ -402,6 +405,9 @@ def _load_quantized_modules( # Replace activation functions self._scale_activations(self, layer) + # Replace mixture of experts + self._scale_moe(self, layer) + # Replace nn.Linear with WQLinear for name, module in named_linears.items(): if use_exllama: @@ -436,5 +442,19 @@ def _scale_activations(self, layer): ) # scale activation - scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like) - set_op_by_name(layer, scale_dict["scale_name"], scaled_act) + scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like) + set_op_by_name(layer, scale_dict['scale_name'], scaled_act) + + def _scale_moe(self, layer): + if hasattr(self, "get_moe_for_scaling") and hasattr(layer.block_sparse_moe, "scales"): + scale_dict: dict = self.get_moe_for_scaling(layer) + + if not isinstance(scale_dict['scale_layer'], ScaledMixtralSparseMoeBlock): + param = next(layer.parameters()) + + # get activation scale + scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device) + + # scale moe + scaled_act = ScaledMixtralSparseMoeBlock(scale_dict['scale_layer'], scale_like) + set_op_by_name(layer, scale_dict['scale_name'], scaled_act) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 8ca8c515..7d2512cb 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -6,13 +6,26 @@ from awq.modules.fused.model import MixtralModel from transformers.models.mixtral.modeling_mixtral import ( MixtralDecoderLayer as OldMixtralDecoderLayer, - MixtralForCausalLM as OldMixtralForCausalLM + MixtralForCausalLM as OldMixtralForCausalLM, + MixtralBLockSparseTop2MLP as OldMixtralBLockSparseTop2MLP, ) from awq.modules.fused.norm import FasterTransformerRMSNorm +def _transformers_version_check(): + import transformers + tv = transformers.__version__.split('.') + if len(tv) == 4: + major, minor, patch, dev = tv + else: + major, minor, patch = tv + + if int(major) == 4 and int(minor) < 37: + raise Exception("Mixtral requires a minimum of 4.37.0.dev0: pip install git+https://github.com/huggingface/transformers.git") + class MixtralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MixtralDecoderLayer" max_new_tokens_key = "max_position_embeddings" + modules_to_not_convert = ["gate"] @staticmethod def fuse_layers(model: OldMixtralForCausalLM): @@ -21,6 +34,7 @@ def fuse_layers(model: OldMixtralForCausalLM): @staticmethod def get_model_layers(model: OldMixtralForCausalLM): + _transformers_version_check() return model.model.layers @staticmethod @@ -29,6 +43,14 @@ def get_act_for_scaling(module): is_scalable=False ) + @staticmethod + def get_moe_for_scaling(module: OldMixtralDecoderLayer): + return dict( + scale_name="block_sparse_moe", + scale_layer=module.block_sparse_moe, + scale_shape=(module.block_sparse_moe.num_experts, module.block_sparse_moe.hidden_dim), + ) + @staticmethod def move_embed(model: OldMixtralForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) @@ -53,19 +75,18 @@ def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kw layers=[module.self_attn.o_proj], inp=input_feat['self_attn.o_proj'], )) - - # linear in + + # NOTE: Scaled in awq.quantize.scale.scale_moe_experts, awq.modules.moe.ScaledMixtralSparseMoeBlock + # Experts: Not a linear layer, special handling is introduced in awq.quantize.quantizer layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[ - w for expert in module.block_sparse_moe.experts - for w in [expert.w1, expert.w3] - ], + prev_op=module.block_sparse_moe, + layers=module.block_sparse_moe.experts, inp=input_feat['block_sparse_moe'], module2inspect=module.block_sparse_moe, )) - # linear out + # scaling w2 + expert: OldMixtralBLockSparseTop2MLP for i, expert in enumerate(module.block_sparse_moe.experts): layers.append(dict( prev_op=expert.w3, diff --git a/awq/modules/moe.py b/awq/modules/moe.py new file mode 100644 index 00000000..03e586f8 --- /dev/null +++ b/awq/modules/moe.py @@ -0,0 +1,75 @@ +import torch +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +class ScaledMixtralSparseMoeBlock(torch.nn.Module): + """ + This is a modified sparse MoE that scales experts individually. + + Modified version of: + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock + """ + + def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor): + super().__init__() + self.hidden_dim = prev_op.hidden_dim + self.ffn_dim = prev_op.ffn_dim + self.num_experts = prev_op.num_experts + self.top_k = prev_op.top_k + + # gating + self.gate = prev_op.gate + + # experts + self.experts = prev_op.experts + + # [expert_num, hidden_dim] + self.scales = torch.nn.Parameter(scales.data) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + + ### NOTE: We scale weights here, modified from original MoE. + current_state = hidden_states[None, top_x_list].reshape( + -1, hidden_dim) / self.scales[expert_idx] # <-- scales + + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. + final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index a4263f70..00ce9320 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -4,9 +4,10 @@ import functools import torch.nn as nn from tqdm import tqdm -from typing import Dict, List from collections import defaultdict +from typing import Dict, List, Union from awq.utils.utils import clear_memory +from transformers import PreTrainedModel from awq.utils.calib_data import get_calib_dataset from awq.quantize.scale import apply_scale, apply_clip from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV @@ -17,6 +18,7 @@ set_op_by_name, exclude_layers_to_not_quantize ) +from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP class AwqQuantizer: @@ -24,7 +26,7 @@ def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=None, export_compatible=False) -> None: self.awq_model = awq_model - self.model = model + self.model: PreTrainedModel = model self.tokenizer = tokenizer self.w_bit = w_bit self.group_size = group_size @@ -162,7 +164,8 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): clear_memory() @torch.no_grad() - def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}): + def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], List[MixtralBLockSparseTop2MLP]], + inp: torch.Tensor, module2inspect=None, kwargs={}): if module2inspect is None: assert len(layers) == 1 module2inspect = layers[0] @@ -174,13 +177,23 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torc inp = inp.to(next(module2inspect.parameters()).device) # [STEP 1]: Compute maximum of weight - weight = torch.cat([_m.weight for _m in layers], dim=0) - org_shape = weight.shape - weight = weight.view(-1, self.group_size) - w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) - w_scale = w_scale.view(org_shape) - w_max = w_scale.mean(0) - clear_memory(weight) + def _get_w_max(layer_weights): + weight = torch.cat([_m.weight for _m in layer_weights], dim=0) + org_shape = weight.shape + weight = weight.view(-1, self.group_size) + w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + w_scale = w_scale.view(org_shape) + w_max = w_scale.mean(0) + clear_memory(weight) + + return w_max + + if type(layers[0]) == nn.Linear: + w_max = _get_w_max(layers) + else: + # FIXME: Specific to Mixtral + weights = [[expert.w1, expert.w3] for expert in layers] + w_max = [_get_w_max(weight) for weight in weights] # [STEP 2]: Compute maximum of x x_max = inp.abs().view(-1, inp.shape[-1]).mean(0) @@ -194,12 +207,23 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torc fp16_output = fp16_output[0] # [STEP 4]: Compute loss - best_scales = self._compute_best_scale( - inp, w_max, x_max, module2inspect, - layers, fp16_output, module_kwargs - ) + if type(layers[0]) == nn.Linear: + best_scales = self._compute_best_scale( + inp, w_max, x_max, module2inspect, + layers, fp16_output, module_kwargs + ) + else: + best_scales = [ + self._compute_best_scale( + inp, w_max[i], x_max, module2inspect, + experts, fp16_output, module_kwargs + ) for i, experts in enumerate(weights) + ] + + prev_op_name = get_op_name(module, prev_op) + layer_names = tuple([get_op_name(module, m) for m in layers]) - return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales) + return (prev_op_name, layer_names, best_scales) def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear], fp16_output, kwargs={}): diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py index 9b7f5feb..be089b9f 100644 --- a/awq/quantize/scale.py +++ b/awq/quantize/scale.py @@ -2,13 +2,19 @@ import torch.nn as nn from typing import Tuple, List from awq.modules.act import ScaledActivation +from awq.modules.moe import ScaledMixtralSparseMoeBlock from awq.utils.module import get_op_by_name, set_op_by_name from transformers.models.bloom.modeling_bloom import BloomGelu from transformers.models.llama.modeling_llama import LlamaRMSNorm from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation +from transformers.models.mixtral.modeling_mixtral import ( + MixtralSparseMoeBlock, + MixtralBLockSparseTop2MLP, +) allowed_norms = [nn.LayerNorm, LlamaRMSNorm] allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUActivation] +allowed_moe = [MixtralSparseMoeBlock] @torch.no_grad() def apply_clip(module, clip_list: Tuple[str, torch.Tensor]): @@ -31,7 +37,12 @@ def apply_scale(module, scales_list, input_feat_dict=None): prev_op.cuda() for layer in layers: layer.cuda() - scales.cuda() + + if type(scales) == list: + for scale in scales: + scale.cuda() + else: + scales.cuda() if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear): scale_fc_fcs(prev_op, layers, scales) @@ -48,6 +59,15 @@ def apply_scale(module, scales_list, input_feat_dict=None): new_module = ScaledActivation(prev_op, scales) set_op_by_name(module, prev_op_name, new_module) scale_gelu_fc(prev_op, layers[0], scales) + + elif any(isinstance(prev_op,t) for t in allowed_moe): + # scales: [best_scale_expert_0, best_scale_expert_1, ...] -> [expert_index, scales] + scales = torch.stack(scales).cuda() + + # apply scales + new_module = ScaledMixtralSparseMoeBlock(prev_op, scales) + set_op_by_name(module, prev_op_name, new_module) + scale_moe_experts(prev_op, layers, scales) else: raise NotImplementedError( @@ -64,7 +84,12 @@ def apply_scale(module, scales_list, input_feat_dict=None): prev_op.cpu() for layer in layers: layer.cpu() - scales.cpu() + + if type(scales) == list: + for scale in scales: + scale.cpu() + else: + scales.cpu() @torch.no_grad() def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): @@ -133,3 +158,18 @@ def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor): for p in fc.parameters(): assert torch.isnan(p).sum() == 0 + +@torch.no_grad() +def scale_moe_experts(moe: MixtralSparseMoeBlock, experts: List[MixtralBLockSparseTop2MLP], scales: torch.Tensor): + assert any(isinstance(moe, allowed_module) for allowed_module in allowed_moe) + assert all(isinstance(m, MixtralBLockSparseTop2MLP) for m in experts) + + # One scale for each expert, applied to w1 and w3 only + # Not applied to w2 because it does not take hidden_states as input + for i, expert in enumerate(experts): + expert.w1.weight.mul_(scales[i].view(1, -1)) + expert.w3.weight.mul_(scales[i].view(1, -1)) + + for expert in experts: + for p in expert.parameters(): + assert torch.isnan(p).sum() == 0 diff --git a/examples/mixtral_quant.py b/examples/mixtral_quant.py index fea92f60..4cd281f9 100644 --- a/examples/mixtral_quant.py +++ b/examples/mixtral_quant.py @@ -3,11 +3,7 @@ model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1' quant_path = 'mixtral-instruct-awq' -modules_to_not_convert = ["gate"] -quant_config = { - "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM", - "modules_to_not_convert": modules_to_not_convert -} +quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} # Load model # NOTE: pass safetensors=True to load safetensors @@ -19,8 +15,7 @@ # Quantize model.quantize( tokenizer, - quant_config=quant_config, - modules_to_not_convert=modules_to_not_convert + quant_config=quant_config ) # Save quantized model