Skip to content

Commit

Permalink
fix(jetstream Pt): make Jetstream Pt install more reliable
Browse files Browse the repository at this point in the history
I was previously referencing a given git revision and install from
github, but since the Jetstream Pytorch package install its dependencies
from its git submodels, these are installed in temporary directories,
that can disappear afterwards. This happened on CI, making the
installation fail.

To work around that, a dedicated install script has been added, and it
is now used to install that.
  • Loading branch information
tengomucho committed Sep 18, 2024
1 parent 4b48358 commit a8cad0a
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 3 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -133,4 +133,6 @@ dmypy.json
*.pt

.vscode
.idea/
.idea/

jetstream-pt-deps
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ tgi_server:
VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server

jetstream_requirements:
bash install-jetstream-pt.sh
python -m pip install .[jetstream-pt] \
-f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \
Expand Down
13 changes: 13 additions & 0 deletions install-jetstream-pt.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
deps_dir=deps
rm -rf $deps_dir
mkdir -p $deps_dir
cd $deps_dir
pwd
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921
git submodule update --init --recursive
# We cannot install in a temporary directory because the directory should not be deleted after the script finishes,
# because it will install its dependendencies from that directory.
pip install -e .
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta"
[project.optional-dependencies]
tests = ["pytest", "safetensors"]
quality = ["black", "ruff", "isort"]
# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit.
# Jetstream/Pytorch support is experimental for now, it needs to be installed manually.
# Pallas is pulled because it will install a compatible version of jax[tpu].
jetstream-pt = [
"jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921",
"jetstream-pt",
"torch-xla[pallas] == 2.4.0"
]

Expand Down

0 comments on commit a8cad0a

Please sign in to comment.