Feat (equalize): enable parametrized scales #1175
base: dev
Conversation
20837b7 to 99bbb4f
src/brevitas/utils/torch_utils.py
Outdated
def update_module_tensor(module: nn.Module, tensor: torch.Tensor, tensor_name: str):
    setattr(module, tensor_name, torch.nn.Parameter(tensor))
Not convinced by this. Either the name of the function must be changed, or we need to control whether the new tensor is a parameter or not. Also, how general is this to have it here?
I'm fine with removing it; it's mainly a leftover of the WeightBiasWrapper.
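As an illustration of the "control whether it is a Parameter" option, here is a minimal sketch of what such a helper could look like; this is a hypothetical variant for discussion, not code from the PR:

import torch
import torch.nn as nn

def update_module_tensor(module: nn.Module, tensor: torch.Tensor, tensor_name: str) -> None:
    # Hypothetical variant: only re-wrap in nn.Parameter if the original attribute was one,
    # so buffers and plain tensors keep their type.
    old = getattr(module, tensor_name)
    if isinstance(old, nn.Parameter):
        setattr(module, tensor_name, nn.Parameter(tensor))
    else:
        setattr(module, tensor_name, tensor)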
src/brevitas/graph/base.py
Outdated
@@ -308,7 +309,7 @@ def apply(self, model: GraphModule) -> GraphModule:
     tensor = getattr(module, self.tensor_name).data
     tensor = self.transform_module(tensor)
     # Modify the weights in-place
-    setattr(module, self.tensor_name, torch.nn.Parameter(tensor))
+    update_module_tensor(module=module, tensor=tensor, tensor_name=self.tensor_name)
Looks like we only do it once, so let's not create a new function.
src/brevitas/graph/equalize.py
Outdated
""" | ||
Given two adjacent tensors', the weights are scaled such that | ||
the ranges of the first tensors' output channel are equal to the | ||
ranges of the second tensors' input channel | ||
""" | ||
# The names of the attributes containing the tensors to equalize, as well as the axis |
Why is this change needed?
I'm not sure I like this solution better than the previous one. Let's see if we can find a compromise
The motivation was twofold:
- Remove the WeightBiasWrapper, whose only functionality was to make sure that the weights are exposed under the attribute "weight", and, more generally, reduce as much as possible the places in which we check the type of the module to get the weight attribute (e.g. "in_proj_weight" in MHA).
- Make the loop in which the parametrizations are added as similar as possible to the one used for rotations: this might make it easier to remove duplication in future PRs.
That being said, I'm not a big fan of having those constants either, so I'm open to any proposal.
Can't we expand the WeightBiasWrapper class to also keep track of the original tensor name? That way we can keep doing module.weight to get the weights (instead of having weight_pos and bias_pos), with the name and axis each stored in their own attribute.
I've expanded the "wrapper" to handle additional logic, so it is more meaningful to keep it.
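For context, a minimal sketch of the kind of wrapper being discussed, illustrative only and not the actual WeightBiasWrapper: expose the tensor uniformly as .weight while remembering the original attribute name and axis.

import torch
import torch.nn as nn

class WeightBiasWrapper:
    # Illustrative sketch: present any weight-like tensor as .weight while keeping
    # track of where it actually lives and along which axis it is equalized.
    def __init__(self, module: nn.Module, tensor_name: str = "weight", axis: int = 0):
        self.module = module
        self.tensor_name = tensor_name  # e.g. "in_proj_weight" for nn.MultiheadAttention
        self.axis = axis

    @property
    def weight(self) -> torch.Tensor:
        return getattr(self.module, self.tensor_name)

mha = nn.MultiheadAttention(embed_dim=8, num_heads=2)
wrapped = WeightBiasWrapper(mha, tensor_name="in_proj_weight", axis=0)
print(wrapped.weight.shape)  # torch.Size([24, 8])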
src/brevitas/graph/equalize.py
Outdated
@@ -500,20 +528,35 @@ def _no_equalize():

    if isinstance(module, nn.MultiheadAttention):
        module = module.out_proj
    src_axes[name] = (module, axis)
    # Bias, if present, needs to be rotated for sources
Rotated?
Typo
src/brevitas/graph/equalize.py
Outdated
@@ -1021,6 +1064,29 @@ def apply(self,
        return graph_model


class ScaleBiasMul(nn.Module):
Not sure if I want to have a new class for what's basically a 99% overlap with ScaledBias
If we switch scaling with inverse_scaling, i.e. scaling_factors = sinks_range / srcs_range, we only need the reciprocal for the weights, and maybe we can get rid of this class?
I wanted to make sure that the order of the operations was the same as before, but given that, empirically, the change in the output is negligible, I'll make that change and remove the ScaleBiasMul module.
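For reference, a small sketch of the convention being proposed here, with illustrative names and values rather than the actual equalize.py code: once the factor is defined as sinks_range / srcs_range, the activation side can use it directly and only the weight side needs the reciprocal.

import torch

srcs_range = torch.tensor([2.0, 4.0])
sinks_range = torch.tensor([1.0, 8.0])

scaling_factors = sinks_range / srcs_range           # applied as-is on the activation side
inverse_scaling = torch.reciprocal(scaling_factors)  # only the weights need the reciprocal

act = torch.randn(3, 2)
scaled_act = act * scaling_factors  # a plain scale/mul module suffices, no extra ScaleBiasMul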
self.axis = axis
self.start_end_idxs = start_end_idxs
self.slice_idxs = slice_idxs
self.use_inverse_scaling = use_inverse_scaling
Why do we need to have the inverse both here and for the activations?
Only one of the two needs the inverse
It is needed in weight equalization: sources are scaled by the scaling factor, and sinks by its inverse.
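To make the source/sink asymmetry concrete, a minimal self-contained sketch with generic tensors and one common choice of factor (not the Brevitas implementation):

import torch

src_weight = torch.randn(8, 4)   # e.g. nn.Linear(4, 8).weight
sink_weight = torch.randn(6, 8)  # e.g. nn.Linear(8, 6).weight

srcs_range = src_weight.abs().amax(dim=1)    # range per output channel of the source
sinks_range = sink_weight.abs().amax(dim=0)  # range per input channel of the sink

# With s = sqrt(sinks_range / srcs_range), srcs_range * s == sinks_range / s after scaling
scaling_factors = torch.sqrt(sinks_range / srcs_range)

src_weight = src_weight * scaling_factors.view(-1, 1)    # sources scaled by the factor
sink_weight = sink_weight / scaling_factors.view(1, -1)  # sinks scaled by its inverse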
src/brevitas/graph/equalize.py
Outdated
sink_broadcast_size = [1] * module.weight.ndim
sink_broadcast_size[axis] = module.weight.size(axis)
insert_mul_node_fn(scaling_factors, act_val_shape, act_axis) | ||
for name, (module, tensor_names_axis) in src_axes.items(): |
What happens in weight equalization with multiple iterations?
Parametrizations are added on top of each other.
Didn't we discuss updating the parametrization instead? In the case of big models with 100 iterations, does that mean we will have 100 scales per weight?
It's unlikely that weight equalization will be used with parametrized scaling at the moment. If needed, logic will be added in the future to fuse the scaling parameters appropriately, so that there's only one scaling factor Parameter per region, irrespective of the number of iterations.
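For reference, a minimal sketch of why unfused iterations stack, using plain PyTorch parametrizations rather than the Brevitas rewriters: each register_parametrization call appends another module, so N iterations leave N scale modules on the weight unless they are fused afterwards.

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class ScaleWeight(nn.Module):
    # Illustrative per-output-channel scale applied to the weight at forward time
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.scale = nn.Parameter(scale)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.scale

linear = nn.Linear(4, 4)
for _ in range(3):  # e.g. three equalization iterations with unfused scaling
    parametrize.register_parametrization(linear, "weight", ScaleWeight(torch.ones(4, 1)))

print(len(linear.parametrizations.weight))  # 3 stacked parametrizations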
021f2e9 to 19fac39
b597820 to e1e489c
Reason for this PR
Enable parametrized scaling, similarly to rotations (see #1148).
Changes Made in this PR
Refactored the function _cross_layer_equalization and incorporated rewriters for handling fused/unfused scaling.
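As a rough sketch of the fused vs. unfused distinction (illustrative code with hypothetical names; the actual rewriters in src/brevitas/graph/equalize.py differ in detail):

import torch
import torch.nn as nn
from torch.nn.utils import parametrize

class ScaleWeight(nn.Module):
    # Applies a per-channel scale to the weight at forward time
    def __init__(self, scale: torch.Tensor):
        super().__init__()
        self.scale = nn.Parameter(scale)

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight * self.scale

def scale_module_weight(module: nn.Module, scale: torch.Tensor, fuse_scaling: bool) -> None:
    if fuse_scaling:
        # Fused: bake the scale into the stored weight values once
        module.weight.data.mul_(scale)
    else:
        # Unfused/parametrized: keep the scale as a separate Parameter, applied on the fly
        parametrize.register_parametrization(module, "weight", ScaleWeight(scale))

With fuse_scaling=True nothing extra is left on the module, while with fuse_scaling=False the scale remains visible (and trainable) under module.parametrizations.weight.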
Testing Summary
Made equalization tests also run with fuse_scaling=False.
Risk Highlight
Checklist
Pull request is based against the dev branch.