Latency Hiding Scheduler not working with jax 0.4.35 #25404

Closed
@lgiacomoni

Description

This looks more like an XLA problem than a JAX one. The latency hiding scheduler does not appear to work correctly in jax 0.4.35. I am running my code on a single DGX node with 8 H100 GPUs, and I pass only the following flags to XLA:

XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false
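For anyone reproducing this from a Python entry point rather than the shell, the same two flags can be exported in-process. This is a minimal sketch: `set_xla_flags` is a hypothetical helper, not a JAX API, and it assumes the variable is set before `jax` is first imported so the backend sees it.

```python
import os


def set_xla_flags(flags):
    """Join flags into one space-separated string and export it as XLA_FLAGS.

    Hypothetical helper for illustration; not part of JAX or XLA.
    """
    value = " ".join(flags)
    os.environ["XLA_FLAGS"] = value
    return value


# The two flags from the report, copied verbatim.
flags_value = set_xla_flags([
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_enable_triton_gemm=false",
])

# Import jax only after this point, e.g.:
# import jax
```

Building the value from a list keeps each flag on its own line, which makes it easier to spot an accidental duplicate or a missing space.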

During compilation of my code, I get the following error:

F external/xla/xla/service/gpu/gpu_hlo_schedule.cc:475] Check failed: (config.collective_broadcast_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_to_all_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_gather_overlap_limit <= config.parallel_collective_overlap_limit) && (config.all_reduce_overlap_limit <= config.parallel_collective_overlap_limit) && (config.reduce_scatter_overlap_limit <= config.parallel_collective_overlap_limit)
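Read literally, the failed check requires every per-collective overlap limit to be no larger than `parallel_collective_overlap_limit`. The sketch below models that condition in Python to show which fields would trip it. The field names are taken from the check message; the `OverlapLimits` class, its default values, and `violated_limits` are assumptions made for illustration and are not XLA code.

```python
from dataclasses import dataclass


@dataclass
class OverlapLimits:
    # Field names mirror the check message; the defaults are made up.
    collective_broadcast_overlap_limit: int = 1
    all_to_all_overlap_limit: int = 1
    all_gather_overlap_limit: int = 1
    all_reduce_overlap_limit: int = 1
    reduce_scatter_overlap_limit: int = 1
    parallel_collective_overlap_limit: int = 1


PER_COLLECTIVE_FIELDS = (
    "collective_broadcast_overlap_limit",
    "all_to_all_overlap_limit",
    "all_gather_overlap_limit",
    "all_reduce_overlap_limit",
    "reduce_scatter_overlap_limit",
)


def violated_limits(config):
    """Return the per-collective limits that exceed the parallel limit."""
    cap = config.parallel_collective_overlap_limit
    return [name for name in PER_COLLECTIVE_FIELDS if getattr(config, name) > cap]


# Hypothetical configuration: one per-collective limit above the parallel cap.
bad = violated_limits(OverlapLimits(all_reduce_overlap_limit=2))
ok = violated_limits(OverlapLimits(parallel_collective_overlap_limit=2))
```

Under this model, the check passes only when `violated_limits` returns an empty list, which is why raising the parallel limit was the natural thing to try next.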

To check whether it would make any difference, I also tried changing config.parallel_collective_overlap_limit by setting XLA_FLAGS to:

ENV XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_experimental_parallel_collective_overlap_limit=2"

but in that case I get:

Unknown flags in XLA_FLAGS: --xla_gpu_experimental_parallel_collective_overlap_limit=2 --xla_gpu_experimental_parallel_collective_overlap_limit=2

If I downgrade to jax 0.4.34, the error does not occur.

System info (python version, jaxlib version, accelerator, etc.)

jax: 0.4.35
jaxlib: 0.4.34
numpy: 2.0.2
python: 3.11.11 (main, Dec 6 2024, 20:02:44) [Clang 18.1.8 ]
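The versions above list jax 0.4.35 alongside jaxlib 0.4.34. Whether that difference is related to the failure is not established here, but it is easy to surface when collecting system info. The sketch below reads installed versions from package metadata without importing jax; `installed_version` and `versions_match` are hypothetical helper names.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string, or None if the package is absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def versions_match(jax_version, jaxlib_version):
    """True only when both versions are known and identical."""
    return jax_version is not None and jax_version == jaxlib_version


# Collect the same packages listed in the report.
report = {name: installed_version(name) for name in ("jax", "jaxlib", "numpy")}
for name, version in report.items():
    print(f"{name}: {version}")

if not versions_match(report["jax"], report["jaxlib"]):
    print("note: jax and jaxlib versions differ or are not installed")
```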

Labels: bug
