Support auto_functionalize in pre-dispatch (pytorch#122177)
Summary: Support the auto_functionalize higher-order op in pre-dispatch export tracing (see title).

Test Plan: CI

Differential Revision: D55042061

Pull Request resolved: pytorch#122177
Approved by: https://github.com/zou3519
tugsbayasgalan authored and pytorchmergebot committed Mar 20, 2024
1 parent dc89d8b commit 5b7ceab
Showing 2 changed files with 89 additions and 5 deletions.
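
For orientation: this change lets pre-dispatch export trace custom ops that mutate their arguments by rewriting them into the functional auto_functionalized higher-order op instead of raising NotImplementedError. Below is a minimal sketch of the new behavior, reusing the testlib::foo_mutated registration added in test/export/test_export.py (those registrations must already be in scope for this to run; the snippet is illustrative and not part of the patch):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        # testlib::foo_mutated mutates its input through testlib::foo
        return torch.ops.testlib.foo_mutated(x)

# Before this commit, pre-dispatch tracing raised NotImplementedError here;
# now the mutation is captured as an auto_functionalized node in the graph.
ep = torch.export._trace._export(M(), (torch.ones(5),), pre_dispatch=True)
print(ep.graph_module.code)  # contains auto_functionalized(torch.ops.testlib.foo.default, ...)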
75 changes: 75 additions & 0 deletions test/export/test_export.py
@@ -77,6 +77,16 @@
     "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
     tags=torch.Tag.pt2_compliant_tag,
 )
+torch.library.define(
+    "testlib::foo_mutated",
+    "(Tensor(a!) x) -> (Tensor, Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)
+torch.library.define(
+    "testlib::foo_functional",
+    "(Tensor x) -> (Tensor)",
+    tags=torch.Tag.pt2_compliant_tag,
+)

 @torch.library.impl("testlib::returns_tensor_symint", "cpu")
 @torch.library.impl_abstract("testlib::returns_tensor_symint")
@@ -94,6 +104,16 @@ def foo_impl(x, z):
 def foo_abstract(x, z):
     return x, z, x + z

+@torch.library.impl("testlib::foo_mutated", "CompositeImplicitAutograd")
+def foo_mutated(x):
+    a, b, c = torch.ops.testlib.foo(x, x.cos())
+    return a, a.cos()
+
+@torch.library.impl("testlib::foo_functional", "CompositeImplicitAutograd")
+def foo_functional(x):
+    a, b, c = torch.ops.testlib.foo(x.cos(), x.cos())
+    return a.cos()
+

 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestDynamismExpression(TestCase):
@@ -3681,6 +3701,61 @@ def forward(self, x, z):
         self.assertTrue(torch.allclose(z_new_eager, z_new_export))
         self.assertTrue(torch.allclose(legit_eager, legit_export))

+    def test_custom_op_auto_functionalize_pre_dispatch(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ops.testlib.foo_mutated(x)
+
+        inps = (torch.ones(5),)
+
+        ep = torch.export.export(M(), inps)
+        self.assertExpectedInline(str(ep.graph_module.code.strip()), """\
+def forward(self, arg0_1):
+    cos = torch.ops.aten.cos.default(arg0_1)
+    auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = arg0_1, z = cos); arg0_1 = cos = None
+    getitem_3 = auto_functionalized[3]; auto_functionalized = None
+    cos_1 = torch.ops.aten.cos.default(getitem_3)
+    return (getitem_3, getitem_3, cos_1)""")
+
+        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
+        self.assertExpectedInline(str(ep.graph_module.code.strip()), """\
+def forward(self, arg0_1):
+    cos = torch.ops.aten.cos.default(arg0_1)
+    auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = arg0_1, z = cos); arg0_1 = cos = None
+    getitem_3 = auto_functionalized[3]; auto_functionalized = None
+    cos_1 = torch.ops.aten.cos.default(getitem_3)
+    return (getitem_3, getitem_3, cos_1)""")
+
+
+    def test_custom_op_auto_warn_pre_dispatch(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.ops.testlib.foo_functional(x)
+
+        inps = (torch.ones(5),)
+
+        ep = torch.export.export(M(), inps)
+        self.assertExpectedInline(str(ep.graph_module.code.strip()), """\
+def forward(self, arg0_1):
+    cos = torch.ops.aten.cos.default(arg0_1)
+    cos_1 = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
+    auto_functionalized = torch._higher_order_ops.auto_functionalize.auto_functionalized(torch.ops.testlib.foo.default, x = cos, z = cos_1); cos = cos_1 = None
+    getitem_3 = auto_functionalized[3]; auto_functionalized = None
+    cos_2 = torch.ops.aten.cos.default(getitem_3); getitem_3 = None
+    return (cos_2,)""")
+
+        ep = torch.export._trace._export(M(), inps, pre_dispatch=True)
+        self.assertExpectedInline(str(ep.graph_module.code.strip()), """\
+def forward(self, arg0_1):
+    foo_functional = torch.ops.testlib.foo_functional.default(arg0_1); arg0_1 = None
+    return (foo_functional,)""")
+
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't support")
 class TestOneOffModelExportResult(TestCase):
     def test_scaled_dot_product_attention_cpu(self):
19 changes: 14 additions & 5 deletions torch/_subclasses/functional_tensor.py
@@ -1,4 +1,5 @@
 import contextlib
+import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Callable, ContextManager, Dict, Optional, Tuple

@@ -307,7 +308,16 @@ def _can_decompose(func):
                 alias_info = len(
                     [i for i in func._schema.arguments if i.alias_info is not None]
                 )
-                return alias_info != 0 or func._schema.is_mutable
+                should_decompose = alias_info != 0 or func._schema.is_mutable
+                if not should_decompose:
+                    if func.namespace not in ["aten", "prim"]:
+                        warnings.warn(
+                            f"At pre-dispatch tracing, we will assume that any "
+                            f"custom op that is marked with CompositeImplicitAutograd "
+                            f"and functional are safe to not decompose. We found {func}"
+                            f" to be one such op."
+                        )
+                return should_decompose
             return True

         if (
@@ -347,10 +357,9 @@ def unwrap(x):
         ) and not torch._C._dispatch_has_kernel_for_dispatch_key(
             func.name(), torch._C.DispatchKey.Functionalize
         ):
-            if self.pre_dispatch:
-                raise NotImplementedError(
-                    "Auto functionalization is not supported on pre-dispatch tracing"
-                )
+            # it doesn't matter what mode we use here because
+            # the implementation of do_auto_functionalize doesn't
+            # interact with FunctionalTensorMode at all
             return do_auto_functionalize(func, args, kwargs)

         from torch._higher_order_ops.effects import handle_effects, has_effects
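
The new _can_decompose logic keeps purely functional CompositeImplicitAutograd custom ops intact during pre-dispatch tracing (emitting the warning above), while ops whose schema mutates or aliases their arguments are still decomposed so that the mutation can be functionalized. A small standalone sketch of that schema check (the helper name is mine, not from the patch):

import torch

def schema_mutates_or_aliases(op: torch._ops.OpOverload) -> bool:
    # Same test as in _can_decompose: any argument carrying alias info,
    # or a mutable schema, means the op cannot be kept as a single call.
    alias_info = len([a for a in op._schema.arguments if a.alias_info is not None])
    return alias_info != 0 or op._schema.is_mutable

print(schema_mutates_or_aliases(torch.ops.aten.add.Tensor))   # False: purely functional
print(schema_mutates_or_aliases(torch.ops.aten.add_.Tensor))  # True: mutates self in place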
