
Resolving dependency issues #246

Open
rinapch opened this issue Jan 7, 2023 · 7 comments

@rinapch commented Jan 7, 2023

There have been a number of issues about different version conflicts and how to fix them. I've spent some time getting this code to run, so maybe these instructions will spare someone else the effort :)

First of all, as per this issue in the jax repo (jax-ml/jax#13321), TPU VMs no longer work with jax older than 0.2.16, while this repo requires jax==0.2.12. I found that the code still works with jax 0.2.18 and 0.2.20.

Additionally, since a number of dependencies in the requirements file do not state the needed versions, I rolled them all back to the latest versions as of January 2022 and used poetry to resolve conflicts. Here is the dependencies section of the resulting pyproject.toml (a short install sketch follows the list):

python = "^3.8"
numpy = ">=1.19.5,<1.20.0"
tqdm = ">=4.45.0,<4.46.0"
wandb = "^0.13.7"
einops = ">=0.3.0,<0.4.0"
requests = ">=2.25.1,<2.26.0"
fabric = ">=2.6.0,<2.7.0"
optax = "0.0.9"
dm-haiku = "0.0.5"
ray = {version = "1.4.1", extras = ["default"]}
jax = "0.2.18"
cloudpickle = ">=1.3.0,<1.4.0"
tensorflow-cpu = ">=2.6.0,<2.7.0"
google-cloud-storage = ">=1.36.2,<1.37.0"
transformers = ">=4.16.2,<4.17.0"
smart-open = {version = ">=5.2.1,<5.3.0", extras = ["gcs"]}
ftfy = ">=6.1,<7.0"
lm-dataformat = "^0.0.20"
pathy = "^0.10.1"
func-timeout = "^4.3.5"
chex = "0.0.5"
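
For anyone who hasn't used poetry before, installing from this file is roughly the following (a minimal sketch, assuming the list above sits under [tool.poetry.dependencies] in pyproject.toml):

pip install poetry
poetry lock      # resolve and pin the versions listed above
poetry install   # install the locked set into the current environment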

After installing all of this with poetry, install jax[tpu] with pip so that it gets the right libtpu nightly build:

pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
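
To confirm the TPU is actually visible to jax after that install, a quick check like the following should list the TPU cores (this check is my addition, not part of the original write-up):

python3 -c "import jax; print(jax.devices())"   # should print TpuDevice entries, e.g. 8 of them on a v3-8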

When starting training you may also find that it gets stuck at validation. As suggested by @versae in issue #218, it helps to change the TPU runtime version to an alpha build, e.g. gcloud alpha compute tpus tpu-vm create gptj --accelerator-type v3-8 --version v2-alpha

@JohnnyOpcode

With Colab Pro, the default TPU lib (and JAX) is now at 0.3.25. I jumped through these hoops as well and got it running with

!pip install mesh-transformer-jax/ jax==0.3.15 tensorflow==2.8.2 chex==0.1.4 jaxlib==0.3.15

Your mileage may vary..

Johnny

@mosmos6 commented Jan 12, 2023

@rinapch Worked perfectly for me. Thank you very much.

@mosmos6 commented Jan 16, 2023

> @rinapch Worked perfectly for me. Thank you very much.

That said, it worked perfectly for fine-tuning but not for inference on Colab (it caused an optax error).
In order to set up the model for inference, I needed to revert the requirements to

numpy~=1.19.5
tqdm~=4.45.0
wandb>=0.11.2
einops~=0.3.0
requests~=2.25.1
fabric~=2.6.0
optax==0.0.6
git+https://github.com/deepmind/dm-haiku
git+https://github.com/EleutherAI/lm-evaluation-harness@c406a62047
ray[default]==1.4.1
jax~=0.2.12
Flask~=1.1.2
cloudpickle~=1.3.0
tensorflow-cpu~=2.5.0
google-cloud-storage~=1.36.2
transformers
smart_open[gcs]
func_timeout
ftfy
fastapi
uvicorn
lm_dataformat
pathy

and

!pip install chex==0.1.2
!pip install jaxlib==0.1.68
!pip install dm-haiku==0.0.5
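
If it helps, a quick way to confirm which versions actually ended up installed (a sketch; the package names are simply taken from the lists above):

!pip show jax jaxlib chex dm-haiku optax | grep -E "^(Name|Version)"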

Just as a note.

@AidanShipperley

Thank you so much for this post; it helped me resolve all of my dependency issues. I have never worked with poetry before, but I was able to get a model training in a conda environment using just install commands.

If anybody is interested, here are the steps I took from scratch, which are currently working based on my test run.

-- First, install conda on the TPU VM

mkdir conda_install
cd conda_install
sudo apt-get update
sudo apt-get install wget
wget https://repo.anaconda.com/archive/Anaconda3-2022.10-Linux-x86_64.sh
bash Anaconda3-2022.10-Linux-x86_64.sh

-- Update path to include conda

export PATH=~/anaconda3/bin:$PATH

-- Create env with mamba and python == 3.8

conda create -n gpt -c conda-forge mamba python==3.8

-- Close and reopen the terminal, then re-ssh

gcloud compute tpus tpu-vm ssh YOUR_TPU_NAME --zone YOUR_ZONE_NAME

-- Leave base

conda deactivate 

-- Enter env

conda activate gpt

-- Install requirements available through conda first

mamba install -c conda-forge numpy==1.19.5 tqdm==4.45.0 einops==0.3.0 requests==2.25.1 fabric==2.6.0 optax==0.0.9 dm-haiku==0.0.5 jax==0.2.18 cloudpickle==1.3.0 tensorflow-cpu==2.6.0 google-cloud-storage==1.36.2 transformers==4.16.2 smart_open==5.2.1 ftfy==6.1.1 pathy==0.10.1 func_timeout==4.3.5

-- Install remaining requirements not available through conda with pip

pip install ray[default]==1.4.1 wandb==0.13.7 chex==0.0.5 lm-dataformat==0.0.20 typing-extensions==4.2.0 protobuf==3.19.5

-- NOTE: You will see an error about tensorflow 2.6.0 not being compatible with typing-extensions 4.2.0. This is fine; ignore it.

-- Jax 0.2.12 does NOT WORK with TPUs anymore, but we can use 0.2.18 or 0.2.20

pip install "jax[tpu]==0.2.18" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

-- If you have issues with protobuf (which may originate from the import wandb call), run this

python3 -m pip uninstall protobuf
python3 -m pip install protobuf==3.19.5
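
-- (My addition, not part of the original steps) A quick sanity check that protobuf and wandb now import cleanly

python3 -c "import google.protobuf, wandb; print(google.protobuf.__version__)"   # should print 3.19.5 without errors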

-- Finally, you can run this and fine-tune your model

cd ./mesh-transformer-jax/
python3 device_train.py --config=./configs/YOUR_CONFIG_NAME.json --tune-model-path=gs://YOUR_BUCKET_NAME/step_383500/
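
-- (Optional, my addition) If device_train.py complains about the checkpoint path, it may be worth confirming the bucket is readable first, using the same placeholder names as above

gsutil ls gs://YOUR_BUCKET_NAME/step_383500/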

@mosmos6 commented Mar 18, 2023

@JohnnyOpcode How did you infer with JAX 0.3.15? I think it runs only with 0.2.12.

@JohnnyOpcode

> @JohnnyOpcode How did you infer with JAX 0.3.15? I think it runs only with 0.2.12.

I was using Colab Pro (paid) and experimented with different versions of the libraries and with pip. The key takeaway is compatibility with the TPUv2 ASIC. I'll try to find some time to go through those motions again and come up with a newer working requirements.txt for everybody.

Python sucks btw. Just like JS and TS. Too many brittle dependencies, but it does create lots of BS positions and salaries.

@indrakalita

1. Uninstall all relevant packages:
   pip uninstall -y jax jaxlib chex

2. Install a compatible version of JAX for CUDA 12:
   pip install --upgrade "jax[cuda12_pip]<0.4.24" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
   (If you encounter dependency warnings, you can safely ignore them.)

3. Install a compatible version of Chex:
   pip install --upgrade chex==0.1.81

4. After completing these steps, verify dependencies using:
   pip check
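
Once pip check passes, a quick way to confirm that jax actually sees the GPU (my addition, not part of the original steps):

python3 -c "import jax; print(jax.default_backend(), jax.devices())"   # expect 'gpu' followed by a list of CUDA devices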
