Crash in eval_jaxpr with 0.4.27 #21116

Open
patrick-kidger opened this issue May 7, 2024 · 17 comments
Labels
bug Something isn't working

@patrick-kidger
Collaborator

Description

import jax
import jax.numpy as jnp

def run(some_tracer):
    def f(x, y):
        return x + y

    g = lambda x: f(x, some_tracer)
    jaxpr = jax.make_jaxpr(g)(1)
    jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)

jax.vmap(run)(jnp.arange(2))

produces:

Traceback (most recent call last):
  File ".../file.py", line 13, in <module>
    jax.vmap(run)(jnp.arange(2))
  File ".../file.py", line 11, in run
    jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
ValueError: safe_map() argument 2 is shorter than argument 1

System info (python version, jaxlib version, accelerator, etc.)

JAX 0.4.27

@yashk2810
Member

This is probably my bug; I'll look into it. I am pretty sure it's happening because the tracer is being converted into an argument instead of being embedded into the consts of the ClosedJaxpr, since it is being closed over in f.

@yashk2810
Member

This happens because make_jaxpr previously embedded tracers into consts, which was wrong. I changed this to unify the code paths of make_jaxpr and jit(f).specialize(*args).jaxpr. The new behavior is correct, but we should still fix this bug. Thanks for the concise repro!

@yashk2810
Member

Actually, thinking about it more: tracers should be passed as arguments. If you change g above to take the tracer as an argument, the error should go away.

So maybe the fix should be in diffrax?
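A minimal sketch of the workaround suggested above, assuming g can be rewritten by the caller: the previously closed-over value becomes an explicit argument to both make_jaxpr and eval_jaxpr, so nothing has to be recovered from consts. It reuses the jax.core.eval_jaxpr call from the original repro.

```python
import jax
import jax.numpy as jnp


def run(some_tracer):
    def f(x, y):
        return x + y

    # The value that used to be closed over is now an explicit argument of g.
    g = lambda x, y: f(x, y)
    jaxpr = jax.make_jaxpr(g)(1, some_tracer)
    # Pass the same arguments to eval_jaxpr, in the same order as the invars.
    return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1, some_tracer)


# eval_jaxpr returns a list of outputs, so vmap returns a one-element list.
out = jax.vmap(run)(jnp.arange(2))
```

This only helps when the caller controls g; it does not address the case of an arbitrary user-provided function.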

@patrick-kidger
Collaborator Author

patrick-kidger commented May 8, 2024

How should that be done?

g is a user-provided function. I don't think there is any way to extract the closed-over tracer. Previously it was in jax.make_jaxpr(g)(*args).consts, but that is now an empty list.

@yashk2810
Member

Now it's in jaxpr.invars. Would it be enough if I gave you a count of how many closed-over tracers there are and how many real arguments exist, so that you can make things work in diffrax?

Note that we have this assert, which is currently disabled but which we want to enable:

# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable

@patrick-kidger
Collaborator Author

patrick-kidger commented May 8, 2024

Hmm, I think I'm only seeing the Var objects:

jaxpr.jaxpr.invars  # [Var(id=4790894848):int32[], Var(id=4790895168):int32[]]

but I don't see any way of getting some_tracer out of jaxpr.


To approach this a different way: in <=0.4.26, the following code works for all functions g:

jaxpr = jax.make_jaxpr(g)(*args)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)

Which I think is quite a nice invariant, actually!
Now that this is no longer the case, what is the equivalent invariant for >=0.4.27? :)
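The invariant above can be written as a small helper. This is a sketch assuming flat array arguments (make_jaxpr flattens pytree inputs, so nested arguments would need jax.tree_util.tree_leaves first); the helper name eval_via_jaxpr is hypothetical. The closed-over value here is a concrete array rather than a tracer.

```python
import jax
import jax.numpy as jnp


def eval_via_jaxpr(g, *args):
    # Trace g to a ClosedJaxpr, then re-run it from its jaxpr and consts.
    # Assumes each element of args is a single array (no nested pytrees),
    # so that args line up one-to-one with jaxpr.invars.
    closed = jax.make_jaxpr(g)(*args)
    return jax.core.eval_jaxpr(closed.jaxpr, closed.consts, *args)


c = jnp.arange(3.0)  # a concrete array closed over by g


def g(x):
    return jnp.sin(x) + c


x = jnp.ones(3)
(out,) = eval_via_jaxpr(g, x)  # eval_jaxpr returns a flat list of outputs
```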

@yashk2810
Member

jaxpr.jaxpr.invars

The first Var is the closed-over tracer. So if I can give you a count of how many tracers there are before the actual args start, would that be enough?

@patrick-kidger
Collaborator Author

patrick-kidger commented May 8, 2024

I don't think so. At least right now, I don't see a way to grab the tracer out of the Var:

(Pdb) jaxpr.jaxpr.invars[0]
Var(id=4790894848):int32[]
(Pdb) dir(jaxpr.jaxpr.invars[0])
['__annotations__', '__class__', '__delattr__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__slots__', '__str__', '__subclasshook__', 'aval', 'count', 'suffix']
(Pdb) jaxpr.jaxpr.invars[0].aval                                                                                                                   
ShapedArray(int32[])
(Pdb) jaxpr.jaxpr.invars[0].count
14
(Pdb) jaxpr.jaxpr.invars[0].suffix
''
# no tracers here!

What am I missing?

@yashk2810
Member

yashk2810 commented May 8, 2024

Sorry, I meant that the first Var corresponds to the tracer (which would have been in consts before). Right now there is no indication that it is a tracer, unless I give you more information from JAX. My question was whether that would be enough for you to fix diffrax.

Before:

In [1]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))

[ShapedArray(int32[], weak_type=True)] [Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
  val = Array([0, 1], dtype=int32)
  batch_dim = 0]

Now:

In [2]: import jax
   ...: import jax.numpy as jnp
   ...:
   ...: def run(some_tracer):
   ...:     def f(x, y):
   ...:         return x + y
   ...:
   ...:     g = lambda x: f(x, some_tracer)
   ...:     jaxpr = jax.make_jaxpr(g)(1)
   ...:     print(jaxpr.in_avals, jaxpr.consts)
   ...:     jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 1)
   ...:
   ...: jax.vmap(run)(jnp.arange(2))
[ShapedArray(int32[]), ShapedArray(int32[], weak_type=True)] []

@patrick-kidger
Collaborator Author

patrick-kidger commented May 8, 2024

Right! But I don't need to know which Vars correspond to tracers; I need the actual Tracer object itself. I don't think this is available anywhere anymore.

Sorry, I think I'm realising where the confusion is coming from: I'm not calling make_jaxpr to get just the jaxpr. My goal here is to take an arbitrary user-provided function, which may include closed-over tracers, and to perform closure conversion to get both a pure function (.jaxpr) and all the closed-over values (.consts). Both are needed for the subsequent eval_jaxpr call. This closure-conversion is what is no longer possible at all.
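A sketch of the closure conversion described above, built on make_jaxpr. The name closure_convert_all is hypothetical, and applying it to closed-over tracers relies on them landing in .consts, which held in <=0.4.26 (and again after the rollback discussed below) but not in 0.4.27. The usage closes over a concrete array so that it does not depend on that behavior.

```python
import jax
import jax.numpy as jnp


def closure_convert_all(g, *example_args):
    # Trace g once; closed-over values that JAX lifts out end up in .consts.
    closed = jax.make_jaxpr(g)(*example_args)

    def pure_fn(consts, *args):
        # Everything g closed over that was lifted out arrives via `consts`.
        return jax.core.eval_jaxpr(closed.jaxpr, consts, *args)

    return pure_fn, closed.consts


scale = jnp.array([1.0, 2.0, 3.0])
g = lambda x: x * scale

pure_fn, consts = closure_convert_all(g, jnp.ones(3))
# The closed-over values now cross a boundary (here: jit) as explicit inputs.
out = jax.jit(pure_fn)(consts, jnp.ones(3))
```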

(Note that jax.closure_convert isn't suitable, at least right now -- it's branded as an AD-specific tool, and as such only closure-converts with respect to floating point arrays. In general we actually need more than that, e.g. vmap'd integer arrays. IIRC this was deliberately removed from jax.closure_convert a couple of years ago for some performance-related reason. I think Matt might recall?)

Taking a step back from what's needed to solve this one particular bug in the short term: in general, abstract evaluation can return four things: the jaxpr, the output pytree/avals, the effects, and the closed-over values. Right now, these pieces are mixed and matched incompletely across the public API. The state of affairs is:

| API                   | Jaxpr | Output pytree/avals | Effects | Closed-over values |
|-----------------------|-------|---------------------|---------|--------------------|
| `jax.make_jaxpr`      |   x   |          x          |         |        (x)         |
| `jax.eval_shape`      |       |          x          |         |                    |
| `jax.closure_convert` |   ~   |                     |         |         ~          |

Here an x means it gives you everything, a ~ means you only get some of it, and (x) marks the capability we lost with the 0.4.27 release.
Speculating: perhaps this is a state of affairs that could be simplified?

(Side note, credit to Gemini for kindly making this table for me ^^ )

@yashk2810
Member

I need the actual Tracer object itself

Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with. That should never have happened, and this assert needs to be enabled:

# assert not any(isinstance(c, Tracer) for c in consts) # TODO(mattjj): enable
which would enforce it.

Speculating: perhaps this is a state of affairs that could be simplified?

Eventually, we are going to merge make_jaxpr and eval_shape into jit(f).specialize(*args).jaxpr | .out_shapes. Currently both eval_shape and jaxpr are available on jitted functions.
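For reference, the shape-only and jaxpr-producing entry points that exist today can be compared as below. This sketch sticks to public calls (jax.eval_shape, the eval_shape method of a jitted function, and jax.make_jaxpr with return_shape=True) and does not use jit(f).specialize itself, whose name and availability may vary between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x):
    return jnp.sum(x ** 2), x.astype(jnp.int32)


x = jnp.ones((3,), dtype=jnp.float32)

# Output shapes/dtypes only; no jaxpr is returned to the caller.
shapes = jax.eval_shape(f, x)
# The same query, made through a jitted function.
jit_shapes = jax.jit(f).eval_shape(x)
# A jaxpr together with the output shapes.
closed_jaxpr, out_shapes = jax.make_jaxpr(f, return_shape=True)(x)
```

None of these returns the closed-over values as a separate output, which is the missing column in the table above.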

@yashk2810
Member

#21140 should roll it back. We are going to do another release tomorrow, so this should be fixed in 0.4.28.

@hawkinsp
Member

hawkinsp commented May 9, 2024

That said, I think we're still trying to figure out the path forward here; this rollback is mostly to unbreak you.

@patrick-kidger
Collaborator Author

Eventually, we are going to merge make_jaxpr and eval_shape into jit(f).specialize(*args).jaxpr | .out_shapes

Ah, interesting! I like this unification. Will we be able to grab the out_shapes without recording the jaxpr? (IIUC, that is the advantage of eval_shape over make_jaxpr(..., return_shape=True) today.)

Why do you need the actual tracer object? Embedding tracers into consts was a mistake to begin with.

I have no preference about whether tracers are placed in jaxprs at all. For example, I would be equally happy to get them via something like: jaxpr, closed_over_values = jax.make_jaxpr(g, closed_over_values=True).

The actual tracer object is required to perform closure conversion before crossing the boundaries of higher-order primitives, custom AD, etc.
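A minimal sketch of closure conversion ahead of a custom-AD boundary, similar to the closure_convert pattern in JAX's custom-derivatives documentation; apply and _apply are hypothetical names. It uses the public jax.closure_convert, so, per the discussion above, only values that may be AD-perturbed are hoisted into consts; something like a closed-over vmap'd integer array would not be.

```python
from functools import partial

import jax
import jax.numpy as jnp


def apply(f, x):
    # Hoist f's (possibly AD-perturbed) closed-over values into `consts`.
    converted, consts = jax.closure_convert(f, x)
    return _apply(converted, x, consts)


@partial(jax.custom_vjp, nondiff_argnums=(0,))
def _apply(f, x, consts):
    return f(x, *consts)


def _apply_fwd(f, x, consts):
    return f(x, *consts), (x, consts)


def _apply_bwd(f, res, g):
    x, consts = res
    # Cotangents for x and for every hoisted constant (same list structure).
    _, f_vjp = jax.vjp(lambda x, consts: f(x, *consts), x, consts)
    return f_vjp(g)


_apply.defvjp(_apply_fwd, _apply_bwd)

# `a` is closed over by the inner lambda, yet its gradient still flows
# through the custom_vjp boundary because closure_convert hoisted it.
grad_a = jax.grad(lambda a: apply(lambda x: a * x, jnp.float32(2.0)))(jnp.float32(3.0))
```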

That said, I think we're still trying to figure out the path forward here; this rollback is mostly to unbreak you.

Thank you! I really appreciate this.

I think if you're looking for a concrete suggestion on a path forward from me:

  • short term: adjust jax.closure_convert to grab all constants, by having this function unconditionally return True:

    def _maybe_perturbed(x: Any) -> bool:

    IIUC the original rationale for this filtering was performance, so that custom_vjp wouldn't need to compute unnecessary cotangents (Hiding constants from closure_convert and the AD system #6415). However, for that purpose we now have symbolic zeros available in custom AD, so this should no longer be a concern.

    Then I can rewrite my code to perform closure-conversion via this function instead.

  • long term: have the jit(f).specialize interface also provide effects and closed-over values. Then many other things (including jax.closure_convert!) could be implemented through this.

@yashk2810
Member

What if we expose jex.trace_to_jaxpr, which you could use to get the behavior you want, and then we roll forward with the new behavior where tracers won't be in consts?

cc @froystig @mattjj

@hawkinsp
Member

hawkinsp commented May 9, 2024

We just released jax 0.4.28, which has the rollback.

@patrick-kidger
Collaborator Author

What if we expose jex.trace_to_jaxpr, which you could use to get the behavior you want, and then we roll forward with the new behavior where tracers won't be in consts?

Do you mean jax.interpreters.partial_eval.trace_to_jaxpr? That acts on a lot of JAX-internal types, though (WrappedFun, PartialVal, Value).

If you like, I could write a quick PR adding a flag to jax.closure_convert that hoists integer-dtyped tracers as well?

  • This should be tiny: only a ~5 LOC change.
  • I'm suggesting placing it behind a flag in case you're concerned about the change.
  • It matches the semantics of what I'm actually trying to do, namely closure conversion.

We just released jax 0.4.28, which has the rollback.

Thank you! I appreciate it.

3 participants