Skip to content

Commit 1451479

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 634149223
1 parent a6a508e commit 1451479

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

trax/layers/research/efficient_attention_test.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,14 @@ def _test_equivalence_to_reference_code(
9494
test_all = self._run_forward_and_backward(test_model, inp, weights, state)
9595
test_out, test_state, test_inp_grad, test_weights_grad = test_all
9696

97-
self.assertEqual(jax.tree_structure(ref_out),
98-
jax.tree_structure(test_out))
99-
self.assertEqual(jax.tree_structure(ref_state),
100-
jax.tree_structure(test_state))
101-
self.assertEqual(jax.tree_structure(ref_inp_grad),
102-
jax.tree_structure(test_inp_grad))
103-
self.assertEqual(jax.tree_structure(ref_weights_grad),
104-
jax.tree_structure(test_weights_grad))
97+
self.assertEqual(jax.tree.structure(ref_out),
98+
jax.tree.structure(test_out))
99+
self.assertEqual(jax.tree.structure(ref_state),
100+
jax.tree.structure(test_state))
101+
self.assertEqual(jax.tree.structure(ref_inp_grad),
102+
jax.tree.structure(test_inp_grad))
103+
self.assertEqual(jax.tree.structure(ref_weights_grad),
104+
jax.tree.structure(test_weights_grad))
105105

106106
check_close = lambda x, y: self.assertAllClose(x, y, rtol=2e-3, atol=2e-3)
107107
fastmath.nested_map_multiarg(check_close, ref_out, test_out)
@@ -168,7 +168,7 @@ def get_slice_for_val(x):
168168
dtype=x.dtype)
169169
else:
170170
return x[:, i:i+1]
171-
return jax.tree_map(get_slice_for_val, pytree)
171+
return jax.tree.map(get_slice_for_val, pytree)
172172

173173
seqlen = x[0].shape[1] if isinstance(x, (tuple, list)) else x.shape[1]
174174

trax/models/research/bert.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,9 @@ def reshape_bias(name):
308308
for a, b in zip(fastmath.tree_leaves(self.weights), new_w):
309309
assert a.shape == b.shape, (
310310
f'Expected shape {a.shape}, got shape {b.shape}')
311-
self.weights = jax.tree_unflatten(jax.tree_structure(self.weights), new_w)
311+
self.weights = jax.tree.unflatten(jax.tree.structure(self.weights), new_w)
312312
move_to_device = jax.jit(lambda x: x)
313-
self.weights = jax.tree_map(move_to_device, self.weights)
313+
self.weights = jax.tree.map(move_to_device, self.weights)
314314

315315
def _settable_attrs(self):
316316
"""We allow to set attributes required for loading the model from its checkpoints."""

0 commit comments

Comments
 (0)