Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,11 @@ def is_jax_array(x: object) -> TypeIs[jax.Array]:
is_pydata_sparse_array
"""
cls = cast(Hashable, type(x))
return _issubclass_fast(cls, "jax", "Array") or _is_jax_zero_gradient_array(x)
return (
_issubclass_fast(cls, "jax", "Array")
or _issubclass_fast(cls, "jax.core", "Tracer")
Copy link
Contributor

@jakevdp jakevdp Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The main reason for the change in v0.8.2 is that tracers now can represent more than just arrays, and so returning True for any tracer may lead to false positives.

The logic in Array.__instancecheck__ is what is required to accurately check in all contexts whether x is an array: https://github.com/jax-ml/jax/blob/82ae1b1cde42a5b93e00d8c3376cde627c2d83bb/jaxlib/py_array.cc#L2187-L2218

The easiest way to accomplish this would be to check isinstance(x, jax.Array) rather than recreating that logic here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That will force us to use a non-cachable operation, which is going to slow things down. But I don't think we have a choice given that the Tracer type itself no longer holds information on whether or not it's an Array.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jakevdp Can you elaborate a bit more on which kinds of non-array objects now create tracers? I.e. we use an _is_writable_cls and _is_lazy_cls. Even if tracers are not arrays, these functions could still be decidable based on the type only. Are tracers still always lazy and always immutable? I realize that these questions might be ill-defined since tracers do not represent real objects and can disappear from the final computation graph, but for our purposes that's not an issue.

Also, could you show an example of a tracer that does not wrap an array? E.g. are bools in the input now traced as bools and not as arrays? This would be very helpful for testing.

@crusaderky Current helper methods such as _is_writable_cls are designed to return None for non-array API objects. It seems we cannot make that decision based off of type information only on jax>=0.8.2. Are you fine with relaxing the None strategy and returning True for Tracers in general, or do you want to be strict here? The former still fits into our current setup, the latter must use non-cachable isinstance checks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, could you show an example of a tracer that does not wrap an array?

An example is the new hijax Box type. There are no public APIs for this (yet), but here's how you can construct it using currently-private APIs at head:

import jax
from jax._src import hijax

box = hijax.new_box()
hijax.box_set(box, (jnp.arange(4), jnp.ones((3, 3)), 2.0, None))

@jax.jit
def f(box):
  print(type(box))  # <class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
  print(box.aval)  # BoxTy()
  print(hijax.box_get(box))  # (JitTracer(int32[4]), JitTracer(float32[3,3]), JitTracer(~float32[]), None)
  # print(box.dtype)  # fails with AttributeError
  # print(box.shape)  # fails with AttributeError

f(box)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current design is that Tracer subclass reflects the type of transformation being traced (e.g. jit, vmap, grad, jaxpr, etc.) while the aval attribute can be inspected to see what kind of object is being traced.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, that's very helpful. At this point I think we need a decision by the array-api-compat team. Both versions shouldn't be hard to implement.

@crusaderky @lucascolley what are your thoughts?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think the false positives seem like a concern from SciPy's side. Maybe we go with this, but add a note into the code comments about the false positives in case anyone complains in the future?

or _is_jax_zero_gradient_array(x)
)
Comment on lines +244 to +248
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The check for jax.core.Tracer has been added to is_jax_array, but several other helper functions in this file may also need similar updates for consistency and correctness. Specifically:

  1. _is_array_api_cls (line 302) - checks for jax.Array but not Tracer
  2. _cls_to_namespace (line 550) - checks for jax.Array but not Tracer
  3. _is_writeable_cls (line 940) - checks for jax.Array but not Tracer (JAX tracers should also be non-writeable)
  4. _is_lazy_cls (line 979) - checks for jax.Array but not Tracer (JAX tracers should also be lazy)

If is_jax_array now returns True for Tracers, these other functions should be updated to handle Tracers consistently. Otherwise, a jitted JAX array might pass is_jax_array but fail in array_namespace or behave incorrectly with is_writeable_array and is_lazy_array.

Copilot uses AI. Check for mistakes.


def is_pydata_sparse_array(x: object) -> TypeIs[sparse.SparseArray]:
Expand Down
9 changes: 9 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def test_is_xp_array(library, func):
assert is_array_api_obj(x)


def test_is_jax_array_jitted():
jax = pytest.importorskip("jax")
import jax.numpy as jnp

x = jnp.asarray([1, 2, 3])
assert is_jax_array(x)
assert jax.jit(lambda y: is_jax_array(y))(x)
Copy link

Copilot AI Dec 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This 'lambda' is just a simple wrapper around a callable object. Use that object directly.

Suggested change
assert jax.jit(lambda y: is_jax_array(y))(x)
assert jax.jit(is_jax_array)(x)

Copilot uses AI. Check for mistakes.


@pytest.mark.parametrize('library', is_namespace_functions.keys())
@pytest.mark.parametrize('func', is_namespace_functions.values())
def test_is_xp_namespace(library, func):
Expand Down
Loading