pip install "jax[cuda12]==0.4.31"
Successfully installed jax-0.4.31
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
Successfully installed whisper_jax-0.0.1
pip list | grep jax
jax 0.4.32.dev20240828
jax-cuda12-pjrt 0.4.31
jax-cuda12-plugin 0.4.31
jaxlib 0.4.31
whisper_jax 0.0.1
python
Python 3.12.5 (main, Aug 7 2024, 00:00:00) [GCC 13.3.1 20240522 (Red Hat 13.3.1-1)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import whisper_jax
/home/k/.local/lib/python3.12/site-packages/transformers/utils/generic.py:311: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/__init__.py", line 18, in <module>
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/modeling_flax_whisper.py", line 57, in <module>
from whisper_jax import layers
File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/layers.py", line 63, in <module>
def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
^^^^^^^^^^^^^^^^^^^
File "/home/k/.local/lib/python3.12/site-packages/jax/_src/deprecations.py", line 55, in getattr
raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'NamedShape'
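The error is raised at the `def _compute_fans(...)` line itself, before the function is ever called. On Python 3.12, parameter annotations are evaluated when the `def` statement runs, so simply importing `whisper_jax.layers` looks up `jax.core.NamedShape`. A minimal sketch of that mechanism, using a hypothetical stand-in object (`fake_core`) instead of the real `jax.core` so it runs without JAX installed:

```python
import types

# Hypothetical stand-in for a `jax.core` module that no longer defines NamedShape.
fake_core = types.SimpleNamespace()


def define_compute_fans():
    """Run a `def` whose annotation references fake_core.NamedShape."""

    def _compute_fans(shape: fake_core.NamedShape, in_axis=-2, out_axis=-1):
        return shape  # the body is never reached in this demo

    # On Python 3.12 the annotation was already evaluated by the `def` above.
    # Python 3.14+ defers annotation evaluation, so read __annotations__ to force it.
    _ = _compute_fans.__annotations__
    return _compute_fans


try:
    define_compute_fans()
    error_message = None
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```

This is why the failure shows up on `import whisper_jax` rather than later, when a model is loaded.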
Tried `pip install -U "jax"`, same error.
Tried `pip install -U "jax[cuda12]==0.4.31"` so that all the JAX packages are on the same version now.
Still getting the same message.
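To confirm that the JAX distributions really are on one version after the pinned reinstall, the standard-library `importlib.metadata` module can read the installed versions. A sketch, assuming the four distribution names from the `pip list` output in this report are the relevant ones:

```python
from importlib import metadata

# Distribution names as they appear in the `pip list` output in this report.
JAX_PACKAGES = ["jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def versions_match(versions):
    """True if every installed (non-None) version string is identical."""
    found = {v for v in versions.values() if v is not None}
    return len(found) <= 1


if __name__ == "__main__":
    current = installed_versions(JAX_PACKAGES)
    for name, version in current.items():
        print(f"{name}: {version or 'not installed'}")
    print("versions match" if versions_match(current) else "version mismatch")
```

With the versions from the first `pip list` (jax 0.4.32.dev20240828 next to jaxlib 0.4.31), `versions_match` returns False. A matching set only rules out a mixed install; as the unchanged error after the pinned reinstall suggests, it does not by itself bring `jax.core.NamedShape` back.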
Tried installing whisper-jax today.
pip uninstall jax jaxlib jax-cuda12-pjrt jax-cuda12-plugin
pip install -U "jax[cuda12]"
python
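Before re-running `import whisper_jax` in the new session, one way to see whether the freshly installed JAX still lacks the attribute named in the traceback is to check for it directly. A sketch, assuming only that `jax` may or may not be importable in the current environment:

```python
import importlib


def has_named_shape():
    """True/False for whether jax.core exposes NamedShape, or None if JAX can't be imported."""
    try:
        core = importlib.import_module("jax.core")
    except (ImportError, RuntimeError):
        # JAX is missing, or (for example) an incompatible jax/jaxlib pair failed at import.
        return None
    # jax.core raises AttributeError for removed names, which hasattr turns into False.
    return hasattr(core, "NamedShape")


if __name__ == "__main__":
    result = has_named_shape()
    if result is None:
        print("jax could not be imported")
    elif result:
        print("jax.core.NamedShape exists")
    else:
        print("jax.core.NamedShape is missing; import whisper_jax will likely fail as above")
```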