```python
from whisper_jax import FlaxWhisperPipline
import jax.numpy as jnp

# instantiate the pipeline in float16
pipeline = FlaxWhisperPipline("openai/whisper-small", dtype=jnp.float16, batch_size=16)

text = pipeline("10m.mp3")
print(text)
```
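As a side note on the `batch_size=16` argument: whisper-jax splits a long file into fixed-length chunks and transcribes them in batches. A minimal sketch of the chunk arithmetic for the 10-minute file above, assuming the 30-second chunk length typical for Whisper-style models:

```python
import math

# Assumed 30 s chunk length (typical for Whisper-style models);
# batch_size matches the snippet above.
chunk_len_s = 30
batch_size = 16
audio_len_s = 10 * 60                          # the 10-minute mp3

n_chunks = math.ceil(audio_len_s / chunk_len_s)    # 600 / 30 -> 20 chunks
n_batches = math.ceil(n_chunks / batch_size)       # 20 / 16  -> 2 batches
print(n_chunks, n_batches)                         # -> 20 2
```

So under these assumptions the file is processed as two batches rather than chunk by chunk, which is where the batched pipeline gets its speedup.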
You might need to check your jax version. I am using jax 0.4.19 and it works:

```
!pip install jax==0.4.19
!pip install -U "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
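After installing, it is worth confirming that JAX actually picked up a CUDA backend rather than silently falling back to CPU. A minimal check (wrapped in a try/except so it also runs where jax is not installed):

```python
# Sanity check: list the devices JAX sees. With a working CUDA install
# this should show a cuda/gpu device, not only CPU devices.
try:
    import jax
    backend_devices = [str(d) for d in jax.devices()]
except ImportError:
    backend_devices = []   # jax not installed in this environment

print("devices:", backend_devices)
```

If only CPU devices appear, the installed wheel does not match the CUDA setup.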
I tried Windows 11 (Settings/Installed apps shows NVIDIA driver 546.33 with NVIDIA CUDA 11.8, while nvidia-smi reports CUDA 12.3) → WSL2 → Ubuntu → a Docker container based on an nvidia/cuda image (https://hub.docker.com/r/nvidia/cuda/tags?page=2&name=11.8), with no luck.
Do I need "jax[cuda12_pip]==0.4.19" or "jax[cuda11_pip]==0.4.19"? And which container version should I use?
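One general rule that may help here (an assumption about the jax-releases wheel naming, not something confirmed in this thread): the `cuda11_pip` wheels are built against CUDA 11.x and `cuda12_pip` against CUDA 12.x, and the extra should match the CUDA toolkit inside the container, not the version `nvidia-smi` prints (that is the driver's maximum supported CUDA). So for an nvidia/cuda 11.8 container, a sketch would be:

```shell
# Hypothetical: inside an nvidia/cuda:11.8.* container, match the CUDA 11 wheel
pip install -U "jax[cuda11_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```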
> You might need to check your jax version. I am using jax 0.4.19 and it works:
> `!pip install jax==0.4.19`
> `!pip install -U "jax[cuda12_pip]==0.4.19" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`
Now the GPU is working, but it's slow: a 10-minute mp3 took more than 300 s.