forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 4
CI: 06/05/25 upstream sync #454
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rocm-repo-management-api-2
wants to merge
1,880
commits into
rocm-main
Choose a base branch
from
ci-upstream-sync-209_1
base: rocm-main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
PiperOrigin-RevId: 762092464
PiperOrigin-RevId: 762099728
PiperOrigin-RevId: 762116819
Continuous and Nightly/Release workflows will now run the CUDA 12.8 test runs by using the Nvidia CUDA packages from PyPI instead of those on the system. PiperOrigin-RevId: 762125356
…built wheels. NCCC wheel is used in `[with-cuda]` requirements as stated [here](https://github.com/jax-ml/jax/blob/0dc70b93f2e13fae5b097837760bd621e746dae7/jax_plugins/cuda/plugin_setup.py#L58). PiperOrigin-RevId: 762126511
Hit this while working with sharding in types -- passing a sharding that had an empty mesh. (I think this was in a test). This failed trying to acces with `with_spec` attribute on None -- so just catching this case early. PiperOrigin-RevId: 762135310
…kernel. Also prevents an out-of-bounds read of SMEM. And re-enables tests for the TPU paged_attention_kernel. @apaszke confirmed the presence of data races using the race detector in the new TPU interpret mode. With the additional semaphores, the race detector no longer detects any races in the this kernel and I no longer see any test failures in 20+ test runs on a TPU. Details on the data races: - In each iteration, the kernel: (a) Starts copying data for `k` and `v` for the next iteration. (b) Waits for the copy of `k` for the current iteration to finish. (c) Waits for the copy of `v` for the current iteration to finish. - It is possible for these copies to happen out of order -- that is: (a) The copies for the next iteration can finish before the copies for the current iteration. (b) And the copies for `v` for the current iteration can finish before the copies for `k` for the current iteration. - If the same DMA semaphore is used for everything, then out-of-order copies can lead to: (a) `k = async_copy_k.wait_and_get_loaded()` returns but the data isn't all available because the underlying semaphore was signaled by the completion of copies of `v` for the current iteration or copies of `k` or `v` for the next iteration. (a) `v = async_copy_v.wait_and_get_loaded()` returns but the data isn't all available because the underlying semaphore was signaled by the completion of copies of `k` or `v` for the next iteration. PiperOrigin-RevId: 762136079
Without this handling, in explicit sharding mode vmap of a function with an internal shmap can introduce unnecessary replication. PiperOrigin-RevId: 762175189
PiperOrigin-RevId: 762213931
PiperOrigin-RevId: 762220186
…`ad_util.Zero` which lead to errors like this: `TypeError: Argument 'Zero(float32[1,1,512])' of type '<class 'jax._src.ad_util.Zero'>' is not a valid JAX type` PiperOrigin-RevId: 762264425
This issue was discovered when enabling the ragged dot example kernel to run using warpgroup semantics. The new test requires `-UNDEBUG`. PiperOrigin-RevId: 762365352
PiperOrigin-RevId: 762370447
http://github.com/openxla/xla/commit/c361fc2992e8d674636e7870992e95658b1be792. PiperOrigin-RevId: 762376295
Apparently we never checked it and it's been quite easy to get this wrong. PiperOrigin-RevId: 762394139
…nline asm. Before this change, in the int32 case, we pass two extra immediate args compared to the number of parameters in the ASM string. Running tests with `-UNDEBUG` detects the error. I think this likely broke with the special-casing of `int32` in cl/761489756. I've added an assert that should prevent mismatches in the future. PiperOrigin-RevId: 762394530
`ValueError: Sharding rule has 1 operands, but the operation has 2 operands` PiperOrigin-RevId: 762412744
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. PiperOrigin-RevId: 762414305
This is an attempt to re-land jax-ml#28650, fixing build failures. **Motivation** jax-ml#28650 was motivated by recent changes in setuptools, which caused test failures with editable installs, but it also exposes a potentially larger issue with our approach to testing with pytest. As far as I understand it, the default pytest behavior is to prepend the working directory to `sys.path`, meaning that `jax` is imported from the source directory, rather than the installed version. Switching to the `importlib` import mode means that we correctly test against the installed version of `jax`, which seems like what we typically want to do. The catch is that then we need to explicitly package any test utilities into the distribution. We don't currently package test-specific utilities like `internal_test_util` with JAX, but these utilities were still available to tests since they live within the `jax` source tree. This breaks when using `importlib` import mode, and a non-editable install of JAX. **Solutions** The approach that I've taken here is to explicitly package everything needed by the tests into the `jax` distribution. This means that we can correctly test against the _installed_ version of JAX when using pytest. This solution isn't ideal because it means that we're distributing `jax` submodules that aren't actually required except when running the test suite, but this seems like a small price to pay to me. **Alternatives** One different approach that we could take would be to only support using pytest with _editable_ installs of JAX. This would work because the required files would still be discoverable in an editable install because they live within the source tree. In fact, before this change, most of our CI jobs actually did install an editable distribution (which is why the failures in jax-ml#28650 weren't caught in pre-submit!). The problem with this approach is that we're not actually testing JAX as it is used when installed from a distribution, and it wouldn't catch things like missing submodules. I think it's much better to test against the installed distribution! A more extreme approach would be to switch JAX to a `src/jax` and `src/jaxlib` layout (i.e. moving `jax` and `jaxlib` out of the root directory) as recommended by the Python packaging docs. Unfortunately this would be complicated with the way JAX is distributed internally at Google, so I think that's probably a non-starter. PiperOrigin-RevId: 762419160
* We now re-export a restricted version of `debug_check` under `pl`. Unlike the original, the `pl` version only allows a static message, i.e. string interpolation is not supported. * Only debug checks are supported, which means that by default no checking is done -- `debug_check` is lowered to a noop. * The context manager enabling debug checks is called `enable_debug_checks`. I would very much like to drop the `enable_` prefix, but without it the context manager reads too similar to `debug_check`. PiperOrigin-RevId: 762433258
We have users of CompileOnlyPyClient that use `backend.compile` as we eventually intend it (i.e., return `ExecutableRef`, possibly `PyExecutable` eventually, instead of `PyLoadedExectuable`). PiperOrigin-RevId: 762440439
PiperOrigin-RevId: 762441171
Creating smaller build rules enforces better organized dependency graphs in the JAX project, helps pytype propagate annotations correctly, and leads to improved build and iteration times. Unfortunately this is not a clean build refactor, because batching depends on jax.lax, which in turn depends on batching. However, the problematic functions are only called within contexts where jax.lax is available for import. We have a few options here: 1. Continue to bundle the batching.py source with the main build. 2. Build separately, but do the local import workaround in this CL (a pattern we use elsewhere). 3. Build this separately, but move some batching definitions into jax.lax for a more strict dependency graph. Or pass the `lax` namespace explicitly to the function at the call site. I opted for (2) here because I judged the benefits of a refactored build to be worth the cost of localized impure dependencies, and the kind of refactoring in (3) would affect some downstream users. PiperOrigin-RevId: 762447323
PiperOrigin-RevId: 762456274
Part of a larger refactor. Today, `compile` returns a loaded executable i.e., fuses the compile and load functions. Eventually, `compile` should return an unloaded executable and `load` should return a loaded exectuable; the default jit path will still return a loaded executable. PiperOrigin-RevId: 762457830
PiperOrigin-RevId: 762475094
PiperOrigin-RevId: 762490961
PiperOrigin-RevId: 766874092
PiperOrigin-RevId: 766895212
PiperOrigin-RevId: 766941330
…put's sharding is concrete i.e. does not contain an AbstractMesh PiperOrigin-RevId: 766962130
PiperOrigin-RevId: 766962750
http://github.com/openxla/xla/commit/a3cb8a0de31a1984a56802981ed3987f63879ce5. PiperOrigin-RevId: 767027087
… for GSPMD. The final module that will be created by JAX export will contain a bit of Shardy and GSPMD ops. What we then do during compilation is detect whether there is a mix of these ops. If there is, we override the build option and instead use GSPMD for propagation (we have well tested code to export Shardy->GSPMD, but not vice versa). PiperOrigin-RevId: 767064075
`TCGEN05_ROW` is to `TCGEN05` what `WGMMA_ROW` is to `WGMMA`. PiperOrigin-RevId: 767068597
The lowering b/w Shardy and GSPMD is slightly different with the custom calls, so I needed to choose different test data based on whether or not Shardy was enabled. PiperOrigin-RevId: 767074094
…antics. PiperOrigin-RevId: 767125470
… without sharding rule. PiperOrigin-RevId: 767131346
…:PyClient::CompileAndLoad`. - Remove redundant `xla::PyClient` `compile` bindings. - Remove host_callback arguments to `compile`. PiperOrigin-RevId: 767135320
…mpile`. Currently, we just forward any calls to `compiler.backend_compile_and_load`, which returns an `xla::PyLoadedExecutable` whereas we'd like `compiler.backend_compile` to return an unloaded `xla::PyExecutable`. PiperOrigin-RevId: 767142396
PiperOrigin-RevId: 767160894
…used_attention_stablehlo.py`. PiperOrigin-RevId: 767166345
…ead of a `tuple`. This is because the order of axes in `unreduced` doesn't matter. While lowering, `unreduced` is sorted wrt the mesh axis names so in McJAX all hosts lower to the same thing. PiperOrigin-RevId: 767178835
…tively tests `lax.axis_index` and `lax.axis_size` on clusters axes. PiperOrigin-RevId: 767185068
PiperOrigin-RevId: 767199650
PiperOrigin-RevId: 767277572
PiperOrigin-RevId: 767333363
PiperOrigin-RevId: 767341856
…vice transfers for TFRT TPU. Reverts 8d8cc2b PiperOrigin-RevId: 767348122
PiperOrigin-RevId: 767365438
Co-authored-by: Yash Katariya <[email protected]>
PiperOrigin-RevId: 767380947
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Daily sync with upstream