Kernel always restarting when JIT compiling the forward call on MacBook Pro M3 Max #195

Gdesau opened this issue on May 22, 2024.


Hello,

I'm trying to use whisper-jax on a MacBook Pro M3 Max (16-core CPU, 40-core GPU, 16-core Neural Engine and 128 GB of unified memory). I'm working in a Jupyter notebook and followed the Apple Developer documentation to install JAX. I had to add one extra step to that procedure, `pip install jax==0.4.26 jaxlib==0.4.26`, because I otherwise ran into an error (see the Apple Developer forum). However, every time I try to JIT-compile the forward call as described in the whisper-jax documentation, the kernel restarts. Here is my code:

```python
# Cell 1: imports, pipeline setup and compilation cache
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp
from jax.experimental.compilation_cache import compilation_cache as cc
from IPython.display import Audio

# Instantiate the pipeline
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.bfloat16, batch_size=16)

# Store compiled executables on disk so later runs can reuse them
cc.set_cache_dir("./jax_cache")

audio = "./test_court.mp3"
Audio(audio)  # play the test file inline in the notebook
```

```python
%%time
# Cell 2: JIT compile the forward call - slow, but we only do it once
text = pipeline(audio)
```
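Because the installation only worked after pinning `jax` and `jaxlib` to 0.4.26, a quick first check is that the interpreter behind the crashing kernel really sees those versions. This is a minimal sketch using only the standard library's `importlib.metadata`; `installed_version` and `pin_mismatches` are hypothetical helper names, and the `jax-metal` distribution name is assumed from Apple's installation guide. Run it in the same kernel (or the same environment as the VS Code interpreter) that shows the problem.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str) -> str | None:
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def pin_mismatches(pins: dict[str, str]) -> dict[str, tuple[str | None, str]]:
    """Map each package whose installed version differs from its pin to (installed, expected)."""
    mismatches = {}
    for package, expected in pins.items():
        found = installed_version(package)
        if found != expected:
            mismatches[package] = (found, expected)
    return mismatches


if __name__ == "__main__":
    # Pins taken from the workaround described above
    pins = {"jax": "0.4.26", "jaxlib": "0.4.26"}
    print(pin_mismatches(pins) or "jax and jaxlib match the pins")
    # Assumed name of Apple's Metal plugin package; printed for the bug report only
    print("jax-metal:", installed_version("jax-metal"))
```

An empty mismatch dict means both packages match; anything else shows the installed and expected versions side by side, which is useful to paste into the issue.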

I've already tried other checkpoints of the model, and I've also run the same code as a .py file in VS Code: there was no error, but nothing was displayed when printing the text.
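One way to tell whether the restart comes from JAX's Metal backend itself rather than from whisper-jax is to run a tiny jitted computation in a separate Python process. A native crash then kills only that child process and surfaces as a return code (negative, minus the signal number, on POSIX) instead of a silent kernel restart or a .py run that prints nothing. This sketch is hypothetical: `JAX_PROBE`, `run_isolated` and `describe` are illustrative names, not part of whisper-jax or JAX, and comparing `float32` with `bfloat16` is only a guess at one variable worth isolating (the pipeline above uses `dtype=jnp.bfloat16`), not a known cause.

```python
import subprocess
import sys
import textwrap

# Hypothetical probe: a small jitted matmul that does not involve whisper-jax
JAX_PROBE = textwrap.dedent("""
    import jax
    import jax.numpy as jnp
    print("backend:", jax.default_backend(), jax.devices())
    f = jax.jit(lambda x: (x @ x.T).sum())
    x = jnp.ones((128, 128), dtype=jnp.{dtype})
    print("result:", f(x))
""")


def run_isolated(code: str, timeout: float = 600.0) -> subprocess.CompletedProcess:
    """Run `code` in a fresh interpreter (same Python as this one) and capture the outcome."""
    return subprocess.run(
        [sys.executable, "-c", code],
        capture_output=True,
        text=True,
        timeout=timeout,
    )


def describe(result: subprocess.CompletedProcess) -> str:
    """Summarise how the child process ended."""
    if result.returncode == 0:
        return "ok"
    if result.returncode < 0:
        # On POSIX a negative code means the child was killed by that signal
        return f"killed by signal {-result.returncode}"
    return f"exited with code {result.returncode}"


if __name__ == "__main__":
    for dtype in ("float32", "bfloat16"):
        result = run_isolated(JAX_PROBE.format(dtype=dtype))
        print(dtype, "->", describe(result))
        print(result.stdout or result.stderr)
```

If the probe itself is killed by a signal, the problem likely sits below whisper-jax; if it succeeds, the same `run_isolated` wrapper can be pointed at the pipeline script to capture its stderr.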

Any ideas?

Thanks in advance for your help!
