Compilation cache does not work with custom partitioning #21159
Comments
Yes, that's correct, and it's how it works at the moment. The custom partitioning ultimately refers to a Python object, which is why it isn't stable from run to run.
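Here is a small, hypothetical simulation of why an address-valued field defeats a content-addressed cache. It assumes a cache keyed on a SHA-256 hash of the lowered module text, and a lowering step that writes `id()` of a Python descriptor object into `backend_config`. In CPython, `id()` returns the object's memory address. The names `lower` and `ContentAddressedCache` are illustrative, not JAX APIs. Real runs are separate processes, so the sketch stands in for them with three live objects in one process.

```python
import hashlib


def lower(partitioner: object) -> str:
    # Hypothetical lowering: the custom call's backend_config carries the
    # address of a Python object, as described in this issue.
    return (
        'custom-call(p0), custom_call_target="CustomSPMDPartitioning", '
        f'backend_config="{id(partitioner)}"'
    )


def cache_key(module_text: str) -> str:
    # Content-addressed key: any change in the text changes the key.
    return hashlib.sha256(module_text.encode("utf-8")).hexdigest()


class ContentAddressedCache:
    def __init__(self) -> None:
        self.entries: dict[str, str] = {}
        self.hits = 0
        self.misses = 0

    def get_or_compile(self, module_text: str) -> str:
        key = cache_key(module_text)
        if key in self.entries:
            self.hits += 1
        else:
            self.misses += 1
            # Stand-in for an expensive compilation.
            self.entries[key] = "executable-" + key[:8]
        return self.entries[key]


cache = ContentAddressedCache()
# Simulate three runs, each with its own descriptor object. The objects stay
# alive in the list, so their ids (addresses) are guaranteed to be distinct.
descriptors = [object() for _ in range(3)]
for descriptor in descriptors:
    cache.get_or_compile(lower(descriptor))
print(cache.hits, cache.misses, len(cache.entries))  # 0 3 3
```

Every "run" misses and adds its own entry, which matches the one-entry-per-run pattern reported in this issue.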
If this only covers the Python callback, could we just remove it from the key? Or could we get the Python AST of the callbacks and hash it? I see that as more work, and I'm not sure it is useful.
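The second suggestion keys on the callback's code instead of its address. It could look roughly like the sketch below, which is an illustration under assumptions, not JAX behavior. `fingerprint_source` and `fingerprint_callback` are hypothetical helpers built only on the standard `ast`, `hashlib`, `inspect` and `textwrap` modules. By default, `ast.dump` omits line and column attributes, so comments and formatting do not affect the digest.

```python
import ast
import hashlib
import inspect
import textwrap


def fingerprint_source(source: str) -> str:
    """Hash the AST of `source`, ignoring comments, whitespace and positions."""
    tree = ast.parse(textwrap.dedent(source))
    # include_attributes defaults to False, so lineno/col_offset are left out
    # and purely cosmetic edits produce the same canonical string.
    canonical = ast.dump(tree)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


def fingerprint_callback(fn) -> str:
    """Fingerprint a Python callback by its source rather than its address.

    inspect.getsource raises OSError if the source cannot be retrieved,
    e.g. for functions defined interactively.
    """
    return fingerprint_source(inspect.getsource(fn))


a = "def f(x):\n    return x * 2  # double\n"
b = "def f(x):\n\n    return x*2\n"
print(fingerprint_source(a) == fingerprint_source(b))  # True
```

One limitation supports the doubt about usefulness. The AST covers only the callback's own source, so two callbacks with identical code but different closed-over values or referenced globals would get the same fingerprint.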
Description
The compilation cache does not trigger for jitted functions with custom_partitioning ops. After running the JAX program multiple times, there is a separate cache entry with a different hash for each run.
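Suppose, as suspected below, that the run-dependent `backend_config` address is the only difference between runs. A cache could then normalize the text before hashing, which corresponds to the "remove it from the key" suggestion in the comments. The sketch assumes a simplified one-line textual form of the custom call. The real HLO/MLIR printout is not reproduced here and may differ. `normalize_for_cache` and `stable_cache_key` are hypothetical names, not JAX or XLA APIs.

```python
import hashlib
import re

# Assumed, simplified syntax: the target and backend_config appear on one
# line, and the config is a double-quoted string.
_PARTITIONING_CONFIG = re.compile(
    r'(custom_call_target="CustomSPMDPartitioning".*?backend_config=")[^"]*(")'
)


def normalize_for_cache(module_text: str) -> str:
    # Replace only the CustomSPMDPartitioning backend_config value; other
    # custom calls keep their configs, which may be meaningful.
    return _PARTITIONING_CONFIG.sub(r"\1<elided>\2", module_text)


def stable_cache_key(module_text: str) -> str:
    # Hash the normalized text so per-run addresses no longer change the key.
    normalized = normalize_for_cache(module_text)
    return hashlib.sha256(normalized.encode("utf-8")).hexdigest()


# Two made-up addresses standing in for two separate runs.
run1 = 'c = custom-call(p0), custom_call_target="CustomSPMDPartitioning", backend_config="140245983"'
run2 = 'c = custom-call(p0), custom_call_target="CustomSPMDPartitioning", backend_config="139871120"'
print(stable_cache_key(run1) == stable_cache_key(run2))  # True
```

The catch is that once the address is elided, two programs that differ only in their partitioning callback share a key. The cache could then return an executable built with the wrong partitioning rule. Substituting a stable fingerprint of the callback for the address, rather than a fixed placeholder, would avoid that collision.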
Here is the invocation:
Here are the contents of the cache afterwards:
The program with the custom-partitioned op:
I believe the cache misses because the backend_config parameter to the CustomSPMDPartitioning custom call is the address of some descriptor data structure, and that address is different from run to run. As a result, the MLIR differs between runs and the cache never triggers. Below is the HLO for the function f; the backend_config here is the address that differs across runs.

System info (python version, jaxlib version, accelerator, etc.)