
Feat (hadamard): support region expansion #1178

Merged: 10 commits into Xilinx:dev on Feb 12, 2025

Conversation

Giuseppe5 (Collaborator)

Reason for this PR

Changes Made in this PR

Testing Summary

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

# Only "weight" is rotated
tensor_names_axis = [("weight", _get_input_axis(module))]
if region.expand_region:
    assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard"
Giuseppe5 (Collaborator, Author):

The code just below can support any layer type; the issue is more on the activation side, since we also need to define the dimension along which the input needs padding. It is not complicated to do, but it is not necessary at the moment.
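
As a rough illustration of that point, a hypothetical helper mapping layer types to the activation axis that would need padding (channels-first assumed for convolutions; this is not code from the PR):

import torch.nn as nn

def activation_pad_dim(module: nn.Module) -> int:
    # Linear inputs are padded along the last (feature) dimension; a
    # channels-first Conv would need padding along dim 1 instead, which is
    # the extra bookkeeping the comment above refers to.
    if isinstance(module, nn.Linear):
        return -1
    if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        return 1
    raise NotImplementedError(f"No padding dimension defined for {type(module).__name__}")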

model: nn.Module,
regions: List[Region],
prefix: str = '',
blacklist_layers: Optional[List[str]] = None):
Giuseppe5 (Collaborator, Author):

The exclusion list needs to be passed since we want to exclude all the layers that might be expanded

@@ -350,20 +351,26 @@ def quantize_llm(args, extra_args=None):
    model(**calibration_loader[0])
    remove_hooks(model)

    layers_to_expand = []
Giuseppe5 (Collaborator, Author):

The idea is that if the user passes q_proj, this piece of code will catch all the layers matching layer.module...q_proj
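
For illustration, a minimal sketch of how such suffix matching could work (the helper name is hypothetical; the actual code in this PR may differ):

from typing import List

import torch.nn as nn

def find_layers_to_expand(model: nn.Module, suffixes: List[str]) -> List[str]:
    # Return fully qualified names of modules whose name ends with one of the
    # given suffixes, e.g. "q_proj" matches "model.layers.0.self_attn.q_proj".
    matches = []
    for name, _ in model.named_modules():
        if any(name == s or name.endswith('.' + s) for s in suffixes):
            matches.append(name)
    return matches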

src/brevitas/graph/equalize.py (resolved comment)
src/brevitas/graph/equalize.py (outdated, resolved comment)
# Only "weight" is rotated
tensor_names_axis = [("weight", _get_input_axis(module))]
if region.expand_region:
    assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard"
Collaborator:

Save convs for a new PR?

Giuseppe5 (Collaborator, Author):

I think so, yeah. I'll open an issue to keep track of that as a good first issue.

if region.expand_region:
    assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard"
    hidden_dim = module.weight.shape[1]
    new_hidden = find_closest_hadamard_number(hidden_dim, hidden_dim + 1)
Collaborator:

+1 here? Maybe the previous comment should live here?

def input_pad(self, inp):
    # TODO: This only works for Linear layers. We have an assert in equalize.py to check for this
    hidden_dim = inp.shape[-1]
    new_hidden = find_closest_hadamard_number(hidden_dim, hidden_dim + 1)
Collaborator:

Similar here - the +1 hardcoding...
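
For context, a minimal sketch of the padding mechanics under discussion, assuming zero-padding of the last activation dimension and of the matching input columns of the Linear weight (new_hidden would come from find_closest_hadamard_number; the helper names below are hypothetical):

import torch
import torch.nn.functional as F

def pad_input(inp: torch.Tensor, new_hidden: int) -> torch.Tensor:
    # Zero-pad the last (feature) dimension up to new_hidden; this is only
    # valid for nn.Linear inputs, matching the assert in equalize.py.
    return F.pad(inp, (0, new_hidden - inp.shape[-1]))

def pad_linear_weight(weight: torch.Tensor, new_hidden: int) -> torch.Tensor:
    # Zero-pad the input dimension (dim 1) of the Linear weight so the padded
    # weight columns line up with the padded activation entries before the
    # expanded Hadamard rotation is applied.
    return F.pad(weight, (0, new_hidden - weight.shape[1]))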

src/brevitas/nn/equalized_layer.py (outdated, resolved comment)
src/brevitas_examples/llm/main.py (two outdated, resolved comments)
nickfraser (Collaborator) left a comment:

Only one note about verbosity flags. I guess this can also be achieved by setting the log level higher globally.

@@ -12,6 +13,25 @@
from brevitas.utils.quant_utils import *


def setup_logger(name):
Collaborator:

:O!

parameter_number_pre = 0
for m in graph_model.parameters():
    parameter_number_pre += m.numel()
logging.info(f"{len(expanded_regions)} layers will be expanded during rotation")
Collaborator:

Personally I think all logging should be protected with a verbose flag. Sometimes we will want to turn these things off (for some specific experiments)

Giuseppe5 (Collaborator, Author):

I will go the PyTorch way and have an env variable to set the level.
The default will be "INFO" and the user can select DEBUG with BREVITAS_LOGGING=DEBUG
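
A minimal sketch of that approach (assuming a per-module logger keyed off the BREVITAS_LOGGING variable mentioned above; the final code in this PR may differ, e.g. in the default level):

import logging
import os

# Read the desired level once; fall back to INFO if unset or unrecognised.
_LEVEL = os.environ.get('BREVITAS_LOGGING', 'INFO').upper()

def setup_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, _LEVEL, logging.INFO))
    if not logger.handlers:
        handler = logging.StreamHandler()
        handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s'))
        logger.addHandler(handler)
    return logger

Running with BREVITAS_LOGGING=DEBUG in the environment would then enable the DEBUG messages without touching the code.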

def setup_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
Collaborator:

Maybe have the logging level as an optional parameter? You might want to turn off all 'INFO' messages at some point.

Giuseppe5 (Collaborator, Author), Feb 12, 2025:

See the comment above. I think this piece of code will change a lot as we figure out what to do and how to do it. I want to have something reasonable to start with, and I believe your comment / my proposal above should accomplish just that.

nickfraser (Collaborator) left a comment:

LGTM!

@@ -20,6 +20,7 @@ def env_to_bool(name, default):
JIT_ENABLED = env_to_bool('BREVITAS_JIT', False) and _enabled
NATIVE_STE_BACKEND_ENABLED = env_to_bool('BREVITAS_NATIVE_STE_BACKEND', False)
VERBOSE = env_to_bool('BREVITAS_VERBOSE', False)
LOGGING_LEVEL = os.environ.get('BREVITAS_LOGGING', 'CRITICAL')
Collaborator:

Love it! <3

@Giuseppe5 requested a review from nickfraser on February 12, 2025 at 15:34
@Giuseppe5 merged commit 15135bc into Xilinx:dev on Feb 12, 2025
23 checks passed