Skip to content

Commit

Permalink
[JAX] Replace uses of jax.devices("cpu") with jax.local_devices(backe…
Browse files Browse the repository at this point in the history
…nd="cpu").

An upcoming change to JAX will include non-local (addressable) CPU devices in jax.devices() when JAX is used multicontroller-style, where there are multiple Python processes.

This change preserves the current behavior by replacing uses of jax.devices("cpu"), which previously only returned local devices, with jax.local_devices("cpu"), which will return local devices both now and in the future.

This change is always be safe (i.e., it should always preserve the previous behavior) but it may sometimes be unnecessary if code is never used in a multicontroller setting.

PiperOrigin-RevId: 582427745
  • Loading branch information
hawkinsp authored and copybara-github committed Nov 15, 2023
1 parent 38adb83 commit bda6deb
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions trax/layers/acceleration.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def f(x):
def _accelerate(f, n_devices):
"""Returns an accelerated version of ``f`` running on ``n_devices``."""
if n_devices == 0: # no accelerators - run on CPU
return fastmath.jit(f, device=jax.devices('cpu')[0])
return fastmath.jit(f, device=jax.local_devices(backend='cpu')[0])

if n_devices == 1:
return fastmath.jit(f)
Expand Down Expand Up @@ -248,7 +248,7 @@ def f(x):
def on_cpu(x):
"""Puts ``x`` in CPU memory in JAX."""
if fastmath.is_backend(fastmath.Backend.JAX):
return jax.device_put(x, jax.devices('cpu')[0])
return jax.device_put(x, jax.local_devices(backend='cpu')[0])
else:
return x

Expand Down

0 comments on commit bda6deb

Please sign in to comment.