PEFT compatible GEMM (#324)
casper-hansen authored Feb 3, 2024
1 parent ebe8fc3 commit 29ee66d
Showing 2 changed files with 181 additions and 27 deletions.
132 changes: 105 additions & 27 deletions awq/modules/linear/gemm.py
@@ -1,5 +1,6 @@
import torch
import torch.nn as nn
from torch.autograd import Function
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm

@@ -10,9 +11,94 @@
except:
AWQ_INSTALLED = False

# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(
ctx,
x,
qweight,
qzeros,
scales,
w_bit=4,
group_size=128,
bias=None,
out_features=0
):
# The forward pass can use ctx.
ctx.save_for_backward(x, qweight, qzeros, scales, bias)
ctx.out_features = out_features

out_shape = x.shape[:-1] + (out_features, )
x = x.to(torch.float16)

if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024

if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
0,
0,
0,
False
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]),
qweight,
scales,
qzeros,
8
)
else:
out = dequantize_gemm(
qweight,
qzeros,
scales,
w_bit,
group_size
)
out = torch.matmul(x, out)

out = out + bias if bias is not None else out
out = out.reshape(out_shape)

# always want 3D tensor if tensor is 2D
if len(out.shape) == 2:
out = out.unsqueeze(0)

return out

@staticmethod
def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors

weights = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
1,
0,
0,
False
)

if ctx.needs_input_grad[0]:
# 2D matrix multiplication, unsqueeze to 3D
grad_input = grad_output.squeeze(0).mm(
weights.transpose(0, 1)
).unsqueeze(0)

return grad_input, None, None, None, None, None, None, None


class WQLinear_GEMM(nn.Module):
-    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
+    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
super().__init__()

if w_bit not in [4]:
@@ -22,6 +108,7 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.training = training

# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
@@ -145,45 +232,36 @@ def from_linear(

return awq_linear

-    @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features,)

         input_dtype = x.dtype
         if input_dtype != torch.float16:
             x = x.half()

-        if AWQ_INSTALLED:
-            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
-
-            if FP16_MATMUL_HEURISTIC_CONDITION:
-                out = awq_ext.dequantize_weights_cuda(
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    0,
-                    0,
-                    0,
-                    False,
-                )
-                out = torch.matmul(x, out)
-            else:
-                out = awq_ext.gemm_forward_cuda(
-                    x.reshape(-1, x.shape[-1]),
-                    self.qweight,
-                    self.scales,
-                    self.qzeros,
-                    8,
-                )
-        else:
-            out = dequantize_gemm(
-                self.qweight,
-                self.qzeros,
-                self.scales,
-                self.w_bit,
-                self.group_size,
-            )
-            out = torch.matmul(x, out)
+        if self.training:
+            out = WQLinearMMFunction.apply(
+                x,
+                self.qweight,
+                self.qzeros,
+                self.scales,
+                self.w_bit,
+                self.group_size,
+                self.bias,
+                self.out_features,
+            )
+        else:
+            with torch.no_grad():
+                out = WQLinearMMFunction.apply(
+                    x,
+                    self.qweight,
+                    self.qzeros,
+                    self.scales,
+                    self.w_bit,
+                    self.group_size,
+                    self.bias,
+                    self.out_features,
+                )

if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
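The change above routes both training and inference through WQLinearMMFunction, a torch.autograd.Function whose backward pass dequantizes the packed weights and only returns a gradient for the activations, so the quantized weights stay frozen while trainable modules upstream (for example LoRA adapters) still receive gradients. Below is a minimal, self-contained sketch of that pattern; FrozenLinearFunction, the plain fp16-style weight, and the shapes are illustrative stand-ins for AutoAWQ's packed buffers, not part of this commit.

import torch
from torch.autograd import Function

class FrozenLinearFunction(Function):
    @staticmethod
    def forward(ctx, x, weight, bias=None):
        # Treat the weight as a black box, analogous to the AWQ kernel output.
        ctx.save_for_backward(weight)
        out = x @ weight
        if bias is not None:
            out = out + bias
        return out

    @staticmethod
    def backward(ctx, grad_output):
        (weight,) = ctx.saved_tensors
        grad_input = None
        if ctx.needs_input_grad[0]:
            # dL/dx = dL/dy @ W^T; no gradient is produced for weight or bias.
            grad_input = grad_output @ weight.transpose(0, 1)
        return grad_input, None, None

x = torch.randn(2, 8, requires_grad=True)
w = torch.randn(8, 4)
FrozenLinearFunction.apply(x, w, None).sum().backward()
print(x.grad.shape)  # torch.Size([2, 8]); w receives no gradient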
76 changes: 76 additions & 0 deletions examples/awq_train.py
@@ -0,0 +1,76 @@
import datasets
from awq import AutoAWQForCausalLM
from transformers import (
AutoTokenizer,
TrainingArguments,
Trainer,
DataCollatorForLanguageModeling
)
from peft import get_peft_model, LoraConfig, TaskType

def prepare_split(tokenizer):
data = datasets.load_dataset("mhenrichsen/alpaca_2k_test", split="train")
prompt_template = "<s>[INST] {system} {prompt} [/INST] {output}</s>"

def format_prompt(x):
return prompt_template.format(
system="",
prompt=x["instruction"],
output=x["output"]
)

data = data.map(
lambda x: {"text": format_prompt(x)},
).select_columns(["text"])
data = data.map(lambda x: tokenizer(x["text"]), batched=True)

return data

model_path = "ybelkada/opt-125m-awq"

# Load model
model = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.pad_token = tokenizer.eos_token

# Prepare data
data_train = prepare_split(tokenizer)

# Config Lora
lora_config = LoraConfig(
r=4,
lora_alpha=8,
lora_dropout=0.5,
bias="none",
task_type=TaskType.CAUSAL_LM,
inference_mode=False
)

model = get_peft_model(model.model, lora_config)

model.print_trainable_parameters()

training_arguments = TrainingArguments(
output_dir="./output",
per_device_train_batch_size=1,
optim="adamw_torch",
num_train_epochs=1,
learning_rate=1e-4,
# fp16=True,
evaluation_strategy="no",
save_strategy="epoch",
save_steps=100,
logging_steps=50,
eval_steps=None,
load_best_model_at_end=False
)

trainer = Trainer(
model=model,
train_dataset=data_train,
args=training_arguments,
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)

trainer.train()
trainer.save_model("output")
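
After training, the LoRA weights are written to ./output by trainer.save_model above. The following is a hedged sketch of loading them back for inference; it is not part of this commit, and it assumes a CUDA device is available for the AWQ kernels and that peft's PeftModel is used to attach the saved adapter. The prompt string is an arbitrary example.

import torch
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from peft import PeftModel

model_path = "ybelkada/opt-125m-awq"

# Reload the quantized base model and attach the trained adapter from ./output.
base = AutoAWQForCausalLM.from_quantized(model_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = PeftModel.from_pretrained(base.model, "output")
model.eval()

device = next(model.parameters()).device
prompt = "<s>[INST]  Give three tips for staying healthy. [/INST]"
inputs = tokenizer(prompt, return_tensors="pt").to(device)

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))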
