
Feat (equalize): enable parametrized scales #1175

Open · pablomlago wants to merge 19 commits into base: dev
Conversation


@pablomlago pablomlago commented Feb 5, 2025

Reason for this PR

Enable parametrized scaling, similarly to rotations (see #1148).

Changes Made in this PR

Refactored the function _cross_layer_equalization and incorporated rewriters for handling fused/unfused scaling.
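
For context, a minimal sketch of the difference between fused and parametrized scaling, using torch.nn.utils.parametrize; the _ScaleWeight class and variable names below are illustrative only and are not the rewriters introduced in this PR.

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _ScaleWeight(nn.Module):
    # Illustrative parametrization: multiply a weight by a learnable per-channel scale.

    def __init__(self, scale: torch.Tensor, axis: int):
        super().__init__()
        self.scale = nn.Parameter(scale)
        self.axis = axis

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        shape = [1] * weight.ndim
        shape[self.axis] = -1
        return weight * self.scale.view(shape)


linear = nn.Linear(8, 4)
scale = torch.rand(4) + 0.5

# Fused scaling: the weight data is rewritten in place and no extra Parameter remains.
with torch.no_grad():
    linear.weight.mul_(scale.view(-1, 1))

# Parametrized scaling: the original weight is preserved and the scale stays trainable.
other = nn.Linear(8, 4)
parametrize.register_parametrization(other, "weight", _ScaleWeight(scale.clone(), axis=0))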

Testing Summary

Made equalization tests also run with fuse_scaling=False.

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

@pablomlago pablomlago changed the base branch from master to dev February 5, 2025 18:35
@pablomlago pablomlago changed the title Feat(equalize): enable parametrized scales Feat (equalize): enable parametrized scales Feb 5, 2025
@pablomlago pablomlago requested a review from Giuseppe5 February 10, 2025 11:50


def update_module_tensor(module: nn.Module, tensor: torch.Tensor, tensor_name: str):
    setattr(module, tensor_name, torch.nn.Parameter(tensor))
Collaborator

Not convinced by this. Either the name of the function must be changed, or we need to control if the new tensor is a parameter or not. Also, how general is this to have it here?

Collaborator Author

I'm fine with removing it; it's mainly a leftover of the WeightBiasWrapper.
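
If the helper is kept instead, one possible way to address the concern about controlling whether the new tensor becomes a Parameter, sketched for illustration only (the name _set_module_tensor is hypothetical, not code from this PR):

import torch
import torch.nn as nn


# Hypothetical variant of the helper discussed above: only wrap the new tensor in an
# nn.Parameter if the attribute it replaces was already one.
def _set_module_tensor(module: nn.Module, tensor: torch.Tensor, tensor_name: str) -> None:
    old = getattr(module, tensor_name, None)
    if isinstance(old, nn.Parameter):
        setattr(module, tensor_name, nn.Parameter(tensor))
    else:
        setattr(module, tensor_name, tensor)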

@@ -308,7 +309,7 @@ def apply(self, model: GraphModule) -> GraphModule:
        tensor = getattr(module, self.tensor_name).data
        tensor = self.transform_module(tensor)
        # Modify the weights in-place
-       setattr(module, self.tensor_name, torch.nn.Parameter(tensor))
+       update_module_tensor(module=module, tensor=tensor, tensor_name=self.tensor_name)
Collaborator

Looks like we only do it once, so let's not create a new function.

"""
Given two adjacent tensors', the weights are scaled such that
the ranges of the first tensors' output channel are equal to the
ranges of the second tensors' input channel
"""
# The names of the attributes containing the tensors to equalize, as well as the axis
Collaborator

Why is this change needed?

Collaborator

I'm not sure I like this solution better than the previous one. Let's see if we can find a compromise

Collaborator Author

The motivation was two-fold:

  • Remove the WeightBiasWrapper, whose only functionality was to make sure that weights are under the attribute "weight". In general, reduce, as much as possible, the places in which we check the type of the module to get the attribute holding the weights (e.g. "in_proj_weight" in MHA).
  • Make the loop in which the parametrizations are added as similar as possible to that of the rotations, which might make it easier to remove duplication in future PRs.

That being said, I'm not a big fan of having those constants either, so I'm open to any proposal.

Collaborator

Can't we expand the WeightBiasWrapper class to also keep track of the original tensor name?
That way we can keep doing module.weight to get the weights (instead of having weight_pos and bias_pos), and keep the name/axis each in their own attribute.
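
Rough sketch of this proposal, for illustration only (the field names are assumptions, not the actual Brevitas implementation):

import torch.nn as nn


class WeightBiasWrapper:
    # Keeps exposing `.weight` while also recording the original attribute name
    # (e.g. "in_proj_weight" for nn.MultiheadAttention) and the equalization axis.

    def __init__(self, module: nn.Module, tensor_name: str = "weight", axis: int = 0):
        self.module = module
        self.tensor_name = tensor_name
        self.axis = axis

    @property
    def weight(self):
        return getattr(self.module, self.tensor_name)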

Collaborator Author

I've expanded the "wrapper" to handle additional logic, so it is more meaningful to keep it.

@@ -500,20 +528,35 @@ def _no_equalize():

        if isinstance(module, nn.MultiheadAttention):
            module = module.out_proj
        src_axes[name] = (module, axis)
        # Bias, if present, needs to be rotated for sources
Collaborator

Rotated?

Collaborator Author

Typo

@@ -1021,6 +1064,29 @@ def apply(self,
        return graph_model


class ScaleBiasMul(nn.Module):
Collaborator

Not sure if I want to have a new class for what's basically a 99% overlap with ScaledBias

Collaborator

If we switch scaling with inverse_scaling, i.e.

scaling_factors = sinks_range / srcs_range

we only need the reciprocal for the weights, and maybe we can get rid of this class?

Collaborator Author

I wanted to make sure that the order of the operations was the same as before, but given that, empirically, the change in the output is negligible, I'll make that change to remove the ScaleBiasMul module.
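
A quick numerical check of why the two formulations are interchangeable (the variable names mirror the comment above, not necessarily the code):

import torch

srcs_range = torch.rand(16) * 4 + 0.5
sinks_range = torch.rand(16) * 4 + 0.5
x = torch.randn(8, 16)

scaling = srcs_range / sinks_range           # current formulation
inverse_scaling = sinks_range / srcs_range   # proposed formulation

# Multiplying by `scaling` vs. dividing by `inverse_scaling` only differs by
# floating-point rounding noise, i.e. the "negligible" change mentioned above.
print(torch.max(torch.abs(x * scaling - x / inverse_scaling)))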

        self.axis = axis
        self.start_end_idxs = start_end_idxs
        self.slice_idxs = slice_idxs
        self.use_inverse_scaling = use_inverse_scaling
Collaborator

Why do we need to have the inverse both here and for the activations?
Only one of the two needs the inverse

Collaborator Author

It is needed in weight equalization: sources are scaled by the scaling factor, and sinks by its inverse.
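
To make that concrete, a self-contained check that scaling sources by a factor and sinks by its reciprocal leaves the composed function unchanged (toy layers for illustration, not the PR's code):

import torch
import torch.nn as nn

torch.manual_seed(0)
src = nn.Linear(6, 4, bias=True)
sink = nn.Linear(4, 3, bias=False)
x = torch.randn(2, 6)

# Per-channel factors on the dimension shared by the source outputs and sink inputs.
scaling_factors = torch.rand(4) + 0.5

with torch.no_grad():
    ref = sink(src(x))
    # Sources (and their bias) are multiplied by the factor, sinks by its inverse,
    # so one of the two sides always needs the reciprocal.
    src.weight.mul_(scaling_factors.view(-1, 1))
    src.bias.mul_(scaling_factors)
    sink.weight.div_(scaling_factors.view(1, -1))
    out = sink(src(x))

print(torch.allclose(ref, out, atol=1e-5))  # True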

sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
insert_mul_node_fn(scaling_factors, act_val_shape, act_axis)
for name, (module, tensor_names_axis) in src_axes.items():
Collaborator

What happens in weight equalization with multiple iterations?

Collaborator Author

Parametrizations are added on top of each other.
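
This follows from how torch.nn.utils.parametrize composes parametrizations registered on the same tensor; a minimal, self-contained illustration (the _Scale class exists only for this example):

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize


class _Scale(nn.Module):
    # Minimal scale-only parametrization, used only for this illustration.

    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.scale = nn.Parameter(scale)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.scale.view(-1, 1)


linear = nn.Linear(8, 4)
# Each equalization iteration registers one more parametrization on the same tensor;
# torch chains them, so the stack grows with the number of iterations.
for _ in range(3):
    parametrize.register_parametrization(linear, "weight", _Scale(torch.rand(4) + 0.5))

print(len(linear.parametrizations.weight))  # 3 stacked scales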

Collaborator

Didn't we discuss updating the parametrization? In the case of big models with 100 iterations, does it mean we will have 100 scales per weight?

Collaborator Author

It's unlikely that weight equalization will be used with parametrized scaling at the moment. If needed, logic will be added in the future to fuse the scaling parameters appropriately, so that there is only one scaling-factor Parameter per region, irrespective of the number of iterations.
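
Sketch of what such a fusion could look like (hypothetical helper, not part of this PR): collapse the stacked per-iteration scales of a region into a single Parameter by taking their elementwise product.

import torch


def fuse_region_scales(scales):
    # `scales` is a list of per-channel scale tensors, one per equalization iteration.
    fused = scales[0].detach().clone()
    for scale in scales[1:]:
        fused.mul_(scale.detach())
    return torch.nn.Parameter(fused)


fused = fuse_region_scales([torch.rand(4) + 0.5 for _ in range(100)])
print(fused.shape)  # torch.Size([4]): one scale per channel, regardless of iteration count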
