
CI: 06/05/25 upstream sync #454


Open
wants to merge 1,880 commits into
base: rocm-main
from

Conversation

rocm-repo-management-api-2[bot]

Daily sync with upstream

Google-ML-Automation and others added 30 commits May 22, 2025 13:05
Continuous and Nightly/Release workflows will now run the CUDA 12.8 test runs by using the Nvidia CUDA packages from PyPI instead of those on the system.

PiperOrigin-RevId: 762125356
Hit this while working with sharding in types: passing a sharding that had an empty mesh (I think this was in a test). This failed when trying to access the `with_spec` attribute on None, so this change catches that case early.

PiperOrigin-RevId: 762135310
…kernel.

Also prevents an out-of-bounds read of SMEM.  And re-enables tests for the TPU paged_attention_kernel.

@apaszke confirmed the presence of data races using the race detector in the new TPU interpret mode.  With the additional semaphores, the race detector no longer detects any races in this kernel and I no longer see any test failures in 20+ test runs on a TPU.

Details on the data races:

 - In each iteration, the kernel:
   (a) Starts copying data for `k` and `v` for the next iteration.
   (b) Waits for the copy of `k` for the current iteration to finish.
   (c) Waits for the copy of `v` for the current iteration to finish.

 - It is possible for these copies to happen out of order -- that is:
   (a) The copies for the next iteration can finish before the copies
       for the current iteration.
   (b) And the copies for `v` for the current iteration can finish
       before the copies for `k` for the current iteration.

 - If the same DMA semaphore is used for everything, then out-of-order
   copies can lead to:
   (a) `k = async_copy_k.wait_and_get_loaded()` returns but the data
       isn't all available because the underlying semaphore was
       signaled by the completion of copies of `v` for the current
       iteration or copies of `k` or `v` for the next iteration.
   (b) `v = async_copy_v.wait_and_get_loaded()` returns but the data
       isn't all available because the underlying semaphore was
       signaled by the completion of copies of `k` or `v` for the
       next iteration.
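
The failure mode above can be sketched with a toy model in plain Python. This is not the kernel's code: `ToySemaphore` and `wait_for_k` are hypothetical names, and the model assumes a simple counting semaphore where any completed copy adds one signal. It only illustrates why a wait on `k` can return early when `k` and `v` share a semaphore.

```python
class ToySemaphore:
    """Hypothetical counting semaphore: signal() adds one, try_wait() consumes one."""

    def __init__(self):
        self.count = 0

    def signal(self):
        self.count += 1

    def try_wait(self):
        # Returns True if a signal was available (the wait would return).
        if self.count > 0:
            self.count -= 1
            return True
        return False


def wait_for_k(shared):
    """Return (wait_returned, k_is_loaded) when the copy of `v` finishes first."""
    sem_k = ToySemaphore()
    sem_v = sem_k if shared else ToySemaphore()
    buffers = {"k": None, "v": None}

    # Out-of-order completion: the copy of `v` lands first and signals.
    buffers["v"] = "v-data"
    sem_v.signal()

    # The consumer waits for `k`. With a shared semaphore the wait is
    # satisfied by the signal from `v`, although `k` has not arrived.
    wait_returned = sem_k.try_wait()
    return wait_returned, buffers["k"] is not None


shared_result = wait_for_k(shared=True)
separate_result = wait_for_k(shared=False)
```

With a shared semaphore the wait returns while `k` is still missing; with one semaphore per copy the wait keeps blocking until `k` itself signals.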

PiperOrigin-RevId: 762136079
Without this handling, in explicit sharding mode vmap of a function with an internal shmap can introduce unnecessary replication.

PiperOrigin-RevId: 762175189
…`ad_util.Zero` which led to errors like this: `TypeError: Argument 'Zero(float32[1,1,512])' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type`

PiperOrigin-RevId: 762264425
This issue was discovered when enabling the ragged dot example kernel to run using warpgroup semantics.

The new test requires `-UNDEBUG`.

PiperOrigin-RevId: 762365352
Apparently we never checked it and it's been quite easy to
get this wrong.

PiperOrigin-RevId: 762394139
…nline asm.

Before this change, in the int32 case, we pass two extra immediate args compared to the number of parameters in the ASM string. Running tests with `-UNDEBUG` detects the error.

I think this likely broke with the special-casing of `int32` in cl/761489756.

I've added an assert that should prevent mismatches in the future.
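
A check of this kind might look like the sketch below. It is not the assert added in this change: `count_asm_placeholders` and `check_asm_operands` are hypothetical helpers, and the sketch assumes operands are referenced in the ASM string as `$0`, `$1`, and so on.

```python
import re


def count_asm_placeholders(asm):
    """Number of operands an ASM string expects, assuming `$N` placeholders."""
    indices = {int(n) for n in re.findall(r"\$(\d+)", asm)}
    return max(indices) + 1 if indices else 0


def check_asm_operands(asm, operands):
    """Raise if the operand count does not match the placeholders in `asm`."""
    expected = count_asm_placeholders(asm)
    if expected != len(operands):
        raise ValueError(
            f"ASM string expects {expected} operands, got {len(operands)}"
        )


# Hypothetical instruction with one result and two inputs.
asm_string = "add.s32 $0, $1, $2;"
check_asm_operands(asm_string, ["out", "a", "b"])  # matches, no error
```

Passing two extra immediates, as in the int32 case described above, would make the counts disagree and raise immediately instead of failing later.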

PiperOrigin-RevId: 762394530
`ValueError: Sharding rule has 1 operands, but the operation has 2 operands`

PiperOrigin-RevId: 762412744
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

PiperOrigin-RevId: 762414305
This is an attempt to re-land jax-ml#28650, fixing build failures.

**Motivation**

jax-ml#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do.

The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX.
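
One way to see which copy of a package a test run would pick up is to inspect where its import spec points. The sketch below is only an illustration: `module_origin` and `is_imported_from` are hypothetical helpers, and the stdlib `json` module stands in for `jax` so the example runs anywhere. The import mode itself is selected with pytest's `--import-mode=importlib` option.

```python
import importlib.util
import json  # stdlib module used only to demonstrate the check
from pathlib import Path


def module_origin(name):
    """Return the resolved file a module would be loaded from, or None."""
    spec = importlib.util.find_spec(name)
    if spec is None or not spec.has_location:
        return None  # not found, or a built-in module with no file
    return Path(spec.origin).resolve()


def is_imported_from(name, root):
    """True if the module named `name` resolves to a file under `root`."""
    origin = module_origin(name)
    return origin is not None and Path(root).resolve() in origin.parents


# `json` lives in <stdlib_root>/json/__init__.py.
stdlib_root = Path(json.__file__).resolve().parent.parent
from_stdlib = is_imported_from("json", stdlib_root)
```

Calling `is_imported_from("jax", <source checkout>)` in a test session would show whether the source tree or the installed distribution is being exercised.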

**Solutions**

The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest.

This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me.

**Alternatives**

One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in jax-ml#28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution!

A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter.

PiperOrigin-RevId: 762419160
* We now re-export a restricted version of `debug_check` under `pl`.
  Unlike the original, the `pl` version only allows a static message, i.e.
  string interpolation is not supported.
* Only debug checks are supported, which means that by default no checking
  is done -- `debug_check` is lowered to a noop.
* The context manager enabling debug checks is called `enable_debug_checks`.
  I would very much like to drop the `enable_` prefix, but without it the
  context manager reads too similar to `debug_check`.

PiperOrigin-RevId: 762433258
We have users of CompileOnlyPyClient that use `backend.compile` as we eventually intend it (i.e., returning `ExecutableRef`, possibly `PyExecutable` eventually, instead of `PyLoadedExecutable`).

PiperOrigin-RevId: 762440439
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times.

Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import.

We have a few options here:

1. Continue to bundle the batching.py source with the main build.
2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere).
3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the `lax` namespace explicitly to the function at the call site.

I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users.
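
For reference, the local import workaround in (2) has roughly the shape sketched below. This is not the code in this CL: `hypothetical_lax` is a stand-in module registered by hand so the example is self-contained, and `batched_add` is an invented function.

```python
import sys
import types


def batched_add(x, y):
    # Function-local import: resolved when the function is called,
    # not when the enclosing module is imported, which breaks the cycle.
    import hypothetical_lax  # stand-in for a module such as jax.lax

    return hypothetical_lax.add(x, y)


# At definition time the dependency does not have to be importable yet.
available_at_definition = "hypothetical_lax" in sys.modules

# Later the dependency becomes available (here: registered by hand).
lax_stub = types.ModuleType("hypothetical_lax")
lax_stub.add = lambda x, y: x + y
sys.modules["hypothetical_lax"] = lax_stub

result = batched_add(1, 2)
```

The cost is that the dependency is invisible to the build graph, which is the "localized impure dependency" trade-off mentioned above.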

PiperOrigin-RevId: 762447323
Part of a larger refactor. Today, `compile` returns a loaded executable, i.e., it fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded executable; the default jit path will still return a loaded executable.

PiperOrigin-RevId: 762457830
danielsuo and others added 28 commits June 3, 2025 16:48
PiperOrigin-RevId: 766895212
…put's sharding is concrete i.e. does not contain an AbstractMesh

PiperOrigin-RevId: 766962130
… for GSPMD.

The final module that will be created by JAX export will contain a bit of Shardy and GSPMD ops. What we then do during compilation is detect whether there is a mix of these ops. If there is, we override the build option and instead use GSPMD for propagation (we have well tested code to export Shardy->GSPMD, but not vice versa).
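
The decision described above can be sketched as a small function. This is a simplified illustration, not the actual detection logic: `choose_propagation` is a hypothetical name, and the sketch assumes that Shardy ops can be recognized by the `sdy.` dialect prefix and GSPMD annotations by the `mhlo.sharding` attribute in the module text.

```python
def choose_propagation(module_text, shardy_requested):
    """Pick the propagation system for a module given as text (toy model)."""
    has_shardy = "sdy." in module_text         # assumed marker for Shardy ops
    has_gspmd = "mhlo.sharding" in module_text  # assumed marker for GSPMD
    if has_shardy and has_gspmd:
        # Mixed module: override the build option and fall back to GSPMD,
        # since only the Shardy -> GSPMD direction is well tested.
        return "gspmd"
    return "shardy" if shardy_requested else "gspmd"


# Hypothetical module snippets, for illustration only.
mixed_module = 'sdy.mesh @mesh = <["x"=2]>\n%0 = ... {mhlo.sharding = "{replicated}"}'
shardy_only_module = 'sdy.mesh @mesh = <["x"=2]>'

mixed_choice = choose_propagation(mixed_module, shardy_requested=True)
shardy_choice = choose_propagation(shardy_only_module, shardy_requested=True)
```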

PiperOrigin-RevId: 767064075
`TCGEN05_ROW` is to `TCGEN05` what `WGMMA_ROW` is to `WGMMA`.

PiperOrigin-RevId: 767068597
The lowering between Shardy and GSPMD is slightly different with the custom calls, so I needed to choose different test data based on whether or not Shardy was enabled.

PiperOrigin-RevId: 767074094
… without sharding rule.

PiperOrigin-RevId: 767131346
…:PyClient::CompileAndLoad`.

- Remove redundant `xla::PyClient` `compile` bindings.
- Remove host_callback arguments to `compile`.

PiperOrigin-RevId: 767135320
…mpile`.

Currently, we just forward any calls to `compiler.backend_compile_and_load`, which returns an `xla::PyLoadedExecutable` whereas we'd like `compiler.backend_compile` to return an unloaded `xla::PyExecutable`.

PiperOrigin-RevId: 767142396
PiperOrigin-RevId: 767149635
PiperOrigin-RevId: 767160894
…used_attention_stablehlo.py`.

PiperOrigin-RevId: 767166345
…ead of a `tuple`. This is because the order of axes in `unreduced` doesn't matter.

While lowering, `unreduced` is sorted wrt the mesh axis names so in McJAX all hosts lower to the same thing.
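
A minimal sketch of the idea, assuming that "sorted wrt the mesh axis names" means ordering by position in the mesh's axis names (`canonical_unreduced` is a hypothetical helper, not the function used in the lowering):

```python
def canonical_unreduced(unreduced, mesh_axis_names):
    """Order an unordered collection of axis names by mesh axis order."""
    unknown = set(unreduced) - set(mesh_axis_names)
    if unknown:
        raise ValueError(f"axes not in mesh: {sorted(unknown)}")
    # Iterate over the mesh axes so the output order is deterministic.
    return tuple(name for name in mesh_axis_names if name in unreduced)


mesh_axis_names = ("x", "y", "z")

# As sets, the two spellings compare equal; as tuples they would not.
same_as_sets = frozenset({"z", "x"}) == frozenset({"x", "z"})

# Every host produces the same ordering regardless of set iteration order.
lowered = canonical_unreduced(frozenset({"z", "x"}), mesh_axis_names)
```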

PiperOrigin-RevId: 767178835
…tively tests `lax.axis_index` and `lax.axis_size` on cluster axes.

PiperOrigin-RevId: 767185068
PiperOrigin-RevId: 767333363
PiperOrigin-RevId: 767341856
…vice transfers for TFRT TPU.

Reverts 8d8cc2b

PiperOrigin-RevId: 767348122
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot requested a review from a team as a code owner June 5, 2025 06:02
@rocm-repo-management-api-2 rocm-repo-management-api-2 bot enabled auto-merge (rebase) June 5, 2025 06:02