address comments
xmfan committed Sep 6, 2024
1 parent 50a6978 commit 271b8f2
Showing 1 changed file with 43 additions and 37 deletions.
80 changes: 43 additions & 37 deletions intermediate_source/compiled_autograd_tutorial.py
.. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
:class-card: card-prerequisites
* How compiled autograd interacts with ``torch.compile``
* How to use the compiled autograd API
* How to inspect logs using ``TORCH_LOGS``
.. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
:class-card: card-prerequisites
* PyTorch 2.4
* Complete the `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`_
"""

######################################################################
# Overview
# ------------
# Compiled Autograd is a ``torch.compile`` extension introduced in PyTorch 2.4
# that allows the capture of a larger backward graph.
#
# While ``torch.compile`` does capture the backward graph, it does so **partially**. The AOTAutograd component captures the backward graph ahead-of-time, with certain limitations:
# * Graph breaks in the forward lead to graph breaks in the backward
# * `Backward hooks <https://pytorch.org/docs/stable/notes/autograd.html#backward-hooks-execution>`_ are not captured
#
# Compiled Autograd addresses these limitations by directly integrating with the autograd engine, allowing
# it to capture the full backward graph at runtime. Models exhibiting either of these two characteristics should try
# Compiled Autograd and may observe better performance.
#
# However, Compiled Autograd introduces its own limitations:
# * Added runtime overhead at the start of the backward for cache lookup
# * More prone to recompiles and graph breaks in Dynamo due to the larger capture
#
# .. note:: Compiled Autograd is under active development and is not yet compatible with all existing PyTorch features. For the latest status on a particular feature, refer to `Compiled Autograd Landing Page <https://docs.google.com/document/d/11VucFBEewzqgkABIjebZIzMvrXr3BtcY1aGKpX61pJY>`_.
#
######################################################################
# Setup
# ------------
# In this tutorial, we will base our examples on this simple neural network model.
# It takes a 10-dimensional input vector, processes it through a single linear layer, and outputs another 10-dimensional vector.
#

import torch
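
# A minimal ``Model`` definition consistent with the description above (the exact
# code may differ slightly): a single linear layer mapping 10 features to 10.
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)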

######################################################################
# Basic usage
# ------------
# Before calling the ``torch.compile`` API, make sure to set ``torch._dynamo.config.compiled_autograd`` to ``True``:
#

model = Model()
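# Create a random 10-dimensional input, enable compiled autograd, and define the
# compiled training step; this block mirrors the walkthrough below.
x = torch.randn(10)

torch._dynamo.config.compiled_autograd = True

@torch.compile
def train(model, x):
    loss = model(x).sum()
    loss.backward()
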
train(model, x)

######################################################################
# In the code above, we create an instance of the ``Model`` class and generate a random 10-dimensional tensor ``x`` by using ``torch.randn(10)``.
# We define the training loop function ``train`` and decorate it with ``@torch.compile`` to optimize its execution.
#
# When ``train(model, x)`` is called:
# * Python Interpreter calls Dynamo, since this call was decorated with ``@torch.compile``
# * Dynamo intercepts the Python bytecode, simulates its execution, and records the operations into a graph
# * AOTDispatcher disables hooks and calls the autograd engine to compute gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph. Using ``torch.autograd.Function``, AOTDispatcher rewrites the forward and backward implementation of ``train``.
# * Inductor generates a function corresponding to an optimized implementation of the AOTDispatcher forward and backward
# * Dynamo sets the optimized function to be evaluated next by Python Interpreter
# * Python Interpreter executes the optimized function, which basically executes ``loss = model(x).sum()``
# * Python Interpreter executes ``loss.backward()``, calling into the autograd engine, which routes to the Compiled Autograd engine since we enabled the config: ``torch._dynamo.config.compiled_autograd = True``
# * Compiled Autograd computes the gradients for ``model.linear.weight`` and ``model.linear.bias``, and records the operations into a graph, including any hooks it encounters. During this, it will record the backward previously rewritten by AOTDispatcher. Compiled Autograd then generates a new function which corresponds to a fully traced implementation of ``loss.backward()``, and executes it with ``torch.compile`` in inference mode
# * The same steps recursively apply to the Compiled Autograd graph, but this time AOTDispatcher does not need to partition this graph into a forward and backward
#

######################################################################
# Inspecting the compiled autograd logs
# -------------------------------------
# Run the script with the ``TORCH_LOGS`` environment variable:
# - To only print the compiled autograd graph, use ``TORCH_LOGS="compiled_autograd" python example.py``
# - To print the graph with more tensor metadata and recompile reasons, at the cost of performance, use ``TORCH_LOGS="compiled_autograd_verbose" python example.py``
#
# Rerun the snippet above; the compiled autograd graph should now be logged to ``stderr``. Certain graph nodes will have names that are prefixed by ``aot0_``;
# these correspond to the nodes previously compiled ahead of time in AOTAutograd backward graph 0. For example, ``aot0_view_2`` corresponds to ``view_2`` of the AOT backward graph with id=0.
#
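
# These artifact logs can also be enabled from Python instead of through the environment.
# This is a sketch and assumes that ``torch._logging.set_logs`` in your PyTorch build
# exposes the ``compiled_autograd_verbose`` flag:

torch._logging.set_logs(compiled_autograd_verbose=True)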

stderr_output = """
... compiled autograd graph output omitted; it contains a forward(self, inputs, sizes, scalars, hooks) definition ...
"""

######################################################################
# .. note:: This is the graph on which we will call ``torch.compile``, **NOT** the optimized graph. Compiled Autograd essentially generates some unoptimized Python code to represent the entire C++ autograd execution.
#

######################################################################
# Compiling the forward and backward pass using different flags
# -------------------------------------------------------------
# You can use different compiler configurations for the two compilations, for example, the backward may be compiled as a full graph even if there are graph breaks in the forward.
#

def train(model, x):
    model = torch.compile(model)
    loss = model(x).sum()
    torch._dynamo.config.compiled_autograd = True
    torch.compile(lambda: loss.backward(), fullgraph=True)()

######################################################################

######################################################################
# Compiled Autograd addresses certain limitations of AOTAutograd
# --------------------------------------------------------------
# 1. Graph breaks in the forward lead to graph breaks in the backward (see the sketch below)
#
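
# A minimal sketch of the first point above (an assumed example, not necessarily the
# exact code used elsewhere in this tutorial): ``torch._dynamo.graph_break()`` splits
# the forward into multiple graphs, but with compiled autograd enabled the backward
# can still be captured as a single graph at runtime.

torch._dynamo.config.compiled_autograd = True

@torch.compile
def fn(x):
    temp = x + 10
    torch._dynamo.graph_break()  # forces a graph break in the forward
    return (temp * 2).sum()

x = torch.randn(10, requires_grad=True)
loss = fn(x)
loss.backward()  # the backward is captured as one graph by Compiled Autograd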


######################################################################
# Common recompilation reasons for Compiled Autograd
# --------------------------------------------------
# 1. Due to change in autograd structure

torch._dynamo.config.compiled_autograd = True
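# A minimal sketch of this recompilation reason (an assumed example, not necessarily
# the exact code used elsewhere in this tutorial): each iteration produces a different
# autograd graph structure, so Compiled Autograd must recompile on every backward.
x = torch.randn(10, requires_grad=True)
for op in [torch.add, torch.sub, torch.mul, torch.div]:
    loss = op(x, x).sum()
    torch.compile(lambda: loss.backward())()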
######################################################################
# Conclusion
# ----------
# In this tutorial, we went over the high-level ecosystem of ``torch.compile`` with compiled autograd, the basics of compiled autograd, and a few common recompilation reasons.
#
