Skip to content

Commit

Permalink
[dynamo] reset grad state in aotdispatch test, add failing trace functional tensor test to dynamo (pytorch#126113)
Browse files Browse the repository at this point in the history

Workaround for pytorch#125568.

We could add additional global state to reset (e.g. autocast?) or move this setup/teardown to a more general place.

Also added a minimal repro for the linked issue - will investigate in a followup PR.

Pull Request resolved: pytorch#126113
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
  • Loading branch information
williamwen42 authored and pytorchmergebot committed May 14, 2024
1 parent f6a00a8 commit 4a8db9d
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 0 deletions.
34 changes: 34 additions & 0 deletions test/dynamo/test_repros.py
Original file line number Diff line number Diff line change
Expand Up @@ -4915,6 +4915,40 @@ def ladder(x):
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
self.assertEqual(opt_ladder(data), ladder(data))

@unittest.expectedFailure
def test_trace_functional_tensor_with_error(self):
    """Minimal repro for pytorch#125568: compiling a function that mutates
    a frozen-storage FunctionalTensor raises, and grad mode can be left
    disabled afterwards. Marked expectedFailure until the underlying bug
    is fixed (tracked for a follow-up PR per the commit message).
    """
    from torch._subclasses.fake_tensor import FakeTensorMode
    from torch._subclasses.functional_tensor import (
        FunctionalTensor,
        FunctionalTensorMode,
    )

    def f(a, tmp):
        # Take a view first, then swap `a`'s storage under no_grad and
        # mutate through the now-stale view — this mutation pattern is
        # what triggers the "frozen storage" error below.
        a_view = a.view(-1)
        with torch.no_grad():
            a.set_(tmp)
            a_view.mul_(2)
        return a + tmp

    # Build inputs that are both fake (static shapes) and functionalized;
    # construction happens inside FunctionalTensorMode so the wrappers
    # are created under the active mode.
    fake_mode = FakeTensorMode()
    with FunctionalTensorMode():
        inp = torch.ones(3, 3, requires_grad=True)
        inp = fake_mode.from_tensor(inp, static_shapes=True)
        inp = FunctionalTensor.to_functional(inp)

        tmp = torch.ones(3, 3, requires_grad=True)
        tmp = fake_mode.from_tensor(tmp, static_shapes=True)
        tmp = FunctionalTensor.to_functional(tmp)

        # Compiling and running f is expected to fail with the frozen
        # storage error (raised while inside f's no_grad block).
        opt_f = torch.compile(f, backend="eager")
        with self.assertRaisesRegex(
            RuntimeError, "cannot mutate tensors with frozen storage"
        ):
            opt_f(inp, tmp)

        # grad state may not be properly reset after the error
        self.assertTrue(torch.is_grad_enabled())


instantiate_parametrized_tests(ReproTests)

Expand Down
Empty file.
5 changes: 5 additions & 0 deletions test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,13 @@

class AOTTestCase(TestCase):
    """Base TestCase that snapshots and restores the global grad-mode flag.

    Workaround for pytorch#125568: a test that errors out while tracing
    under ``no_grad`` can leak a disabled grad state into later tests.
    The commit message notes this setup/teardown may later move somewhere
    more general (and possibly cover more global state, e.g. autocast).
    """

    def setUp(self):
        # Capture grad mode before the test body runs; tearDown restores it.
        self.prev_grad_state = torch.is_grad_enabled()
        super().setUp()

    def tearDown(self):
        # Undo any grad-mode leakage from the test before standard teardown.
        torch.set_grad_enabled(self.prev_grad_state)
        super().tearDown()


class TestPythonKey(AOTTestCase):
def test_make_fx(self, device):
Expand Down

0 comments on commit 4a8db9d

Please sign in to comment.