Latency Hiding Scheduler not working with jax 0.4.35 #25404

Open
lgiacomoni opened this issue Dec 11, 2024 · 1 comment
Labels
bug Something isn't working

Comments

@lgiacomoni

Description

This looks more like an XLA problem than a JAX one: the latency hiding scheduler does not seem to work correctly with jax 0.4.35. I am running my code on a single DGX node with 8 H100 GPUs, and the only flags I pass to XLA are:

XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
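For anyone trying to reproduce this from a script rather than a shell or Dockerfile, here is a minimal sketch of setting the same two flags from Python. The helper name `build_xla_flags` is hypothetical (it is not part of JAX or XLA); the flag names are the ones quoted above. The usual advice is to set the variable before `jax` is imported, so the import is left as a comment here.

```python
import os


def build_xla_flags(flags):
    """Join a mapping of XLA flag names to values into one XLA_FLAGS string.

    Hypothetical helper for illustration; booleans are rendered as
    "true"/"false", matching the flags quoted in the report.
    """
    parts = []
    for name, value in flags.items():
        if isinstance(value, bool):
            value = "true" if value else "false"
        parts.append(f"--{name}={value}")
    return " ".join(parts)


xla_flags = build_xla_flags({
    "xla_gpu_enable_latency_hiding_scheduler": True,
    "xla_gpu_enable_triton_gemm": False,
})

# Set the variable first, then import jax so the flags are picked up.
os.environ["XLA_FLAGS"] = xla_flags
# import jax

print(xla_flags)
```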

During compilation of my code, I get the following error:

F external/xla/xla/service/gpu/gpu_hlo_schedule.cc:475] Check failed: (config.collective_broadcast_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_to_all_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_gather_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_reduce_overlap_limit <= config.parallel_collective_overlap_limit) && (config.reduce_scatter_overlap_limit <= config.parallel_collective_overlap_limit)
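To make the failed check easier to read: it requires every per-collective overlap limit to be no larger than `parallel_collective_overlap_limit`. The sketch below restates that condition in Python. It is not XLA's code; the field names are copied from the message, and the example values are made up, since the report does not show the actual values.

```python
# Field names taken from the "Check failed" message above.
PER_COLLECTIVE_LIMITS = (
    "collective_broadcast_overlap_limit",
    "all_to_all_overlap_limit",
    "all_gather_overlap_limit",
    "all_reduce_overlap_limit",
    "reduce_scatter_overlap_limit",
)


def violated_limits(config):
    """Return the per-collective limits that exceed the parallel limit."""
    cap = config["parallel_collective_overlap_limit"]
    return [name for name in PER_COLLECTIVE_LIMITS if config[name] > cap]


# Hypothetical values for illustration only.
example = {name: 1 for name in PER_COLLECTIVE_LIMITS}
example["parallel_collective_overlap_limit"] = 1
example["all_reduce_overlap_limit"] = 2

print(violated_limits(example))  # ['all_reduce_overlap_limit']
```

The check passes only when this list is empty, so any single limit above the parallel limit is enough to trigger the failure.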

I also tried changing config.parallel_collective_overlap_limit to see whether it would make any difference, by setting XLA_FLAGS to:

ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_experimental_parallel_collective_overlap_limit=2"

but in that case I get:

Unknown flags in XLA_FLAGS: --xla_gpu_experimental_parallel_collective_overlap_limit=2 --xla_gpu_experimental_parallel_collective_overlap_limit=2

If I downgrade to jax 0.4.34, the error does not occur.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 6 2024, 20:02:44) [Clang 18.1.8 ]
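The listing above shows jax 0.4.35 next to jaxlib 0.4.34. Whether that pairing plays any role in the failure is not established in this thread, but the installed versions are easy to confirm. The sketch below uses only the standard library; `installed_versions` is a hypothetical helper name. If JAX is importable, `jax.print_environment_info()` gives a fuller report.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "numpy")):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


versions = installed_versions()
for name, version in versions.items():
    print(f"{name}: {version or 'not installed'}")
```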

lgiacomoni added the bug label Dec 11, 2024
@dfm
Collaborator

dfm commented Dec 11, 2024

Thanks for the report! It's possible that someone on this repo will be able to provide suggestions, but (as you mention!) you might get more mileage asking this question on the XLA repository if you haven't already: https://github.com/openxla/xla/issues
