ssm_enhancement #689
base: main
Conversation
These enhancements provide additional flexibility and options for implementing and experimenting with different recurrence methods in the Mamba and Jamba models, potentially improving performance and accuracy for various tasks.
Hey @markblee, could you please take a look at my PR when you get a chance? Thanks!
Hey @swiseman, could you please take a look at my PR when you get a chance? Thanks!
Thanks for the PR!
`class HybridMambaRecurrence(BaseMambaRecurrence):`
Thanks for these new classes. Do people use either the hybrid recurrences or the alternative recurrences defined below? Is there evidence that they are useful empirically? If not, I think it would be simpler to leave these classes out for now and, if necessary, let people define them in downstream experiment files which import `axlearn.common.ssm`.
Hey @swiseman , Thank you for your valuable input. I've reviewed the hybrid recurrences and alternative recurrences, and it seems that they haven't been used extensively in practice. Based on your benchmarking results, it appears that the AssociativeScanMambaRecurrence is more efficient than the HybridMambaRecurrence.
Given the lack of empirical evidence and the performance advantage of the AssociativeScanMambaRecurrence, I agree that it's reasonable to remove the HybridMambaRecurrence and other less-used recurrences from the core axlearn.common.ssm module for now.
This will simplify the codebase and make it easier for users to understand and use. If there's a strong need for these recurrences in the future, they can be defined in downstream experiment files as you suggested.
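For context on why the associative-scan recurrence tends to win, the diagonal SSM update `h_t = a_t * h_{t-1} + b_t` composes associatively, so all T steps can be combined in O(log T) parallel depth (e.g. with `jax.lax.associative_scan`), whereas a step-by-step hybrid loop pays O(T) sequential depth. A minimal pure-Python sketch of the combine rule; the function names here are illustrative, not part of axlearn:

```python
# Sketch: the diagonal SSM recurrence h_t = a_t * h_{t-1} + b_t can be
# computed with an associative scan because composing two affine steps
# yields another affine step of the same form.

def combine(left, right):
    """Associative operator on (a, b) pairs: applying `left` then `right`
    to h gives a2 * (a1 * h + b1) + b2 = (a1 * a2) * h + (a2 * b1 + b2)."""
    a1, b1 = left
    a2, b2 = right
    return (a1 * a2, a2 * b1 + b2)

def scan_recurrence(a, b, h0=0.0):
    """Prefix-combine all steps, then apply each prefix to h0.

    Written sequentially here for clarity; the point is that the
    prefix-combines themselves are associative, so a parallel scan
    can evaluate them in logarithmic depth.
    """
    out, acc = [], None
    for step in zip(a, b):
        acc = step if acc is None else combine(acc, step)
        out.append(acc[0] * h0 + acc[1])
    return out

def sequential_recurrence(a, b, h0=0.0):
    """Reference step-by-step recurrence for comparison."""
    h, out = h0, []
    for a_t, b_t in zip(a, b):
        h = a_t * h + b_t
        out.append(h)
    return out
```

Both functions produce the same hidden states (up to floating-point rounding); the scan formulation is what makes an `AssociativeScanMambaRecurrence`-style implementation parallelizable over the sequence length.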
-Vishesh
- Fixed redundant function definitions
- Fixed an incorrect module import in layers.py
Hi, shall we close the PR or turn it into a draft? Thanks.
Hey @ruomingp, we can close this PR. I am working on these hybrid structures and it will take a while. Thanks for reading!
Pull Request: Enhancements to Mamba and Jamba State-space Models
Summary
This pull request introduces several enhancements to the Mamba and Jamba state-space models (SSMs) implementation, including new recurrence methods, hybrid approaches, and comprehensive testing.
Changes
1. New Recurrence Methods
   - Added `HybridMambaRecurrence` and `AlternativeMambaRecurrence` classes.
2. Enhancements to `ssm.py`
   - Updated `MambaMixerLayer` and `JambaMixerLayer` to integrate the new recurrence methods.
3. Comprehensive Testing in `ssm_test.py`
   - Added tests for `HybridMambaRecurrence` and `AlternativeMambaRecurrence` in `MambaMixerLayerTest`.
   - Updated `StackedMambaTest`.
   - Updated `StackedMixedSSMTransformerTest`.
4. Documentation and Examples
   - Updated docstrings and comments to reflect the new features and changes.
Testing
All new features have been thoroughly tested with the following configurations:
- Data types: `jnp.float32` and `jnp.bfloat16`.
- Layers: `MambaBlock`, `JambaMambaBlock`, and `StackedMixedSSMTransformerLayer`.
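A dtype-by-layer sweep like the one above is commonly expressed as a product over the two configuration axes. A hedged pure-Python sketch: the layer names come from this PR, but `run_case` is a hypothetical stand-in for building each layer and running a forward pass, and NumPy `float16` stands in for `jnp.bfloat16`, which requires JAX:

```python
import itertools
import numpy as np

# Layer names under test (from the PR description).
LAYERS = ("MambaBlock", "JambaMambaBlock", "StackedMixedSSMTransformerLayer")
# The real tests use jnp.float32 / jnp.bfloat16; float16 is a stand-in here.
DTYPES = (np.float32, np.float16)

def run_case(layer_name, dtype):
    """Hypothetical placeholder for one test case: runs a tiny computation
    and checks that the requested dtype survives it."""
    x = np.ones((2, 4), dtype=dtype)
    y = x * 2.0  # stand-in for layer(x)
    return layer_name, y.dtype

def sweep():
    """Run every (layer, dtype) combination, mirroring a parameterized test."""
    return [run_case(name, dt) for name, dt in itertools.product(LAYERS, DTYPES)]
```

In the actual test suite this pattern would typically be expressed with `absl.testing.parameterized` rather than a hand-rolled loop.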
Conclusion
These enhancements provide additional flexibility and options for implementing and experimenting with different recurrence methods in the Mamba and Jamba models, potentially improving performance and accuracy for various tasks.