diff --git a/src/lightning/pytorch/plugins/precision/amp.py b/src/lightning/pytorch/plugins/precision/amp.py
index 75e792af46b90..bd34ca83d5e2b 100644
--- a/src/lightning/pytorch/plugins/precision/amp.py
+++ b/src/lightning/pytorch/plugins/precision/amp.py
@@ -112,7 +112,9 @@ def clip_gradients(
         super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
 
     def autocast_context_manager(self) -> torch.autocast:
-        return torch.autocast(self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half))
+        return torch.autocast(
+            self.device, dtype=(torch.bfloat16 if self.precision == "bf16-mixed" else torch.half), cache_enabled=False
+        )
 
     @override
     @contextmanager
diff --git a/tests/tests_pytorch/plugins/precision/test_amp.py b/tests/tests_pytorch/plugins/precision/test_amp.py
index cb061c540b2be..3894c4256e0b8 100644
--- a/tests/tests_pytorch/plugins/precision/test_amp.py
+++ b/tests/tests_pytorch/plugins/precision/test_amp.py
@@ -14,6 +14,8 @@
 from unittest.mock import Mock
 
 import pytest
+import torch
+from torch import nn
 from torch.optim import Optimizer
 
 from lightning.pytorch.plugins import MixedPrecision
@@ -51,3 +53,19 @@ def test_optimizer_amp_scaling_support_in_step_method():
 
     with pytest.raises(RuntimeError, match="The current optimizer.*does not allow for gradient clipping"):
         precision.clip_gradients(optimizer, clip_val=1.0)
+
+
+def test_amp_with_no_grad():
+    """Test that asserts using `no_grad` context wrapper with a persistent AMP context wrapper does not break gradient
+    tracking."""
+    layer = nn.Linear(2, 1)
+    x = torch.randn(1, 2)
+    amp = MixedPrecision(precision="bf16-mixed", device="cpu")
+
+    with amp.autocast_context_manager():
+        with torch.no_grad():
+            _ = layer(x)
+
+        loss = layer(x).mean()
+    loss.backward()
+    assert loss.grad_fn is not None
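
For context, a minimal sketch (not part of the diff) of the PyTorch behaviour the change works around: inside a single autocast region with the cast cache enabled, a forward pass under `torch.no_grad()` can populate the cache with low-precision weight copies that carry no autograd history, so a later forward in the same region may yield a loss with no `grad_fn`. Passing `cache_enabled=False` forces a fresh, differentiable cast on every call. The repro below uses plain `torch.autocast` rather than the Lightning plugin and is only an illustration under that assumption.

import torch
from torch import nn

layer = nn.Linear(2, 1)
x = torch.randn(1, 2)

for cache_enabled in (True, False):
    with torch.autocast("cpu", dtype=torch.bfloat16, cache_enabled=cache_enabled):
        with torch.no_grad():
            _ = layer(x)  # first cast of the weights happens under no_grad
        loss = layer(x).mean()  # same autocast region, so the cached cast may be reused
    # Expected (assumption): grad_fn is None when the cache is enabled, present when it is disabled.
    print(f"cache_enabled={cache_enabled}: loss.grad_fn={loss.grad_fn}")

Whether the stale cache entry is actually reused can vary across PyTorch versions; the new `test_amp_with_no_grad` in the diff is the authoritative regression check, while the sketch only illustrates the mechanism behind `cache_enabled=False`.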