
Simplify device count external API calls #9199

@pgmoka

Description


Currently there are many external APIs for getting the number of devices associated with PyTorch/XLA. The ones I could find are:

  • "global_runtime_device_count": returns the total number of devices across all processes/hosts, and is decorated with "@functools.lru_cache()".
  • "global_device_count": returns the total number of devices across all processes/hosts, and is also decorated with "@functools.lru_cache()".
  • "addressable_runtime_device_count": returns the number of addressable devices visible to the current process.
  • "addressable_device_count": returns the number of addressable devices visible to the current process, except that it returns 1 under SPMD.
  • "local_device_count": multiplies the number of addressable devices by the number of local processes. The result is the number of devices on the host.
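To make the relationships between these counts concrete, here is a minimal pure-Python model of a hypothetical multi-host topology. It does not call torch_xla: the `Topology` class and the `jax_style_local_device_count` helper are illustrative assumptions, and the functions only mirror the semantics described in the list of APIs (including the assumption that every process addresses the same number of devices).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Topology:
    """Hypothetical uniform cluster layout; not a torch_xla type."""
    num_hosts: int
    processes_per_host: int
    devices_per_process: int


def global_device_count(t: Topology) -> int:
    # Total devices across all hosts and processes. Per the issue,
    # global_runtime_device_count reports the same total.
    return t.num_hosts * t.processes_per_host * t.devices_per_process


def addressable_runtime_device_count(t: Topology) -> int:
    # Devices visible to a single process.
    return t.devices_per_process


def addressable_device_count(t: Topology, spmd: bool = False) -> int:
    # Same as addressable_runtime_device_count, except that the issue
    # notes it returns 1 under SPMD.
    return 1 if spmd else addressable_runtime_device_count(t)


def local_device_count(t: Topology, spmd: bool = False) -> int:
    # Addressable devices times local processes: a per-host count.
    return addressable_device_count(t, spmd) * t.processes_per_host


def jax_style_local_device_count(t: Topology) -> int:
    # Hypothetical helper: jax.local_devices() is scoped to the calling
    # process, so its length matches the per-process count.
    return addressable_runtime_device_count(t)


if __name__ == "__main__":
    topo = Topology(num_hosts=2, processes_per_host=2, devices_per_process=4)
    print(global_device_count(topo))                  # 16
    print(addressable_device_count(topo))             # 4
    print(addressable_device_count(topo, spmd=True))  # 1
    print(local_device_count(topo))                   # 8
    print(jax_style_local_device_count(topo))         # 4
```

With two processes per host, the host-level `local_device_count` (8) differs from the per-process count that JAX would call "local" (4); the two only coincide when each host runs a single process.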

Some observations about these APIs:

  • addressable_runtime_device_count and addressable_device_count are nearly identical in both implementation and name. We should make the distinction clearer, unless there is some context around addressable_device_count in particular that I am missing.
  • The local_device_count terminology can be confusing next to JAX's notion of local devices (jax.local_devices): local_device_count is the number of devices on the host, while JAX's local devices are the devices attached to the process.
  • We should deduplicate global_runtime_device_count and global_device_count by having one simply call the other, instead of both making their own calls.

Labels: documentation, usability (Bugs/features related to improving the usability of PyTorch/XLA)