Thanks for the report! It's possible that someone on this repo will be able to provide suggestions, but (as you mention!) you might get more mileage asking this question on the XLA repository if you haven't already: https://github.com/openxla/xla/issues
Description
This is more of an XLA problem than a JAX one. It looks like the latency hiding scheduler is not working correctly in jax 0.4.35. I am running my code on a single DGX node with 8 H100 GPUs, and the only flags I pass to XLA are:
XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
During compilation of my code, I get the following error:
I also tried changing the value of config.parallel_collective_overlap_limit, to check whether it would make any difference, by modifying the XLA_FLAGS variable to:
but in that case I get:
If I downgrade to jax 0.4.34, the error does not occur.
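
For anyone trying to reproduce this, the two flags quoted in the description can also be set from Python rather than the shell. The following is only a minimal sketch: `build_xla_flags` and `FLAGS` are hypothetical names introduced here for illustration, and it assumes the variable has to be in place before `jax` is first imported so that XLA picks it up when it initializes.

```python
import os

# Flags copied from the report above; nothing else about the setup is assumed.
FLAGS = {
    "xla_gpu_enable_latency_hiding_scheduler": "true",
    "xla_gpu_enable_triton_gemm": "false",
}


def build_xla_flags(flags):
    """Render a dict of flag names and values as an XLA_FLAGS string."""
    return " ".join(f"--{name}={value}" for name, value in flags.items())


# Assumption: set the variable before `import jax`, so the flags are
# visible when XLA initializes.
os.environ["XLA_FLAGS"] = build_xla_flags(FLAGS)
print(os.environ["XLA_FLAGS"])
```

Keeping the flags in a dict makes it easy to toggle one at a time when bisecting between jax 0.4.34 and 0.4.35.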
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 6 2024, 20:02:44) [Clang 18.1.8 ]
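
A version listing like the one above can be gathered with the standard library alone, which helps when comparing the working and failing environments. This is a rough sketch, not the way the report was produced: `collect_versions` is a hypothetical helper, and the package names are simply the ones listed above.

```python
import platform
from importlib import metadata


def collect_versions(packages=("jax", "jaxlib", "numpy")):
    """Return installed versions keyed by package name, plus the Python version."""
    info = {}
    for name in packages:
        try:
            info[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # Report missing packages instead of failing, so the listing is always complete.
            info[name] = "not installed"
    info["python"] = platform.python_version()
    return info


for name, version in collect_versions().items():
    print(f"{name}: {version}")
```

Running it in both environments makes a jax/jaxlib version mismatch easy to spot.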