Description
In `jax._src.xla_bridge` we have a function named `is_gpu` that determines whether a platform is a GPU platform (lines 846 to 847 in 7dd401c).
It appears to hardcode just two platforms, `cuda` and `rocm`, inside this function.
However, I notice that `xla_bridge` already has some utilities for this kind of check, such as `_platform_aliases`, `_alias_to_platforms`, and `expand_platform_alias` (lines 838 to 844 in 7dd401c).
I think we could change this function to:

```python
def is_gpu(platform):
  return platform in expand_platform_alias("gpu")
```

so that it is extensible and consistent with `_platform_aliases`. To add a new GPU platform, we would then only need to update the alias list rather than also editing this function.
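A minimal sketch comparing the two approaches, assuming a simplified, self-contained stand-in for `_platform_aliases` and `expand_platform_alias` (not the real JAX definitions). The names `is_gpu_hardcoded` and `newgpu` are hypothetical, used only to show that both versions agree today and that a new GPU platform then needs only an alias-table entry.

```python
from typing import Dict, List

# Hypothetical alias table standing in for _platform_aliases.
_platform_aliases: Dict[str, str] = {
    "cuda": "gpu",
    "rocm": "gpu",
}


def expand_platform_alias(platform: str) -> List[str]:
    """Returns the concrete platforms for an alias, or [platform] otherwise.

    Simplification: recomputed on every call so that entries added to the
    table at runtime are picked up.
    """
    expanded = [p for p, alias in _platform_aliases.items() if alias == platform]
    return expanded or [platform]


def is_gpu_hardcoded(platform: str) -> bool:
    # Current approach as described in the issue: a fixed set of names.
    return platform in ("cuda", "rocm")


def is_gpu(platform: str) -> bool:
    # Proposed approach: derive the GPU platforms from the alias table.
    return platform in expand_platform_alias("gpu")


# Both versions agree on the existing platforms.
for name in ("cuda", "rocm", "cpu", "tpu"):
    assert is_gpu(name) == is_gpu_hardcoded(name)

# Registering a hypothetical new GPU platform only touches the alias table.
_platform_aliases["newgpu"] = "gpu"
print(is_gpu("newgpu"))            # True
print(is_gpu_hardcoded("newgpu"))  # False
```

In this sketch the expansion is rebuilt on each call so the runtime addition is visible; if the inverse mapping is precomputed at import time, as `_alias_to_platforms` suggests, a new platform would instead be added by editing the table in the source.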
WDYT? I'd be glad to submit a patch if this looks good to people in the JAX community :)