
Commit 2fd8ab0

Fix errors
1 parent 06d4b28 commit 2fd8ab0

8 files changed: +16 additions, −281 deletions


API_GUIDE.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -197,7 +197,7 @@ single device snippet. Let's go over then one by one.
 - `torch_xla.launch()`
   - Creates the processes that each run an XLA device.
   - This function is a wrapper of multithreading spawn to allow user run the script with torchrun command line also. Each process will only be able to access the device assigned to the current process. For example on a TPU v4-8, there will be 4 processes being spawn up and each process will own a TPU device.
-  - Note that if you print the `torch.device('xla')` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
+  - Note that if you print the `torch_xla.device()` on each process you will see `xla:0` on all devices. This is because each process can only see one device. This does not mean multi-process is not functioning. The only exeption is with PJRT runtime on TPU v2 and TPU v3 since there will be `#devices/2` processes and each process will have 2 threads (check this [doc](https://github.com/pytorch/xla/blob/master/docs/pjrt.md#tpus-v2v3-vs-v4) for more details).
 - `MpDeviceLoader`
   - Loads the training data onto each device.
   - `MpDeviceLoader` can wrap on a torch dataloader. It can preload the data to the device and overlap the dataloading with device execution to improve the performance.
```
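For context, the sketch below illustrates the multi-process pattern the guide text describes, assuming a recent torch_xla release; the toy model, dataset, and the `_mp_fn` name are placeholders for illustration, not part of the guide or this commit.

```python
# Minimal sketch of torch_xla.launch() + MpDeviceLoader (assumption:
# recent torch_xla; model/dataset below are placeholders).
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from torch_xla.distributed.parallel_loader import MpDeviceLoader


def _mp_fn(index):
  device = torch_xla.device()  # reports as xla:0 in every process
  model = torch.nn.Linear(10, 1).to(device)
  optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

  loader = torch.utils.data.DataLoader(
      [(torch.randn(10), torch.randn(1)) for _ in range(64)], batch_size=8)
  # MpDeviceLoader preloads batches onto the device and overlaps data
  # loading with device execution.
  device_loader = MpDeviceLoader(loader, device)

  for data, target in device_loader:
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(data), target)
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduce gradients across processes


if __name__ == '__main__':
  torch_xla.launch(_mp_fn, args=())
```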

test/pjrt/test_runtime_multi_cpu.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -27,7 +27,7 @@ def test_default_cpu_device(self):
     os.environ.pop(xenv.PJRT_CPU_ASYNC_CLIENT, None)

     expected = {0: torch.device('xla:0')}
-    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(torch_xla.device)
     self.assertDictEqual(devices_per_process, expected)

   def test_multi_cpu_devices(self):
@@ -38,7 +38,7 @@ def test_multi_cpu_devices(self):
         3: torch.device('xla:3'),
     }

-    devices_per_process = pjrt.run_multiprocess(xm.xla_device)
+    devices_per_process = pjrt.run_multiprocess(torch_xla.device)
     self.assertDictEqual(devices_per_process, expected)

   def test_global_ordinal(self):
@@ -65,7 +65,7 @@ def forward(ctx, x):
       def backward(ctx, grad_output):
         results['forward_ordinal'] = ctx.forward_ordinal
         results['backward_ordinal'] = xr.global_ordinal()
-        results['device'] = str(torch.device('xla'))
+        results['device'] = str(torch_xla.device())
         return grad_output

     x = torch.ones(1, requires_grad=True, device='xla')
```
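The substitution in these tests relies on the two spellings being interchangeable; a hedged illustration, not taken from the diff itself:

```python
# Hedged illustration: both names are zero-argument callables that return
# the current process's XLA device, which is why either can be handed to
# pjrt.run_multiprocess above.
import torch_xla
import torch_xla.core.xla_model as xm

new_style = torch_xla.device()  # e.g. device(type='xla', index=0)
old_style = xm.xla_device()     # older spelling of the same lookup
assert new_style == old_style
```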

test/pjrt/test_runtime_multi_gpu.py

Lines changed: 0 additions & 266 deletions
This file was deleted.

test/pytorch_test_base.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -559,7 +559,7 @@ def _alt_lookup(d, keys, defval):
 def instantiate_test(cls, name, test, *, generic_cls):
   test_name = name + '_' + cls.device_type
   class_name = cls.__name__
-  real_device_type = xm.xla_device_hw(str(torch.device('xla')))
+  real_device_type = xm.xla_device_hw(str(torch.device('xla:0')))
   assert real_device_type in DISABLED_TORCH_TESTS, 'Unsupported device type:' + real_device_type
   disabled_torch_tests = DISABLED_TORCH_TESTS[real_device_type]

@@ -631,8 +631,8 @@ def get_primary_device(cls):

   @classmethod
   def setUpClass(cls):
-    # Sets the primary test device to the xla_device (CPU or TPU)
-    cls.primary_device = str(torch.device('xla'))
+    # Sets the primary test device to the torch_xla.device (CPU or TPU)
+    cls.primary_device = str(torch_xla.device())
     torch_xla._XLAC._xla_set_mat_mul_precision('highest')

   def setUp(self):
```
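A short note on why this hunk likely fixes an error: a bare `torch.device('xla')` carries no index, so its string form is not the indexed `'xla:0'` that the device-type lookup expects. The sketch below only uses stock PyTorch and illustrative prints:

```python
# Likely the error being fixed: torch.device('xla') has no index, so its
# string form is 'xla' rather than the indexed 'xla:0' expected by
# lookups such as xm.xla_device_hw.
import torch

print(str(torch.device('xla')))    # prints: xla
print(str(torch.device('xla:0')))  # prints: xla:0
```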

test/spmd/test_xla_spmd_python_api_interaction.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -38,7 +38,7 @@ def test_is_master_ordinal(self):
     self.assertTrue(xm.is_master_ordinal())

   def test_xla_device(self):
-    device = torch.device('xla')
+    device = torch_xla.device()
     self.assertEqual(device, torch.device('xla:0'))

   def test_xla_real_devices(self):
```

test/test_operations.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -442,7 +442,7 @@ class TestOptimizationBarrier(test_utils.XlaTestCase):
   def test_optimization_barrier_correctness(self):
     device = torch.device('xla')
     # only test optimization_barrier on TPU
-    if xm.xla_device_hw(device) != 'TPU':
+    if xr.device_type() != 'TPU':
       return
     x = torch.randn(5, 5, device=device)
     y = torch.randn(5, 5, device=device)
@@ -1532,7 +1532,7 @@ def test_deepcopy(self):
     self.assertEqual(x[0], x0)

   def test_print(self):
-    xla_device = torch.device('xla')
+    xla_device = torch.device('xla:0')
     x = torch.tensor([5], device=xla_device)
     expected_str = 'tensor([5], device=\'' + str(xla_device) + '\')'
     self.assertEqual(str(x), expected_str)
@@ -2759,7 +2759,7 @@ def test_send_to_device_grad(self):
     self.assertTrue(dt[0].requires_grad)

   def test_send_to_device_single(self):
-    xla_device = torch.device('xla')
+    xla_device = torch.device('xla:0')
     t = _gen_tensor(2, 2)
     dt = xm.send_cpu_data_to_device(t, xla_device)
     self.assertEqual(dt[0].device, xla_device)
@@ -2859,7 +2859,7 @@ def from_tensors(self, tensors):

     wpack = PackWrapper(pack)

-    xla_device = torch.device('xla')
+    xla_device = torch.device('xla:0')
     xdata = xm.send_cpu_data_to_device(wpack, xla_device)
     self.assertTrue(isinstance(xdata, nn.utils.rnn.PackedSequence))
     self.assertEqual(xdata.batch_sizes.device, torch.device('cpu'))
```
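The first hunk swaps a device-handle lookup for a runtime query; a small sketch under the assumption, consistent with the diff, that `xr.device_type()` reports the backend hardware type (e.g. `'CPU'` or `'TPU'`) without needing a device argument:

```python
# Sketch: xr.device_type() queries the runtime's hardware type directly,
# so no device handle is passed the way xm.xla_device_hw(...) required.
import torch_xla.runtime as xr

if xr.device_type() != 'TPU':
  print('skipping TPU-only check on', xr.device_type())
else:
  print('running TPU-only check')
```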

torch_xla/_internal/pjrt.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -104,7 +104,7 @@ def initialize_singleprocess():
     plugins.default().configure_single_process()
   elif runtime.device_type() == 'TPU':
     tpu.configure_one_chip_topology()
-  xm.set_replication(torch.device('xla'), [])
+  xm.set_replication(torch_xla.device(), [])


 def initialize_multiprocess(local_rank: int, local_world_size: int):
@@ -119,7 +119,7 @@ def initialize_multiprocess(local_rank: int, local_world_size: int):
     neuron.initialize_env(local_rank, local_world_size)

   devices = xm.get_xla_supported_devices()
-  xm.set_replication(torch.device('xla'), devices)
+  xm.set_replication(torch_xla.device(), devices)


 def run_multiprocess(fn: Callable[..., R],
```

torch_xla/runtime.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -156,7 +156,8 @@ def local_ordinal() -> int:
   Local ordinal is in range [0, local_device_count)."""
   local_rank = xu.getenv_as(xenv.PJRT_LOCAL_PROCESS_RANK, int, 0)
   devices_per_process = addressable_device_count()
-  return local_rank * devices_per_process + torch.device('xla').index
+  return local_rank * devices_per_process + torch.device(
+      torch_xla._XLAC._xla_get_default_device()).index


 def process_index() -> int:
```
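A worked example of the formula in this hunk, with illustrative values only (the TPU v3 topology is an assumption taken from the guide text above):

```python
# local_ordinal = local_rank * devices_per_process + default_device_index
# e.g. two addressable devices per process (TPU v3),
# PJRT_LOCAL_PROCESS_RANK=1, default device 'xla:1':
local_rank = 1
devices_per_process = 2
default_device_index = 1  # .index of torch.device('xla:1')
assert local_rank * devices_per_process + default_device_index == 3
```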
