
Commit c0aec83

Clarify xr.device_type() API and use it
1 parent 78cff03 commit c0aec83

File tree

5 files changed (+7, -8 lines)

test/ds/test_dynamic_shape_models.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def forward(self, x):
 
 
 @unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU',
+    xr.device_type() != 'TPU',
     f"The tests fail on CPU. See https://github.com/pytorch/xla/issues/4298 for more detail."
 )
 class TestDynamicShapeModels(unittest.TestCase):

test/pjrt/test_dynamic_plugin_tpu.py

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ def setUpClass(cls):
   @staticmethod
   def _assert_tpus_exist(index=0):
     del index
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
 
   def test_single_process(self):
     with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:

test/test_autocast.py

Lines changed: 2 additions & 3 deletions
@@ -348,8 +348,7 @@ def compare(first, second):
     self.assertFalse(self.is_autocast_enabled())
 
 
-@unittest.skipIf(
-    xm.xla_device_hw(torch_xla.device()) != 'TPU', f"TPU autocast test.")
+@unittest.skipIf(xr.device_type() != 'TPU', f"TPU autocast test.")
 class TestAutocastTPU(TestAutocastBase):
 
   @classmethod

@@ -405,7 +404,7 @@ class TestOtherOps(unittest.TestCase):
 
   # On TPU, the input of batch norm is casted into fp32, see torch_xla/csrc/autocast_mode.cpp
   @unittest.skipIf(
-      xm.xla_device_hw(torch_xla.device()) != 'TPU',
+      xr.device_type() != 'TPU',
       "the behavior of batch_norm autocast on TPU is different from others")
   def test_batch_norm_tpu(self):
     device = torch_xla.device()

torch_xla/distributed/spmd/xla_sharding.py

Lines changed: 1 addition & 1 deletion
@@ -334,7 +334,7 @@ def _get_physical_tpu_mesh(self, devices: np.ndarray) -> np.ndarray:
       A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
       v2 and v3, global_z is instead cores_per_chip (i.e., 2).
     """
-    assert xm.xla_device_hw(torch_xla.device()) == 'TPU'
+    assert xr.device_type() == 'TPU'
     # coords is a 3-dims tuple representing the device in physical mesh
     device_coords = [self.device_attributes[d]['coords'] for d in devices]
     dims = tuple(d + 1 for d in max(device_coords))

torch_xla/runtime.py

Lines changed: 2 additions & 2 deletions
@@ -82,12 +82,12 @@ def _maybe_select_default_device():
 
 
 def device_type() -> Optional[str]:
-  """Returns the current PjRt device type.
+  """Returns the current PJRT device type.
 
   Selects a default device if none has been configured
 
   Returns:
-    A string representation of the device.
+    A string representation of the PJRT device: "CPU", "TPU", etc.
   """
   pjrt_device = xu.getenv_as(xenv.PJRT_DEVICE, str)
   return pjrt_device.split('_')[0] if pjrt_device else pjrt_device
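
Taken together, the change swaps every xm.xla_device_hw(torch_xla.device()) call for the shorter xr.device_type() wherever a test or helper only needs the hardware backend name. The sketch below shows the resulting call-site pattern in isolation; it assumes the usual torch_xla.runtime import alias, and the test class name TpuOnlyTest is made up for illustration.

# Sketch only: TpuOnlyTest is a hypothetical test class; xr.device_type()
# is the API clarified in torch_xla/runtime.py above.
import unittest

import torch_xla.runtime as xr


@unittest.skipIf(xr.device_type() != 'TPU', "requires a TPU PJRT device")
class TpuOnlyTest(unittest.TestCase):

  def test_device_type_is_tpu(self):
    # device_type() returns the PJRT device type string ("CPU", "TPU", ...),
    # selecting a default device if none has been configured.
    self.assertEqual(xr.device_type(), 'TPU')


if __name__ == '__main__':
  unittest.main()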
