Description
This looks more like an XLA problem than a JAX one. The latency hiding scheduler does not appear to work correctly in jax 0.4.35. I am running my code on a single DGX node with 8 H100 GPUs, and I am passing only the following flags to XLA:
XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
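For anyone trying to reproduce this from a script rather than the shell, the same flags can be set from Python. This is only a minimal sketch: the helper name `build_xla_flags` and the list `XLA_FLAG_LIST` are made up for illustration, and it assumes the variable is set before `jax` is first imported, which is the safe ordering.

```python
import os

# Flags copied from the report above, one entry per flag.
XLA_FLAG_LIST = [
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
]


def build_xla_flags(flags):
    """Join individual flags into one space-separated XLA_FLAGS value."""
    return " ".join(flags)


# Set the variable before the first `import jax` so the flags are in the
# environment by the time XLA reads them.
os.environ["XLA_FLAGS"] = build_xla_flags(XLA_FLAG_LIST)
print(os.environ["XLA_FLAGS"])
```

Importing `jax` after this point and compiling any multi-GPU function should then go through the same scheduler configuration as the shell invocation.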
During compilation of my code, I get the following error:
F external/xla/xla/service/gpu/gpu_hlo_schedule.cc:475] Check failed: (config.collective_broadcast_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_to_all_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_gather_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_reduce_overlap_limit <= config.parallel_collective_overlap_limit) && (config.reduce_scatter_overlap_limit <= config.parallel_collective_overlap_limit)
To check whether it would make any difference, I also tried changing the value of config.parallel_collective_overlap_limit by setting the XLA_FLAGS variable to:
ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_experimental_parallel_collective_overlap_limit=2"
but in that case I get:
Unknown flags in XLA_FLAGS: --xla_gpu_experimental_parallel_collective_overlap_limit=2 --xla_gpu_experimental_parallel_collective_overlap_limit=2
If I downgrade to jax 0.4.34, the error does not occur.
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 6 2024, 20:02:44) [Clang 18.1.8 ]
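Note that the versions above pair jax 0.4.35 with jaxlib 0.4.34. A differing pair is not necessarily an error, but since the compiled XLA code ships in jaxlib, it may be worth confirming which versions are actually installed when reproducing. The snippet below is a small sketch for that; the helper names `installed_version` and `versions_differ` are hypothetical and use only the standard library.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def versions_differ(jax_version, jaxlib_version):
    """True only when both versions are known and are not identical."""
    if jax_version is None or jaxlib_version is None:
        return False
    return jax_version != jaxlib_version


jax_v = installed_version("jax")
jaxlib_v = installed_version("jaxlib")
print(f"jax={jax_v} jaxlib={jaxlib_v} differ={versions_differ(jax_v, jaxlib_v)}")
```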