Persistent compilation cache does not work #21067

Open
neel04 opened this issue May 3, 2024 · 2 comments
Labels
bug Something isn't working

Comments


neel04 commented May 3, 2024

Description

The persistent compilation cache doesn't work. It used to work well with older versions of jax, but some breaking change seems to have landed in the past few weeks.

The problem is that the compilation_cache folder is never created, and I can confirm from the lack of speedup that jax is definitely not using the persistent cache.

Reproduce:

import os

import jax

# Point the persistent compilation cache at a local directory.
jax.config.update("jax_compilation_cache_dir", './jax-cache')

@jax.jit
def some_op(A, B):
    return (A @ B) * (A + B)

A = jax.numpy.ones((128,))
B = jax.numpy.zeros((128,))

some_op(A, B)

print(os.listdir('./')) # no folder gets created

I don't think it's an environment issue: I also have a Docker image that reproduces it.
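One way to narrow this down is to rule out the cache's write threshold. JAX's persistent cache has a documented minimum compile time below which entries are not written, so a function as small as `some_op` may compile too quickly to produce a cache file. The sketch below is a diagnostic under stated assumptions, not a confirmed explanation of this report: it assumes the `jax_persistent_cache_min_compile_time_secs` option is available in the installed JAX version, and `list_cache_entries` and `run_cache_check` are hypothetical helper names, not part of JAX.

```python
import os


def list_cache_entries(cache_dir):
    """Return the sorted entry names in cache_dir, or [] if it does not exist."""
    if not os.path.isdir(cache_dir):
        return []
    return sorted(os.listdir(cache_dir))


def run_cache_check(cache_dir):
    """Compile one small function and return what landed in cache_dir."""
    # Imported lazily so list_cache_entries stays usable without JAX installed.
    import jax
    import jax.numpy as jnp

    # An absolute path avoids surprises if the working directory changes.
    cache_dir = os.path.abspath(cache_dir)
    jax.config.update("jax_compilation_cache_dir", cache_dir)
    # Assumption: this option exists in the installed JAX version. Setting it
    # to 0.0 makes even very fast compilations eligible for caching.
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)

    @jax.jit
    def some_op(a, b):
        return (a @ b) * (a + b)

    # Block until the computation finishes so compilation has completed.
    some_op(jnp.ones((128,)), jnp.zeros((128,))).block_until_ready()
    return list_cache_entries(cache_dir)
```

Calling `run_cache_check('./jax-cache')` in a fresh Python process, before anything else has been compiled, returns the entries found in the cache directory. An empty list with the threshold at zero would suggest the cache is not being written at all, rather than the entry being skipped as too cheap to cache.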

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.26
jaxlib: 0.4.25
numpy: 1.26.2
python: 3.10.12 (main, Jul 5 2023, 15:02:25) [Clang 14.0.6 ]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='Neels-MacBook-Air.local', release='23.2.0', version='Darwin Kernel Version 23.2.0: Wed Nov 15 21:59:33 PST 2023; root:xnu-10002.61.3~2/RELEASE_ARM64_T8112', machine='arm64')

neel04 added the bug label May 3, 2024
rajasekharporeddy (Contributor) commented

Hi @neel04

I tested the provided reproduction code on my MacBook Pro (M1 Pro chip) with jax versions 0.4.26 and 0.4.27.dev20240503, each paired with the matching jaxlib version (0.4.26 and 0.4.27.dev20240503, respectively). In both cases, a folder named 'jax-cache' was created. Please see the screenshot below for reference.

[image: screenshot]

jax.print_environment_info():
jax: 0.4.27.dev20240503
jaxlib: 0.4.27.dev20240503
numpy: 1.26.4
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct 2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
jax.devices (1 total, 1 local): [CpuDevice(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')

Could you please verify with jaxlib 0.4.26 together with jax 0.4.26, or with the JAX nightly version, and let us know?

Thank you.


neel04 commented May 4, 2024

Yep, upgrading jaxlib from 0.4.25 -> 0.4.26 works locally. So this is definitely a version issue.
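Since the local fix was aligning jaxlib with jax, a quick check of the installed versions can catch the same mismatch elsewhere. This is a minimal sketch using only the standard library; `installed_version` and `versions_match` are hypothetical helper names. Note the assumption: JAX accepts a range of jaxlib versions, so requiring identical version strings is stricter than what JAX itself enforces. It is used here only because matching versions resolved the problem in this thread.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of package, or None if missing."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def versions_match(jax_version, jaxlib_version):
    """Return True when both version strings are present and identical."""
    return jax_version is not None and jax_version == jaxlib_version


jax_version = installed_version("jax")
jaxlib_version = installed_version("jaxlib")
if not versions_match(jax_version, jaxlib_version):
    # Strict equality is an assumption of this sketch, not a JAX requirement.
    print(f"jax={jax_version} and jaxlib={jaxlib_version} differ or are missing")
```

With the versions from the original report (jax 0.4.26, jaxlib 0.4.25), `versions_match` returns `False`.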

However, in my Docker image on a TPU v3-8 I have these library versions:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 (main, Apr 24 2024, 11:58:32) [GCC 12.2.0]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1) ... TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
process_count: 1
platform: uname_result(system='Linux', node='f996f75a635a', release='5.13.0-1027-gcp', version='#32~20.04.1-Ubuntu SMP Thu May 26 10:53:08 UTC 2022', machine='x86_64')

and it doesn't work there. I can rebuild it, since this image hasn't been updated in a few weeks, but I'm not really sure where exactly the problem lies.
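With the versions already matching inside the container, one thing that can be checked independently of JAX is whether the cache directory can be created and written to by the process. This is only a sketch of one possible check, not a diagnosis of the TPU case; `cache_dir_is_writable` is a hypothetical helper name.

```python
import os
import tempfile


def cache_dir_is_writable(cache_dir):
    """Return True if cache_dir can be created and a file written inside it."""
    try:
        os.makedirs(cache_dir, exist_ok=True)
        # Create and immediately remove a temporary file in the directory.
        with tempfile.NamedTemporaryFile(dir=cache_dir):
            pass
        return True
    except OSError:
        return False
```

If `cache_dir_is_writable('./jax-cache')` returns `False` inside the container, the problem is with the filesystem or permissions rather than with JAX. If it returns `True`, the directory itself is not the obstacle and the cause lies elsewhere.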
