diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index caf978456c03d..eb64e03a559a9 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -87,7 +87,7 @@ def _compare_version(package: str, op, version) -> bool:
 _POPTORCH_AVAILABLE = _module_available("poptorch")
 _RICH_AVAILABLE = _module_available("rich")
 _TORCH_CPU_AMP_AVAILABLE = _compare_version(
-    "torch", operator.ge, "1.10.0.dev20210501"
+    "torch", operator.ge, "1.10.0.dev20210902"
 )  # todo: swap to 1.10.0 once released
 _TORCH_BFLOAT_AVAILABLE = _compare_version(
     "torch", operator.ge, "1.10.0.dev20210902"
diff --git a/tests/models/test_amp.py b/tests/models/test_amp.py
index 8632b3b9eb6f6..c56cfb11f0164 100644
--- a/tests/models/test_amp.py
+++ b/tests/models/test_amp.py
@@ -184,10 +184,13 @@ def test_amp_gpu_ddp_slurm_managed(tmpdir):
 
 
 @pytest.mark.skipif(torch.cuda.is_available(), reason="test is restricted only on CPU")
-def test_cpu_model_with_amp(tmpdir):
-    """Make sure model trains on CPU."""
-    with pytest.raises(MisconfigurationException, match="AMP is only available on GPU"):
-        Trainer(precision=16)
+@RunIf(max_torch="1.9")
+@pytest.mark.parametrize("precision", [16, "bf16"])
+def test_cpu_model_with_amp(tmpdir, precision):
+    """Make sure exception is thrown on CPU when precision 16 is enabled on PyTorch 1.9 and lower."""
+
+    with pytest.raises(MisconfigurationException, match="AMP is only available on GPU for PyTorch 1.9"):
+        Trainer(precision=precision)
 
 
 @mock.patch("pytorch_lightning.plugins.precision.apex_amp.ApexMixedPrecisionPlugin.backward")
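
Reviewer note: the imports.py change aligns the nightly build that gates native CPU AMP with the bfloat16 gate, and the updated test asserts that requesting precision=16 or precision="bf16" on CPU raises a MisconfigurationException on PyTorch 1.9 and lower (via RunIf(max_torch="1.9")). The sketch below is only an illustration of how such an availability flag is typically consumed; it is not code from this PR, and the helper name cpu_autocast_context and its error message are assumptions made for illustration.

# Illustrative sketch only (not part of this diff): consuming the availability flag
# to decide whether CPU autocast can be used. `cpu_autocast_context` is a made-up helper.
import operator

import torch

from pytorch_lightning.utilities.imports import _compare_version

# Same gate as in imports.py: native CPU AMP landed in the PyTorch 1.10 nightlies.
_TORCH_CPU_AMP_AVAILABLE = _compare_version("torch", operator.ge, "1.10.0.dev20210902")


def cpu_autocast_context():
    """Return a bfloat16 autocast context for CPU, or raise on older PyTorch."""
    if not _TORCH_CPU_AMP_AVAILABLE:
        # Mirrors the behaviour the updated test expects on PyTorch 1.9 and lower.
        raise RuntimeError("AMP is only available on GPU for PyTorch 1.9 and lower.")
    # torch.autocast with device_type="cpu" only supports bfloat16.
    return torch.autocast("cpu", dtype=torch.bfloat16)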