From ffdf962e8291f262e5abbc68c55b6225899f195e Mon Sep 17 00:00:00 2001
From: "Zhong, Ruijie"
Date: Thu, 29 Aug 2024 17:15:09 +0800
Subject: [PATCH] enable hook for sample_inputs_index_put_nofp64 and
 reference_inputs_cat on op_db instead of context enter

---
 test/xpu/xpu_test_utils.py | 33 ++++++++++++++++++++-------------
 1 file changed, 20 insertions(+), 13 deletions(-)

diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index ca53831f5..92435c9b6 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -570,7 +570,7 @@ def convert_dtype(obj, dtype, requires_grad=False):
 CriterionTest.test_cuda = CriterionTest_test_xpu
 
 from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat, S, M
-from torch.testing._internal.common_methods_invocations import make_tensor
+from torch.testing._internal.common_methods_invocations import make_tensor, mask_not_all_zeros
 from functools import partial
 from torch.testing._internal.opinfo.core import SampleInput
 
@@ -604,6 +604,21 @@ def index_variable_nofp64(shape, max_indices, device=torch.device('cpu')):
     index = torch.rand(*shape, dtype=torch.float32, device=device).mul_(max_indices).floor_().long()
     return index
 
+def sample_inputs_index_put_nofp64(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    for accumulate in [False, True]:
+        # Test with indices arg
+        yield SampleInput(
+            make_arg((S, S,)),
+            (index_variable_nofp64(2, S, device=device),),
+            make_arg((2, S)),
+            accumulate=accumulate)
+
+        # Test with mask arg
+        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
+        yield SampleInput(
+            make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)
 
 def sample_inputs_softmax_variant_nofp64(
     op_info,
@@ -695,9 +710,6 @@ def __init__(self, patch_test_case=True) -> None:
         self.cuda_is_available = cuda.is_available
         self.cuda_is_bf16_supported = cuda.is_bf16_supported
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            self.index_variable = common_methods_invocations.index_variable
-            self.reference_inputs_cat = common_methods_invocations.reference_inputs_cat
 
     def align_db_decorators(self, db):
         def gen_xpu_wrappers(op_name, wrappers):
@@ -774,18 +786,17 @@ def filter_fp64_sample_input(self, db):
                    opinfo.sample_inputs_func = sample_inputs_softmax_variant_nofp64
                elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_like_fns.__name__:
                    opinfo.sample_inputs_func = sample_inputs_like_fns_nofp64
+                elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_index_put.__name__:
+                    opinfo.sample_inputs_func = sample_inputs_index_put_nofp64
 
-
+                if opinfo.reference_inputs_func != None and opinfo.reference_inputs_func.__name__ == common_methods_invocations.reference_inputs_cat.__name__:
+                    opinfo.reference_inputs_func = reference_inputs_cat_nofp64
 
 
     def __enter__(self):
         # Monkey patch until we have a fancy way
         common_device_type.onlyCUDA = common_device_type.onlyXPU
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = index_variable_nofp64
-            common_methods_invocations.reference_inputs_cat = reference_inputs_cat_nofp64
-
         class dtypesIfXPU(common_device_type.dtypes):
             def __init__(self, *args):
                 super().__init__(*args, device_type="xpu")
@@ -909,10 +920,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_available = self.cuda_is_available
         cuda.is_bf16_supported = self.cuda_is_bf16_supported
 
-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = self.index_variable
-            common_methods_invocations.reference_inputs_cat = self.reference_inputs_cat
-
 
 # Copy the test cases from generic_base_class to generic_test_class.
 # It serves to reuse test cases. Regarding some newly added hardware,
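For context on the pattern this patch moves to: rather than monkey-patching module-level helpers in `__enter__` and restoring them in `__exit__`, the fp64-free variants are installed directly on each matching entry when the op database is filtered. Below is a minimal, self-contained sketch of that hook-swap pattern; `FakeOpInfo` and the stub sample functions are illustrative stand-ins, not code from the patch or from PyTorch.

```python
# Sketch of the per-OpInfo hook pattern: match an entry's sample-input
# function by __name__ and replace it in place with an fp64-free variant.
# Nothing global is patched, so there is no state to restore on exit.
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeOpInfo:  # illustrative stand-in for torch's OpInfo
    name: str
    sample_inputs_func: Callable
    reference_inputs_func: Optional[Callable] = None


def sample_inputs_index_put(op_info, device, dtype):
    return ["samples that may require fp64"]


def sample_inputs_index_put_nofp64(op_info, device, dtype):
    return ["samples built from fp32 only"]


def filter_fp64_sample_input(db, has_fp64):
    # Swap hooks once, when the database is filtered, instead of
    # monkey-patching the module on every context enter.
    if not has_fp64:
        for opinfo in db:
            if opinfo.sample_inputs_func.__name__ == sample_inputs_index_put.__name__:
                opinfo.sample_inputs_func = sample_inputs_index_put_nofp64


op_db = [FakeOpInfo("index_put", sample_inputs_index_put)]
filter_fp64_sample_input(op_db, has_fp64=False)
assert op_db[0].sample_inputs_func is sample_inputs_index_put_nofp64
```

Matching by `__name__` (as the patch does) rather than by object identity keeps the comparison stable even when the original function has been wrapped or re-imported.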