diff --git a/awq/models/base.py b/awq/models/base.py
index 53ee2f50..8ef243ab 100644
--- a/awq/models/base.py
+++ b/awq/models/base.py
@@ -12,11 +12,17 @@
 from huggingface_hub import snapshot_download
 from transformers.modeling_utils import shard_checkpoint
 
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
-from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
-from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_Exllama,
+    WQLinear_ExllamaV2,
+    WQLinear_GEMVFast,
+    marlin_post_init,
+    exllama_post_init,
+    exllamav2_post_init,
+)
 from awq.utils.module import (
     get_named_linears,
     set_op_by_name,
@@ -541,6 +547,8 @@ def _load_quantized_modules(
                     q_linear_module = WQLinear_GEMM
                 elif version == "gemv":
                     q_linear_module = WQLinear_GEMV
+                elif version == "gemv_fast":
+                    q_linear_module = WQLinear_GEMVFast
 
                 q_linear = q_linear_module.from_linear(
                     module, quant_config.w_bit, quant_config.q_group_size, True
diff --git a/awq/modules/linear/__init__.py b/awq/modules/linear/__init__.py
index 41996f22..aa341fae 100644
--- a/awq/modules/linear/__init__.py
+++ b/awq/modules/linear/__init__.py
@@ -1,5 +1,6 @@
-from .exllama import WQLinear_Exllama
-from .exllamav2 import WQLinear_ExllamaV2
+from .exllama import WQLinear_Exllama, exllama_post_init
+from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
 from .gemm import WQLinear_GEMM
 from .gemv import WQLinear_GEMV
-from .marlin import WQLinear_Marlin
+from .marlin import WQLinear_Marlin, marlin_post_init
+from .gemv_fast import WQLinear_GEMVFast
diff --git a/awq/modules/linear/gemv_fast.py b/awq/modules/linear/gemv_fast.py
new file mode 100644
index 00000000..6e75bd06
--- /dev/null
+++ b/awq/modules/linear/gemv_fast.py
@@ -0,0 +1,209 @@
+import torch
+
+try:
+    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)
+
+    AWQ_INSTALLED = True
+except:
+    AWQ_INSTALLED = False
+
+
+def make_divisible(c, divisor):
+    return (c + divisor - 1) // divisor
+
+
+def calculate_zeros_width(in_features, group_size=128, pack_num=8):
+    if group_size >= 128:
+        size_multiplier = 1
+    elif group_size == 64:
+        size_multiplier = 2
+    elif group_size == 32:
+        size_multiplier = 4
+    else:
+        raise NotImplementedError
+
+    base_width = make_divisible(in_features // group_size, pack_num)
+    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
+    return base_width
+
+
+def pack_intweight(unpacked_qweight, interleave, kstride):
+    # unpacked_qweight: [N, K]
+    N = unpacked_qweight.shape[0]
+    K = unpacked_qweight.shape[1]
+
+    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
+    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
+    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
+    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
+
+    # reorder each 8 weights for fast dequantization
+    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
+    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
+    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
+    Packed_Kernel = Packed_Kernel.reshape(N, K)
+
+    # interleaving every four rows
+    Packed_Kernel = Packed_Kernel.reshape(
+        N // interleave, interleave, K // kstride, kstride
+    )
+    # N // 4, K // 64, 4, 64
+    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
+    Packed_Kernel = Packed_Kernel.reshape(
+        N // interleave, K // kstride, kstride, interleave
+    )
+    # Packing -> (N // 4, K // 64, 64)
+    Packed_Kernel = (
+        Packed_Kernel[..., 0]
+        | (Packed_Kernel[..., 1] << 4)
+        | (Packed_Kernel[..., 2] << 8)
+        | (Packed_Kernel[..., 3] << 12)
+    )
+    # reshape to (N // 4, K), FP16 format
+    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
+    qweight = (
+        torch.tensor(Packed_Kernel.astype("int16"))
+        .to(unpacked_qweight.device)
+        .contiguous()
+    )
+    return qweight
+
+
+class WQLinear_GEMVFast(torch.nn.Module):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+        super().__init__()
+
+        self.in_features = in_features
+        self.out_features = out_features
+        self.w_bit = w_bit
+        self.group_size = group_size if group_size != -1 else in_features
+        self.split_k_iters = 8
+        self.interleave = 4
+
+        # quick sanity check (make sure alignment)
+        assert self.in_features % self.group_size == 0
+        assert out_features % (32 // self.w_bit) == 0
+        pack_num = 32 // self.w_bit
+        int16_pack_num = 16 // self.w_bit
+
+        assert out_features % (self.interleave) == 0
+        self.register_buffer(
+            "qweight",
+            torch.zeros(
+                (
+                    out_features // self.interleave,
+                    in_features // int16_pack_num * self.interleave,
+                ),
+                dtype=torch.int16,
+                device=dev,
+            ),
+        )
+        self.register_buffer(
+            "scales",
+            torch.zeros(
+                (
+                    calculate_zeros_width(in_features, self.group_size) * pack_num,
+                    out_features,
+                ),
+                dtype=torch.float16,
+                device=dev,
+            ),
+        )
+        self.register_buffer(
+            "qzeros",
+            torch.zeros(
+                (
+                    calculate_zeros_width(in_features, self.group_size) * pack_num,
+                    out_features,
+                ),
+                dtype=torch.float16,
+                device=dev,
+            ),
+        )
+
+        if bias:
+            self.register_buffer(
+                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
+            )
+        else:
+            self.bias = None
+
+    @classmethod
+    def from_linear(
+        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
+    ):
+        awq_linear = cls(
+            w_bit,
+            group_size,
+            linear.in_features,
+            linear.out_features,
+            linear.bias is not None,
+            linear.weight.device,
+        )
+        if init_only:
+            return awq_linear
+
+        # need scales and zeros info for real quantization
+        assert scales is not None and zeros is not None
+        scale_zeros = zeros * scales
+
+        pack_num = 32 // awq_linear.w_bit
+        qscales = torch.zeros(
+            (
+                scales.shape[0],
+                calculate_zeros_width(linear.in_features, group_size) * pack_num,
+            ),
+            dtype=torch.float16,
+            device=scales.device,
+        )
+        qscales[:, : scales.shape[1]] = scales
+        # awq_linear.scales = scales.clone().half()
+        awq_linear.scales = qscales.transpose(1, 0).contiguous()
+        if linear.bias is not None:
+            awq_linear.bias = linear.bias.clone().half()
+
+        intweight = []
+        for idx in range(awq_linear.in_features):
+            intweight.append(
+                torch.round(
+                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
+                    / qscales[:, idx // group_size]
+                ).to(torch.int)[:, None]
+            )
+        intweight = torch.cat(intweight, dim=1)
+        intweight = intweight.to(dtype=torch.int32)
+        awq_linear.qweight = pack_intweight(
+            intweight.contiguous(), interleave=4, kstride=64
+        )
+
+        zeros = zeros.to(dtype=torch.int32)
+        qzeros = torch.zeros_like(qscales)
+
+        qzeros[:, : scales.shape[1]] = -(
+            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
+        ).to(torch.float16)
+        awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()
+
+        return awq_linear
+
+    @torch.no_grad()
+    def forward(self, x):
+        inputs = x
+        if inputs.numel() / inputs.shape[-1] < 8:
+            out = awq_v2_ext.gemv_forward_cuda_decode(
+                inputs,
+                self.qweight,
+                self.scales,
+                self.qzeros,
+                inputs.numel() // inputs.shape[-1],
+                self.out_features,
+                self.in_features,
+                self.group_size,
+            )
+        else:
+            out = awq_v2_ext.gemm_forward_cuda_prefill(
+                inputs, self.qweight, self.scales, self.qzeros
+            )
+        out = out + self.bias if self.bias is not None else out
+
+        return out
diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py
index aa82cd6b..1bf89fd3 100644
--- a/awq/quantize/quantizer.py
+++ b/awq/quantize/quantizer.py
@@ -9,9 +9,12 @@
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.scale import apply_scale, apply_clip
 from awq.utils.utils import clear_memory, get_best_device
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_GEMVFast,
+)
 from awq.utils.module import (
     append_str_prefix,
     get_op_name,
@@ -200,6 +203,9 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
 
             elif self.version == "marlin":
                 q_linear_module = WQLinear_Marlin
+
+            elif self.version == "gemv_fast":
+                q_linear_module = WQLinear_GEMVFast
 
             else:
                 raise ValueError(f"Unknown version {self.version}")
@@ -466,6 +472,7 @@ def forward(self, *args, **kwargs):
             self.model(samples.to(next(self.model.parameters()).device))
         except ValueError:  # work with early exit
             pass
+        modules[0] = modules[0].module  # restore
 
         # Update the layer kwargs with `prepare_inputs_for_generation` method
         # that takes care of everything to avoid unexpected errors.
@@ -474,7 +481,6 @@ def forward(self, *args, **kwargs):
         layer_kwargs.pop("input_ids")
 
         del samples
-        modules[0] = modules[0].module  # restore
         inps = inps[0]
 
         modules[0] = modules[0].cpu()
diff --git a/awq/utils/fused_utils.py b/awq/utils/fused_utils.py
index 64d63947..6f0a091d 100644
--- a/awq/utils/fused_utils.py
+++ b/awq/utils/fused_utils.py
@@ -1,10 +1,13 @@
 import torch
 
-from awq.modules.linear.gemm import WQLinear_GEMM
-from awq.modules.linear.gemv import WQLinear_GEMV
-from awq.modules.linear.marlin import WQLinear_Marlin
-from awq.modules.linear.exllama import WQLinear_Exllama
-from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
+from awq.modules.linear import (
+    WQLinear_GEMM,
+    WQLinear_GEMV,
+    WQLinear_Marlin,
+    WQLinear_Exllama,
+    WQLinear_ExllamaV2,
+    WQLinear_GEMVFast,
+)
 
 
 def prepare_correct_devices(next_layer, hidden_states, mask):
@@ -73,6 +76,8 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
         q_linear = WQLinear_ExllamaV2
     elif isinstance(q_proj, WQLinear_Marlin):
         q_linear = WQLinear_Marlin
+    elif isinstance(q_proj, WQLinear_GEMVFast):
+        q_linear = WQLinear_GEMVFast
 
     qkv_layer = q_linear(
         q_proj.w_bit,
@@ -132,6 +137,17 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
             [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
         )
         # workspace is created in post_init
+    elif isinstance(q_proj, WQLinear_GEMVFast):
+        qkv_layer.qweight = torch.cat(
+            [q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
+        )
+        qkv_layer.qzeros = torch.cat(
+            [q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
+        ).contiguous()
+        qkv_layer.scales = torch.cat(
+            [q_proj.scales, k_proj.scales, v_proj.scales], dim=1
+        ).contiguous()
+        qkv_layer.split_k_iters = q_proj.split_k_iters
 
     qkv_layer.bias = bias
 
diff --git a/examples/benchmark.py b/examples/benchmark.py
index ed71af45..df84ecb1 100644
--- a/examples/benchmark.py
+++ b/examples/benchmark.py
@@ -117,11 +117,12 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
         raise RuntimeError(ex)
 
     total_memory_used = 0
+    memory_pct = 100
     if successful_generate:
         # number of tokens in context / time for processing context * batch size
-        prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
+        prefill_tokens_per_second = round(input_ids.shape[1] / context_time * batch_size, 2)
         # 1 second / median time per token in seconds * batch size
-        decode_tokens_per_second = 1 / np.median(generate_time) * batch_size
+        decode_tokens_per_second = round(1 / np.median(generate_time) * batch_size, 2)
 
         print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
         print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
diff --git a/setup.py b/setup.py
index b5de72b0..a580dc86 100644
--- a/setup.py
+++ b/setup.py
@@ -89,7 +89,6 @@ def get_kernels_whl_url(
     "transformers>=4.35.0",
    "tokenizers>=0.12.1",
     "typing_extensions>=4.8.0",
-    "triton",
     "accelerate",
     "datasets",
     "zstandard",
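
Usage sketch (not part of the diff): the new kernels are selected through the existing `version` field of the quantization config, which the diff compares against the lowercase string "gemv_fast" in both `AwqQuantizer._apply_quant` and `_load_quantized_modules`. The snippet below uses the standard AutoAWQ entry points and assumes the config value reaches the quantizer unchanged; the model path and output directory are placeholders.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "path/to/hf-model"          # placeholder
    quant_path = "path/to/output-gemv-fast"  # placeholder
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "gemv_fast"}

    # Quantize with the gemv_fast kernels and save the result.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model.quantize(tokenizer, quant_config=quant_config)
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)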
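
For reference, a standalone sketch of the final packing step in `pack_intweight`: four 4-bit values taken from four interleaved output rows are OR-ed into a single int16, mirroring the shift-and-or expression in the new file (the sample values below are arbitrary and only illustrate the bit layout):

    import numpy as np

    # Four arbitrary 4-bit weights, one per interleaved row.
    vals = np.array([3, 10, 7, 12], dtype=np.uint16)

    # Same layout as pack_intweight: row 0 in bits 0-3, row 1 in bits 4-7, and so on.
    packed = vals[0] | (vals[1] << 4) | (vals[2] << 8) | (vals[3] << 12)

    # Unpacking shifts and masks each nibble back out.
    unpacked = [int(packed >> (4 * i)) & 0xF for i in range(4)]
    assert unpacked == [3, 10, 7, 12]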