
Relax assumptions for bcoo_dot_general #27397


Open
ryan112358 opened this issue Mar 25, 2025 · 4 comments
Labels: enhancement (New feature or request); P3 (no schedule) (We have no plan to work on this and, if it is unassigned, we would be happy to review a PR)

Comments

@ryan112358

ryan112358 commented Mar 25, 2025

I'm trying to do some einsums with some sparse matrices and am hitting an error. Here is a minimal example:

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

def foo(ab, a):
  return jnp.einsum('ab,a->a', ab, a)

A = jax.random.normal(jax.random.key(0), shape=(2,))
AB = jax.random.normal(jax.random.key(1), shape=(2, 5))
AB = sparse.BCOO.fromdense(AB)

ans = sparse.sparsify(foo)(AB, A)
```

```
NotImplementedError: bcoo_dot_general batch dimensions must be among the batch dimensions in the sparse representation.
got lhs_batch=(0,), n_batch=0
```

It would be nice if this type of operation could be supported natively. In this simple case, one could write something like `ab.data *= a[ab.indices[:, 0]]` followed by `ab.sum(axis=1)`, or, slightly better, `ab.sum(axis=1) * a`.
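A minimal sketch of the two workarounds described above, assuming the `jax.experimental.sparse` API (`BCOO((data, indices), shape=...)`, `BCOO.sum`, `BCOO.todense`). Instead of mutating `ab.data` in place, it builds a new BCOO with scaled values, and it compares both results against the dense `jnp.einsum`. It is a user-side workaround, not a change to `bcoo_dot_general` itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

A = jax.random.normal(jax.random.key(0), shape=(2,))
AB_dense = jax.random.normal(jax.random.key(1), shape=(2, 5))
AB = sparse.BCOO.fromdense(AB_dense)

# Dense reference for 'ab,a->a': out[i] = a[i] * sum_j ab[i, j].
expected = jnp.einsum('ab,a->a', AB_dense, A)

# Workaround 1: scale each stored value by the entry of A for its row.
# With n_batch=0, AB.indices has shape (nse, 2); column 0 is the row ("a") index.
scaled = sparse.BCOO((AB.data * A[AB.indices[:, 0]], AB.indices), shape=AB.shape)
out1 = scaled.sum(axis=1).todense()

# Workaround 2: reduce over the "b" axis first, then scale by the dense vector.
out2 = AB.sum(axis=1).todense() * A
```

The second form is cheaper because it performs one multiply per row instead of one per stored element.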

ryan112358 added the enhancement (New feature or request) label Mar 25, 2025
@jakevdp
Collaborator

jakevdp commented Mar 25, 2025

Hi - thanks for the request! We don't have plans to further develop jax.experimental.sparse at this point, but we would consider pull requests in this area if you're willing to work on this feature. Thanks!

jakevdp added the P3 (no schedule) label Mar 25, 2025 (We have no plan to work on this and, if it is unassigned, we would be happy to review a PR)
@Aniketsy

Hi @jakevdp,
I’m interested in working on this feature to support batch dimensions in bcoo_dot_general for einsum-like operations with sparse matrices. Before proceeding, I’d appreciate some guidance on whether this aligns with JAX’s future plans for sparse support and if there are any design constraints or preferred implementation approaches.

If this is a reasonable contribution, I’d be happy to submit a PR. Could you point me to relevant parts of the codebase or existing discussions that might help?

@ryan112358
Author

@Aniketsy Personally, I think this would be a very welcome contribution. This seems like a good place to start:

https://github.com/jax-ml/jax/blob/main/jax/experimental/sparse/bcoo.py#L681
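The error in the original report refers to the BCOO layout. A BCOO array splits its dimensions into leading batch dimensions (`n_batch`), sparse dimensions (`n_sparse`), and trailing dense dimensions. The message says `bcoo_dot_general` only accepts sparse-operand batch dimensions that fall among the first `n_batch` axes. Below is a short sketch inspecting both layouts for the same matrix, assuming the `n_batch` argument of `BCOO.fromdense` and the `n_batch`/`n_sparse` attributes. It only illustrates the layouts and does not show how the check should be relaxed.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse

AB_dense = jax.random.normal(jax.random.key(1), shape=(2, 5))

# Default layout: both axes are sparse dimensions, so n_batch == 0 and
# axis 0 is not a batch dimension of the sparse representation.
AB = sparse.BCOO.fromdense(AB_dense)
print(AB.n_batch, AB.n_sparse, AB.indices.shape)

# Batched layout: axis 0 becomes a batch dimension, so each row stores its
# own list of column indices (indices shape: (batch, nse, n_sparse)).
AB_batched = sparse.BCOO.fromdense(AB_dense, n_batch=1)
print(AB_batched.n_batch, AB_batched.n_sparse, AB_batched.indices.shape)
```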

@Aniketsy

@ryan112358 Thank you for the suggestion. I will work on this.
