From 4d838755901ad71695a2042c8ec96cdecacfdd4f Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 3 Jan 2024 13:04:37 +0100 Subject: [PATCH 1/9] Mixtral Scaling [WIP] --- awq/models/base.py | 18 +++++++++ awq/models/mixtral.py | 8 ++++ awq/modules/moe.py | 90 +++++++++++++++++++++++++++++++++++++++++++ awq/quantize/scale.py | 20 ++++++++++ 4 files changed, 136 insertions(+) create mode 100644 awq/modules/moe.py diff --git a/awq/models/base.py b/awq/models/base.py index 756e8a2f..59904e9b 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -33,6 +33,7 @@ from awq.models._config import AwqConfig from awq.modules.act import ScaledActivation +from awq.modules.moe import ScaledMixtralSparseMoeBlock from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV from awq.quantize.quantizer import AwqQuantizer from awq.utils.module import get_named_linears, set_op_by_name @@ -276,6 +277,9 @@ def _load_quantized_modules(self, model, quant_config, version): # Replace activation functions self._scale_activations(self, layer) + # Replace mixture of experts + self._scale_moe(self, layer) + # Replace nn.Linear with WQLinear for name, module in named_linears.items(): if version == 'GEMM': @@ -309,3 +313,17 @@ def _scale_activations(self, layer): # scale activation scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like) set_op_by_name(layer, scale_dict['scale_name'], scaled_act) + + def _scale_moe(self, layer): + if hasattr(self, "get_moe_for_scaling"): + scale_dict: dict = self.get_moe_for_scaling() + + if not isinstance(scale_dict['scale_layer'], ScaledMixtralSparseMoeBlock): + param = next(layer.parameters()) + + # get activation scale + scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device) + + # scale moe + scaled_act = ScaledMixtralSparseMoeBlock(scale_dict['scale_layer'], scale_like) + set_op_by_name(layer, scale_dict['scale_name'], scaled_act) \ No newline at end of file diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 87b65cec..35bc7a13 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -30,6 +30,14 @@ def get_act_for_scaling(module): is_scalable=False ) + @staticmethod + def get_moe_for_scaling(module: OldMixtralDecoderLayer): + return dict( + scale_name="block_sparse_moe", + scale_layer=module.block_sparse_moe, + scale_shape=(module.block_sparse_moe.num_experts, module.block_sparse_moe.hidden_dim), + ) + @staticmethod def move_embed(model: OldMixtralForCausalLM, device: str): model.model.embed_tokens = model.model.embed_tokens.to(device) diff --git a/awq/modules/moe.py b/awq/modules/moe.py new file mode 100644 index 00000000..0fc90b82 --- /dev/null +++ b/awq/modules/moe.py @@ -0,0 +1,90 @@ +import torch + +from transformers.models.mixtral.configuration_mixtral import MixtralConfig +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + +class ScaledMixtralSparseMoeBlock(torch.nn.Module): + """ + This is a modified sparse MoE that scales experts individually. 
+ + Modified version of: + transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock + """ + + def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor): + super().__init__() + config: MixtralConfig = prev_op.config + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = torch.nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.experts = torch.nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # [expert_num, hidden_dim] + self.scales = torch.nn.Parameter(scales.data) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = torch.nn.functional.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros( + (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device + ) + + # One hot encode the selected experts to create an expert mask + # this will be used to easily index which expert is going to be sollicitated + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0) + + # Loop over all available experts in the model and perform the computation on each expert + for expert_idx in range(self.num_experts): + expert_layer = self.experts[expert_idx] + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + # in torch it is faster to index using lists than torch tensors + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + # Index the correct hidden states and compute the expert hidden state for + # the current expert. We need to make sure to multiply the output hidden + # states by `routing_weights` on the corresponding tokens (top-1 and top-2) + ### MODIFICATION START + current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) / self.scales[expert_idx] + ### MODIFICATION END + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + + # However `index_add_` only support torch tensors for indexing so we'll use + # the `top_x` tensor here. 
+ final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + return final_hidden_states, router_logits + +class MixtralBLockSparseTop2MLP(torch.nn.Module): + def __init__(self, config: MixtralConfig): + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.w2 = torch.nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) + self.w3 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) + self.act_fn = torch.nn.SiLU + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states \ No newline at end of file diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py index 9b7f5feb..688b4de2 100644 --- a/awq/quantize/scale.py +++ b/awq/quantize/scale.py @@ -2,13 +2,16 @@ import torch.nn as nn from typing import Tuple, List from awq.modules.act import ScaledActivation +from awq.modules.moe import ScaledMixtralSparseMoeBlock from awq.utils.module import get_op_by_name, set_op_by_name from transformers.models.bloom.modeling_bloom import BloomGelu from transformers.models.llama.modeling_llama import LlamaRMSNorm +from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation allowed_norms = [nn.LayerNorm, LlamaRMSNorm] allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUActivation] +allowed_moes = [MixtralSparseMoeBlock] @torch.no_grad() def apply_clip(module, clip_list: Tuple[str, torch.Tensor]): @@ -48,6 +51,11 @@ def apply_scale(module, scales_list, input_feat_dict=None): new_module = ScaledActivation(prev_op, scales) set_op_by_name(module, prev_op_name, new_module) scale_gelu_fc(prev_op, layers[0], scales) + + elif any(isinstance(prev_op,t) for t in allowed_moes): + new_module = ScaledMixtralSparseMoeBlock(prev_op, scales) + set_op_by_name(module, prev_op_name, new_module) + scale_moe_experts(prev_op, layers, scales) else: raise NotImplementedError( @@ -133,3 +141,15 @@ def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor): for p in fc.parameters(): assert torch.isnan(p).sum() == 0 + +@torch.no_grad() +def scale_moe_experts(moe, experts: List[nn.Linear], scales: torch.Tensor): + assert any(isinstance(moe,m) for m in allowed_moes) + assert all(isinstance(m, nn.Linear) for m in experts) + + for expert in experts: + expert.weight.mul_(scales.view(1, -1)) + + for expert in experts: + for p in expert.parameters(): + assert torch.isnan(p).sum() == 0 \ No newline at end of file From ff3772061d2dacbe3c57374d1e89955c39eb2226 Mon Sep 17 00:00:00 2001 From: Casper Date: Wed, 3 Jan 2024 17:16:32 +0100 Subject: [PATCH 2/9] Mixtral individual expert scaling --- awq/models/mixtral.py | 18 ++++++------- awq/modules/moe.py | 32 ++++++++++++++--------- awq/quantize/quantizer.py | 53 ++++++++++++++++++++++++++++----------- awq/quantize/scale.py | 30 ++++++++++++++-------- 4 files changed, 88 insertions(+), 45 deletions(-) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 35bc7a13..9a56680a 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -6,7 +6,8 @@ from awq.modules.fused.model import MixtralModel from 
transformers.models.mixtral.modeling_mixtral import ( MixtralDecoderLayer as OldMixtralDecoderLayer, - MixtralForCausalLM as OldMixtralForCausalLM + MixtralForCausalLM as OldMixtralForCausalLM, + MixtralBLockSparseTop2MLP as OldMixtralBLockSparseTop2MLP, ) from awq.modules.fused.mlp import QuantFusedMLP from awq.modules.fused.norm import FasterTransformerRMSNorm @@ -62,19 +63,18 @@ def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kw layers=[module.self_attn.o_proj], inp=input_feat['self_attn.o_proj'], )) - - # linear in + + # NOTE: Scaled in awq.quantize.scale.scale_moe_experts, awq.modules.moe.ScaledMixtralSparseMoeBlock + # Experts: Not a linear layer, special handling is introduced in awq.quantize.quantizer layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[ - w for expert in module.block_sparse_moe.experts - for w in [expert.w1, expert.w3] - ], + prev_op=module.block_sparse_moe, + layers=module.block_sparse_moe.experts, inp=input_feat['block_sparse_moe'], module2inspect=module.block_sparse_moe, )) - # linear out + # scaling w2 + expert: OldMixtralBLockSparseTop2MLP for i, expert in enumerate(module.block_sparse_moe.experts): layers.append(dict( prev_op=expert.w3, diff --git a/awq/modules/moe.py b/awq/modules/moe.py index 0fc90b82..c9abc805 100644 --- a/awq/modules/moe.py +++ b/awq/modules/moe.py @@ -13,15 +13,19 @@ class ScaledMixtralSparseMoeBlock(torch.nn.Module): def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor): super().__init__() - config: MixtralConfig = prev_op.config - self.hidden_dim = config.hidden_size - self.ffn_dim = config.intermediate_size - self.num_experts = config.num_local_experts - self.top_k = config.num_experts_per_tok + self.hidden_dim = prev_op.hidden_dim + self.ffn_dim = prev_op.ffn_dim + self.num_experts = prev_op.num_experts + self.top_k = prev_op.top_k # gating self.gate = torch.nn.Linear(self.hidden_dim, self.num_experts, bias=False) - self.experts = torch.nn.ModuleList([MixtralBLockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # experts + self.experts = torch.nn.ModuleList([ + MixtralBLockSparseTop2MLP(self.ffn_dim, self.hidden_dim) + for _ in range(self.num_experts) + ]) # [expert_num, hidden_dim] self.scales = torch.nn.Parameter(scales.data) @@ -62,9 +66,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Index the correct hidden states and compute the expert hidden state for # the current expert. We need to make sure to multiply the output hidden # states by `routing_weights` on the corresponding tokens (top-1 and top-2) - ### MODIFICATION START + + ### NOTE: We scale weights here, modified from original MoE. 
current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) / self.scales[expert_idx] - ### MODIFICATION END current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use @@ -74,17 +78,21 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: return final_hidden_states, router_logits class MixtralBLockSparseTop2MLP(torch.nn.Module): - def __init__(self, config: MixtralConfig): + def __init__(self, ffn_dim, hidden_dim): super().__init__() - self.ffn_dim = config.intermediate_size - self.hidden_dim = config.hidden_size + self.ffn_dim = ffn_dim + self.hidden_dim = hidden_dim self.w1 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.w2 = torch.nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) self.w3 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) self.act_fn = torch.nn.SiLU - def forward(self, hidden_states): + def forward(self, hidden_states, routing_weights=None): current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) current_hidden_states = self.w2(current_hidden_states) + + if routing_weights is not None: + current_hidden_states = current_hidden_states * routing_weights + return current_hidden_states \ No newline at end of file diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index feb563eb..7495786e 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -4,7 +4,7 @@ import functools import torch.nn as nn from tqdm import tqdm -from typing import Dict, List +from typing import Dict, List, Union from collections import defaultdict from awq.utils.utils import clear_memory from awq.utils.calib_data import get_calib_dataset @@ -17,6 +17,7 @@ set_op_by_name, exclude_layers_to_not_quantize ) +from transformers.models.mixtral.modeling_mixtral import MixtralBLockSparseTop2MLP class AwqQuantizer: @@ -145,7 +146,8 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]): clear_memory() @torch.no_grad() - def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}): + def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], List[MixtralBLockSparseTop2MLP]], + inp: torch.Tensor, module2inspect=None, kwargs={}): if module2inspect is None: assert len(layers) == 1 module2inspect = layers[0] @@ -157,13 +159,25 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torc inp = inp.to(next(module2inspect.parameters()).device) # [STEP 1]: Compute maximum of weight - weight = torch.cat([_m.weight for _m in layers], dim=0) - org_shape = weight.shape - weight = weight.view(-1, self.group_size) - w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) - w_scale = w_scale.view(org_shape) - w_max = w_scale.mean(0) - clear_memory(weight) + if all(isinstance(m, MixtralBLockSparseTop2MLP) for m in layers): + w_max = [] + for expert in layers: + weight = torch.cat([expert.w1.weight, expert.w3.weight], dim=0) + org_shape = weight.shape + weight = weight.view(-1, self.group_size) + w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + w_scale = w_scale.view(org_shape) + expert_w_max = w_scale.mean(0) + w_max.append(expert_w_max) + clear_memory(weight) + else: + weight = torch.cat([_m.weight for _m in layers], dim=0) + org_shape = weight.shape + weight = weight.view(-1, self.group_size) + w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + 
w_scale = w_scale.view(org_shape) + w_max = w_scale.mean(0) + clear_memory(weight) # [STEP 2]: Compute maximum of x x_max = inp.abs().view(-1, inp.shape[-1]).mean(0) @@ -177,12 +191,23 @@ def _search_best_scale(self, module, prev_op, layers: List[nn.Linear], inp: torc fp16_output = fp16_output[0] # [STEP 4]: Compute loss - best_scales = self._compute_best_scale( - inp, w_max, x_max, module2inspect, - layers, fp16_output, module_kwargs - ) + if all(isinstance(m, MixtralBLockSparseTop2MLP) for m in layers): + best_scales = [ + self._compute_best_scale( + inp, w_max[i], x_max, module2inspect, + [expert.w1, expert.w3], fp16_output, module_kwargs + ) for i, expert in enumerate(layers) + ] + else: + best_scales = self._compute_best_scale( + inp, w_max, x_max, module2inspect, + layers, fp16_output, module_kwargs + ) + + prev_op_name = get_op_name(module, prev_op) + layer_names = tuple([get_op_name(module, m) for m in layers]) - return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales) + return (prev_op_name, layer_names, best_scales) def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: List[nn.Linear], fp16_output, kwargs={}): diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py index 688b4de2..af27d45f 100644 --- a/awq/quantize/scale.py +++ b/awq/quantize/scale.py @@ -6,12 +6,15 @@ from awq.utils.module import get_op_by_name, set_op_by_name from transformers.models.bloom.modeling_bloom import BloomGelu from transformers.models.llama.modeling_llama import LlamaRMSNorm -from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation +from transformers.models.mixtral.modeling_mixtral import ( + MixtralSparseMoeBlock, + MixtralBLockSparseTop2MLP, +) allowed_norms = [nn.LayerNorm, LlamaRMSNorm] allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation, PytorchGELUTanh, GELUActivation] -allowed_moes = [MixtralSparseMoeBlock] +allowed_moe = [MixtralSparseMoeBlock] @torch.no_grad() def apply_clip(module, clip_list: Tuple[str, torch.Tensor]): @@ -52,7 +55,11 @@ def apply_scale(module, scales_list, input_feat_dict=None): set_op_by_name(module, prev_op_name, new_module) scale_gelu_fc(prev_op, layers[0], scales) - elif any(isinstance(prev_op,t) for t in allowed_moes): + elif any(isinstance(prev_op,t) for t in allowed_moe): + # scales: [best_scale_expert_0, best_scale_expert_1, ...] 
-> [expert_index, scales] + scales = torch.stack(scales) + + # apply scales new_module = ScaledMixtralSparseMoeBlock(prev_op, scales) set_op_by_name(module, prev_op_name, new_module) scale_moe_experts(prev_op, layers, scales) @@ -143,13 +150,16 @@ def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor): assert torch.isnan(p).sum() == 0 @torch.no_grad() -def scale_moe_experts(moe, experts: List[nn.Linear], scales: torch.Tensor): - assert any(isinstance(moe,m) for m in allowed_moes) - assert all(isinstance(m, nn.Linear) for m in experts) - - for expert in experts: - expert.weight.mul_(scales.view(1, -1)) +def scale_moe_experts(moe: MixtralSparseMoeBlock, experts: List[MixtralBLockSparseTop2MLP], scales: torch.Tensor): + assert any(isinstance(moe, allowed_module) for allowed_module in allowed_moe) + assert all(isinstance(m, MixtralBLockSparseTop2MLP) for m in experts) + + # One scale for each expert, applied to w1 and w3 only + # Not applied to w2 because it does not take hidden_states as input + for i, expert in enumerate(experts): + expert.w1.weight.mul_(scales[i].view(1, -1)) + expert.w3.weight.mul_(scales[i].view(1, -1)) for expert in experts: for p in expert.parameters(): - assert torch.isnan(p).sum() == 0 \ No newline at end of file + assert torch.isnan(p).sum() == 0 From 0a6a6d025346304687ba69872168ef786fcccd95 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Thu, 4 Jan 2024 14:15:10 +0000 Subject: [PATCH 3/9] Various fixes --- awq/models/base.py | 2 +- awq/modules/moe.py | 32 ++++++++------------------------ awq/quantize/scale.py | 16 +++++++++++++--- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/awq/models/base.py b/awq/models/base.py index 59904e9b..120cf95f 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -316,7 +316,7 @@ def _scale_activations(self, layer): def _scale_moe(self, layer): if hasattr(self, "get_moe_for_scaling"): - scale_dict: dict = self.get_moe_for_scaling() + scale_dict: dict = self.get_moe_for_scaling(layer) if not isinstance(scale_dict['scale_layer'], ScaledMixtralSparseMoeBlock): param = next(layer.parameters()) diff --git a/awq/modules/moe.py b/awq/modules/moe.py index c9abc805..df861355 100644 --- a/awq/modules/moe.py +++ b/awq/modules/moe.py @@ -19,13 +19,12 @@ def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor): self.top_k = prev_op.top_k # gating - self.gate = torch.nn.Linear(self.hidden_dim, self.num_experts, bias=False) + self.gate = prev_op.gate # experts - self.experts = torch.nn.ModuleList([ - MixtralBLockSparseTop2MLP(self.ffn_dim, self.hidden_dim) - for _ in range(self.num_experts) - ]) + self.experts = prev_op.experts + for expert in self.experts: + expert.forward = expert_forward # [expert_num, hidden_dim] self.scales = torch.nn.Parameter(scales.data) @@ -77,22 +76,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits -class MixtralBLockSparseTop2MLP(torch.nn.Module): - def __init__(self, ffn_dim, hidden_dim): - super().__init__() - self.ffn_dim = ffn_dim - self.hidden_dim = hidden_dim - - self.w1 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.w2 = torch.nn.Linear(self.ffn_dim, self.hidden_dim, bias=False) - self.w3 = torch.nn.Linear(self.hidden_dim, self.ffn_dim, bias=False) - self.act_fn = torch.nn.SiLU - - def forward(self, hidden_states, routing_weights=None): - current_hidden_states = 
self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - - if routing_weights is not None: - current_hidden_states = current_hidden_states * routing_weights - - return current_hidden_states \ No newline at end of file +def expert_forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states diff --git a/awq/quantize/scale.py b/awq/quantize/scale.py index af27d45f..be089b9f 100644 --- a/awq/quantize/scale.py +++ b/awq/quantize/scale.py @@ -37,7 +37,12 @@ def apply_scale(module, scales_list, input_feat_dict=None): prev_op.cuda() for layer in layers: layer.cuda() - scales.cuda() + + if type(scales) == list: + for scale in scales: + scale.cuda() + else: + scales.cuda() if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear): scale_fc_fcs(prev_op, layers, scales) @@ -57,7 +62,7 @@ def apply_scale(module, scales_list, input_feat_dict=None): elif any(isinstance(prev_op,t) for t in allowed_moe): # scales: [best_scale_expert_0, best_scale_expert_1, ...] -> [expert_index, scales] - scales = torch.stack(scales) + scales = torch.stack(scales).cuda() # apply scales new_module = ScaledMixtralSparseMoeBlock(prev_op, scales) @@ -79,7 +84,12 @@ def apply_scale(module, scales_list, input_feat_dict=None): prev_op.cpu() for layer in layers: layer.cpu() - scales.cpu() + + if type(scales) == list: + for scale in scales: + scale.cpu() + else: + scales.cpu() @torch.no_grad() def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor): From aec278e300a8d199ceaed9c4af86d1438f0ad41b Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Thu, 4 Jan 2024 16:03:47 +0000 Subject: [PATCH 4/9] Refactor quantization --- awq/models/mixtral.py | 10 ++++++++++ awq/modules/moe.py | 2 +- awq/quantize/quantizer.py | 40 +++++++++++++++++++-------------------- 3 files changed, 30 insertions(+), 22 deletions(-) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 9a56680a..fbfd81bb 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -63,6 +63,16 @@ def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kw layers=[module.self_attn.o_proj], inp=input_feat['self_attn.o_proj'], )) + + layers.append(dict( + prev_op=module.post_attention_layernorm, + layers=[ + w for expert in module.block_sparse_moe.experts + for w in [expert.w1, expert.w3] + ], + inp=input_feat['block_sparse_moe'], + module2inspect=module.block_sparse_moe, + )) # NOTE: Scaled in awq.quantize.scale.scale_moe_experts, awq.modules.moe.ScaledMixtralSparseMoeBlock # Experts: Not a linear layer, special handling is introduced in awq.quantize.quantizer diff --git a/awq/modules/moe.py b/awq/modules/moe.py index df861355..48eb6c77 100644 --- a/awq/modules/moe.py +++ b/awq/modules/moe.py @@ -68,7 +68,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: ### NOTE: We scale weights here, modified from original MoE. current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) / self.scales[expert_idx] - current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] + current_hidden_states = expert_layer(expert_layer, current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. 
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index 7495786e..af0f18c6 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -159,19 +159,8 @@ def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], Lis inp = inp.to(next(module2inspect.parameters()).device) # [STEP 1]: Compute maximum of weight - if all(isinstance(m, MixtralBLockSparseTop2MLP) for m in layers): - w_max = [] - for expert in layers: - weight = torch.cat([expert.w1.weight, expert.w3.weight], dim=0) - org_shape = weight.shape - weight = weight.view(-1, self.group_size) - w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) - w_scale = w_scale.view(org_shape) - expert_w_max = w_scale.mean(0) - w_max.append(expert_w_max) - clear_memory(weight) - else: - weight = torch.cat([_m.weight for _m in layers], dim=0) + def _get_w_max(layer_weights): + weight = torch.cat([_m.weight for _m in layer_weights], dim=0) org_shape = weight.shape weight = weight.view(-1, self.group_size) w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) @@ -179,6 +168,15 @@ def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], Lis w_max = w_scale.mean(0) clear_memory(weight) + return w_max + + if type(layers[0]) == nn.Linear: + w_max = _get_w_max(layers) + else: + # FIXME: Specific to Mixtral + weights = [[expert.w1, expert.w3] for expert in layers] + w_max = [_get_w_max(weight) for weight in weights] + # [STEP 2]: Compute maximum of x x_max = inp.abs().view(-1, inp.shape[-1]).mean(0) @@ -191,18 +189,18 @@ def _search_best_scale(self, module, prev_op, layers: Union[List[nn.Linear], Lis fp16_output = fp16_output[0] # [STEP 4]: Compute loss - if all(isinstance(m, MixtralBLockSparseTop2MLP) for m in layers): - best_scales = [ - self._compute_best_scale( - inp, w_max[i], x_max, module2inspect, - [expert.w1, expert.w3], fp16_output, module_kwargs - ) for i, expert in enumerate(layers) - ] - else: + if type(layers[0]) == nn.Linear: best_scales = self._compute_best_scale( inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs ) + else: + best_scales = [ + self._compute_best_scale( + inp, w_max[i], x_max, module2inspect, + experts, fp16_output, module_kwargs + ) for i, experts in enumerate(weights) + ] prev_op_name = get_op_name(module, prev_op) layer_names = tuple([get_op_name(module, m) for m in layers]) From 3a581a576504333b41f7f3e2f85e9e3db8f45ff4 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Thu, 4 Jan 2024 20:21:04 +0000 Subject: [PATCH 5/9] Minor fixes and type hinting --- awq/models/mixtral.py | 12 ++++++++++++ awq/quantize/quantizer.py | 5 +++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index fbfd81bb..974985c1 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -12,6 +12,17 @@ from awq.modules.fused.mlp import QuantFusedMLP from awq.modules.fused.norm import FasterTransformerRMSNorm +def _transformers_version_check(): + import transformers + tv = transformers.__version__.split('.') + if len(tv) == 4: + major, minor, patch, dev = tv + else: + major, minor, patch = tv + + if int(major) >= 4 and int(minor) < 37: + raise Exception("Mixtral requires a minimum of 4.37.0.dev0: pip install git+https://github.com/huggingface/transformers.git") + class MixtralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MixtralDecoderLayer" max_new_tokens_key = "max_position_embeddings" @@ -23,6 +34,7 @@ def fuse_layers(model: OldMixtralForCausalLM): @staticmethod def 
get_model_layers(model: OldMixtralForCausalLM): + _transformers_version_check() return model.model.layers @staticmethod diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index af0f18c6..e14024ff 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -4,9 +4,10 @@ import functools import torch.nn as nn from tqdm import tqdm -from typing import Dict, List, Union from collections import defaultdict +from typing import Dict, List, Union from awq.utils.utils import clear_memory +from transformers import PreTrainedModel from awq.utils.calib_data import get_calib_dataset from awq.quantize.scale import apply_scale, apply_clip from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV @@ -24,7 +25,7 @@ class AwqQuantizer: def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=None) -> None: self.awq_model = awq_model - self.model = model + self.model: PreTrainedModel = model self.tokenizer = tokenizer self.w_bit = w_bit self.group_size = group_size From 43f21c95a40d7a779572b1279656adf092e93a18 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Tue, 9 Jan 2024 21:44:48 +0000 Subject: [PATCH 6/9] Lower perplexity --- awq/models/mixtral.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 974985c1..a3720789 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -75,16 +75,6 @@ def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kw layers=[module.self_attn.o_proj], inp=input_feat['self_attn.o_proj'], )) - - layers.append(dict( - prev_op=module.post_attention_layernorm, - layers=[ - w for expert in module.block_sparse_moe.experts - for w in [expert.w1, expert.w3] - ], - inp=input_feat['block_sparse_moe'], - module2inspect=module.block_sparse_moe, - )) # NOTE: Scaled in awq.quantize.scale.scale_moe_experts, awq.modules.moe.ScaledMixtralSparseMoeBlock # Experts: Not a linear layer, special handling is introduced in awq.quantize.quantizer From 112de418d5f09bcc5bd8441d0822ffa2555c02fb Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Tue, 9 Jan 2024 21:48:31 +0000 Subject: [PATCH 7/9] Minor fixes, code cleaning --- awq/models/mixtral.py | 2 +- awq/modules/moe.py | 15 ++++----------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index a3720789..67364083 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -20,7 +20,7 @@ def _transformers_version_check(): else: major, minor, patch = tv - if int(major) >= 4 and int(minor) < 37: + if int(major) == 4 and int(minor) < 37: raise Exception("Mixtral requires a minimum of 4.37.0.dev0: pip install git+https://github.com/huggingface/transformers.git") class MixtralAWQForCausalLM(BaseAWQForCausalLM): diff --git a/awq/modules/moe.py b/awq/modules/moe.py index 48eb6c77..03e586f8 100644 --- a/awq/modules/moe.py +++ b/awq/modules/moe.py @@ -1,6 +1,4 @@ import torch - -from transformers.models.mixtral.configuration_mixtral import MixtralConfig from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock class ScaledMixtralSparseMoeBlock(torch.nn.Module): @@ -23,8 +21,6 @@ def __init__(self, prev_op: MixtralSparseMoeBlock, scales: torch.Tensor): # experts self.experts = prev_op.experts - for expert in self.experts: - expert.forward = expert_forward # [expert_num, hidden_dim] self.scales = torch.nn.Parameter(scales.data) @@ -67,16 +63,13 @@ def forward(self, 
hidden_states: torch.Tensor) -> torch.Tensor: # states by `routing_weights` on the corresponding tokens (top-1 and top-2) ### NOTE: We scale weights here, modified from original MoE. - current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim) / self.scales[expert_idx] - current_hidden_states = expert_layer(expert_layer, current_state) * routing_weights[top_x_list, idx_list, None] + current_state = hidden_states[None, top_x_list].reshape( + -1, hidden_dim) / self.scales[expert_idx] # <-- scales + + current_hidden_states = expert_layer(current_state) * routing_weights[top_x_list, idx_list, None] # However `index_add_` only support torch tensors for indexing so we'll use # the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype)) final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) return final_hidden_states, router_logits - -def expert_forward(self, hidden_states): - current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) - current_hidden_states = self.w2(current_hidden_states) - return current_hidden_states From 7525fb2770629c6770c35e2c7103ab591b1aa28a Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Sun, 21 Jan 2024 19:16:17 +0000 Subject: [PATCH 8/9] Backward compatibility --- awq/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awq/models/base.py b/awq/models/base.py index 120cf95f..e7eab4d0 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -315,7 +315,7 @@ def _scale_activations(self, layer): set_op_by_name(layer, scale_dict['scale_name'], scaled_act) def _scale_moe(self, layer): - if hasattr(self, "get_moe_for_scaling"): + if hasattr(self, "get_moe_for_scaling") and hasattr(layer.block_sparse_moe, "scales"): scale_dict: dict = self.get_moe_for_scaling(layer) if not isinstance(scale_dict['scale_layer'], ScaledMixtralSparseMoeBlock): From 6c065abcee3ca148b67dfddbe047df824b71e8b3 Mon Sep 17 00:00:00 2001 From: Casper Hansen Date: Sun, 21 Jan 2024 19:46:00 +0000 Subject: [PATCH 9/9] Rework modules_to_not_convert --- awq/models/base.py | 8 ++++++-- awq/models/mixtral.py | 1 + examples/mixtral_quant.py | 9 ++------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/awq/models/base.py b/awq/models/base.py index e7eab4d0..e26a0bdf 100644 --- a/awq/models/base.py +++ b/awq/models/base.py @@ -84,12 +84,16 @@ def generate(self, *args, **kwargs): @torch.no_grad() def quantize(self, tokenizer=None, quant_config={}, calib_data: Union[str, List[str]]="pileval", - split="train", text_column="text", duo_scaling=True, modules_to_not_convert=None): + split="train", text_column="text", duo_scaling=True): self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config) + if hasattr(self, "modules_to_not_convert"): + self.quant_config.modules_to_not_convert = self.modules_to_not_convert + quantizer = AwqQuantizer( self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size, - self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert + self.quant_config.version, calib_data, split, text_column, duo_scaling, + modules_to_not_convert=self.quant_config.modules_to_not_convert ) quantizer.quantize() diff --git a/awq/models/mixtral.py b/awq/models/mixtral.py index 67364083..a6c5e776 100644 --- a/awq/models/mixtral.py +++ b/awq/models/mixtral.py @@ -26,6 +26,7 @@ def _transformers_version_check(): class 
MixtralAWQForCausalLM(BaseAWQForCausalLM): layer_type = "MixtralDecoderLayer" max_new_tokens_key = "max_position_embeddings" + modules_to_not_convert = ["gate"] @staticmethod def fuse_layers(model: OldMixtralForCausalLM): diff --git a/examples/mixtral_quant.py b/examples/mixtral_quant.py index fea92f60..4cd281f9 100644 --- a/examples/mixtral_quant.py +++ b/examples/mixtral_quant.py @@ -3,11 +3,7 @@ model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1' quant_path = 'mixtral-instruct-awq' -modules_to_not_convert = ["gate"] -quant_config = { - "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM", - "modules_to_not_convert": modules_to_not_convert -} +quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"} # Load model # NOTE: pass safetensors=True to load safetensors @@ -19,8 +15,7 @@ # Quantize model.quantize( tokenizer, - quant_config=quant_config, - modules_to_not_convert=modules_to_not_convert + quant_config=quant_config ) # Save quantized model
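
Usage after this series: PATCH 9 moves `modules_to_not_convert` onto the model class (`MixtralAWQForCausalLM.modules_to_not_convert = ["gate"]`) and folds it into the quant config inside `quantize()`, so callers only pass a plain config, as in the updated `examples/mixtral_quant.py` above. Below is a minimal end-to-end sketch under that assumption, using the AutoAWQ entry points the example relies on (`AutoAWQForCausalLM.from_pretrained`, `model.quantize`, `model.save_quantized`); the final save/tokenizer lines are assumed from the standard example layout, since the hunk above ends at the `# Save quantized model` comment.

```python
# Sketch of the post-PATCH-9 Mixtral quantization flow; paths are placeholders.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
quant_path = 'mixtral-instruct-awq'

# No "modules_to_not_convert" here: the gate is excluded by the model class itself.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load model and tokenizer (safetensors=True loads the safetensors checkpoint)
model = AutoAWQForCausalLM.from_pretrained(model_path, safetensors=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize: per-expert scaling and the gate exclusion are handled internally
model.quantize(tokenizer, quant_config=quant_config)

# Save quantized model and tokenizer (assumed standard example ending)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```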