diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py
index bddb3ced..2af6689c 100644
--- a/awq/modules/linear/gemm.py
+++ b/awq/modules/linear/gemm.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from torch.autograd import Function
 from awq.utils.utils import get_best_device
 from awq.utils.packing_utils import dequantize_gemm
 
@@ -10,9 +11,94 @@ except:
     AWQ_INSTALLED = False
 
 
+# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
+class WQLinearMMFunction(Function):
+    @staticmethod
+    # ctx is the first argument to forward
+    def forward(
+        ctx,
+        x,
+        qweight,
+        qzeros,
+        scales,
+        w_bit=4,
+        group_size=128,
+        bias=None,
+        out_features=0
+    ):
+        # The forward pass can use ctx.
+        ctx.save_for_backward(x, qweight, qzeros, scales, bias)
+        ctx.out_features = out_features
+
+        out_shape = x.shape[:-1] + (out_features, )
+        x = x.to(torch.float16)
+
+        if AWQ_INSTALLED:
+            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
+
+            if FP16_MATMUL_HEURISTIC_CONDITION:
+                out = awq_ext.dequantize_weights_cuda(
+                    qweight,
+                    scales,
+                    qzeros,
+                    0,
+                    0,
+                    0,
+                    False
+                )
+                out = torch.matmul(x, out)
+            else:
+                out = awq_ext.gemm_forward_cuda(
+                    x.reshape(-1, x.shape[-1]),
+                    qweight,
+                    scales,
+                    qzeros,
+                    8
+                )
+        else:
+            out = dequantize_gemm(
+                qweight,
+                qzeros,
+                scales,
+                w_bit,
+                group_size
+            )
+            out = torch.matmul(x, out)
+
+        out = out + bias if bias is not None else out
+        out = out.reshape(out_shape)
+
+        # always want 3D tensor if tensor is 2D
+        if len(out.shape) == 2:
+            out = out.unsqueeze(0)
+
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        input, qweight, qzeros, scales, bias = ctx.saved_tensors
+
+        weights = awq_ext.dequantize_weights_cuda(
+            qweight,
+            scales,
+            qzeros,
+            1,
+            0,
+            0,
+            False
+        )
+
+        if ctx.needs_input_grad[0]:
+            # 2D matrix multiplication, unsqueeze to 3D
+            grad_input = grad_output.squeeze(0).mm(
+                weights.transpose(0, 1)
+            ).unsqueeze(0)
+
+        return grad_input, None, None, None, None, None, None, None
+
 class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
         super().__init__()
 
         if w_bit not in [4]:
@@ -22,6 +108,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
         self.out_features = out_features
         self.w_bit = w_bit
         self.group_size = group_size if group_size != -1 else in_features
+        self.training = training
 
         # quick sanity check (make sure aligment)
         assert self.in_features % self.group_size == 0
@@ -145,7 +232,6 @@ def from_linear(
 
         return awq_linear
 
-    @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)
 
@@ -153,37 +239,29 @@ def forward(self, x):
         if input_dtype != torch.float16:
             x = x.half()
 
-        if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-            if FP16_MATMUL_HEURISTIC_CONDITION:
-                out = awq_ext.dequantize_weights_cuda(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    0,
-                    0,
-                    0,
-                    False,
-                )
-                out = torch.matmul(x, out)
-            else:
-                out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]),
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    8,
-                )
-        else:
-            out = dequantize_gemm(
+        if self.training:
+            out = WQLinearMMFunction.apply(
+                x,
                 self.qweight,
                 self.qzeros,
                 self.scales,
                 self.w_bit,
                 self.group_size,
+                self.bias,
+                self.out_features,
             )
-            out = torch.matmul(x, out)
+        else:
+            with torch.no_grad():
+                out = WQLinearMMFunction.apply(
+                    x,
+                    self.qweight,
+                    self.qzeros,
+                    self.scales,
+                    self.w_bit,
+                    self.group_size,
+                    self.bias,
+                    self.out_features,
+                )
 
         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)
diff --git a/examples/awq_train.py b/examples/awq_train.py
new file mode 100644
index 00000000..5e8fd0f5
--- /dev/null
+++ b/examples/awq_train.py
@@ -0,0 +1,76 @@
+import datasets
+from awq import AutoAWQForCausalLM
+from transformers import (
+    AutoTokenizer,
+    TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling
+)
+from peft import get_peft_model, LoraConfig, TaskType
+
+def prepare_split(tokenizer):
+    data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
+    prompt_template = "[INST] {system} {prompt} [/INST] {output}"
+
+    def format_prompt(x):
+        return prompt_template.format(
+            system="",
+            prompt=x["instruction"],
+            output=x["output"]
+        )
+
+    data = data.map(
+        lambda x: {"text": format_prompt(x)},
+    ).select_columns(["text"])
+    data = data.map(lambda x: tokenizer(x["text"]), batched=True)
+
+    return data
+
+model_path = "ybelkada/opt-125m-awq"
+
+# Load model
+model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
+tokenizer = AutoTokenizer.from_pretrained(model_path)
+tokenizer.pad_token = tokenizer.eos_token
+
+# Prepare data
+data_train = prepare_split(tokenizer)
+
+# Config Lora
+lora_config = LoraConfig(
+    r=4,
+    lora_alpha=8,
+    lora_dropout=0.5,
+    bias="none",
+    task_type=TaskType.CAUSAL_LM,
+    inference_mode=False
+)
+
+model = get_peft_model(model.model, lora_config)
+
+model.print_trainable_parameters()
+
+training_arguments = TrainingArguments(
+    output_dir="./output",
+    per_device_train_batch_size=1,
+    optim="adamw_torch",
+    num_train_epochs=1,
+    learning_rate=1e-4,
+    # fp16=True,
+    evaluation_strategy="no",
+    save_strategy="epoch",
+    save_steps=100,
+    logging_steps=50,
+    eval_steps=None,
+    load_best_model_at_end=False
+)
+
+trainer = Trainer(
+    model=model,
+    train_dataset=data_train,
+    args=training_arguments,
+    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+)
+
+trainer.train()
+trainer.save_model("output")
\ No newline at end of file
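
Usage note (not part of the patch): a minimal sketch of exercising the new autograd path directly, assuming the AWQ CUDA kernels (awq_ext) are installed and the quantized model is placed on a CUDA device. The checkpoint name is reused from examples/awq_train.py; the layer lookup, input shape, and device handling below are illustrative assumptions, not something defined by this diff.

import torch
from awq import AutoAWQForCausalLM
from awq.modules.linear.gemm import WQLinear_GEMM

# Load a quantized model without fused layers so WQLinear_GEMM.forward is used.
model = AutoAWQForCausalLM.from_quantized("ybelkada/opt-125m-awq", fuse_layers=False)

# Pick one quantized linear layer and switch it to the training (autograd) path.
layer = next(m for m in model.model.modules() if isinstance(m, WQLinear_GEMM))
layer.training = True

# Forward routes through WQLinearMMFunction.apply; backward dequantizes the
# weights via awq_ext and returns a gradient only w.r.t. the input, so the
# quantized weights themselves stay frozen (LoRA adapters upstream still train).
x = torch.randn(1, 8, layer.in_features, dtype=torch.float16,
                device=layer.qweight.device, requires_grad=True)
out = layer(x)
out.float().sum().backward()

print(x.grad.shape)  # matches x: gradient flowed through the quantized GEMM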