Open
Description
Currently there are many external APIs related to getting the number of devices associated with PyTorch XLA. Those that I could find were:
- "global_runtime_device_count": returns the total number of devices across all processes/hosts, but it has "@functools.lru_cache()"
- "global_device_count": returns the total number of devices across all processes/hosts, but it has "@functools.lru_cache()"
- "addressable_runtime_device_count": Access number of addressable devices visible to a process.
- "addressable_device_count": Access number of addressable devices visible to a process. It specifically returns 1 in case of SPMD.
- "local_device_count": takes the number of addressable devices and multiplies it by the number of local processes. This is equivalent to the number of devices running on a host.
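To make the caching caveat and the per-host arithmetic concrete, here is a minimal sketch in plain Python. It does not call torch_xla: `_runtime`, `cached_global_count`, and `local_count` are hypothetical stand-ins that only mirror the pattern described above, and the device numbers are made up.

```python
import functools

# Hypothetical stand-in for what the runtime reports; not torch_xla state.
_runtime = {"global_devices": 8}


@functools.lru_cache()
def cached_global_count() -> int:
    # Mirrors the caching pattern only; the real implementation may differ.
    return _runtime["global_devices"]


def local_count(addressable_devices: int, local_processes: int) -> int:
    # Per-host count as described above: addressable devices per process
    # multiplied by the number of processes on the host.
    return addressable_devices * local_processes


first = cached_global_count()
_runtime["global_devices"] = 16      # simulate the runtime reporting a new count
second = cached_global_count()       # served from the cache, not re-queried
cached_global_count.cache_clear()    # functools exposes this on cached functions
third = cached_global_count()        # re-queried after clearing the cache

print(first, second, third)
print(local_count(addressable_devices=1, local_processes=4))
```

The point of the sketch is that a cached count stays fixed until the cache is cleared, which matters if anything can change the device count after the first call.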
From these, some existing observations are:
- "addressable_runtime_device_count" and "addressable_device_count" are extremely similar in implementation and name. Perhaps we should make the distinction clearer. Perhaps there is some context around "addressable_device_count" in particular that I don't fully grasp.
- "local_device_count" terminology can be confusing when compared with JAX's concept of local devices in "jax.local_devices": "local_device_count" is the number of devices on the host, while JAX's definition covers the devices in the process.
- We should deduplicate "global_runtime_device_count" and "global_device_count": just have one reference the other to remove multiple calls.
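One possible shape for the deduplication, sketched in plain Python rather than against the actual torch_xla source: `_query_backend_device_count` and the `backend_calls` counter are hypothetical, and the two public names are local stand-ins that only reuse the names from this issue.

```python
import functools

backend_calls = 0  # counts how often the (hypothetical) backend is queried


def _query_backend_device_count() -> int:
    # Hypothetical stand-in for the underlying runtime query.
    global backend_calls
    backend_calls += 1
    return 8


@functools.lru_cache()
def global_runtime_device_count() -> int:
    # Single cached implementation.
    return _query_backend_device_count()


def global_device_count() -> int:
    # Thin alias: delegates instead of holding its own cache and backend call.
    return global_runtime_device_count()


a = global_runtime_device_count()
b = global_device_count()
print(a, b, backend_calls)
```

With this layout there is one cache and one backend call path, so the two names cannot drift apart. Which name should remain the primary one is left open here.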