Replies: 2 comments
-
This question might be better asked on the openxla/xla repo, because JAX just lowers the FFT to whatever XLA provides. I believe (but don't quote me!) that on CPU XLA will use the multithreaded FFT from DUCC, and on NVIDIA GPUs it'll use cuFFT. I'm not sure what decompositions those implementations use!
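One way to see the lowering for yourself is to print the IR that JAX hands to XLA. This is a minimal sketch, assuming a JAX version where `jax.jit(...).lower(...).as_text()` is available. The array shape is arbitrary, and the exact op name in the output depends on the JAX version. It shows that the FFT is passed down as a single op, but it does not reveal which library the backend picks.

```python
import jax
import jax.numpy as jnp

# Arbitrary small complex input; the shape is only for illustration.
x = jnp.ones((8, 8), dtype=jnp.complex64)

# Lower the jitted FFT without running it and get the IR as text.
lowered_text = jax.jit(jnp.fft.fftn).lower(x).as_text()

# Keep only the lines that mention "fft" (the FFT op and possibly the module name).
fft_lines = [line.strip() for line in lowered_text.splitlines() if "fft" in line.lower()]

print("backend:", jax.default_backend())
print("\n".join(fft_lines))
```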
-
Hello, I can say for sure that they do not parallelize across multiple GPUs, since cuFFT is single-GPU only. There is cuFFTMp, but it is not very easy to install and work with. I would suggest that you use my implementation of parallel FFTs in JAX (https://github.com/differentiableuniverseinitiative/jaxdecomp ). It is also available on PyPI, and it supports both pencil and slab decompositions.
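For readers unfamiliar with the two terms, here is a single-device sketch of what slab and pencil decompositions mean for a 3D FFT. It is only an illustration: `slab_fft3d`, `fft_on_pencils` and `pencil_fft3d` are hypothetical helper names, not part of the jaxdecomp API, and the splits and concatenations stand in for the communication a real multi-device run would need.

```python
import jax
import jax.numpy as jnp


def slab_fft3d(x, n_slabs):
    """3D FFT via a slab (1D) decomposition, simulated on one device."""
    # Cut along axis 0: each slab holds complete (axis 1, axis 2) planes.
    slabs = jnp.split(x, n_slabs, axis=0)
    # Local 2D FFT over the two axes that are fully present in each slab.
    slabs = [jnp.fft.fft2(s, axes=(1, 2)) for s in slabs]
    # Global transpose (an all-to-all in a distributed run), simulated by
    # reassembling the array and re-cutting it along axis 1.
    pieces = jnp.split(jnp.concatenate(slabs, axis=0), n_slabs, axis=1)
    # Axis 0 is now complete inside each piece, so finish with a 1D FFT.
    pieces = [jnp.fft.fft(p, axis=0) for p in pieces]
    return jnp.concatenate(pieces, axis=1)


def fft_on_pencils(x, axis, grid):
    """1D FFT along `axis`, with the other two axes cut into a 2D grid."""
    a, b = [ax for ax in range(3) if ax != axis]
    rows = []
    for row in jnp.split(x, grid[0], axis=a):
        # Each pencil holds complete lines along `axis` only.
        pencils = [jnp.fft.fft(p, axis=axis) for p in jnp.split(row, grid[1], axis=b)]
        rows.append(jnp.concatenate(pencils, axis=b))
    return jnp.concatenate(rows, axis=a)


def pencil_fft3d(x, grid):
    """3D FFT via a pencil (2D) decomposition, simulated on one device."""
    # One 1D FFT per axis; between steps a real run would transpose the
    # data so that the next axis becomes local to each pencil.
    for axis in (2, 1, 0):
        x = fft_on_pencils(x, axis, grid)
    return x


x = jax.random.normal(jax.random.PRNGKey(0), (8, 8, 8))
reference = jnp.fft.fftn(x)
slab_ok = bool(jnp.allclose(slab_fft3d(x, 4), reference, rtol=1e-3, atol=1e-2))
pencil_ok = bool(jnp.allclose(pencil_fft3d(x, (2, 2)), reference, rtol=1e-3, atol=1e-2))
print(slab_ok, pencil_ok)
```

The practical difference is scaling: a slab decomposition of an N×N×N grid can use at most N devices with one global transpose, while a pencil decomposition can use up to N² devices at the cost of two transposes.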
-
Do JAX FFTs parallelize across cores automatically? Do they use pencil or slab decompositions?