Jax dependency Issue #579

Open
Aradhya-Tripathi opened this issue Feb 29, 2024 · 12 comments

Comments

@Aradhya-Tripathi

JAX removed linear_util as of v0.4.25 (the latest release), so importing Haiku with the latest JAX crashes. Because ColabFold's pyproject.toml accepts any JAX version from 0.4.20 upward, the latest JAX gets installed, and this breaks several ColabFold functions.

Please let me know if this makes sense or if you need more information.

@milot-mirdita
Collaborator

I'm not sure how we can fix this. As far as I understand, we should pin only the jax version and not jaxlib; however, that doesn't prevent pip from installing a newer jaxlib.

Installing through poetry instead of pip should solve most issues, as it will install the versions specified in the lock file.
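When comparing a pip install against the lock file, it can help to print which versions were actually resolved. The helper below is a minimal sketch and is hypothetical (installed_versions is not part of ColabFold). It uses only the standard-library importlib.metadata and assumes the PyPI distribution names jax, jaxlib and dm-haiku.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(dists):
    """Return {distribution name: installed version, or None if missing}.

    Hypothetical debugging helper: it reads package metadata only and
    does not import the packages, so it works even if an import would crash.
    """
    result = {}
    for dist in dists:
        try:
            result[dist] = version(dist)
        except PackageNotFoundError:
            result[dist] = None
    return result


if __name__ == "__main__":
    # Assumed PyPI distribution names for the packages discussed in this thread.
    for name, ver in installed_versions(["jax", "jaxlib", "dm-haiku"]).items():
        print(f"{name}: {ver or 'not installed'}")
```

Reading metadata instead of importing avoids triggering the very ImportError or RuntimeError you are trying to diagnose.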

@milot-mirdita
Collaborator

milot-mirdita commented Mar 7, 2024

I guess we have to wait for this to be merged:
google-deepmind/dm-haiku#739

Never mind, this was already merged. However, we pin dm-haiku to 0.0.10 because I had some problem with 0.0.11, though I don't remember what it was.

@milot-mirdita
Collaborator

I pinned dm-haiku to 0.0.12; this should hopefully avoid issues once Google decides to upgrade JAX within Colab.

Thanks!

@Aradhya-Tripathi
Author

Aradhya-Tripathi commented Mar 7, 2024

I think we can update jax to the latest version and change the code to use linear_util from the jax.extend module (it can be found here). I can make a PR for this, and maybe fix the jax installs on Colab as well.
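One way to support both old and new JAX layouts during such a migration is an import fallback. The sketch below is hypothetical (import_first is not a JAX or ColabFold function) and assumes that newer JAX exposes the module as jax.extend.linear_util while older releases still ship jax.linear_util. The helper itself only relies on importlib.import_module.

```python
import importlib
from types import ModuleType


def import_first(*candidates: str) -> ModuleType:
    """Import and return the first module in `candidates` that imports cleanly.

    Raises ImportError listing every failure if none of them can be imported.
    """
    failures = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # also covers ModuleNotFoundError
            failures.append(f"{name}: {exc}")
    raise ImportError("no candidate module could be imported: " + "; ".join(failures))


# Assumed usage: prefer the jax.extend location, fall back to the old module.
# lu = import_first("jax.extend.linear_util", "jax.linear_util")
```

Pinning a dependency that already does this (as dm-haiku 0.0.12 does, per the reply below) avoids carrying a shim like this in ColabFold itself.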

@milot-mirdita
Collaborator

dm-haiku 0.0.12 already does this, so it should be fine now that the dependency is updated.

@johnjacobpeters

I think this is a related (or the same) error. I returned to a "ColabFold v1.5.5: AlphaFold2 w/ MMseqs2 BATCH" session that was working yesterday, and now I get the error below. Is there a fix for this? I may just not understand what pinning dm-haiku 0.0.12 means.

RuntimeError Traceback (most recent call last)
in <cell line: 5>()
3 import sys
4
----> 5 from colabfold.batch import get_queries, run
6 from colabfold.download import default_data_dir
7 from colabfold.utils import setup_logging

/usr/local/lib/python3.10/dist-packages/jax/_src/lib/__init__.py in check_jaxlib_version(jax_version, jaxlib_version, minimum_jaxlib_version)
62 msg = (f'jaxlib is version {jaxlib_version}, but this version '
63 f'of jax requires version >= {minimum_jaxlib_version}.')
---> 64 raise RuntimeError(msg)
65
66 if _jaxlib_version > _jax_version:

RuntimeError: jaxlib is version 0.3.25, but this version of jax requires version >= 0.4.20.
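This RuntimeError comes from JAX comparing the installed jaxlib version against the minimum required by the installed jax. A simplified sketch of that kind of check follows. It is not JAX's actual check_jaxlib_version code: parse_version and check_min_version are hypothetical names, and the sketch assumes plain dotted-integer version strings.

```python
def parse_version(text: str) -> tuple:
    """Parse a plain release string like '0.4.20' into (0, 4, 20).

    Assumption: only dotted integers; suffixes such as 'rc1' or '.dev0'
    are not handled by this sketch.
    """
    return tuple(int(part) for part in text.split("."))


def check_min_version(name: str, installed: str, minimum: str) -> None:
    """Raise RuntimeError if `installed` is older than `minimum`."""
    if parse_version(installed) < parse_version(minimum):
        raise RuntimeError(
            f"{name} is version {installed}, but this version "
            f"of jax requires version >= {minimum}."
        )
```

With jaxlib 0.3.25 and a minimum of 0.4.20, the tuple comparison (0, 3, 25) < (0, 4, 20) is true, so the check raises the same message as in the traceback. This points to an old jaxlib left in the runtime rather than a problem in ColabFold's own code.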

@milot-mirdita
Collaborator

Are you using this notebook:
https://colab.research.google.com/github/sokrypton/ColabFold/blob/main/batch/AlphaFold2_batch.ipynb

The error message sounds like you are using an old version of it.

@johnjacobpeters

I tried running from that notebook, with no change. I did notice that I only seem to get this error when using TPU; GPU seems fine.

@milot-mirdita
Collaborator

I pushed a fix for TPU; it should work again.

@johnjacobpeters

Thanks! Works great.

@Aradhya-Tripathi
Author

Since this has been fixed by the upgrade to dm-haiku 0.0.12, can I close this issue?

@milot-mirdita
Collaborator

I still need to make a new pip release; I updated the conda package a few days ago.

3 participants