diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index 1bf89fd3..37b1dbdc 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -244,17 +244,23 @@ def _search_best_scale( # Put x on the right device inp = inp.to(next(module2inspect.parameters()).device) - # [STEP 1]: Compute maximum of weight + # [STEP 1]: Compute per-channel mean of normalised weights + # All layer weights are concatted together weight = torch.cat([_m.weight for _m in layers], dim=0) org_shape = weight.shape + # The weights are reshaped to be organised by quantization group weight = weight.view(-1, self.group_size) + # Calculates the relative magnitude of the weights within each of the quantization groups, + # and rescales each group individually so that each group has weights on a 0-1 scale. w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True) + # Resizes the rescaled weight matrix back up to its original dimensions w_scale = w_scale.view(org_shape) - w_max = w_scale.mean(0) + # Gets the average rescaled magnitude for each output channel + w_mean = w_scale.mean(0) clear_memory(weight) - # [STEP 2]: Compute maximum of x - x_max = inp.abs().view(-1, inp.shape[-1]).mean(0) + # [STEP 2]: Compute per-channel mean of the input activation + x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0) # [STEP 3]: Compute output of module with torch.no_grad(): @@ -266,7 +272,7 @@ def _search_best_scale( # [STEP 4]: Compute loss best_scales = self._compute_best_scale( - inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs + inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs ) return ( @@ -278,8 +284,8 @@ def _search_best_scale( def _compute_best_scale( self, x, - w_max, - x_max, + w_mean, + x_mean, module2inspect, linears2scale: List[nn.Linear], fp16_output, @@ -303,8 +309,8 @@ def _compute_best_scale( org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()} device = x.device - x_max = x_max.view(-1).to(device) - w_max = w_max.view(-1).to(device) + x_mean = x_mean.view(-1).to(device) + w_mean = w_mean.view(-1).to(device) for ratio in range(n_grid): # create new scales @@ -312,9 +318,9 @@ def _compute_best_scale( # NOTE: s^-1 * x is fused here, according to paper if self.duo_scaling: - scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4) + scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4) else: - scales = x_max.pow(ratio).clamp(min=1e-4).view(-1) + scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1) scales = scales / (scales.max() * scales.min()).sqrt() scales_view = scales.view(1, -1).to(device)