Importing jax-moseq produces deprecation error #32
jax-moseq produces an error on import because chex needs to be updated.
So are you saying that
Yes
I'm a little worried about this given that keypoint-moseq install currently works for most people, and upgrading chex/jax might open a whole can of worms. When you say "a fresh install of jax-moseq installs jax==0.4.20", what do you mean exactly? Like if you update jax-moseq it tries to update jax too?
google-deepmind/chex#251 was merged to chex master in March, so probably 0.1.8. I'm having the exact same issue on Ubuntu 22 with a GPU. Installing a newer chex gives a dependency error from jax-moseq.
Is the dependency error in pip? If so, I think it can just be ignored.
Yeah, just edited my comment as you replied. I did ignore it, thanks Caleb :)
A fresh install of jax-moseq installs jax==0.4.20, which removes some deprecated classes that chex==0.1.6 depends on. As a result, importing jax-moseq produces the following error:
AttributeError: module 'jax.interpreters.pxla' has no attribute 'ShardedDeviceArray'
The solution is to update chex to the latest version (0.1.84), which no longer references ShardedDeviceArray.
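Before upgrading, it can help to confirm which jax and chex versions are actually installed. The following is a minimal sketch, assuming Python 3.8+ (for importlib.metadata). The helper names parse_version, chex_needs_upgrade, and check_installed are hypothetical and not part of jax-moseq or chex, and the 0.1.84 threshold comes from this issue. It reads package metadata instead of importing chex, so it still runs in an environment where importing chex would raise the AttributeError above.

```python
# Hypothetical pre-import check (not part of jax-moseq or chex): report installed
# jax/chex versions and flag a chex older than 0.1.84, the release this issue
# suggests upgrading to.
import re
from importlib.metadata import PackageNotFoundError, version

MIN_CHEX = (0, 1, 84)  # threshold taken from this issue


def parse_version(v: str) -> tuple:
    """Turn a release string like '0.4.20' into (0, 4, 20).

    Simplified parser: keeps the leading digits of each dot-separated piece
    and stops at the first piece with none (so '0.1.84rc1' -> (0, 1, 84)).
    """
    parts = []
    for piece in v.split("."):
        m = re.match(r"\d+", piece)
        if m is None:
            break
        parts.append(int(m.group()))
    return tuple(parts)


def chex_needs_upgrade(installed: str) -> bool:
    """True if the installed chex version sorts below MIN_CHEX."""
    return parse_version(installed) < MIN_CHEX


def check_installed() -> dict:
    """Map 'jax' and 'chex' to their installed version strings, or None if absent."""
    report = {}
    for pkg in ("jax", "chex"):
        try:
            report[pkg] = version(pkg)
        except PackageNotFoundError:
            report[pkg] = None
    return report


if __name__ == "__main__":
    installed = check_installed()
    print(f"jax={installed['jax']} chex={installed['chex']}")
    if installed["chex"] is not None and chex_needs_upgrade(installed["chex"]):
        print("chex is older than 0.1.84; upgrading it should avoid the "
              "ShardedDeviceArray AttributeError.")
```

The comparison uses plain integer tuples, which works for chex's 0.1.8x numbering (0.1.6 < 0.1.8 < 0.1.84); a full version parser such as packaging.version would handle pre-release tags more carefully. The upgrade itself is an ordinary pip upgrade of chex (for example pip install -U "chex>=0.1.84"). As noted in the thread, pip may print a dependency-conflict message because jax-moseq's declared requirements conflict with the newer chex.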