From a8cad0af35c1d6755b74e7ddfebac8b390e3e0c0 Mon Sep 17 00:00:00 2001 From: Alvaro Moran Date: Wed, 18 Sep 2024 15:37:16 +0000 Subject: [PATCH] fix(jetstream Pt): make Jetstream Pt install more reliable I was previously referencing a given git revision and install from github, but since the Jetstream Pytorch package install its dependencies from its git submodels, these are installed in temporary directories, that can disappear afterwards. This happened on CI, making the installation fail. To work around that, a dedicated install script has been added, and it is now used to install that. --- .gitignore | 4 +++- Makefile | 1 + install-jetstream-pt.sh | 13 +++++++++++++ pyproject.toml | 4 ++-- 4 files changed, 19 insertions(+), 3 deletions(-) create mode 100644 install-jetstream-pt.sh diff --git a/.gitignore b/.gitignore index 7beb0e7f..55b9b3cc 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,6 @@ dmypy.json *.pt .vscode -.idea/ \ No newline at end of file +.idea/ + +jetstream-pt-deps \ No newline at end of file diff --git a/Makefile b/Makefile index dc1b770e..00d4f043 100644 --- a/Makefile +++ b/Makefile @@ -88,6 +88,7 @@ tgi_server: VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server jetstream_requirements: + bash install-jetstream-pt.sh python -m pip install .[jetstream-pt] \ -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ diff --git a/install-jetstream-pt.sh b/install-jetstream-pt.sh new file mode 100644 index 00000000..aa5bd621 --- /dev/null +++ b/install-jetstream-pt.sh @@ -0,0 +1,13 @@ +#!/bin/bash +deps_dir=deps +rm -rf $deps_dir +mkdir -p $deps_dir +cd $deps_dir +pwd +git clone https://github.com/google/jetstream-pytorch.git +cd jetstream-pytorch +git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921 +git submodule update --init --recursive +# We cannot install in a temporary directory because the directory should not be deleted after the script finishes, +# because it will install its dependendencies from that directory. +pip install -e . diff --git a/pyproject.toml b/pyproject.toml index 01ad1a5d..c57ba05d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] tests = ["pytest", "safetensors"] quality = ["black", "ruff", "isort"] -# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit. +# Jetstream/Pytorch support is experimental for now, it needs to be installed manually. # Pallas is pulled because it will install a compatible version of jax[tpu]. jetstream-pt = [ - "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921", + "jetstream-pt", "torch-xla[pallas] == 2.4.0" ]