AttributeError: module 'jax.core' has no attribute 'NamedShape' #199

Closed
themanyone opened this issue Aug 28, 2024 · 2 comments

@themanyone

themanyone commented Aug 28, 2024

Tried installing whisper-jax today

pip install "jax[cuda12]==0.4.31"
Successfully installed jax-0.4.31
pip install --upgrade --no-deps --force-reinstall git+https://github.com/sanchit-gandhi/whisper-jax.git
Successfully installed whisper_jax-0.0.1
pip list | grep jax
jax                                             0.4.32.dev20240828
jax-cuda12-pjrt                                 0.4.31
jax-cuda12-plugin                               0.4.31
jaxlib                                          0.4.31
whisper_jax                                     0.0.1
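The listing shows `jax` at 0.4.32.dev20240828 while `jaxlib` and the CUDA plugins are at 0.4.31. A minimal sketch for printing the same information from inside Python with the standard-library `importlib.metadata`, assuming the distribution names shown above. The helper names are hypothetical, and treating any jax/jaxlib difference as a mismatch is a simplification that may be stricter than what JAX actually requires.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names taken from the `pip list` output above (assumption:
# these are the packages worth checking on this machine).
JAX_DISTRIBUTIONS = ("jax", "jaxlib", "jax-cuda12-pjrt", "jax-cuda12-plugin", "whisper_jax")


def installed_versions(names=JAX_DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if it is absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def jax_jaxlib_mismatch(versions):
    """True when both jax and jaxlib are installed but report different versions."""
    jax_v, jaxlib_v = versions.get("jax"), versions.get("jaxlib")
    return jax_v is not None and jaxlib_v is not None and jax_v != jaxlib_v


for dist, ver in installed_versions().items():
    print(f"{dist:20} {ver}")
```

On the machine above, `jax_jaxlib_mismatch(installed_versions())` would flag the 0.4.32.dev20240828 / 0.4.31 pair.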

python
Python 3.12.5 (main, Aug  7 2024, 00:00:00) [GCC 13.3.1 20240522 (Red Hat 13.3.1-1)] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import whisper_jax
/home/k/.local/lib/python3.12/site-packages/transformers/utils/generic.py:311: FutureWarning: `torch.utils._pytree._register_pytree_node` is deprecated. Please use `torch.utils._pytree.register_pytree_node` instead.
  torch.utils._pytree._register_pytree_node(
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/__init__.py", line 18, in <module>
    from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
  File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/modeling_flax_whisper.py", line 57, in <module>
    from whisper_jax import layers
  File "/home/k/.local/lib/python3.12/site-packages/whisper_jax/layers.py", line 63, in <module>
    def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
                             ^^^^^^^^^^^^^^^^^^^
  File "/home/k/.local/lib/python3.12/site-packages/jax/_src/deprecations.py", line 55, in getattr
    raise AttributeError(f"module {module!r} has no attribute {name!r}")
AttributeError: module 'jax.core' has no attribute 'NamedShape'
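The error surfaces during `import whisper_jax`, not when the model runs, because on Python 3.12 (the interpreter in this session) a parameter annotation such as `jax.core.NamedShape` is evaluated when the `def` statement executes. The sketch below reproduces that without JAX by using a `types.SimpleNamespace` as a hypothetical stand-in for a `jax.core` that lacks `NamedShape`. Python 3.14 defers annotation evaluation (PEP 649), so there the `def` itself would not raise.

```python
import sys
import types

# Hypothetical stand-in for a jax.core release that no longer defines NamedShape.
fake_core = types.SimpleNamespace(ShapedArray=object, DShapedArray=object)

# Annotations are evaluated eagerly at definition time before Python 3.14.
EAGER_ANNOTATIONS = sys.version_info < (3, 14)

error_message = None
try:
    # Merely defining the function triggers the attribute lookup, because the
    # annotation expression runs as part of executing this `def` statement.
    def _compute_fans(shape: fake_core.NamedShape, in_axis=-2, out_axis=-1):
        return shape
except AttributeError as exc:
    error_message = str(exc)

print(error_message)
```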

Tried `pip install -U "jax"`; same error.
Tried `pip install -U "jax[cuda12]==0.4.31"` so that they are all the same version now.
Still getting the same message.

pip uninstall jax jaxlib jax-cuda12-pjrt jax-cuda12-plugin
pip install -U "jax[cuda12]"
python

from jax import core
dc = dir(core)
[x for x in dc if 'shap' in x]
['UnshapedArray', 'is_constant_shape', 'raise_to_shaped', 'raise_to_shaped_mappings']
[x for x in dc if 'Shap' in x]
['DShapedArray', 'ShapedArray']
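The two searches above (`'shap'` and `'Shap'`) can be folded into one case-insensitive lookup. A small sketch with a hypothetical helper name, `find_names`, demonstrated on the standard-library `math` module so it runs without JAX; on the machine above it would be called as `find_names(core, "shape")` after `from jax import core`.

```python
import math
import types


def find_names(module: types.ModuleType, fragment: str) -> list[str]:
    """Return the attribute names of `module` that contain `fragment`, ignoring case."""
    needle = fragment.lower()
    # dir() already returns the names sorted alphabetically.
    return [name for name in dir(module) if needle in name.lower()]


print(find_names(math, "SQRT"))  # ['isqrt', 'sqrt']
```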
@themanyone
Author

Changing `jax.core.NamedShape` to `jax.core.DShapedArray` in `whisper_jax/layers.py`, line 63, gets it working.

@GO0108

GO0108 commented Sep 1, 2024

I had this problem recently; all I did was change the jax and jaxlib versions to 0.4.26.

jimburtoft added a commit to jimburtoft/whisper-jax that referenced this issue Oct 23, 2024