Skip to content

Commit

Permalink
Replace references to deprecated jax array attributes device_buffer a…
Browse files Browse the repository at this point in the history
…nd device_buffers

PiperOrigin-RevId: 588553845
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Dec 6, 2023
1 parent d72bd65 commit 10a95cc
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions trax/optimizers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True):
logging.info('Deleting %d live buffers.', len(live_buffers))
exceptions_buffers = []
for x in fastmath.tree_flatten(exceptions):
if hasattr(x, 'device_buffer'): # DeviceArray
if hasattr(x, 'addressable_shards'): # Array
exceptions_buffers.extend(shard.data for shard in x.addressable_shards)
elif hasattr(x, 'device_buffer'): # DeviceArray
exceptions_buffers.append(x.device_buffer)
if hasattr(x, 'device_buffers'): # ShardedDeviceArray
elif hasattr(x, 'device_buffers'): # ShardedDeviceArray
exceptions_buffers.extend(x.device_buffers)
for b in live_buffers:
should_delete = True
Expand Down

0 comments on commit 10a95cc

Please sign in to comment.