[Doc] highlight some features as experimental (#2152)
* generic python
* Update feature list in release note
* fine tune, add experimental to horovod, simple trace and profiler_legacy
* Update CPU part in release note
* add cpu to OS matrix
* DDP doc: Add torch-ccl source build command for cpu (#2159)
---------
Co-authored-by: Ye Ting <[email protected]>
Co-authored-by: zhuhong61 <[email protected]>
docs/tutorials/features.rst (6 additions, 6 deletions)
@@ -66,9 +66,9 @@ On Intel® GPUs, quantization usages follow PyTorch default quantization APIs. C
 Distributed Training
 --------------------
 
-To meet demands of large scale model training over multiple devices, distributed training on Intel® GPUs and CPUs are supported. Two alternative methodologies are available. Users can choose either to use PyTorch native distributed training module, `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`_, with `Intel® oneAPI Collective Communications Library (oneCCL) <https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html>`_ support via `Intel® oneCCL Bindings for PyTorch (formerly known as torch_ccl) <https://github.com/intel/torch-ccl>`_ or use Horovod with `Intel® oneAPI Collective Communications Library (oneCCL) <https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html>`_ support.
+To meet demands of large scale model training over multiple devices, distributed training on Intel® GPUs and CPUs are supported. Two alternative methodologies are available. Users can choose either to use PyTorch native distributed training module, `Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`_, with `Intel® oneAPI Collective Communications Library (oneCCL) <https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html>`_ support via `Intel® oneCCL Bindings for PyTorch (formerly known as torch_ccl) <https://github.com/intel/torch-ccl>`_ or use Horovod with `Intel® oneAPI Collective Communications Library (oneCCL) <https://www.intel.com/content/www/us/en/developer/tools/oneapi/oneccl.html>`_ support (Experimental).
 
-For more detailed information, check `DDP <features/DDP.md>`_ and `Horovod <features/horovod.md>`_.
+For more detailed information, check `DDP <features/DDP.md>`_ and `Horovod (Experimental) <features/horovod.md>`_.
 
 .. toctree::
    :hidden:
@@ -122,8 +122,8 @@ For more detailed information, check `Advanced Configuration <features/advanced_
    features/advanced_configuration
 
 
-Legacy Profiler Tool
---------------------
+Legacy Profiler Tool (Experimental)
+-----------------------------------
 
 The legacy profiler tool is an extension of PyTorch* legacy profiler for profiling operators' overhead on XPU devices. With this tool, users can get the information in many fields of the run models or code scripts. User should build Intel® Extension for PyTorch* with profiler support as default and enable this tool by a `with` statement before the code segment.
 
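The `with`-statement usage described in the diffed paragraph can be illustrated with a minimal, hedged sketch. `torch.autograd.profiler_legacy.profile` is PyTorch's legacy profiler entry point; the `use_xpu` flag and the `self_xpu_time_total` sort key are assumptions inferred from the feature description above, so check the Legacy Profiler Tool page for the exact options.

```python
# Hedged sketch: profile a code segment on an XPU device with the legacy profiler.
# The use_xpu flag and the xpu sort key are assumptions taken from the feature description.
import torch
import intel_extension_for_pytorch  # noqa: F401  # registers the "xpu" device

x = torch.randn(1024, 1024, device="xpu")
w = torch.randn(1024, 1024, device="xpu")

with torch.autograd.profiler_legacy.profile(enabled=True, use_xpu=True) as prof:
    y = torch.matmul(x, w)
    torch.xpu.synchronize()  # make sure the kernels finish inside the profiled region

print(prof.key_averages().table(sort_by="self_xpu_time_total"))
```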
@@ -135,8 +135,8 @@ For more detailed information, check `Legacy Profiler Tool <features/profiler_le
 
    features/profiler_legacy
 
-Simple Trace Tool
------------------
+Simple Trace Tool (Experimental)
+--------------------------------
 
 Simple Trace is a built-in debugging tool that lets you control printing out the call stack for a piece of code. Once enabled, it can automatically print out verbose messages of called operators in a stack format with indenting to distinguish the context.
docs/tutorials/features/DDP.md (29 additions, 11 deletions)
@@ -1,14 +1,15 @@
-# DistributedDataParallel (DDP)
+DistributedDataParallel (DDP)
+=============================
 
 ## Introduction
 
 `DistributedDataParallel (DDP)` is a PyTorch\* module that implements multi-process data parallelism across multiple GPUs and machines. With DDP, the model is replicated on every process, and each model replica is fed a different set of input data samples. DDP enables overlapping between gradient communication and gradient computations to speed up training. Please refer to [DDP Tutorial](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html) for an introduction to DDP.
 
-The PyTorch `Collective Communication (c10d)` library supports communication across processes. To run DDP on XPU, we use Intel® oneCCL Bindings for Pytorch\* (formerly known as torch-ccl) to implement the PyTorch c10d ProcessGroup API (https://github.com/intel/torch-ccl). It holds PyTorch bindings maintained by Intel for the Intel® oneAPI Collective Communications Library\* (oneCCL), a library for efficient distributed deep learning training implementing such collectives as `allreduce`, `allgather`, and `alltoall`. Refer to [oneCCL Github page](https://github.com/oneapi-src/oneCCL) for more information about oneCCL.
+The PyTorch `Collective Communication (c10d)` library supports communication across processes. To run DDP on GPU, we use Intel® oneCCL Bindings for Pytorch\* (formerly known as torch-ccl) to implement the PyTorch c10d ProcessGroup API (https://github.com/intel/torch-ccl). It holds PyTorch bindings maintained by Intel for the Intel® oneAPI Collective Communications Library\* (oneCCL), a library for efficient distributed deep learning training implementing such collectives as `allreduce`, `allgather`, and `alltoall`. Refer to [oneCCL Github page](https://github.com/oneapi-src/oneCCL) for more information about oneCCL.
 
 ## Installation of Intel® oneCCL Bindings for Pytorch\*
 
-To use PyTorch DDP on XPU, install Intel® oneCCL Bindings for Pytorch\* as described below.
+To use PyTorch DDP on GPU, install Intel® oneCCL Bindings for Pytorch\* as described below.
 
 ### Install PyTorch and Intel® Extension for PyTorch\*
 
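As a rough illustration of how the pieces named in this hunk fit together (the `ccl` ProcessGroup backend provided by the bindings, plus the XPU device from the extension), here is a minimal sketch of a DDP setup. It is not the official example from the DDP page: the rendezvous environment variables, rank/world-size sources, and the toy model are placeholders, and the import name may be `torch_ccl` in older releases of the bindings.

```python
# Minimal sketch: DDP over the oneCCL ("ccl") backend on an XPU device.
# Rendezvous variables and the toy model are placeholders; take rank/world size from your launcher.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401  # registers the "xpu" device
import oneccl_bindings_for_pytorch  # noqa: F401  # registers the "ccl" backend ("torch_ccl" in older releases)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
rank = int(os.environ.get("PMI_RANK", 0))        # typically set by an MPI launcher; placeholder default
world_size = int(os.environ.get("PMI_SIZE", 1))

dist.init_process_group("ccl", rank=rank, world_size=world_size)

device = f"xpu:{rank}"                            # one process per XPU device assumed
model = nn.Linear(16, 16).to(device)              # toy model
ddp_model = nn.parallel.DistributedDataParallel(model)

loss = ddp_model(torch.randn(4, 16, device=device)).sum()
loss.backward()                                   # gradients are allreduced via oneCCL
dist.destroy_process_group()
```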
@@ -19,6 +20,18 @@ For more detailed information, check [installation guide](../installation.md).
-**Note:** Make sure you have installed basekit from https://www.intel.com/content/www/us/en/developer/tools/oneapi/toolkits.html#base-kit
+**Note:** Make sure you have installed [basekit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/toolkits.html#base-kit) when using Intel® oneCCL Bindings for Pytorch\* on Intel® GPUs.
 
 ```bash
 source $basekit_root/ccl/latest/env/vars.sh
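After sourcing the oneCCL environment script shown in the hunk above, a quick import check can confirm the stack is wired up. This is a hedged sketch, assuming the extension exposes the `torch.xpu` device module and that recent bindings install under the `oneccl_bindings_for_pytorch` package name (older releases use `torch_ccl`).

```python
# Hedged sanity check (not part of the official docs): verify torch, the extension, and the oneCCL bindings.
import torch
import intel_extension_for_pytorch as ipex

try:
    import oneccl_bindings_for_pytorch  # noqa: F401  # newer package name
except ImportError:
    import torch_ccl  # noqa: F401  # older releases of the bindings

print("torch:", torch.__version__)
print("ipex:", ipex.__version__)
print("xpu available:", torch.xpu.is_available())  # assumes the extension registers torch.xpu
```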
@@ -165,7 +183,7 @@ For using one GPU card with multiple tiles, each tile could be regarded as a dev
 
 ### Usage of DDP scaling API
 
-Note: This API supports XPU devices on one card.
+Note: This API supports GPU devices on one card.
 
 ```python
 Args:
@@ -221,5 +239,5 @@ print("DDP Use XPU: {} for training".format(xpu))
docs/tutorials/features/advanced_configuration.md (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
-# Advanced Configuration
+Advanced Configuration
+======================
 
 The default settings for Intel® Extension for PyTorch\* are sufficient for most use cases. However, if users want to customize Intel® Extension for PyTorch\*, advanced configuration is available at build time and runtime.
docs/tutorials/features/horovod.md (2 additions, 1 deletion)
@@ -1,4 +1,5 @@
-# Horovod with PyTorch
+Horovod with PyTorch (Experimental)
+===================================
 
 Horovod is a distributed deep learning training framework for TensorFlow, Keras, PyTorch, and Apache MXNet. The goal of Horovod is to make distributed deep learning fast and easy to use. Horovod core principles are based on MPI concepts such as size, rank, local rank, allreduce, allgather, broadcast, and alltoall. To use Horovod with PyTorch, you need to install Horovod with Pytorch first, and make specific change for Horovod in your training script.
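For orientation on the "specific change for Horovod" mentioned in this page, the following is a minimal sketch of the per-process changes Horovod expects in a PyTorch training script. The Horovod calls follow Horovod's standard PyTorch API; the `xpu:{local_rank}` device selection is an assumption carried over from the extension's GPU support, and the model and learning-rate scaling are placeholders.

```python
# Minimal sketch of the Horovod-specific changes in a PyTorch training script.
# The xpu device string is an assumption for Intel GPUs; pick the device to match your backend.
import torch
import torch.nn as nn
import intel_extension_for_pytorch  # noqa: F401  # registers the "xpu" device
import horovod.torch as hvd

hvd.init()                                        # one process per device
device = f"xpu:{hvd.local_rank()}"

model = nn.Linear(16, 16).to(device)              # toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size())

# Wrap the optimizer so gradients are averaged across workers with allreduce.
optimizer = hvd.DistributedOptimizer(optimizer, named_parameters=model.named_parameters())
# Start all workers from the same weights and optimizer state.
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
hvd.broadcast_optimizer_state(optimizer, root_rank=0)

loss = model(torch.randn(4, 16, device=device)).sum()
optimizer.zero_grad()
loss.backward()
optimizer.step()
```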