
Make xla_bridge.is_gpu more extensible #25521

Open
@PragmaTwice

Description


In jax._src.xla_bridge, there is a function named is_gpu that determines whether a platform is a GPU platform:

jax/jax/_src/xla_bridge.py

Lines 846 to 847 in 7dd401c

def is_gpu(platform):
  return platform in ("cuda", "rocm")

It seems this function simply hardcodes two platforms, cuda and rocm. However, I noticed that xla_bridge already has some utilities for this kind of check, such as _platform_aliases, _alias_to_platforms, and expand_platform_alias:

jax/jax/_src/xla_bridge.py

Lines 838 to 844 in 7dd401c

def expand_platform_alias(platform: str) -> list[str]:
  """Expands, e.g., "gpu" to ["cuda", "rocm"].
  This is used for convenience reasons: we expect cuda and rocm to act similarly
  in many respects since they share most of the same code.
  """
  return _alias_to_platforms.get(platform, [platform])

I think we could change this function to:

def is_gpu(platform):
  return platform in expand_platform_alias("gpu")

so that it becomes extensible and consistent with _platform_aliases. To add a new GPU platform, we would then only need to update the alias list instead of also modifying this function.
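One way to sanity-check the proposal is to compare the hardcoded and alias-based versions side by side. The sketch below is self-contained: the helper names build_alias_to_platforms, is_gpu_hardcoded, and make_is_gpu are hypothetical, the alias table contents are assumed to mirror today's cuda/rocm setup, and "newgpu" is a made-up platform used only to show extensibility.

```python
from typing import Callable


def build_alias_to_platforms(platform_aliases: dict[str, str]) -> dict[str, list[str]]:
    """Inverts a platform -> alias table into alias -> [platforms]."""
    result: dict[str, list[str]] = {}
    for platform, alias in platform_aliases.items():
        result.setdefault(alias, []).append(platform)
    return result


def is_gpu_hardcoded(platform: str) -> bool:
    # Current behaviour: the GPU platforms are listed inline.
    return platform in ("cuda", "rocm")


def make_is_gpu(platform_aliases: dict[str, str]) -> Callable[[str], bool]:
    """Returns an is_gpu predicate driven entirely by the alias table (hypothetical helper)."""
    alias_to_platforms = build_alias_to_platforms(platform_aliases)

    def expand_platform_alias(platform: str) -> list[str]:
        return alias_to_platforms.get(platform, [platform])

    def is_gpu(platform: str) -> bool:
        # Proposed behaviour: whatever "gpu" expands to counts as a GPU platform.
        return platform in expand_platform_alias("gpu")

    return is_gpu


# Assumed alias table mirroring the current cuda/rocm setup.
is_gpu = make_is_gpu({"cuda": "gpu", "rocm": "gpu"})
for p in ("cuda", "rocm", "cpu", "tpu", "gpu"):
    assert is_gpu(p) == is_gpu_hardcoded(p)

# Adding a hypothetical new GPU backend only touches the table.
is_gpu_ext = make_is_gpu({"cuda": "gpu", "rocm": "gpu", "newgpu": "gpu"})
print(is_gpu_ext("newgpu"))  # True
```

One subtlety: if no platform were aliased to "gpu", expand_platform_alias("gpu") would fall back to ["gpu"], so is_gpu("gpu") would return True. With the current cuda/rocm table that case does not arise, but it may be worth keeping in mind when editing the alias list.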

WDYT? I'd be glad to submit a patch if this looks good to the JAX community :)

Labels: enhancement (New feature or request)