
[mosaic] Indexing into 3rd minor-most+ dimension with sub-32-bit dtypes #27532


Closed
oliverdutton-iso opened this issue Mar 27, 2025 · 3 comments
Labels
enhancement New feature or request


@oliverdutton-iso
Contributor

Could indexing into 3rd minor-most+ dimensions with sub 32-bit dtypes be supported?

[I know indexing on refs is supported, but I was hoping to use jax.vjp on something directly inside a Mosaic kernel, and this blocks me.]

MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Only 32-bit types supported

The MLIR operation involved:
  %69 = "vector.extract"(%65) <{static_position = array<i64: 0>}> : (vector<7x128x128xbf16>) -> vector<128x128xbf16>
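A minimal sketch of the pattern, assuming the op comes from indexing a loaded *value* rather than a ref. The shape `(7, 128, 128)` and `bfloat16` dtype are copied from the MLIR above; the kernel and helper names are hypothetical, and we have not traced which jaxpr ops produce `vector.extract` in the Mosaic lowering, so `x[0]` on a value is a guess at the failing pattern. `index_ref_kernel` indexes the ref instead, which the comment says is already supported. With `interpret=True` both kernels run on CPU without going through Mosaic; `interpret=False` on a TPU exercises the Mosaic lowering.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Shape and dtype copied from the MLIR in the error message (hypothetical repro).
SHAPE = (7, 128, 128)


def index_value_kernel(x_ref, o_ref):
    x = x_ref[...]     # load the whole block as a vector value
    o_ref[...] = x[0]  # index the *value* along the 3rd minor-most dimension


def index_ref_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[0]  # index the *ref* instead, which the issue says works


def run(kernel, x, interpret=True):
    # No grid or BlockSpecs: the whole array is handed to the kernel as one block.
    return pl.pallas_call(
        kernel,
        out_shape=jax.ShapeDtypeStruct(SHAPE[1:], x.dtype),
        interpret=interpret,  # True: evaluate on CPU without Mosaic
    )(x)


x = jax.random.normal(jax.random.PRNGKey(0), SHAPE).astype(jnp.bfloat16)
via_value = run(index_value_kernel, x)
via_ref = run(index_ref_kernel, x)
```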
@oliverdutton-iso oliverdutton-iso added the enhancement New feature or request label Mar 27, 2025
@justinjfu
Collaborator

I've forwarded this request to the internal Mosaic team.

It seems like this operation should be easy to support since it just involves selecting the correct tiles out of the source array.

I'm not sure if you have control over whatever is generating this op, but as a workaround it's also possible to store to a scratch ref in VMEM and immediately load out the slice you want. The compiler should elide the memory transfer.

@oliverdutton-iso
Contributor Author

Ooo, I didn't know a store/load through scratch would get elided. I thought any write to a Ref in Mosaic TPU was treated as a barrier for all ops defined after it (even ones that don't depend on the ref being written).
e.g. for

def kernel(a_ref, b_ref, c_ref):
    a_ref[...] *= 5
    b_ref[...] = c_ref[...]

the c -> b copy can't be moved ahead of the a *= 5, and I assumed the same rules applied to scratch refs?
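One way to experiment with the ordering question is to wire the snippet into a complete `pallas_call`. This is a sketch under assumptions: the in-place `a_ref[...] *= 5` is rewritten as a read from an input ref plus a write to an output ref aliased to it through `input_output_aliases`, and `b` is an ordinary output; the helper name `run` is hypothetical. Interpret mode only confirms the program-order result; it says nothing about how Mosaic schedules the two stores on a TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def kernel(a_ref, c_ref, a_out_ref, b_ref):
    # Adapted from the snippet above: read `a` from an input ref and write the
    # scaled result to an output ref that is aliased to that input.
    a_out_ref[...] = a_ref[...] * 5
    # Independent copy c -> b; it touches neither a_ref nor a_out_ref.
    b_ref[...] = c_ref[...]


def run(a, c, interpret=True):
    return pl.pallas_call(
        kernel,
        out_shape=(
            jax.ShapeDtypeStruct(a.shape, a.dtype),
            jax.ShapeDtypeStruct(c.shape, c.dtype),
        ),
        # Alias input 0 (a) with output 0 so the update can reuse a's buffer.
        input_output_aliases={0: 0},
        interpret=interpret,
    )(a, c)


a = jnp.ones((8, 128), jnp.float32)
c = jnp.full((8, 128), 3.0, jnp.float32)
new_a, b = run(a, c)
```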

@apaszke
Member

apaszke commented Apr 16, 2025

This should be fixed now!

@apaszke apaszke closed this as completed Apr 16, 2025