dynamic config scope under jit doesn't change partitionable threefry behavior #21061

Closed
froystig opened this issue May 3, 2024 · 0 comments · Fixed by #21410
Labels
bug Something isn't working

froystig (Member) commented May 3, 2024

Description

import jax

def f(x):
  return x + jax.random.randint(jax.random.key(72), (), 0, 10)

def g(x):
  with jax.threefry_partitionable(True):  # False by default
    return x + jax.random.randint(jax.random.key(72), (), 0, 10)

h = jax.jit(g)
print('f', f(1))
print('g', g(1))
print('h', h(1))

prints:

f 4
g 8
h 4
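The expected behavior is that `h(1)` matches `g(1)`, because both set the flag around the same random draw. The sketch below is a minimal regression check for that. The helper names `add_random`, `scoped` and `eager_and_jitted` are hypothetical and do not appear in the report. It compares the eager and jitted results for both flag values. On versions that include the fix the two columns should agree, but no specific numbers are assumed here.

```python
import jax

def add_random(x):
  # Same computation as `f` above: add a fixed-key random integer in [0, 10).
  return x + jax.random.randint(jax.random.key(72), (), 0, 10)

def scoped(x, partitionable):
  # Like `g` above, but the flag value is passed in as an argument.
  with jax.threefry_partitionable(partitionable):
    return add_random(x)

def eager_and_jitted(partitionable):
  # Run the scoped function op-by-op, then under jit. The flag value is a
  # static argument, so each setting gets its own trace.
  eager = scoped(1, partitionable)
  jitted = jax.jit(scoped, static_argnums=1)(1, partitionable)
  return int(eager), int(jitted)

for flag in (False, True):
  eager, jitted = eager_and_jitted(flag)
  print(flag, eager, jitted, "ok" if eager == jitted else "MISMATCH")
```

A mismatch on the `True` row reproduces this issue.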

System info (python version, jaxlib version, accelerator, etc.)

jax 0.4.27.dev with jaxlib 0.4.26.
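On affected versions like the one above, one possible workaround is to set the flag where jit traces and lowers the function. That means placing it around the jitted call, or setting it process-wide through the `jax_threefry_partitionable` config option, rather than inside the traced function. This is a sketch and was not proposed in this thread. The names `f_jit`, `scoped_call` and `eager_ref` are hypothetical.

```python
import jax

def f(x):
  # No flag handling inside the traced function.
  return x + jax.random.randint(jax.random.key(72), (), 0, 10)

f_jit = jax.jit(f)

# Option 1: scope the flag around the jitted *call*, so it is already active
# when jit traces, lowers and compiles `f`.
with jax.threefry_partitionable(True):
  scoped_call = int(f_jit(1))
  eager_ref = int(f(1))  # op-by-op reference under the same setting

print(scoped_call, eager_ref)

# Option 2 (process-wide): set the option once, before any random calls:
# jax.config.update("jax_threefry_partitionable", True)
```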

froystig added the bug label May 3, 2024
froystig self-assigned this May 3, 2024
copybara-service bot pushed a commit that referenced this issue May 24, 2024
… setting it inside jit.

Previously, this information was lost in the round trip through `mlir.lower_fun` -> `jaxpr_subcomp`. Now that it is stored on the jaxpr equations, `jaxpr_subcomp` preserves it as it enters each equation's ctx.

Fixes: #21061
PiperOrigin-RevId: 636380603