Skip to content

[Mosaic] Add none memory space. Add a test for copying a (1,) shaped array to SMEM. #28306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ def TPU_MemorySpace : I32EnumAttr<"MemorySpace", "Memory space", [
I32EnumAttrCase<"smem", 1, "smem">,
I32EnumAttrCase<"kHbm", 2, "hbm">,
I32EnumAttrCase<"kCmem", 3, "cmem">,
I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">
I32EnumAttrCase<"kSemaphoreMem", 4, "semaphore_mem">,
I32EnumAttrCase<"kNone", 6, "none">
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tpu";
Expand Down
22 changes: 22 additions & 0 deletions tests/pallas/tpu_pallas_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,28 @@ def kernel(index, x, y, sem):
np.testing.assert_array_equal(y, i)
del y

def test_copy_to_smem(self):

def kernel(src, dst, sem):
pltpu.async_copy(src, dst, sem).wait()

def run(src):
return pl.pallas_call(
kernel,
out_shape=jax.ShapeDtypeStruct(src.shape, jnp.float32),
in_specs=[
pl.BlockSpec(memory_space=pltpu.ANY),
],
scratch_shapes=[
pltpu.SemaphoreType.DMA,
],
out_specs=pl.BlockSpec(memory_space=pltpu.SMEM),
)(src)

src = jnp.full((1,), 3.1415, dtype=jnp.float32)
expected = jnp.full((1,), 3.1415, dtype=jnp.float32)
np.testing.assert_array_equal(run(src), expected)

def test_dynamic_dma_on_2nd_minor(self):
def kernel(array, data, index, size, _, sem):
pltpu.async_copy(
Expand Down
Loading