Skip to content

Commit d80e893

Browse files
committed
a quick patch for JAX 0.7.0 compatibility
1 parent a896e70 commit d80e893

File tree

3 files changed

+8
-3
lines changed

3 files changed

+8
-3
lines changed

tests/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
from jax.lib import xla_extension
44

5-
Device = xla_extension.Device
5+
Device = jax.Device if hasattr(jax, "Device") else xla_extension.Device
66

77

88
def jax_randn(shape, device, dtype):

torch2jax/api.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from torch import Tensor
1010
import jax
1111
from jax import ShapeDtypeStruct
12-
from jax.util import safe_zip
12+
try:
13+
from jax.util import safe_zip
14+
except ImportError:
15+
safe_zip = zip
16+
17+
jax.config.update('jax_use_shardy_partitioner', False) # TODO: temporary workaround for JAX 0.7.0
1318

1419
# jax version-friendly way of importing the ffi module in jax
1520
try:

torch2jax/dlpack_passing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
except ImportError:
1313
from jax.core import ConcretizationTypeError
1414

15-
JAXDevice = jax.lib.xla_extension.Device
15+
JAXDevice = jax.Device if hasattr(jax, "Device") else jax.lib.xla_extension.Device
1616

1717

1818
def _transfer(x: Array | Tensor, via: str = "dlpack", device: str = "cuda"):

0 commit comments

Comments
 (0)