Clarify xr.device_type() API and use them #9343

Status: Open · wants to merge 1 commit into master
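The diffs below replace call sites of xm.xla_device_hw(torch_xla.device()) with xr.device_type(). A minimal sketch of the resulting test-skip pattern, assuming torch_xla is installed; the test class and method names are illustrative, not from this PR:

import unittest

import torch_xla.runtime as xr


# Skip the whole class unless the PJRT runtime reports a TPU.
@unittest.skipIf(xr.device_type() != 'TPU', "requires a TPU runtime")
class ExampleTpuOnlyTest(unittest.TestCase):

  def test_device_is_tpu(self):
    # xr.device_type() returns the PJRT device type as a string, e.g. "CPU" or "TPU".
    self.assertEqual(xr.device_type(), 'TPU')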
test/ds/test_dynamic_shape_models.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@ def forward(self, x):


 @unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU',
+    xr.device_type() != 'TPU',
     f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail."
 )
 class TestDynamicShapeModels(unittest.TestCase):
test/pjrt/test_dynamic_plugin_tpu.py (2 changes: 1 addition & 1 deletion)
@@ -20,7 +20,7 @@ def setUpClass(cls):
   @staticmethod
   def _assert_tpus_exist(index=0):
     del index
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'

   def test_single_process(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
test/test_autocast.py (5 changes: 2 additions & 3 deletions)
@@ -348,8 +348,7 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())


-@unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU', f"TPU autocast test.")
+@unittest.skipIf(xr.device_type() != 'TPU', f"TPU autocast test.")
 class TestAutocastTPU(TestAutocastBase):

   @classmethod
@@ -405,7 +404,7 @@ class TestOtherOps(unittest.TestCase):

   # On TPU, the input of batch norm is casted into fp32, see torch_xla/csrc/autocast_mode.cpp
   @unittest.skipIf(
-      xm.xla_device_hw(torch_xla.device()) != 'TPU',
+      xr.device_type() != 'TPU',
       "the behavior of batch_norm autocast on TPU is different from others")
   def test_batch_norm_tpu(self):
     device = torch_xla.device()
torch_xla/distributed/spmd/xla_sharding.py (2 changes: 1 addition & 1 deletion)
@@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray:
       A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
       v2 and v3, global_z is instead cores_per_chip (i.e., 2).
     """
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
     # coords is a 3-dims tuple representing the device in physical mesh
     device_coords = [self.device_attributes[d]['coords'] for d in devices]
     dims = tuple(d + 1 for d in max(device_coords))
torch_xla/runtime.py (4 changes: 2 additions & 2 deletions)
@@ -82,12 +82,12 @@ def _maybe_select_default_device():


 def device_type() -> Optional[str]:
-  """Returns the current PjRt device type.
+  """Returns the current PJRT device type.

   Selects a default device if none has been configured

   Returns:
-    A string representation of the device.
+    A string representation of the PJRT device: "CPU", "TPU", etc.
   """
   pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
   return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
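For reference, a brief usage sketch of the updated API; the surrounding script is illustrative, not part of this PR:

import torch_xla.runtime as xr

# device_type() reads PJRT_DEVICE (selecting a default device if none has been
# configured) and returns the PJRT device type as a string, e.g. "CPU" or "TPU".
hw = xr.device_type()
print(f"Running on {hw}")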