diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py index c6c239cca..1683e8a2f 100644 --- a/test/xpu/xpu_test_utils.py +++ b/test/xpu/xpu_test_utils.py @@ -659,6 +659,11 @@ def align_supported_dtypes(self, db): backward_dtypes.add(bfloat16) opinfo.backward_dtypes = tuple(backward_dtypes) + if "has_fp64=0" in str(torch.xpu.get_device_properties(0)): + fp64_dtypes = [ torch.float64, torch.complex128, torch.double, ] + opinfo.dtypesIfXPU = set(filter(lambda x: (x not in fp64_dtypes), list(opinfo.dtypesIfXPU))) + opinfo.backward_dtypes = tuple(filter(lambda x: (x not in fp64_dtypes), list(opinfo.backward_dtypes))) + def __enter__(self): # Monkey patch until we have a fancy way