Skip to content

Commit

Permalink
Update instruction to install jax on GPU (pyro-ppl#1470)
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi authored Aug 14, 2022
1 parent f48e341 commit 589b352
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ pip install numpyro[cpu]

To use **NumPyro on the GPU**, you need to install CUDA first and then use the following pip command:
```
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_releases.html
pip install numpyro[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
If you need further guidance, please have a look at the [JAX GPU installation instructions](https://github.com/google/jax#pip-installation-gpu-cuda).

Expand Down
10 changes: 2 additions & 8 deletions docker/dev/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ FROM nvidia/cuda:11.2.2-cudnn8-devel-ubuntu20.04

# declare the image name
# note that this image uses Python 3.8
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04 \
# declare the cuda version for pulling appropriate jaxlib wheel
JAXLIB_CUDA=111
ENV IMG_NAME=11.2.2-cudnn8-devel-ubuntu20.04

# install python3 and pip on top of the base Ubuntu image
# unlike for release, we need to install git and setuptools too
Expand All @@ -22,11 +20,7 @@ RUN apt update && \
ENV PATH=/root/.local/bin:$PATH

# install python packages via pip
# install pip-versions to detect the latest version of jax and jaxlib
RUN pip3 install pip-versions
# this uses latest version of jax and jaxlib available from pypi
RUN pip-versions latest jaxlib | xargs -I{} pip3 install jaxlib=={}+cuda${JAXLIB_CUDA} -f https://storage.googleapis.com/jax-releases/jax_releases.html \
jax
RUN pip3 install jax[cuda] -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# clone the numpyro git repository and run pip install
RUN git clone https://github.com/pyro-ppl/numpyro.git && \
Expand Down

0 comments on commit 589b352

Please sign in to comment.