diff --git a/.gitignore b/.gitignore index 7beb0e7..55b9b3c 100644 --- a/.gitignore +++ b/.gitignore @@ -133,4 +133,6 @@ dmypy.json *.pt .vscode -.idea/ \ No newline at end of file +.idea/ + +jetstream-pt-deps \ No newline at end of file diff --git a/Makefile b/Makefile index dc1b770..00d4f04 100644 --- a/Makefile +++ b/Makefile @@ -88,6 +88,7 @@ tgi_server: VERSION=${VERSION} TGI_VERSION=${TGI_VERSION} make -C text-generation-inference/server gen-server jetstream_requirements: + bash install-jetstream-pt.sh python -m pip install .[jetstream-pt] \ -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \ -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html \ diff --git a/install-jetstream-pt.sh b/install-jetstream-pt.sh new file mode 100644 index 0000000..aa5bd62 --- /dev/null +++ b/install-jetstream-pt.sh @@ -0,0 +1,13 @@ +#!/bin/bash +deps_dir=deps +rm -rf $deps_dir +mkdir -p $deps_dir +cd $deps_dir +pwd +git clone https://github.com/google/jetstream-pytorch.git +cd jetstream-pytorch +git checkout ec4ac8f6b180ade059a2284b8b7d843b3cab0921 +git submodule update --init --recursive +# We cannot install in a temporary directory because the directory should not be deleted after the script finishes, +# because it will install its dependendencies from that directory. +pip install -e . diff --git a/pyproject.toml b/pyproject.toml index 01ad1a5..c57ba05 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,10 +58,10 @@ build-backend = "setuptools.build_meta" [project.optional-dependencies] tests = ["pytest", "safetensors"] quality = ["black", "ruff", "isort"] -# Jetstream/Pytorch support is experimental for now, requires installation from fixed commit. +# Jetstream/Pytorch support is experimental for now, it needs to be installed manually. # Pallas is pulled because it will install a compatible version of jax[tpu]. jetstream-pt = [ - "jetstream-pt @ git+https://github.com/google/jetstream-pytorch.git@ec4ac8f6b180ade059a2284b8b7d843b3cab0921", + "jetstream-pt", "torch-xla[pallas] == 2.4.0" ]