File tree Expand file tree Collapse file tree 3 files changed +8
-3
lines changed Expand file tree Collapse file tree 3 files changed +8
-3
lines changed Original file line number Diff line number Diff line change 2
2
import time
3
3
from jax .lib import xla_extension
4
4
5
- Device = xla_extension .Device
5
+ Device = jax . Device if hasattr ( jax , "Device" ) else xla_extension .Device
6
6
7
7
8
8
def jax_randn (shape , device , dtype ):
Original file line number Diff line number Diff line change 9
9
from torch import Tensor
10
10
import jax
11
11
from jax import ShapeDtypeStruct
12
- from jax .util import safe_zip
12
+ try :
13
+ from jax .util import safe_zip
14
+ except ImportError :
15
+ safe_zip = zip
16
+
17
+ jax .config .update ('jax_use_shardy_partitioner' , False ) # TODO: temporary workaround for JAX 0.7.0
13
18
14
19
# jax version-friendly way of importing the ffi module in jax
15
20
try :
Original file line number Diff line number Diff line change 12
12
except ImportError :
13
13
from jax .core import ConcretizationTypeError
14
14
15
- JAXDevice = jax .lib .xla_extension .Device
15
+ JAXDevice = jax .Device if hasattr ( jax , "Device" ) else jax . lib .xla_extension .Device
16
16
17
17
18
18
def _transfer (x : Array | Tensor , via : str = "dlpack" , device : str = "cuda" ):
You can’t perform that action at this time.
0 commit comments