Kvax: Fast Multi-Document Flash Attention with Context Parallelism #26813
southfreebird asked this question in Show and tell
Hi everyone!
We've just open-sourced kvax, a custom Flash Attention implementation built on JAX. It is designed for efficient training on long sequences, with features including multi-document (document mask) attention and context parallelism.
We use this library in-house to train models on very long sequences, for example in our agentic research. In the document-mask scenario, it outperforms both the cuDNN implementation and FlexAttention.
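To illustrate what a document mask is (this is a conceptual sketch in plain JAX, not kvax's API), the idea is that several documents are packed into one sequence, and each token may only attend to earlier tokens within its own document. The function and variable names below (`document_causal_mask`, `masked_attention`, `segment_ids`) are hypothetical:

```python
# Conceptual sketch of multi-document ("document mask") attention in plain JAX.
# Not kvax's API -- just the masking idea a flash kernel can exploit.
import jax
import jax.numpy as jnp


def document_causal_mask(segment_ids: jnp.ndarray) -> jnp.ndarray:
    """Build a [seq, seq] boolean mask from per-token document (segment) IDs.

    True means "query i may attend to key j": same document AND causal (j <= i).
    """
    same_doc = segment_ids[:, None] == segment_ids[None, :]
    causal = jnp.tril(jnp.ones((segment_ids.shape[0],) * 2, dtype=bool))
    return same_doc & causal


def masked_attention(q, k, v, mask):
    """Reference (non-flash) attention with the mask applied before softmax."""
    scores = q @ k.T / jnp.sqrt(q.shape[-1])
    scores = jnp.where(mask, scores, -jnp.inf)
    return jax.nn.softmax(scores, axis=-1) @ v


# Example: two documents of lengths 3 and 2 packed into one sequence of length 5.
segment_ids = jnp.array([0, 0, 0, 1, 1])
key = jax.random.PRNGKey(0)
q = jax.random.normal(key, (5, 8))
out = masked_attention(q, q, q, document_causal_mask(segment_ids))
```

Such a mask is block-sparse, which is presumably where a specialized flash-attention kernel gains over dense baselines: whole fully-masked blocks can be skipped instead of computed and discarded.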
The library is available under the Apache 2.0 license and can be easily integrated into an existing JAX codebase, especially if you are using Flax.
GitHub: https://github.com/nebius/kvax
Blog post with benchmarks
We hope it will be useful to the community and would appreciate any feedback!