Skip to content

Commit 0e2265e

Browse files
authored
Upgrade JAX (#1072)
- Use new pip install syntax. - The proper version of jaxlib is installed based on the version of jax.
1 parent 7cd9929 commit 0e2265e

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ RUN pip install flashtext && \
400400
pip install pycrypto && \
401401
pip install easyocr && \
402402
# Keep JAX version in sync with GPU image.
403-
pip install jax==0.2.16 jaxlib==0.1.68 && \
403+
pip install jax[cpu]==0.2.19 && \
404404
# ipympl adds interactive widget support for matplotlib
405405
pip install ipympl==0.7.0 && \
406406
pip install pandarallel && \

gpu.Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ RUN pip uninstall -y lightgbm && \
7777
/tmp/clean-layer.sh
7878

7979
# Install JAX (Keep JAX version in sync with CPU image)
80-
RUN pip install jax==0.2.16 jaxlib==0.1.68+cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
80+
RUN pip install jax[cuda$CUDA_MAJOR_VERSION$CUDA_MINOR_VERSION]==0.2.19 -f https://storage.googleapis.com/jax-releases/jax_releases.html && \
8181
/tmp/clean-layer.sh
8282

8383
# Reinstall packages with a separate version for GPU support.

0 commit comments

Comments
 (0)