Skip to content

Commit 2a50fab

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Avoid use of deprecated xla_bridge.get_backend().live_buffers()
xla_bridge.get_backend is deprecated, and the public API for this is jax.live_arrays(). This is a drop-in replacement with no change of behavior. PiperOrigin-RevId: 723227225
1 parent 2f32eff commit 2a50fab

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

trax/optimizers/trainer.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,7 @@ def _collect_weights(self, layer):
440440

441441
def _free_accelerators(self, exceptions=(), keep_constants=True):
442442
"""Deletes all live buffers from accelerator with no safety guarantees."""
443-
backend = jax.lib.xla_bridge.get_backend()
444-
live_buffers = backend.live_buffers()
443+
live_buffers = jax.live_arrays()
445444
logging.info('Deleting %d live buffers.', len(live_buffers))
446445
exceptions_buffers = []
447446
for x in fastmath.tree_flatten(exceptions):

0 commit comments

Comments
 (0)