
Compilation cache does not work with custom partitioning #21159

Open
jaro-sevcik opened this issue May 10, 2024 · 2 comments
Labels
enhancement New feature or request

Comments

@jaro-sevcik
Contributor

Description

The compilation cache does not trigger for jitted functions that contain custom_partitioning ops. After running the same JAX program multiple times, the cache contains a separate entry with a different hash for each run.

Here is the invocation:

mkdir -p /jaxcache
rm /jaxcache/*
for i in {0..10}; do python3 custom_part_cache.py; done
ls -1 /jaxcache

Here are the contents of the cache afterwards:

jit_iota-6af29e74be0b748c9cbec33fe00007655cc219c4cb4a9c2e7421faeca18eb5bc
pjit_f-1e8207d3cc8af860ea1e72a893f277a3b07aac94e008608b6523f5056d21c2ef
pjit_f-5163b782b60ae09e6082e7fe93733122289bc2fbd24252e05067e5eac4271d6e
pjit_f-550c04c40fce4cf52ab92af870ca9356001d9b0ac74b698e79e0bd7ebed09e73
pjit_f-6b0a81f2b7ba72aee5fe80bb2d99f2841ec6c1d6b762663b7e1ad8bbf40e2d86
pjit_f-77361fcb361bc13000d8492b336406b0d90d29d2aacc7afa0b83f32cfbae37c7
pjit_f-7b4beb49745e864c056f08aa58e52c2252a6cc54b68d6075dead11b870d59a03
pjit_f-aa9cf91afb1e36a26f07da4579353e96f1660faef89c69bab7c8b5aa88d73601
pjit_f-cf1bdb3e8b7a3a0d9cc4f018076e700469b2a7507bc00cd68ab3bb033088bb7f
pjit_f-d3b00a0a01cf56d8dee67f426e7bce38343737f514027a1263773ceca4c4daa0
pjit_f-ecfe91dfe48573323fd140b60ed16a907f2f52bf9ec163fd855eaa86a9d8ebb3

The program with the custom-partitioned op:

# custom_part_cache.py
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import custom_partitioning
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh

from jax.experimental.compilation_cache import compilation_cache

compilation_cache.set_cache_dir("/jaxcache")

def compute_shard(x):
    return x + 1.0

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    return ((NamedSharding(mesh, P('x')),), NamedSharding(mesh, P('x')))

@custom_partitioning
def custom_op(x):
    return x  # Only used for the output shape

def partition(mesh, arg_shapes, result_shape):
    return mesh, compute_shard, NamedSharding(mesh, P('x')), (NamedSharding(mesh, P('x')),)

custom_op.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition)

def f(x):
  return custom_op(x)

with Mesh(jax.devices(), ('x',)):
  x = jnp.arange(8, dtype=jnp.float32)
  print(pjit(f, in_shardings=P('x'), out_shardings=P('x'))(x))

I believe the cache misses because the backend_config parameter of the CustomSPMDPartitioning custom call is the address of a descriptor data structure, and that address differs from run to run. As a result, the emitted MLIR differs between runs and the cache never hits. Below is the HLO for the function f; the backend_config here is the address that changes across runs.

HloModule pjit_f, entry_computation_layout={(f32[8]{0})->f32[8]{0}}, num_partitions=8

ENTRY main.5 {
  Arg_0.1 = f32[8]{0} parameter(0), sharding={devices=[8]<=[8]}
  custom-call.2 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="CustomSPMDPartitioning", api_version=..., metadata=..., backend_config="139653345189008"
  tuple.3 = (f32[8]{0}) tuple(custom-call.2)
  ROOT get-tuple-element.4 = f32[8]{0} get-tuple-element(tuple.3), index=0, sharding={devices=[8]<=[8]}
}
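For illustration, the instability can be reproduced outside JAX by masking the pointer-valued backend_config before comparing the dumps. This is only a sketch of the idea (the `mask_backend_config` helper and the two sample strings are hypothetical, not part of any cache-key code): the two HLO lines differ only in that one field, so a canonicalized key would be stable across runs.

```python
import re

# Two copies of the same custom-call line as it might appear in two runs;
# only the address-valued backend_config differs (second value is made up).
hlo_run1 = 'custom-call.2 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="CustomSPMDPartitioning", backend_config="139653345189008"'
hlo_run2 = 'custom-call.2 = f32[8]{0} custom-call(Arg_0.1), custom_call_target="CustomSPMDPartitioning", backend_config="139700123456789"'

def mask_backend_config(hlo: str) -> str:
    # Replace the numeric (address-valued) backend_config with a placeholder
    # so that otherwise-identical programs compare equal.
    return re.sub(r'backend_config="\d+"', 'backend_config="<masked>"', hlo)

assert hlo_run1 != hlo_run2                                        # raw dumps differ
assert mask_backend_config(hlo_run1) == mask_backend_config(hlo_run2)  # masked dumps agree
```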

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.28.dev20240509+1e88e2f86
jaxlib: 0.4.28.dev20240509
numpy:  1.26.4
python: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0]
jax.devices (8 total, 8 local): [cuda(id=0) cuda(id=1) ... cuda(id=6) cuda(id=7)]
process_count: 1
platform: uname_result(system='Linux', node='...', release='4.15.0-101-generic', version='#102-Ubuntu SMP Mon May 11 10:07:26 UTC 2020', machine='x86_64')


$ nvidia-smi
Fri May 10 08:00:22 2024       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla V100-SXM2...  Off  | 00000000:06:00.0 Off |                    0 |
| N/A   29C    P0    59W / 300W |    311MiB / 32768MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  Off  | 00000000:07:00.0 Off |                    0 |
| N/A   31C    P0    58W / 300W |    311MiB / 32768MiB |      2%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   2  Tesla V100-SXM2...  Off  | 00000000:0A:00.0 Off |                    0 |
| N/A   32C    P0    58W / 300W |    311MiB / 32768MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   3  Tesla V100-SXM2...  Off  | 00000000:0B:00.0 Off |                    0 |
| N/A   28C    P0    59W / 300W |    311MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   4  Tesla V100-SXM2...  Off  | 00000000:85:00.0 Off |                    0 |
| N/A   31C    P0    59W / 300W |    311MiB / 32768MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   5  Tesla V100-SXM2...  Off  | 00000000:86:00.0 Off |                    0 |
| N/A   32C    P0    59W / 300W |    311MiB / 32768MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   6  Tesla V100-SXM2...  Off  | 00000000:89:00.0 Off |                    0 |
| N/A   33C    P0    59W / 300W |    311MiB / 32768MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   7  Tesla V100-SXM2...  Off  | 00000000:8A:00.0 Off |                    0 |
| N/A   31C    P0    60W / 300W |    311MiB / 32768MiB |      1%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
@jaro-sevcik jaro-sevcik added the bug Something isn't working label May 10, 2024
@hawkinsp hawkinsp added enhancement New feature or request and removed bug Something isn't working labels May 10, 2024
@hawkinsp
Member

Yes, that's correct and how it works at the moment. The custom partitioning ultimately refers to a Python object, which is why the key isn't stable from run to run.

@nouiz
Collaborator

nouiz commented May 10, 2024

If this only covers the Python callback, could we just remove it from the key? Not everything is 100% versioned right now (XLA itself isn't versioned, for example), so it would end up being the end user's responsibility to handle the cache correctly.

Or could we get the Python AST of the callbacks and hash it? I see that as more work, and I'm not sure it is useful.
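A rough sketch of that hashing idea (hypothetical, not an existing JAX mechanism): derive the key from the callback's compiled bytecode and constants instead of its per-process object address. Hashing the source text or AST via `inspect`/`ast` would be the fuller version nouiz describes; bytecode is used here only as a simple stand-in, and it is stable only within one interpreter version.

```python
import hashlib

def stable_callback_key(fn) -> str:
    # Hash the compiled bytecode and constants rather than id(fn),
    # which is an address and changes on every run.
    code = fn.__code__
    payload = code.co_code + repr(code.co_consts).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()

def compute_shard(x):
    return x + 1.0

# Deterministic for the same function body, unlike an object address;
# a different body produces a different key.
assert stable_callback_key(compute_shard) == stable_callback_key(compute_shard)
assert stable_callback_key(compute_shard) != stable_callback_key(lambda x: x - 1.0)
```

The trade-off noted in the comment still applies: any change to the callback (even a harmless refactor) changes the key and invalidates cache entries, while removing the callback from the key entirely shifts cache-correctness onto the user.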
