Fix auto_functionalize (pytorch#121990)
Differential Revision: D54964130

When we re-export, the auto_functionalize HOP will be in the graph. Therefore, we need to implement a proper functionalization rule for it. Since the content inside auto_functionalize is guaranteed to be functional, it is safe to simply fall through it.
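
For context, a minimal sketch (not part of the commit) of the re-export scenario described above, assuming the mutating custom op testlib::foo registered in the test diff below:

    import torch

    class M(torch.nn.Module):
        def forward(self, x, z):
            # testlib::foo mutates x and z in place; export rewrites the call
            # into the auto_functionalized higher-order op.
            return torch.ops.testlib.foo(x, z)

    ep = torch.export.export(M(), (torch.ones(5), torch.ones(5)))
    # Re-exporting the unlifted module traces through a graph that already
    # contains auto_functionalized, which is where the new rule is needed.
    ep2 = torch.export.export(ep.module(), (torch.ones(5), torch.ones(5)))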

Pull Request resolved: pytorch#121990
Approved by: https://github.com/ydwu4, https://github.com/zou3519
tugsbayasgalan authored and pytorchmergebot committed Mar 19, 2024
1 parent a2a88f3 commit 0d845f7
Showing 2 changed files with 48 additions and 0 deletions.
test/export/test_export.py: 40 additions, 0 deletions
@@ -72,12 +72,28 @@


torch.library.define("testlib::returns_tensor_symint", "(Tensor x) -> (Tensor, SymInt)")
torch.library.define(
    "testlib::foo",
    "(Tensor(a!) x, Tensor(b!) z) -> (Tensor, Tensor, Tensor)",
    tags=torch.Tag.pt2_compliant_tag,
)

@torch.library.impl("testlib::returns_tensor_symint", "cpu")
@torch.library.impl_abstract("testlib::returns_tensor_symint")
def returns_tensor_symint_impl(x):
    return x, x.shape[0]

@torch.library.impl("testlib::foo", "cpu")
@torch._dynamo.disable
def foo_impl(x, z):
    x.add_(5)
    z.add_(5)
    return x, z, x + z

@torch.library.impl_abstract("testlib::foo")
def foo_abstract(x, z):
    return x, z, x + z


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestDynamismExpression(TestCase):
@@ -3601,6 +3617,30 @@ def forward(self, x):

        self._test_export_same_as_eager(Module(), (torch.randn(4, 4),))

    def test_custom_op_auto_functionalize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, z):
                return torch.ops.testlib.foo(x, z)

        inps = (torch.ones(5), torch.ones(5))
        inps_for_export = (torch.ones(5), torch.ones(5))
        inps_for_export_with_decomp = (torch.ones(5), torch.ones(5))

        ep = torch.export.export(M(), inps_for_export)
        x_new_eager, z_new_eager, legit_eager = M()(*inps)
        x_new_export, z_new_export, legit_export = ep.module()(*inps_for_export)
        self.assertTrue(torch.allclose(x_new_eager, x_new_export))
        self.assertTrue(torch.allclose(z_new_eager, z_new_export))
        self.assertTrue(torch.allclose(legit_eager, legit_export))

        ep = ep.run_decompositions()
        x_new_export, z_new_export, legit_export = ep.module()(*inps_for_export_with_decomp)
        self.assertTrue(torch.allclose(x_new_eager, x_new_export))
        self.assertTrue(torch.allclose(z_new_eager, z_new_export))
        self.assertTrue(torch.allclose(legit_eager, legit_export))

@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestOneOffModelExportResult(TestCase):
torch/_higher_order_ops/auto_functionalize.py: 8 additions, 0 deletions
@@ -251,3 +251,11 @@ def do_auto_functionalize(
        ctx.sync(orig_arg)

    return ctx.wrap_tensors(unwrapped_actual_out)  # type: ignore[arg-type]


@auto_functionalized.py_functionalize_impl
def auto_functionalized_func(ctx, _mutable_op, **kwargs):
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
    with ctx.redispatch_to_next():
        result = auto_functionalized(_mutable_op, **unwrapped_kwargs)
    return ctx.wrap_tensors(result)
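
The added rule is a pure fall-through: because the op captured by auto_functionalized is already functional, the kernel only unwraps the tensor kwargs, redispatches to the next dispatch key, and re-wraps the result, with no extra mutation bookkeeping. A rough way to see the HOP node this rule handles, reusing the M module and testlib::foo registration from the sketch above the diffs (the printed names are illustrative, not output taken from the commit):

    ep = torch.export.export(M(), (torch.ones(5), torch.ones(5)))
    print(ep.graph_module.code)
    # Expect a call roughly like:
    #   auto_functionalized = torch.ops.higher_order.auto_functionalized(
    #       torch.ops.testlib.foo.default, x=..., z=...)
    # run_decompositions() (as in the new test) retraces the graph, and the
    # fall-through rule above lets functionalization pass straight through it.
    ep = ep.run_decompositions()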
