Migrate torch_xla.device() to torch.device('xla') #9253

Draft · wants to merge 3 commits into base: master
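The documentation change in this PR is mechanical. A minimal before/after sketch of the pattern (a hedged illustration, assuming a PyTorch/XLA build recent enough that the `xla` device type is registered with `torch.device`):

```python
import torch
import torch_xla  # importing torch_xla initializes PyTorch/XLA

# Old spelling used in the docs until now:
# device = torch_xla.device()

# New spelling the docs are migrated to:
device = torch.device('xla')

t = torch.randn(2, 2, device=device)  # tensor created on the XLA device
print(t.device)
```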
8 changes: 4 additions & 4 deletions API_GUIDE.md
@@ -22,7 +22,7 @@ print(t)

This code should look familiar. PyTorch/XLA uses the same interface as regular
PyTorch with a few additions. Importing `torch_xla` initializes PyTorch/XLA, and
- `torch_xla.device()` returns the current XLA device. This may be a CPU or TPU
+ `torch.device('xla')` returns the current XLA device. This may be a CPU or TPU
depending on your environment.

## XLA Tensors are PyTorch Tensors
@@ -112,7 +112,7 @@ train_loader = xu.SampleGenerator(
torch.zeros(batch_size, dtype=torch.int64)),
sample_count=60000 // batch_size // xr.world_size())

- device = torch_xla.device() # Get the XLA device (TPU).
+ device = torch.device('xla') # Get the XLA device (TPU).
model = MNIST().train().to(device) # Create a model and move it to the device.
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
@@ -169,7 +169,7 @@ def _mp_fn(index):
index: Index of the process.
"""

- device = torch_xla.device() # Get the device assigned to this process.
+ device = torch.device('xla') # Get the device assigned to this process.
# Wrap the loader for multi-device.
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

@@ -290,7 +290,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

- device = torch_xla.device()
+ device = torch.device('xla')

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
2 changes: 1 addition & 1 deletion benchmarks/experiment_runner.py
@@ -255,7 +255,7 @@ def _default_iter_fn(self, benchmark_experiment: BenchmarkExperiment,

def _pure_wall_time_iter_fn(self, benchmark_experiment: BenchmarkExperiment,
benchmark_model: BenchmarkModel, input_tensor):
- device = torch_xla.device() if benchmark_experiment.xla else 'cuda'
+ device = torch.device('xla') if benchmark_experiment.xla else 'cuda'
sync_fn = xm.wait_device_ops if benchmark_experiment.xla else torch.cuda.synchronize
timing, output = bench.do_bench(
lambda: benchmark_model.model_iter_fn(
6 changes: 3 additions & 3 deletions contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb
@@ -193,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": null,
"metadata": {
"execution": {
"iopub.execute_input": "2024-01-10T19:30:28.607393Z",
@@ -210,7 +210,7 @@
"lock = mp.Manager().Lock()\n",
"\n",
"def print_device(i, lock):\n",
" device = torch_xla.device()\n",
" device = torch.device('xla')\n",
" with lock:\n",
" print('process', i, device)"
]
@@ -454,7 +454,7 @@
"import torch_xla.experimental.pjrt_backend # Required for torch.distributed on TPU v2 and v3\n",
"\n",
"def toy_model(index, lock):\n",
" device = torch_xla.device()\n",
" device = torch.device('xla')\n",
" dist.init_process_group('xla', init_method='xla://')\n",
"\n",
" # Initialize a basic toy model\n",
2 changes: 1 addition & 1 deletion docs/source/learn/_pjrt.md
@@ -73,7 +73,7 @@ import torch_xla.distributed.xla_backend


def _mp_fn(index):
- device = torch_xla.device()
+ device = torch.device('xla')
- dist.init_process_group('xla', rank=xr.global_ordinal(), world_size=xr.world_size())
+ dist.init_process_group('xla', init_method='xla://')

4 changes: 2 additions & 2 deletions docs/source/learn/eager.md
@@ -13,7 +13,7 @@ import torch
import torch_xla
import torchvision

- device = torch_xla.device()
+ device = torch.device('xla')
model = torchvision.models.resnet18().to(device)
input = torch.randn(64, 3, 224, 224).to(device)

@@ -71,7 +71,7 @@ import torchvision
# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

- device = torch_xla.device()
+ device = torch.device('xla')
model = torchvision.models.resnet18().to(device)

# Mark the function to be compiled
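To round out the truncated `eager.md` snippet above, here is a hedged sketch of mixing eager mode with a compiled region; it assumes `torch_xla.compile` accepts a module or callable, as in recent torch_xla releases:

```python
import torch
import torch_xla
import torchvision

# Run ops eagerly by default
torch_xla.experimental.eager_mode(True)

device = torch.device('xla')
model = torchvision.models.resnet18().to(device)

# Compile just the model; everything outside the compiled region stays eager.
compiled_model = torch_xla.compile(model)

out = compiled_model(torch.randn(4, 3, 224, 224, device=device))
```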
10 changes: 5 additions & 5 deletions docs/source/learn/pytorch-on-xla-devices.md
@@ -21,7 +21,7 @@ print(t)

This code should look familiar. PyTorch/XLA uses the same interface as
regular PyTorch with a few additions. Importing `torch_xla` initializes
- PyTorch/XLA, and `torch_xla.device()` returns the current XLA device. This
+ PyTorch/XLA, and `torch.device('xla')` returns the current XLA device. This
may be a CPU or TPU depending on your environment.

## XLA Tensors are PyTorch Tensors
@@ -81,7 +81,7 @@ The following snippet shows a network training on a single XLA device:
``` python
import torch_xla.core.xla_model as xm

- device = torch_xla.device()
+ device = torch.device('xla')
model = MNIST().train().to(device)
loss_fn = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
@@ -120,7 +120,7 @@ import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

def _mp_fn(index):
- device = torch_xla.device()
+ device = torch.device('xla')
mp_device_loader = pl.MpDeviceLoader(train_loader, device)

model = MNIST().train().to(device)
@@ -148,7 +148,7 @@ previous single device snippet. Let's go over them one by one.
will only be able to access the device assigned to the current
process. For example on a TPU v4-8, there will be 4 processes
being spawned, and each process will own a TPU device.
- - Note that if you print the `torch_xla.device()` on each process you
+ - Note that if you print the `torch.device('xla')` on each process you
will see `xla:0` on all devices. This is because each process
can only see one device. This does not mean multi-process is not
functioning. The only exception is with PJRT runtime on TPU v2
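To make the per-process note above concrete, a small sketch (hedged: it assumes `torch_xla.launch` as the process-spawning entry point; `_print_device_fn` is an illustrative name, not part of the docs):

```python
import torch
import torch_xla

def _print_device_fn(index):
    # Each spawned process is bound to one accelerator core, so the same
    # torch.device('xla') call resolves to a different physical device per
    # process, even though every process prints it as `xla:0`.
    device = torch.device('xla')
    print('process', index, device)

if __name__ == '__main__':
    torch_xla.launch(_print_device_fn, args=())
```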
@@ -283,7 +283,7 @@ import torch
import torch_xla
import torch_xla.core.xla_model as xm

- device = torch_xla.device()
+ device = torch.device('xla')

t0 = torch.randn(2, 2, device=device)
t1 = torch.randn(2, 2, device=device)
8 changes: 4 additions & 4 deletions docs/source/learn/xla-overview.md
@@ -184,7 +184,7 @@ repo. contains examples for training and serving many LLM and diffusion models.

General guidelines to modify your code:

- - Replace `cuda` with `torch_xla.device()`
+ - Replace `cuda` with `torch.device('xla')`
- Remove progress bar, printing that would access the XLA tensor
values
- Reduce logging and callbacks that would access the XLA tensor values
@@ -227,7 +227,7 @@ tutorial, but you can pass the `device` value to the function as well.

``` python
import torch_xla.core.xla_model as xm
- self.device = torch_xla.device()
+ self.device = torch.device('xla')
```

Another place in the code that has cuda specific code is DDIM scheduler.
@@ -244,7 +244,7 @@ if attr.device != torch.device("cuda"):
with

``` python
- device = torch_xla.device()
+ device = torch.device('xla')
attr = attr.to(torch.device(device))
```

@@ -339,7 +339,7 @@ with the following lines:

``` python
import torch_xla.core.xla_model as xm
- device = torch_xla.device()
+ device = torch.device('xla')
pipe.to(device)
```

10 changes: 5 additions & 5 deletions docs/source/perf/amp.md
@@ -27,7 +27,7 @@ for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
- with autocast(torch_xla.device()):
+ with autocast(torch.device('xla')):
output = model(input)
loss = loss_fn(output, target)

@@ -36,7 +36,7 @@ for input, target in data:
xm.optimizer_step(optimizer)
```

- `autocast(torch_xla.device())` aliases `torch.autocast('xla')` when the XLA
+ `autocast(torch.device('xla'))` aliases `torch.autocast('xla')` when the XLA
Device is a TPU. Alternatively, if a script is only used with TPUs, then
`torch.autocast('xla', dtype=torch.bfloat16)` can be directly used.
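As a concrete illustration of the alias described above, a minimal sketch of a bf16 autocast forward pass (assuming a TPU-backed PyTorch/XLA install; `torch.autocast('xla', ...)` is used directly, as the text permits):

```python
import torch
import torch_xla

device = torch.device('xla')
model = torch.nn.Linear(8, 4).to(device)
data = torch.randn(16, 8, device=device)

# bf16 autocast on TPU does not need gradient scaling.
with torch.autocast('xla', dtype=torch.bfloat16):
    out = model(data)

torch_xla.sync()  # flush the pending lazy computation
```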

@@ -115,7 +115,7 @@ for input, target in data:
optimizer.zero_grad()

# Enables autocasting for the forward pass
- with autocast(torch_xla.device()):
+ with autocast(torch.device('xla')):
output = model(input)
loss = loss_fn(output, target)

@@ -127,12 +127,12 @@ for input, target in data:
scaler.update()
```

- `autocast(torch_xla.device())` aliases `torch.cuda.amp.autocast()` when the
+ `autocast(torch.device('xla'))` aliases `torch.cuda.amp.autocast()` when the
XLA Device is a CUDA device (XLA:GPU). Alternatively, if a script is
only used with CUDA devices, then `torch.cuda.amp.autocast` can be
directly used, but requires `torch` is compiled with `cuda` support for
datatype of `torch.bfloat16`. We recommend using
- `autocast(torch_xla.device())` on XLA:GPU as it does not require
+ `autocast(torch.device('xla'))` on XLA:GPU as it does not require
`torch.cuda` support for any datatypes, including `torch.bfloat16`.

### AMP for XLA:GPU Best Practices
2 changes: 1 addition & 1 deletion docs/source/perf/ddp.md
@@ -105,7 +105,7 @@ def demo_basic(rank):
setup(rank, world_size)

# create model and move it to XLA device
- device = torch_xla.device()
+ device = torch.device('xla')
model = ToyModel().to(device)
ddp_model = DDP(model, gradient_as_bucket_view=True)

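For context, a self-contained sketch of the per-process setup that `demo_basic` relies on, pieced together from the snippets in this PR (hedged: it assumes `torch_xla.launch` for spawning and the `xla://` init method shown in `_pjrt.md`; the `torch.nn.Linear` model is a stand-in for `ToyModel`):

```python
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.distributed.xla_backend  # registers the 'xla' process-group backend
from torch.nn.parallel import DistributedDataParallel as DDP

def _mp_fn(index):
    dist.init_process_group('xla', init_method='xla://')
    device = torch.device('xla')
    model = torch.nn.Linear(10, 10).to(device)  # stand-in for ToyModel
    ddp_model = DDP(model, gradient_as_bucket_view=True)
    out = ddp_model(torch.randn(20, 10, device=device))
    print(index, out.shape)

if __name__ == '__main__':
    torch_xla.launch(_mp_fn, args=())
```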
4 changes: 2 additions & 2 deletions docs/source/perf/dynamo.md
@@ -41,7 +41,7 @@ import torchvision
import torch_xla.core.xla_model as xm

def eval_model(loader):
- device = torch_xla.device()
+ device = torch.device('xla')
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.eval()
dynamo_resnet18 = torch.compile(
@@ -129,7 +129,7 @@ def train_model(model, data, target, optimizer):
return pred

def train_model_main(loader):
- device = torch_xla.device()
+ device = torch.device('xla')
xla_resnet18 = torchvision.models.resnet18().to(device)
xla_resnet18.train()
dynamo_train_model = torch.compile(
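A compact sketch of the inference path that `eval_model` above drives (hedged: it assumes the installed torch_xla registers the `openxla` Dynamo backend):

```python
import torch
import torch_xla
import torchvision

device = torch.device('xla')
model = torchvision.models.resnet18().to(device).eval()
compiled_model = torch.compile(model, backend='openxla')

with torch.no_grad():
    out = compiled_model(torch.randn(4, 3, 224, 224, device=device))
```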
4 changes: 2 additions & 2 deletions docs/source/perf/fori_loop.md
@@ -30,7 +30,7 @@ result = while_loop(cond_fn, body_fn, init)
>>> from torch._higher_order_ops.while_loop import while_loop
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = torch_xla.device()
>>> device = torch.device('xla')
>>>
>>> def cond_fn(iteri, x):
... return iteri > 0
@@ -60,7 +60,7 @@ with similar logic: cumulative plus 1 for ten times:
>>> import torch_xla
>>> import torch_xla.core.xla_model as xm
>>>
>>> device = torch_xla.device()
>>> device = torch.device('xla')
>>>
>>> init_val = torch.tensor(1, device=device)
>>> iteri = torch.tensor(50, device=device)
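Filling in the truncated doctest above, a minimal runnable sketch of `while_loop` on an XLA device (hedged: argument order and the tuple carry follow the snippet shown here, not an authoritative API reference):

```python
import torch
import torch_xla
from torch._higher_order_ops.while_loop import while_loop

device = torch.device('xla')

def cond_fn(iteri, x):
    return iteri > 0                    # loop while the counter is positive

def body_fn(iteri, x):
    return iteri - 1, torch.add(x, 1)   # decrement counter, accumulate +1

init_val = torch.tensor(1, device=device)
iteri = torch.tensor(10, device=device)
_, result = while_loop(cond_fn, body_fn, (iteri, init_val))  # expect init_val + 10
```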
2 changes: 1 addition & 1 deletion docs/source/perf/quantized_ops.md
@@ -48,7 +48,7 @@ scaler = torch.randn((N_OUTPUT_FEATURES,), dtype=torch.bfloat16)
# Call with torch CPU tensor (For debugging purpose)
matmul_output = torch.ops.xla.quantized_matmul(x, w_int, scaler)

- device = torch_xla.device()
+ device = torch.device('xla')
x_xla = x.to(device)
w_int_xla = w_int.to(device)
scaler_xla = scaler.to(device)
2 changes: 1 addition & 1 deletion examples/train_decoder_only_base.py
@@ -35,7 +35,7 @@ def __init__(self,
torch.zeros(self.batch_size, self.seq_len, dtype=torch.int64)),
sample_count=self.train_dataset_len // self.batch_size)

- self.device = torch_xla.device()
+ self.device = torch.device('xla')
self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
self.model = decoder_cls(self.config).to(self.device)
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=0.0001)
2 changes: 1 addition & 1 deletion examples/train_resnet_amp.py
@@ -19,7 +19,7 @@ def train_loop_fn(self, loader, epoch):
for step, (data, target) in enumerate(loader):
self.optimizer.zero_grad()
# Enables autocasting for the forward pass
- with autocast(torch_xla.device()):
+ with autocast(torch.device('xla')):
output = self.model(data)
loss = self.loss_fn(output, target)
# TPU amp uses bf16 hence gradient scaling is not necessary. If running with XLA:GPU
2 changes: 1 addition & 1 deletion examples/train_resnet_base.py
@@ -28,7 +28,7 @@ def __init__(self):
sample_count=self.train_dataset_len // self.batch_size //
xr.world_size())

- self.device = torch_xla.device()
+ self.device = torch.device('xla')
self.train_device_loader = pl.MpDeviceLoader(train_loader, self.device)
self.model = torchvision.models.resnet50().to(self.device)
self.optimizer = optim.SGD(self.model.parameters(), weight_decay=1e-4)
2 changes: 1 addition & 1 deletion plugins/cpu/README.md
@@ -38,5 +38,5 @@ plugins.use_dynamic_plugins()
plugins.register_plugin('CPU', torch_xla_cpu_plugin.CpuPlugin())
xr.set_device_type('CPU')

- print(torch_xla.device())
+ print(torch.device('xla'))
```
2 changes: 1 addition & 1 deletion plugins/cuda/README.md
@@ -35,5 +35,5 @@ plugins.use_dynamic_plugins()
plugins.register_plugin('CUDA', torch_xla_cuda_plugin.CudaPlugin())
xr.set_device_type('CUDA')

- print(torch_xla.device())
+ print(torch.device('xla'))
```
2 changes: 1 addition & 1 deletion test/bench.py
@@ -29,7 +29,7 @@ class BaseBench(object):

def __init__(self, args):
self.args = args
- self.device = torch_xla.device()
+ self.device = torch.device('xla')
self.test_time = xu.getenv_as('BENCH_TEST_TIME', float, 5.0)
torch.manual_seed(42)

2 changes: 1 addition & 1 deletion test/debug_tool/test_mp_pt_xla_debug.py
@@ -16,7 +16,7 @@ def _mp_fn(index):
assert False, "This test should be run with PT_XLA_DEBUG_FILE"
if index == 0:
open(debug_file_name, 'w').close()
- device = torch_xla.device()
+ device = torch.device('xla')
t1 = torch.randn(10, 10, device=device)
t2 = t1 * 100
torch_xla.sync()