diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index debab066a..ca53831f5 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -638,6 +638,30 @@ def sample_inputs_softmax_variant_nofp64(
         SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
     )
 
+def sample_inputs_like_fns_nofp64(self, device, dtype, requires_grad, **kwargs):
+
+    inputs = [
+        ((), {}),
+        ((S, S), {}),
+        ((0, S, 0), {}),
+        ((S,), {'dtype': dtype, 'device': device}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device).
+        # disabled for ARC (no fp64 support):
+        # ((S,), {'dtype': torch.double}),
+        ((S,), {'device': 'cpu'}),
+        # disabled for ARC (no fp64 support):
+        # ((S,), {'dtype': torch.double, 'device': 'cpu'}),
+    ]
+    if torch.cuda.is_available():
+        inputs.append(((S,), {'device': 'cuda'}))
+
+    for shape, kwargs in inputs:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=None, high=None,
+                        requires_grad=requires_grad)
+        yield SampleInput(t, **kwargs)
+
 class XPUPatchForImport:
     def __init__(self, patch_test_case=True) -> None:
         self.test_package = (
@@ -672,7 +696,6 @@ def __init__(self, patch_test_case=True) -> None:
         self.cuda_is_bf16_supported = cuda.is_bf16_supported
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            self.sample_inputs_softmax_variant = common_methods_invocations.sample_inputs_softmax_variant
             self.index_variable = common_methods_invocations.index_variable
             self.reference_inputs_cat = common_methods_invocations.reference_inputs_cat
 
@@ -737,13 +760,29 @@ def align_supported_dtypes(self, db):
                 opinfo.dtypesIfXPU = set(filter(lambda x: (x not in fp64_dtypes), list(opinfo.dtypesIfXPU)))
                 opinfo.backward_dtypes = tuple(filter(lambda x: (x not in fp64_dtypes), list(opinfo.backward_dtypes)))
 
+    def filter_fp64_sample_input(self, db):
+        # Only for platforms without fp64 support
+        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
+            for opinfo in db:
+                if opinfo.name in _xpu_computation_op_list:
+                    if opinfo.variant_test_name == "with_dtype" and \
+                            opinfo.name in ["log_softmax", "softmax", "nn.functional.softmin"] and \
+                            get_wrapped_fn(opinfo.sample_inputs_func) != opinfo.sample_inputs_func and \
+                            get_wrapped_fn(opinfo.sample_inputs_func).func.__name__ == common_methods_invocations.sample_inputs_softmax_variant.__name__:
+                        opinfo.sample_inputs_func = torch.no_grad()(partial(sample_inputs_softmax_variant_nofp64, with_dtype=True))
+                    elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_softmax_variant.__name__:
+                        opinfo.sample_inputs_func = sample_inputs_softmax_variant_nofp64
+                    elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_like_fns.__name__:
+                        opinfo.sample_inputs_func = sample_inputs_like_fns_nofp64
+
+
+
     def __enter__(self):
         # Monkey patch until we have a fancy way
         common_device_type.onlyCUDA = common_device_type.onlyXPU
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.sample_inputs_softmax_variant = sample_inputs_softmax_variant_nofp64
             common_methods_invocations.index_variable = index_variable_nofp64
             common_methods_invocations.reference_inputs_cat = reference_inputs_cat_nofp64
@@ -773,6 +812,7 @@ def __init__(self, *args):
         ]:
             self.align_supported_dtypes(db)
             self.align_db_decorators(db)
+            self.filter_fp64_sample_input(db)
         self.align_db_decorators(module_db)
         common_methods_invocations.python_ref_db = [
             op
@@ -870,7 +910,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_bf16_supported = self.cuda_is_bf16_supported
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.sample_inputs_softmax_variant = self.sample_inputs_softmax_variant
             common_methods_invocations.index_variable = self.index_variable
             common_methods_invocations.reference_inputs_cat = self.reference_inputs_cat
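Usage note: the per-OpInfo swaps made by filter_fp64_sample_input above only
affect test modules imported while XPUPatchForImport is active. A minimal
sketch of that pattern, assuming a torch-xpu-ops-style test wrapper (the
test_ops import target, the TestCommon name, and the surrounding file layout
are illustrative assumptions, not part of this diff):

    # test_ops_xpu.py (sketch, not part of this diff)
    from xpu_test_utils import XPUPatchForImport

    with XPUPatchForImport(False):
        # Imports resolved inside the context see the XPU-aligned OpInfo
        # database: on devices reporting has_fp64=0, fp64 dtypes have been
        # stripped and filter_fp64_sample_input() has swapped in the
        # fp64-free sample-input functions (e.g. sample_inputs_like_fns_nofp64).
        from test_ops import TestCommon  # noqa: F401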