Following the most recent commit, the jax.tree module is now used instead of the jax.tree_* functions. This change requires jax >= 0.4.25. However, there are still many parts of the repository that rely on old/deprecated APIs. For instance, if you install 0.4.25 you might get something like:
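For reference, the old and new spellings differ only in namespace. A minimal sketch, assuming a jax version that ships the jax.tree module (per the issue, >= 0.4.25). The old-style call is written through jax.tree_util here, and the `params` dictionary is just an illustrative pytree:

```python
import jax
import jax.numpy as jnp

# An illustrative pytree: a dict of arrays, like a set of network parameters.
params = {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}

# Older spelling, via the jax.tree_util module.
doubled_old = jax.tree_util.tree_map(lambda x: 2 * x, params)

# Newer spelling, via the jax.tree module used after the commit.
doubled_new = jax.tree.map(lambda x: 2 * x, params)

# Both produce the same tree structure and the same leaf values.
leaves_old = jax.tree_util.tree_leaves(doubled_old)
leaves_new = jax.tree.leaves(doubled_new)
print(all(bool(jnp.array_equal(a, b)) for a, b in zip(leaves_old, leaves_new)))  # prints True
```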
AttributeError: module 'jax.random' has no attribute 'KeyArray'
This is because jax.random.KeyArray was removed in jax 0.4.24, so to avoid this error you need jax <= 0.4.23. Obviously this contradicts the requirement above.
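If pinning jax <= 0.4.23 is not an option, one local workaround is to stop depending on the removed name. This is a minimal sketch. It assumes KeyArray is only used in type annotations, and it relies on JAX's deprecation notice for KeyArray, which pointed to jax.Array for annotations. `split_key` is a hypothetical helper added for illustration:

```python
import jax

# KeyArray was only a type, so code that uses it purely in annotations can
# fall back to jax.Array when the attribute no longer exists (jax >= 0.4.24).
KeyArray = getattr(jax.random, "KeyArray", jax.Array)


def split_key(key: KeyArray, num: int = 2):
    """Hypothetical helper: split a PRNG key into `num` new keys."""
    return jax.random.split(key, num)


keys = split_key(jax.random.PRNGKey(0), 3)
print(keys.shape)  # (3, 2) with the default threefry PRNG implementation
```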
Lastly, there is still the issue with DeviceArray and ShardedDeviceArray. Both were replaced by jax.Array back in jax 0.4.0! In the repo's current state, you basically need to add lines like
jax.interpreters.xla.DeviceArray = jax.Array
just to be able to import acme...
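That workaround can be collected in a small module that runs before acme is imported. This is a sketch with several assumptions. The module name `jax_compat.py` is hypothetical. It restores only the two names mentioned in this issue. It also assumes ShardedDeviceArray is looked up in jax.interpreters.pxla, where older JAX versions defined it:

```python
# jax_compat.py (hypothetical module name): import this before `import acme`.
import jax
import jax.interpreters.pxla as pxla
import jax.interpreters.xla as xla

# Both legacy array types were replaced by jax.Array, so alias the old names.
# An alias is set only when the name is missing, so older JAX versions keep
# their original classes.
if not hasattr(xla, "DeviceArray"):
    xla.DeviceArray = jax.Array
if not hasattr(pxla, "ShardedDeviceArray"):
    pxla.ShardedDeviceArray = jax.Array
```

Importing it first (`import jax_compat`, then `import acme`) applies the aliases before acme's modules look them up. Monkeypatching JAX internals like this only hides the problem. The lasting fix is to replace these names with jax.Array in the repository itself.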