Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit b921904

Browse files
Jake VanderPlascopybara-github
Jake VanderPlas
authored andcommittedFeb 4, 2025·
Avoid use of deprecated xla_bridge.get_backend().live_buffers()
xla_bridge.get_backend is deprecated, and the public API for this is jax.live_arrays(). PiperOrigin-RevId: 723227225
1 parent 2f32eff commit b921904

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed
 

‎trax/optimizers/trainer.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -440,8 +440,7 @@ def _collect_weights(self, layer):
440440

441441
def _free_accelerators(self, exceptions=(), keep_constants=True):
442442
"""Deletes all live buffers from accelerator with no safety guarantees."""
443-
backend = jax.lib.xla_bridge.get_backend()
444-
live_buffers = backend.live_buffers()
443+
live_buffers = jax.live_arrays()
445444
logging.info('Deleting %d live buffers.', len(live_buffers))
446445
exceptions_buffers = []
447446
for x in fastmath.tree_flatten(exceptions):

0 commit comments

Comments
 (0)
Please sign in to comment.