From a90fe644ade00b537e75870971c336ef0e499ffa Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 6 Dec 2023 15:01:57 -0800 Subject: [PATCH] Replace references to deprecated device_buffer attributes `jax.Array.device_buffer` and `jax.Array.device_buffers` will be deprecated as of jax version 0.4.22; see https://github.com/google/jax/pull/18844. PiperOrigin-RevId: 588553845 --- trax/optimizers/trainer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index 41707701f..4d4c4bd9d 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True): logging.info('Deleting %d live buffers.', len(live_buffers)) exceptions_buffers = [] for x in fastmath.tree_flatten(exceptions): - if hasattr(x, 'device_buffer'): # DeviceArray + if hasattr(x, 'addressable_shards'): # Array + exceptions_buffers.extend(shard.data for shard in x.addressable_shards) + elif hasattr(x, 'device_buffer'): # DeviceArray exceptions_buffers.append(x.device_buffer) - if hasattr(x, 'device_buffers'): # ShardedDeviceArray + elif hasattr(x, 'device_buffers'): # ShardedDeviceArray exceptions_buffers.extend(x.device_buffers) for b in live_buffers: should_delete = True