README.md (+24, −246)
@@ -19,6 +19,13 @@ This repository still contains:
# TorchDynamo

+ > TorchDynamo makes it easy to experiment with different compiler backends to make PyTorch code faster with a single line decorator `torch._dynamo.optimize()`
+
+ TorchDynamo supports arbitrary PyTorch code, control flow, mutation and dynamic shapes.
+
+ You can follow our nightly benchmarks [here](https://github.com/pytorch/torchdynamo/issues/681)
+
+
TorchDynamo is a Python-level JIT compiler designed to make unmodified
PyTorch programs faster. TorchDynamo hooks into the frame evaluation API
in CPython ([PEP 523]) to dynamically modify Python bytecode right before
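To make the one-line decorator concrete, here is a minimal sketch (assuming a recent PyTorch nightly where TorchDynamo ships as `torch._dynamo`; the function and the choice of the inductor backend are illustrative):

```py
import torch
import torch._dynamo as dynamo

@dynamo.optimize("inductor")  # any backend name from dynamo.list_backends() works here
def fn(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# The first call triggers bytecode capture and compilation; later calls reuse the compiled graph.
fn(torch.randn(10), torch.randn(10))
```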
@@ -90,279 +97,50 @@ cd tools/dynamo
python verify_dynamo.py
```

- ## Usage Example
+ ## Getting started

- Here is a basic example of how to use TorchDynamo. One can decorate a function
- or a method using `torch._dynamo.optimize` to enable TorchDynamo optimization.
+ Here is a basic example of how to use TorchDynamo. You can decorate a function
+ or a method using `torch._dynamo.optimize()`, pass in the name of a compiler, e.g. inductor, and your code will run faster.
[…]
- call_function  mul     <built-in function mul>  (x, b)     {}
- output         output  output                   ((mul,),)  {}
- ```
+ It's also easy to define your own compiler backends in pure Python; see [custom backend](./documentation/custom-backend.md).
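For reference, the kind of custom backend the linked doc describes is just a function that receives the captured `torch.fx.GraphModule` and returns a callable. A minimal sketch (the `my_compiler` name and the toy function are illustrative; the printed table is where rows like the `call_function` / `output` lines above come from):

```py
from typing import List
import torch
import torch._dynamo as dynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    print("my_compiler() called with FX graph:")
    gm.graph.print_tabular()  # rows of opcode / name / target / args / kwargs
    return gm.forward         # return a Python callable that runs the graph

@dynamo.optimize(my_compiler)
def toy_example(x, b):
    return x * b

toy_example(torch.randn(10), torch.randn(10))
```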
- Note that the order of the last two graphs is nondeterministic depending
- on which one is encountered first by the just-in-time compiler.

### Existing Backends

- TorchDynamo has a growing list of backends, which can be found in [backends.py]
- or `torchdynamo.list_backends()`. Note many backends require installing
- additional packages. Some of the most commonly used backends are
+ TorchDynamo has a growing list of backends, which can be found in [backends.py](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/optimizations/backends.py)
+ or `torchdynamo.list_backends()`; each backend comes with its own optional dependencies.
+
+ Some of the most commonly used backends are:
- Debugging backends:
+ **Debugging backends**:
* `dynamo.optimize("eager")` - Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo issues.
* `dynamo.optimize("aot_eager")` - Uses AotAutograd with no compiler, i.e., just using PyTorch eager for AotAutograd's extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.

- Training & inference backends:
- * `dynamo.optimize("inductor")` - Uses TorchInductor backend with AotAutograd and cudagraphs. [Read more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
+ **Training & inference backends**:
+ * `dynamo.optimize("inductor")` - Uses the TorchInductor backend with AotAutograd and cudagraphs, leveraging code-generated Triton kernels. [Read more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
* `dynamo.optimize("nvfuser")` - nvFuser with TorchScript. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
* `dynamo.optimize("aot_nvfuser")` - nvFuser with AotAutograd. [Read more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
* `dynamo.optimize("aot_cudagraphs")` - cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
* `dynamo.optimize("fx2trt")` - Uses Nvidia TensorRT for inference optimizations. [Read more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
* `dynamo.optimize("onnxrt")` - Uses ONNXRT for inference on CPU/GPU. [Read more](https://onnxruntime.ai/)
* `dynamo.optimize("ipex")` - Uses IPEX for inference on CPU. [Read more](https://github.com/intel/intel-extension-for-pytorch)
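As a rough sketch of how you might switch between these backends (assuming the relevant extra packages are installed; `list_backends()` shows what your build supports):

```py
import torch
import torch._dynamo as dynamo

print(dynamo.list_backends())  # names accepted by dynamo.optimize(...)

def fn(x):
    return torch.relu(x) + 1.0

debug_fn = dynamo.optimize("eager")(fn)     # graph capture only, useful for debugging
fast_fn = dynamo.optimize("inductor")(fn)   # compiled with TorchInductor

x = torch.randn(8)
assert torch.allclose(debug_fn(x), fast_fn(x))
```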
- ### Training and AotAutograd
-
- Torchdynamo supports training, using AotAutograd to capture backwards:
- * the .forward() graph and optimizer.step() is captured by torchdynamo's python evalframe frontend
- * for each segment of .forward() that torchdynamo captures, it uses AotAutograd to generate a backward graph segment
- * each pair of forward, backward graph are (optionally) min-cut partitioned to save the minimal state between forward/backward
- * the forward, backward pairs are wrapped in autograd.function modules
- * usercode calling .backward() still triggers eager's autograd engine, which runs each 'compiled backward' graph as if it were one op, also running any non-compiled eager ops' .backward() functions
-
- Current limitations:
- * DDP and FSDP, which rely on autograd 'hooks' firing between backward ops to schedule communications ops, may be pessimized by having all communication ops scheduled _after_ whole compiled regions of backwards ops (WIP to fix this)
-
- Example
- ```py
- model = ...
- optimizer = ...
-
- @dynamo.optimize("inductor")
- def training_iter_fn(...):
-     outputs = model(...)
-     loss = outputs.loss
-     loss.backward()
-     optimizer.step()
-     optimizer.zero_grad()
-     return loss
-
- for _ in range(100):
-     loss = training_iter_fn(...)
- ```
- For more details, you can follow our [E2E model training benchmark](./benchmarks/training_loss.py) to onboard your own model training and evaluation. It runs the popular [Hugging Face BERT model](https://huggingface.co/docs/transformers/training) on the [Yelp Reviews dataset](https://huggingface.co/datasets/yelp_review_full), and it prints out whether the loss converged and the performance speedup compared to native PyTorch at the end.
-
-
- ## Troubleshooting
- See [Troubleshooting](./documentation/TROUBLESHOOTING.md).
-
- ## Adding Backends
-
- One could replace `my_compiler()` in the examples above with something that generates faster
- code, for example one using [optimize_for_inference]:
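The code block that followed this sentence is not visible in this diff. As a hedged sketch, a backend built around [optimize_for_inference] could look roughly like the following (the `inference_compiler` name is illustrative, and the TorchScript round-trip is one possible approach, not necessarily the one the removed example used):

```py
from typing import List
import torch

def inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # GraphModule is an nn.Module, so it can be traced like any other module.
    scripted = torch.jit.trace(gm.eval(), tuple(example_inputs))
    # Apply TorchScript's inference-only optimizations to the traced subgraph.
    return torch.jit.optimize_for_inference(scripted)
```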
documentation/FAQ.md (+20)
@@ -5,6 +5,15 @@ Below is the TorchDynamo compiler stack.
At a high level, the TorchDynamo stack consists of a graph capture from Python code using dynamo and a backend compiler. In this example the backend compiler consists of backward graph tracing using AOTAutograd and graph lowering using TorchInductor. There are of course many more compilers available here https://github.com/pytorch/torchdynamo/blob/0b8aaf340dad4777a080ef24bf09623f1aa6f3dd/README.md#existing-backend, but for this document we will focus on inductor as a motivating example.

+ Torchdynamo supports training, using AotAutograd to capture backwards (a minimal training-loop sketch follows this list):
+ 1. the `.forward()` graph and `optimizer.step()` are captured by torchdynamo's python evalframe frontend
+ 2. for each segment of `.forward()` that torchdynamo captures, it uses AotAutograd to generate a backward graph segment
+ 3. each pair of forward and backward graphs is (optionally) min-cut partitioned to save the minimal state between forward and backward
+ 4. the forward/backward pairs are wrapped in autograd.function modules
+ 5. user code calling `.backward()` still triggers eager's autograd engine, which runs each 'compiled backward' graph as if it were one op, also running any non-compiled eager ops' `.backward()` functions
+
+ Current limitations:
+ * DDP and FSDP, which rely on autograd 'hooks' firing between backward ops to schedule communications ops, may be pessimized by having all communication ops scheduled _after_ whole compiled regions of backwards ops (WIP to fix this)
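The training example that this PR removes from the README is worth keeping next to this list; reproduced here as a sketch, with an illustrative model, optimizer, and loss filled in (any backend name, e.g. "eager" for debugging, can replace "inductor"):

```py
import torch
import torch._dynamo as dynamo

model = torch.nn.Linear(10, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()

@dynamo.optimize("inductor")
def training_iter_fn(inputs, targets):
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss

for _ in range(100):
    inputs, targets = torch.randn(8, 10), torch.randn(8, 10)
    loss = training_iter_fn(inputs, targets)
```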
## Why is my code crashing?
@@ -69,6 +78,17 @@ print(prof.report())
Many of the reasons for graph breaks and excessive recompilation will be fixed with upcoming support for [tracing dynamic tensor shapes](https://docs.google.com/document/d/1QJB-GOnbv-9PygGlOMXwiO9K6vVNm8sNg_olixJ9koc/edit?usp=sharing), more careful choices for guards and better tuned heuristics.

+ ### Why are you recompiling in production?
+
+ In some cases, you may not want unexpected compiles after a program
+ has warmed up, for example if you are serving production traffic in a
+ latency-critical application. For this, TorchDynamo provides an alternate
+ mode where prior compiled graphs are used, but no new ones are generated:
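A hedged sketch of that run-only mode, assuming the `torch._dynamo.run()` entry point (the counterpart of `optimize()` that reuses existing compiled graphs and falls back to eager rather than compiling anything new):

```py
import torch
import torch._dynamo as dynamo

@dynamo.optimize("eager")
def fn(x):
    return torch.sin(x) + torch.cos(x)

fn(torch.randn(4))  # warm-up: graphs are captured and compiled here

# After warm-up, wrap the function with dynamo.run() so production traffic
# only ever hits previously compiled graphs (or plain eager), never a new compile.
serve_fn = dynamo.run(fn)
serve_fn(torch.randn(4))
```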