diff --git a/tpu/requirements.txt b/tpu/requirements.txt index ec6fb273..72e61d1d 100644 --- a/tpu/requirements.txt +++ b/tpu/requirements.txt @@ -12,7 +12,7 @@ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TOR torchaudio==${TORCHAUDIO_VERSION} torchvision==${TORCHVISION_VERSION} # Jax packages -jax[tpu]>=0.4.34 +jax[tpu]>=0.5.3 --find-links https://storage.googleapis.com/jax-releases/libtpu_releases.html distrax flax