diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 4ed0661a469ca..33f8d10a7b71b 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -4915,6 +4915,40 @@ def ladder(x): opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager") self.assertEqual(opt_ladder(data), ladder(data)) + @unittest.expectedFailure + def test_trace_functional_tensor_with_error(self): + from torch._subclasses.fake_tensor import FakeTensorMode + from torch._subclasses.functional_tensor import ( + FunctionalTensor, + FunctionalTensorMode, + ) + + def f(a, tmp): + a_view = a.view(-1) + with torch.no_grad(): + a.set_(tmp) + a_view.mul_(2) + return a + tmp + + fake_mode = FakeTensorMode() + with FunctionalTensorMode(): + inp = torch.ones(3, 3, requires_grad=True) + inp = fake_mode.from_tensor(inp, static_shapes=True) + inp = FunctionalTensor.to_functional(inp) + + tmp = torch.ones(3, 3, requires_grad=True) + tmp = fake_mode.from_tensor(tmp, static_shapes=True) + tmp = FunctionalTensor.to_functional(tmp) + + opt_f = torch.compile(f, backend="eager") + with self.assertRaisesRegex( + RuntimeError, "cannot mutate tensors with frozen storage" + ): + opt_f(inp, tmp) + + # grad state may not be properly reset after the error + self.assertTrue(torch.is_grad_enabled()) + instantiate_parametrized_tests(ReproTests) diff --git a/test/dynamo_expected_failures/TestAOTAutograd.test_set__and_data_mutation_good b/test/dynamo_expected_failures/TestAOTAutograd.test_set__and_data_mutation_good deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 5c17b7f84d0d4..ffa71a7e905b5 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -103,8 +103,13 @@ class AOTTestCase(TestCase): def setUp(self): + self.prev_grad_state = torch.is_grad_enabled() super().setUp() + def tearDown(self): + torch.set_grad_enabled(self.prev_grad_state) + super().tearDown() + class TestPythonKey(AOTTestCase): def test_make_fx(self, device):