[JAX] Collective GEMM custom op + primitive + minimal supporting functions #1846

Open
wants to merge 3 commits into
base: main

denera
Collaborator

@denera commented Jun 3, 2025

Description

This PR introduces a new XLA custom op that calls nvte_cublas_gemm or the related comm+GEMM overlap algorithms, the accompanying JAX primitive, and the bare-minimum Python wrappers required to work with the custom call.

The FWD/BWD autograd implementation will be handled in a separate, upcoming PR.
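
For readers unfamiliar with comm+GEMM overlap, here is a rough single-device sketch of the arithmetic that all-gather (AG) and reduce-scatter (RS) overlaps rely on. It is not the custom op or primitive from this PR: the function names, shapes, and the `num_devices` value are hypothetical, and the device shards are simulated with `jnp.split` rather than real collectives. It only shows that a GEMM can be broken into independent chunks, which is what lets communication for one chunk proceed while another chunk is being computed.

```python
import jax
import jax.numpy as jnp


def all_gather_gemm_reference(lhs_shards, rhs):
    """Chunked equivalent of all_gather(lhs) @ rhs (hypothetical helper)."""
    # Each row shard's GEMM is independent of the others, so one chunk can be
    # computed while the next chunk is still in flight between devices.
    return jnp.concatenate([shard @ rhs for shard in lhs_shards], axis=0)


def gemm_reduce_scatter_reference(lhs_shards, rhs_shards, num_devices):
    """Chunked equivalent of reduce_scatter(lhs @ rhs) (hypothetical helper)."""
    # Each simulated device holds one slice of the contracting dimension and
    # produces a full-size partial product.
    partials = [l @ r for l, r in zip(lhs_shards, rhs_shards)]
    total = sum(partials)  # the "reduce" step
    return jnp.split(total, num_devices, axis=0)  # the "scatter" step (by rows)


num_devices = 4  # assumed device count, for illustration only
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
lhs = jax.random.normal(key_a, (8, 16), dtype=jnp.float32)
rhs = jax.random.normal(key_b, (16, 12), dtype=jnp.float32)

# Unsharded result that both decompositions should reproduce.
full = lhs @ rhs

# AG case: LHS is sharded by rows, RHS is replicated.
ag_out = all_gather_gemm_reference(jnp.split(lhs, num_devices, axis=0), rhs)

# RS case: the contracting dimension is sharded on both operands.
rs_out = gemm_reduce_scatter_reference(
    jnp.split(lhs, num_devices, axis=1),
    jnp.split(rhs, num_devices, axis=0),
    num_devices,
)
```

Here `ag_out` matches `full` up to floating-point rounding, and concatenating the entries of `rs_out` along axis 0 does as well. A real overlap implementation would interleave these per-chunk GEMMs with the corresponding communication instead of materializing Python lists.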

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Alp Dener <[email protected]>

CollectiveGemm XLA FFI op and FWD+BWD rules done

Signed-off-by: Alp Dener <[email protected]>

added missing custom VJP def

Signed-off-by: Alp Dener <[email protected]>

added comm overlap initializer/destroy

Signed-off-by: Alp Dener <[email protected]>

docstrings

Signed-off-by: Alp Dener <[email protected]>

compile errors resolved

Signed-off-by: Alp Dener <[email protected]>

added example

Signed-off-by: Alp Dener <[email protected]>

fixed compile and import issues

Signed-off-by: Alp Dener <[email protected]>

retooled how casting and transposes are done in FWD/BWD

Signed-off-by: Alp Dener <[email protected]>

some API tweaks/cleanup

Signed-off-by: Alp Dener <[email protected]>

added missing TE DType to pybind helper, and restored missing comm overlap config keys

Signed-off-by: Alp Dener <[email protected]>

fixed incorrect variable assignment order in bootstrapping

Signed-off-by: Alp Dener <[email protected]>

XLA custom op working, cublas unsupported parameter error underneath

Signed-off-by: Alp Dener <[email protected]>

fixed cublas invalid parameter error

Signed-off-by: Alp Dener <[email protected]>

both AG and RS overlaps tested working in BF16

Signed-off-by: Alp Dener <[email protected]>

removed FWD/BWD impl to shrink PR down to the XLA custom op, JAX primitive, and minimum accompanying Python wrappers

Signed-off-by: Alp Dener <[email protected]>

fixed incorrect rowwise/columnwise placement of data and scales in XLA custom op

Signed-off-by: Alp Dener <[email protected]>

reverted some stuff

Signed-off-by: Alp Dener <[email protected]>
@denera force-pushed the jax/collective-gemm-api branch from 1a845e9 to e92c81a on June 4, 2025 17:47