Flax needs to be upgraded in the tensorflow/jax image #489
Comments
The latest version of We could either wait for the updates on
OK, looks like
and got this traceback:
It seems like Edit:
Looks like we might need to upgrade from CUDA 11.8 to 12 to get a newer version of
@yuvipanda and @jbusecke - just getting your attention here, since my hacky workflow seems to have stopped working today. Recently, to work with jax and an appropriate version of flax, I would do 2 steps on the LEAP hub:
However, as of this morning this leads to a new error. After following these steps, I get the error:
Hey @dhruvbalwada, I suspect this is due to the recent update of the pangeo-docker-image on the LEAP hub. To unblock you for now, I recommend you manually run from an older image (the LEAP docs provide instructions). But I think this does not change the core problem here. Is there anything I could help with or test to contribute here, @weiji14?
Right, looks like we'll need to expedite the upgrade to CUDA 12 then, as mentioned at #489 (comment). Let me open a PR for that (got some free time today), and then we'll be able to upgrade to newer tensorflow/flax versions.
Ok, not as simple as I thought. I tried running
It looks like we'll need to wait for
Do you think a slightly older version may be enough for now? Maybe flax>=0.7?
Nope,
I've also tried older version combinations with CUDA 11.2 and tensorflow 2.13.x last year (see all my crossed out links in #489 (comment)), but none of them work. We really need to get all the tensorflow/jax libraries to align on the correct version of libprotobuf in conda-forge.
The new hack that works is:
Hopefully alignment will come in the near future.
Describe the bug
Current version of flax (0.6.1) on the image does not work properly with the jax version (0.4.13).
To Reproduce
Issue can be reproduced by running
from flax.training import checkpoints
which gives the error ModuleNotFoundError: No module named 'jax.experimental.global_device_array'.
This has been discussed in google/flax#3087.
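For anyone who wants to check an image without the session stopping at the first traceback, here is a minimal sketch of the reproduction as a small diagnostic. The helper name try_import is made up for illustration and is not part of jax or flax; it only wraps the standard-library importlib call and reports the error text instead of raising.

```python
import importlib


def try_import(module_name):
    """Try to import a module; return (ok, message) instead of raising."""
    try:
        importlib.import_module(module_name)
    except ImportError as exc:
        # ModuleNotFoundError is a subclass of ImportError, so the error
        # from the report above would be caught here.
        return False, f"{type(exc).__name__}: {exc}"
    return True, "ok"


if __name__ == "__main__":
    # The last entry is the import from the reproduction step.
    for name in ("jax", "flax", "flax.training.checkpoints"):
        ok, message = try_import(name)
        print(f"{name}: {message}")
```

Only import errors are caught; any other exception raised while importing would still propagate, which is probably what you want when debugging.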
Expected behavior
Flax should be importable.
Infrastructure (Where you are running this image):
Solution
At the moment I work around this by running
pip install flax==0.6.10
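To confirm whether the pin actually took effect in a given environment, something like the following sketch could be used. It relies only on the standard-library importlib.metadata; the helper names and the PINNED_FLAX constant are hypothetical, and 0.6.10 is used as the reference only because it is the version pinned above, not because it is known to be the minimum that works.

```python
from importlib import metadata

# Assumption: the reference version is the one pinned in the workaround above.
PINNED_FLAX = (0, 6, 10)


def parse_version(text):
    """Turn '0.6.10' into (0, 6, 10); stops at the first non-numeric part."""
    parts = []
    for piece in text.split("."):
        if not piece.isdigit():
            break
        parts.append(int(piece))
    return tuple(parts)


def installed_version(dist_name):
    """Return the installed version string, or None if it is not installed."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def is_older_than_pin(version_text):
    """True if the given flax version is older than the pinned one."""
    return parse_version(version_text) < PINNED_FLAX


if __name__ == "__main__":
    for dist in ("jax", "flax"):
        print(dist, installed_version(dist))
    flax_version = installed_version("flax")
    if flax_version is not None and is_older_than_pin(flax_version):
        print("flax is older than the pinned 0.6.10")
```

Versions are compared as integer tuples rather than strings because a plain string comparison would rank 0.6.9 above 0.6.10.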