We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 2f32eff commit b921904Copy full SHA for b921904
trax/optimizers/trainer.py
@@ -440,8 +440,7 @@ def _collect_weights(self, layer):
440
441
def _free_accelerators(self, exceptions=(), keep_constants=True):
442
"""Deletes all live buffers from accelerator with no safety guarantees."""
443
- backend = jax.lib.xla_bridge.get_backend()
444
- live_buffers = backend.live_buffers()
+ live_buffers = jax.live_arrays()
445
logging.info('Deleting %d live buffers.', len(live_buffers))
446
exceptions_buffers = []
447
for x in fastmath.tree_flatten(exceptions):
0 commit comments