ENH / FIX: Few enhancements and fix for mixed-precision training (#348)
younesbelkada authored Feb 16, 2024
1 parent 2de6092 commit 969b290
Showing 1 changed file with 8 additions and 2 deletions.
awq/modules/linear/gemm.py (10 changes: 8 additions & 2 deletions)
@@ -63,9 +63,16 @@ def forward(
     def backward(ctx, grad_output):
         input, qweight, qzeros, scales, bias = ctx.saved_tensors
 
+        if awq_ext is None:
+            raise ValueError(
+                "auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
+                " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
+            )
+
+        # Cast to correct dtype for mixed precision training
         weights = awq_ext.dequantize_weights_cuda(
             qweight, scales, qzeros, 1, 0, 0, False
-        )
+        ).to(grad_output.dtype)
 
         if ctx.needs_input_grad[0]:
             # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
@@ -75,7 +82,6 @@ def backward(ctx, grad_output):
 
         return grad_input, None, None, None, None, None, None, None
 
-
 class WQLinear_GEMM(nn.Module):
     def __init__(
         self, w_bit, group_size, in_features, out_features, bias, dev, training=False
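Why the cast matters: under mixed-precision training, the grad_output entering backward() can arrive in a different dtype (for example fp32) than the fp16 weights returned by the dequantization kernel, and the matmul that follows requires matching dtypes. Below is a minimal, self-contained sketch of the same pattern, not the AutoAWQ implementation; ToyQuantLinearFn and fake_dequantize are hypothetical stand-ins for the module's autograd function and awq_ext.dequantize_weights_cuda, and the assumption that dequantization yields fp16 weights is purely illustrative.

import torch

def fake_dequantize(qweight):
    # Hypothetical stand-in for awq_ext.dequantize_weights_cuda;
    # assume (for illustration) that it returns fp16 weights.
    return qweight.half()

class ToyQuantLinearFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, qweight):
        ctx.save_for_backward(qweight)
        w = fake_dequantize(qweight).to(x.dtype)
        return x @ w.t()

    @staticmethod
    def backward(ctx, grad_output):
        (qweight,) = ctx.saved_tensors
        # Same idea as the commit's fix: grad_output may be fp32 while the
        # dequantized weight is fp16, so cast before the matmul.
        w = fake_dequantize(qweight).to(grad_output.dtype)
        grad_input = grad_output @ w
        return grad_input, None

# Usage: fp32 activations, fp16 "dequantized" weight; backward succeeds
# because of the dtype cast (without it, the matmul raises a dtype mismatch).
x = torch.randn(4, 8, requires_grad=True)   # fp32 activations
qweight = torch.randn(16, 8)                # stands in for packed weights
out = ToyQuantLinearFn.apply(x, qweight)    # shape (4, 16), fp32
out.sum().backward()                        # grad_output is fp32
print(x.grad.dtype)                         # torch.float32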
