
Commit caa40cc

refactor. (#9040)
1 parent 047de24 commit caa40cc

22 files changed: +177, -143 lines

docs/source/contribute/bazel.md (+1, -1)

@@ -1,4 +1,4 @@
-# Bazel in Pytorch/XLA
+# Building with Bazel
 
 [Bazel](https://bazel.build/) is a free software tool used for the
 automation of building and testing software.

docs/source/contribute/codegen_migration.md (+1, -1)

@@ -1,4 +1,4 @@
-# Codegen migration Guide
+# Codegen Migration Guide
 
 As PyTorch/XLA migrates to the LTC (Lazy Tensor Core), we need to clean
 up the existing stub code (which spans over 6+ files) that were used to

docs/source/contribute/configure-environment.md (+1, -1)

@@ -1,4 +1,4 @@
-# Configure a development environment
+# Configure A Development Environment
 
 The goal of this guide is to set up an interactive development
 environment on a Cloud TPU with PyTorch/XLA installed. If this is your

docs/source/contribute/op_lowering.md (+1, -1)

@@ -1,4 +1,4 @@
-# OP Lowering Guide
+# Op Lowering Guide
 
 PyTorch wraps the C++ ATen tensor library that offers a wide range of
 operations implemented on GPU and CPU. Pytorch/XLA is a PyTorch

docs/source/contribute/plugins.md (+1, -1)

@@ -1,7 +1,7 @@
 # Custom Hardware Plugins
 
 PyTorch/XLA supports custom hardware through OpenXLA's PJRT C API. The
-PyTorch/XLA team direclty supports plugins for Cloud TPU (`libtpu`) and
+PyTorch/XLA team directly supports plugins for Cloud TPU (`libtpu`) and
 GPU ([OpenXLA](https://github.com/openxla/xla/tree/main/xla/pjrt/gpu)).
 The same plugins may also be used by JAX and TF.

docs/source/features/scan.md (+16, -17)

@@ -1,26 +1,26 @@
-# Guide for using `scan` and `scan_layers`
+# Optimizing Repeated Layers with `scan` and `scan_layers`
 
 This is a guide for using `scan` and `scan_layers` in PyTorch/XLA.
 
 ## When should you use this
 
-You should consider using [`scan_layers`][scan_layers] if you have a model with
+Consider using [`scan_layers`][scan_layers] if you have a model with
 many homogenous (same shape, same logic) layers, for example LLMs. These models
 can be slow to compile. `scan_layers` is a drop-in replacement for a for loop over
 homogenous layers, such as a bunch of decoder layers. `scan_layers` traces the
 first layer and reuses the compiled result for all subsequent layers, significantly
 reducing the model compile time.
 
 [`scan`][scan] on the other hand is a lower level higher-order-op modeled after
-[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to help implement
-`scan_layers` under the hood. However, you may find it useful if you would like
-to program some sort of loop logic where the loop itself has a first-class
-representation in the compiler (specifically, an XLA `While` op).
+[`jax.lax.scan`][jax-lax-scan]. Its primary purpose is to implement
+`scan_layers` under the hood. However, you may find it useful
+to program loop logic where the loop itself has a first-class
+representation in the compiler (specifically, the XLA `while` op).
 
 ## `scan_layers` example
 
 Typically, a transformer model passes the input embedding through a sequence of
-homogenous decoder layers like the following:
+homogenous decoder layers:
 
 ```python
 def run_decoder_layers(self, hidden_states):

@@ -31,7 +31,7 @@ def run_decoder_layers(self, hidden_states):
 
 When this function is lowered into an HLO graph, the for loop is unrolled into a
 flat list of operations, resulting in long compile times. To reduce compile
-times, you can replace the for loop with a call to `scan_layers`, as shown in
+times, replace the for loop with `scan_layers`, as shown in
 [`decoder_with_scan.py`][decoder_with_scan]:
 
 ```python

@@ -61,7 +61,7 @@ def scan(
   ...
 ```
 
-You can use it to loop over the leading dimension of tensors efficiently. If `xs`
+Use it to loop over the leading dimension of tensors efficiently. If `xs`
 is a single tensor, this function is roughly equal to the following Python code:
 
 ```python

@@ -74,8 +74,8 @@ def scan(fn, init, xs):
   return carry, torch.stack(ys, dim=0)
 ```
 
-Under the hood, `scan` is implemented much more efficiently by lowering the loop
-into an XLA `While` operation. This ensures that only one iteration of the loop
+Under the hood, `scan` is implemented efficiently by lowering the loop
+into an XLA `while` operation. This ensures that only one iteration of the loop
 is compiled by XLA.
 
 [`scan_examples.py`][scan_examples] contains some example code showing how to use

@@ -114,19 +114,18 @@ Means over time: tensor([[1.0000],
 The functions/modules passed to `scan` and `scan_layers` must be AOTAutograd
 traceable. In particular, as of PyTorch/XLA 2.6, `scan` and `scan_layers` cannot
 trace functions with custom Pallas kernels. That means if your decoder uses,
-for example flash attention, then it's incompatible with `scan`. We are working on
-[supporting this important use case][flash-attn-issue] in nightly and the next
-releases.
+for example flash attention, then it is incompatible with `scan`. We are working on
+[supporting this important use case][flash-attn-issue].
 
 ### AOTAutograd overhead
 
 Because `scan` uses AOTAutograd to figure out the backward pass of the input
-function/module on every iteration, it's easy to become tracing bound compared to
+function/module on every iteration, it is easy to become tracing-bound compared to
 a for loop implementation. In fact, the `train_decoder_only_base.py` example runs
 slower under `scan` than with for loop as of PyTorch/XLA 2.6 due to this overhead.
 We are working on [improving tracing speed][retracing-issue]. This is less of a
 problem when your model is very large or has many layers, which are the situations
-you would want to use `scan` anyways.
+you would want to use `scan`.
 
 ## Compile time experiments
 

@@ -180,7 +179,7 @@ Metric: CompileTime
   99%=18s995ms301.667us
 ```
 
-We can see that the maximum compile time dropped from `1m03s` to `19s` by
+The maximum compile time dropped from `1m03s` to `19s` by
 switching to `scan_layers`.
 
 ## References
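The guide above defers the full example to `decoder_with_scan.py`. As a rough orientation only, here is a minimal sketch of the pattern being described, assuming `scan_layers` is importable from `torch_xla.experimental.scan_layers` and accepts the sequence of homogeneous modules plus the input tensor; `TinyDecoderLayer` is an invented stand-in, not part of the library.

```python
# Minimal sketch (not the repository's decoder_with_scan.py): replace an
# unrolled for loop over identical layers with a single scan_layers call.
import torch
import torch.nn as nn
import torch_xla
from torch_xla.experimental.scan_layers import scan_layers  # assumed import path


class TinyDecoderLayer(nn.Module):
    """Invented homogeneous layer: every instance has the same shape and logic."""

    def __init__(self, hidden: int):
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, hidden_states):
        return torch.relu(self.linear(hidden_states))


device = torch_xla.device()  # recent releases; older code uses xm.xla_device()
layers = nn.ModuleList([TinyDecoderLayer(128) for _ in range(8)]).to(device)
hidden_states = torch.randn(4, 16, 128, device=device)

# For-loop version: the loop is unrolled into the HLO graph, so every layer
# is traced and compiled.
out = hidden_states
for layer in layers:
    out = layer(out)

# scan_layers version: only the first layer is traced; the compiled body is
# reused for the remaining layers through an XLA while loop.
out = scan_layers(layers, hidden_states)
```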

docs/source/features/distop.md renamed to docs/source/features/torch_distributed.md (+5, -2)

@@ -1,6 +1,9 @@
-# Support of Torch Distributed API in PyTorch/XLA
-Before the 2.5 release, PyTorch/XLA only supported collective ops through our custom API call `torch_xla.core.xla_model.*`. In the 2.5 release, we adopt `torch.distributed.*` in PyTorch/XLA for both Dynamo and non-Dynamo cases.
+# Support for Torch Distributed
+
+Before the 2.5 release, PyTorch/XLA only supported collective ops through the custom API call `torch_xla.core.xla_model.*`. In the 2.5 release, we adopted `torch.distributed.*` in PyTorch/XLA for both Dynamo and non-Dynamo cases.
+
 ## Collective ops lowering
+
 ### Collective ops lowering stack
 After introducing the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), dynamo can support the collective ops with reimplementing lowering in PyTorch/XLA. The collective op is only traceable through `torch.ops._c10d_functional` call. Below is the figure that shows how the collective op, `all_reduce` in this case, is lowered between torch and torch_xla:
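For orientation, a hedged sketch of the usage this renamed page covers: calling a `torch.distributed` collective on XLA tensors after registering the `xla` process group backend. The launch and init pattern below follows the PyTorch/XLA examples as I understand them; the page itself documents the exact supported invocation.

```python
# Sketch only: torch.distributed all_reduce on XLA devices (non-Dynamo path).
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process group backend


def _mp_fn(index: int):
    # The xla:// init method derives rank and world size from the XLA runtime.
    dist.init_process_group("xla", init_method="xla://")
    device = xm.xla_device()

    # Each rank contributes its own value; all_reduce sums them in place.
    value = torch.ones(4, device=device) * dist.get_rank()
    dist.all_reduce(value, op=dist.ReduceOp.SUM)
    xm.mark_step()  # flush the traced graph so the collective executes
    print(f"rank {dist.get_rank()}: {value.cpu()}")


if __name__ == "__main__":
    torch_xla.launch(_mp_fn)
```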

docs/source/index.rst (+70, -34)

@@ -2,72 +2,108 @@
 
 PyTorch/XLA documentation
 ===================================
-PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs.
+``torch_xla`` is a Python package that implements \
+`XLA <https://openxla.org/xla>`_ as a backend for PyTorch.
+
++------------------------------------------------+------------------------------------------------+------------------------------------------------+
+| **Familiar APIs**                              | **High Performance**                           | **Cost Efficient**                             |
+|                                                |                                                |                                                |
+| Create and train PyTorch models on TPUs,       | Scale training jobs across thousands of        | TPU hardware and the XLA compiler are optimized|
+| with only minimal changes required.            | TPU cores while maintaining high MFU.          | for cost-efficient training and inference.     |
++------------------------------------------------+------------------------------------------------+------------------------------------------------+
+
+Getting Started
+---------------
+
+Install with pip.
+
+.. code-block:: sh
+
+   pip install torch torch_xla[tpu]
+
+Verify the installation:
+
+.. code-block:: sh
+
+   python -c "import torch_xla; print(torch_xla.__version__)"
+   python -c "import torch; import torch_xla; print(torch.tensor(1.0, device='xla').device)"
+
+Tutorials
+---------
 
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Learn about Pytorch/XLA
+   :caption: Learn the Basics
 
-   learn/xla-overview
    learn/pytorch-on-xla-devices
-   learn/api-guide
-   learn/dynamic_shape
-   learn/eager
-   learn/pjrt
-   learn/troubleshoot
+   learn/xla-overview
 
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Learn about accelerators
+   :caption: Distributed Training on TPU
 
    accelerators/tpu
-   accelerators/gpu
+   perf/spmd_basic
+   perf/spmd_advanced
+   perf/spmd_distributed_checkpoint
+   features/torch_distributed
+   perf/ddp
+   perf/fsdp_collectives
+   perf/fsdp_spmd
 
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Run ML workloads with Pytorch/XLA
+   :caption: Advanced Techniques
 
-   workloads/kubernetes
+   features/pallas
+   features/stablehlo
+   perf/amp
+   learn/dynamic_shape
+   perf/dynamo
+   perf/quantized_ops
+   features/scan
+   perf/fori_loop
+   perf/assume_pure
 
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: PyTorch/XLA features
+   :caption: Troubleshooting
 
-   features/pallas.md
-   features/stablehlo.md
-   features/triton.md
-   features/scan.md
+   learn/troubleshoot
+   learn/eager
+   notes/source_of_recompilation
+   perf/recompilation
 
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Improve Pytorch/XLA workload performance
+   :caption: Training on GPU
 
-   perf/amp
-   perf/spmd_basic
-   perf/spmd_advanced
-   perf/spmd_distributed_checkpoint
+   accelerators/gpu
+   features/triton
    perf/spmd_gpu
-   perf/ddp
-   perf/dynamo
-   perf/fori_loop
-   perf/fsdp
-   perf/fsdpv2
-   perf/quantized_ops
-   perf/recompilation
-
+
 .. toctree::
    :glob:
    :maxdepth: 1
-   :caption: Contribute to Pytorch/XLA
+   :caption: Contributing
 
+   contribute/bazel
    contribute/configure-environment
-   contribute/codegen_migration
+   contribute/cpp_debugger
    contribute/op_lowering
+   contribute/codegen_migration
    contribute/plugins
-   contribute/bazel
-   contribute/recompilation
+
+API Reference
+-------------
+
+.. toctree::
+   :glob:
+   :maxdepth: 2
+
+   learn/api-guide
File renamed without changes.

docs/source/learn/dynamic_shape.md (+17, -14)

@@ -1,30 +1,33 @@
-# Dynamic shape
+# Dynamic Shapes
 
-Dynamic shape refers to the variable nature of a tensor shape where its shape depends on the value of another upstream tensor. For example:
-```
+Dynamic shapes means a tensor's shape depends on the value of another tensor. For example:
+```python
 >>> import torch, torch_xla
 >>> in_tensor = torch.randint(low=0, high=2, size=(5,5), device='xla:0')
 >>> out_tensor = torch.nonzero(in_tensor)
 ```
-the shape of `out_tensor` depends on the value of `in_tensor` and is bounded by the shape of `in_tensor`. In other words, if you do
-```
+
+The shape of `out_tensor` depends on the value of `in_tensor` and is bounded by the shape of `in_tensor`. In other words, if you do
+
+```python
 >>> print(out_tensor.shape)
 torch.Size([<=25, 2])
 ```
-you can see the first dimension depends on the value of `in_tensor` and its maximum value is 25. We call the first dimension the dynamic dimension. The second dimension does not depend on any upstream tensors so we call it the static dimension.
+the first dimension depends on the value of `in_tensor` and its maximum value is 25. We call the first dimension the dynamic dimension. The second dimension does not depend on any upstream tensors so we call it the static dimension.
 
 Dynamic shape can be further categorized into bounded dynamic shape and unbounded dynamic shape.
-- bounded dynamic shape: refers to a shape whose dynamic dimensions are bounded by static values. It works for accelerators that require static memory allocation (e.g. TPU).
-- unbounded dynamic shape: refers to a shape whose dynamic dimensions can be infinitely large. It works for accelerators that don’t require static memory allocation (e.g. GPU).
+- Bounded dynamic shape: refers to a shape whose dynamic dimensions are bounded by static values. It works for accelerators that require static memory allocation (e.g. TPU).
+- Unbounded dynamic shape: refers to a shape whose dynamic dimensions can be infinitely large. It works for accelerators that don’t require static memory allocation (e.g. GPU).
 
 Today, only the bounded dynamic shape is supported and it is in the experimental phase.
 
 ## Bounded dynamic shape
 
 Currently, we support multi-layer perceptron models (MLP) with dynamic size input on TPU.
 
-This feature is controlled by a flag `XLA_EXPERIMENTAL="nonzero:masked_select"`. To run a model with the feature enabled, you can do:
-```
+This feature is controlled by a flag `XLA_EXPERIMENTAL="nonzero:masked_select"`. To run a model with the feature enabled, launch Python with the following environment variable:
+
+```sh
 XLA_EXPERIMENTAL="nonzero:masked_select:masked_scatter" python your_scripts.py
 ```

@@ -40,8 +43,8 @@ Here are some numbers we get when we run the MLP model for 100 iterations:
 
 One of the motivations of the dynamic shape is to reduce the number of excessive recompilation when the shape keeps changing between iterations. From the figure above, you can see the number of compilations reduced by half which results in the drop of the training time.
 
-To try it out, run
-```
+To try it:
+
+```sh
 XLA_EXPERIMENTAL="nonzero:masked_select" PJRT_DEVICE=TPU python3 pytorch/xla/test/ds/test_dynamic_shape_models.py TestDynamicShapeModels.test_backward_pass_with_dynamic_input
-```
-For more details on how we plan to expand the dynamic shape support on PyTorch/XLA in the future, feel free to review our [RFC](https://github.com/pytorch/xla/issues/3884).
+```
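The flag above also names `masked_select`; as a hedged companion to the `nonzero` example in this file, the sketch below shows the analogous value-dependent op. The bounded shape noted in the comment is the expected behavior under the experimental flag, not verified output.

```python
# Sketch: run with XLA_EXPERIMENTAL="nonzero:masked_select" set, as described above.
import torch
import torch_xla

t = torch.randn(5, 5, device='xla:0')
positive = torch.masked_select(t, t > 0)  # length depends on the values in t
print(positive.shape)                     # dynamic first dimension, bounded by 25
```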

docs/source/learn/troubleshoot.md (+1, -1)

@@ -1,4 +1,4 @@
-# Troubleshoot
+# Troubleshooting Basics
 
 Note that the information in this section is subject to be removed in
 future releases of the *PyTorch/XLA* software, since many of them are

docs/source/learn/xla-overview.md (+1, -1)

@@ -1,4 +1,4 @@
-# Pytorch/XLA overview
+# Pytorch/XLA Overview
 
 This section provides a brief overview of the basic details of PyTorch
 XLA, which should help readers better understand the required

docs/source/perf/assume_pure.md (+1, -1)

@@ -1,4 +1,4 @@
-# Use `@assume_pure` to speed up lazy tensor tracing
+# Speed Up Tracing with `@assume_pure`
 
 This document explains how to use `torch_xla.experimental.assume_pure` to
 eliminate lazy tensor tracing overhead. See [this blog post][lazy-tensor] for a
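A minimal hedged sketch of the decorator usage this retitled page describes, assuming `assume_pure` is importable from the `torch_xla.experimental.assume_pure` module named in the diff; the page itself has the authoritative example.

```python
# Sketch: mark a side-effect-free function as pure so its trace can be reused.
import torch
import torch_xla
from torch_xla.experimental.assume_pure import assume_pure  # assumed import path


@assume_pure
def gelu_mlp(x, w1, w2):
    # Pure: the output depends only on the inputs, with no side effects.
    return torch.nn.functional.gelu(x @ w1) @ w2
```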

docs/source/perf/ddp.md (+1, -1)

@@ -1,4 +1,4 @@
-# How to do DistributedDataParallel(DDP)
+# Distributed Data Parallel (DDP)
 
 This document shows how to use torch.nn.parallel.DistributedDataParallel
 in xla, and further describes its difference against the native xla data
