Feat (hadamard): support region expansion #1178
Conversation
# Only "weight" is rotated | ||
tensor_names_axis = [("weight", _get_input_axis(module))] | ||
if region.expand_region: | ||
assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard" |
The code just below can support any layer type; the issue is more on the activation side, since we would also need to define the dimension along which the input needs padding. It is not complicated to do, tbh, but not necessary atm.
model: nn.Module,
regions: List[Region],
prefix: str = '',
blacklist_layers: Optional[List[str]] = None):
The exclusion list needs to be passed since we want to exclude all the layers that might be expanded
@@ -350,20 +351,26 @@ def quantize_llm(args, extra_args=None):
    model(**calibration_loader[0])
    remove_hooks(model)

    layers_to_expand = []
The idea is that if the user passes `q_proj`, this piece of code will catch all the layers like `layer.module...q_proj`.
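For illustration, a minimal sketch of the suffix matching described here (the helper name and its exact signature are hypothetical, not the code in this PR):

```python
from typing import List

import torch.nn as nn


def collect_layers_to_expand(model: nn.Module, suffixes: List[str]) -> List[str]:
    # Hypothetical helper: match every submodule whose qualified name ends with
    # one of the user-provided suffixes (e.g. "q_proj"), so that a name such as
    # "model.layers.0.self_attn.q_proj" is caught when "q_proj" is passed.
    matched = []
    for name, _ in model.named_modules():
        if any(name == s or name.endswith('.' + s) for s in suffixes):
            matched.append(name)
    return matched
```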
# Only "weight" is rotated | ||
tensor_names_axis = [("weight", _get_input_axis(module))] | ||
if region.expand_region: | ||
assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard" |
Save convs for a new PR?
I think so, yeah. I'll open an issue to keep track of that as a good first issue.
src/brevitas/graph/equalize.py
Outdated
if region.expand_region:
    assert isinstance(module, nn.Linear), "Currently only Linear layers support expanded hadamard"
    hidden_dim = module.weight.shape[1]
    new_hidden = find_closest_hadamard_number(hidden_dim, hidden_dim + 1)
The `+1` here? Maybe the previous comment should live here?
src/brevitas/nn/equalized_layer.py
Outdated
def input_pad(self, inp):
    # TODO: This only works for Linear layers. We have an assert in equalize.py to check for this
    hidden_dim = inp.shape[-1]
    new_hidden = find_closest_hadamard_number(hidden_dim, hidden_dim + 1)
Similar here - the `+1` hardcoding...
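For context, a rough sketch of what padding a Linear layer's activation up to the next supported Hadamard size could look like. It assumes the `new_hidden` value computed by `find_closest_hadamard_number` as in the diff above; the padding logic itself is an illustration, not the PR's implementation:

```python
import torch
import torch.nn.functional as F


def pad_input_to_hadamard_size(inp: torch.Tensor, new_hidden: int) -> torch.Tensor:
    # Illustrative only: zero-pad the last dimension (the Linear input features)
    # from its current size up to new_hidden. For other layer types (e.g. convs)
    # the dimension to pad would have to be chosen explicitly, which is the
    # activation-side issue mentioned earlier in this conversation.
    hidden_dim = inp.shape[-1]
    if new_hidden == hidden_dim:
        return inp
    return F.pad(inp, (0, new_hidden - hidden_dim))
```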
Only one note about verbosity flags. I guess this can also be achieved by setting the log level higher globally.
@@ -12,6 +13,25 @@
from brevitas.utils.quant_utils import *


def setup_logger(name):
:O!
parameter_number_pre = 0
for m in graph_model.parameters():
    parameter_number_pre += m.numel()
logging.info(f"{len(expanded_regions)} layers will be expanded during rotation")
Personally I think all logging should be protected with a `verbose` flag. Sometimes we will want to turn these things off (for some specific experiments).
I will go the PyTorch way and have an env variable to set the level. The default will be "INFO", and the user can select DEBUG with `BREVITAS_LOGGING=DEBUG`.
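A minimal sketch of what that could look like, combining the env variable with an optional level parameter (the names follow the snippets in this PR, but the exact implementation here is an assumption):

```python
import logging
import os


def setup_logger(name, level=None):
    # Resolve the level from the BREVITAS_LOGGING env variable unless an
    # explicit level is passed; fall back to INFO as described above.
    if level is None:
        level = os.environ.get('BREVITAS_LOGGING', 'INFO')
    logger = logging.getLogger(name)
    logger.setLevel(getattr(logging, str(level).upper(), logging.INFO))
    return logger
```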
src/brevitas/utils/logging.py
Outdated
def setup_logger(name):
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)
Maybe have the logging level as an optional parameter? You might want to turn off all 'INFO' messages at some point.
See the comment above. I think this piece of code will change a lot as we figure out what to do and how to do it; I want to have something reasonable to start with, and I believe your comment / my proposal above should accomplish just that.
LGTM!
@@ -20,6 +20,7 @@ def env_to_bool(name, default):
JIT_ENABLED = env_to_bool('BREVITAS_JIT', False) and _enabled
NATIVE_STE_BACKEND_ENABLED = env_to_bool('BREVITAS_NATIVE_STE_BACKEND', False)
VERBOSE = env_to_bool('BREVITAS_VERBOSE', False)
LOGGING_LEVEL = os.environ.get('BREVITAS_LOGGING', 'CRITICAL')
Love it! <3
Reason for this PR
Changes Made in this PR
Testing Summary
Risk Highlight
Checklist
dev branch.