docs/source/features/torch_distributed.md
# Support for Torch Distributed
Before the 2.5 release, PyTorch/XLA only supported collective ops through its custom `torch_xla.core.xla_model.*` APIs. In the 2.5 release, we adopted `torch.distributed.*` in PyTorch/XLA for both Dynamo and non-Dynamo cases.
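As a quick illustration, here is a minimal sketch of the non-Dynamo path, assuming a multi-process XLA (e.g. TPU) setup; the worker function and tensor contents are just placeholders:

```python
import torch
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend

def _mp_fn(rank):
    # The xla:// init method discovers rank and world size from the XLA runtime.
    dist.init_process_group("xla", init_method="xla://")
    device = xm.xla_device()

    # torch.distributed collectives now work directly on XLA tensors.
    t = torch.tensor([rank], dtype=torch.float32, device=device)
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    xm.mark_step()
    print(t)

if __name__ == "__main__":
    xmp.spawn(_mp_fn)
```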
## Collective ops lowering
### Collective ops lowering stack
After the introduction of the [traceable collective communication APIs](https://github.com/pytorch/pytorch/issues/93173), Dynamo can support collective ops by reimplementing their lowering in PyTorch/XLA. A collective op is only traceable through a `torch.ops._c10d_functional` call. The figure below shows how a collective op, `all_reduce` in this case, is lowered between torch and torch_xla:
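As a complement to the figure, here is a minimal sketch of the Dynamo case, assuming a process group with the `xla` backend has already been initialized (the function and tensor names are illustrative):

```python
import torch
import torch.distributed as dist

def collective_fn(t):
    # Under torch.compile, this c10d call is captured through the traceable
    # functional op torch.ops._c10d_functional.all_reduce (plus a wait op),
    # which PyTorch/XLA then lowers to an HLO all-reduce.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t

# "openxla" is the Dynamo backend provided by PyTorch/XLA.
compiled_fn = torch.compile(collective_fn, backend="openxla")
```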
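For concreteness, the dynamically shaped tensors discussed below could be produced as in the following sketch (the 5x5 input and the use of `torch.nonzero` are illustrative assumptions):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# A 5x5 tensor of zeros and ones; how many entries are nonzero is data-dependent.
in_tensor = torch.randint(low=0, high=2, size=(5, 5), device=device)
# torch.nonzero returns one row per nonzero element, so the first dimension of
# out_tensor depends on the values in in_tensor and is bounded by 5 * 5 = 25.
out_tensor = torch.nonzero(in_tensor)
```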
The shape of `out_tensor` depends on the value of `in_tensor` and is bounded by the shape of `in_tensor`. In other words, if you do
```python
>>> print(out_tensor.shape)
torch.Size([<=25, 2])
```
you can see that the first dimension depends on the value of `in_tensor` and that its maximum value is 25. We call the first dimension the dynamic dimension. The second dimension does not depend on any upstream tensors, so we call it the static dimension.
Dynamic shape can be further categorized into bounded dynamic shape and unbounded dynamic shape.
- Bounded dynamic shape: refers to a shape whose dynamic dimensions are bounded by static values. It works for accelerators that require static memory allocation (e.g. TPU).
- Unbounded dynamic shape: refers to a shape whose dynamic dimensions can be infinitely large. It works for accelerators that don’t require static memory allocation (e.g. GPU).
Today, only bounded dynamic shape is supported, and it is in the experimental phase.
## Bounded dynamic shape
Currently, we support multi-layer perceptron (MLP) models with dynamic-size input on TPU.
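A minimal sketch of what this looks like is below; the layer sizes and the input are illustrative, and it assumes the experimental flag described next is set in the environment:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

# Assumes XLA_EXPERIMENTAL="nonzero:masked_select" is set in the environment.
device = xm.xla_device()
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1)).to(device)

x = torch.randint(low=0, high=2, size=(5, 5), device=device)
# The nonzero output has a bounded dynamic first dimension (<= 25) and a
# static second dimension of 2, matching the MLP's input features.
features = torch.nonzero(x).float()
out = model(features)
xm.mark_step()
```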
This feature is controlled by the flag `XLA_EXPERIMENTAL="nonzero:masked_select"`. To run a model with the feature enabled, launch Python with that environment variable set.
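For example (the script name below is just a placeholder for your own training script):

```
XLA_EXPERIMENTAL="nonzero:masked_select" python your_script.py
```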
Here are some numbers we get when we run the MLP model for 100 iterations:
One of the motivations for dynamic shape is to reduce excessive recompilation when the shape keeps changing between iterations. From the figure above, you can see that the number of compilations is reduced by half, which results in a drop in the training time.

For more details on how we plan to expand dynamic shape support in PyTorch/XLA in the future, feel free to review our [RFC](https://github.com/pytorch/xla/issues/3884).