enable hook for sample_inputs_index_put_nofp64 and reference_inputs_c… #846

Merged: 1 commit, Aug 30, 2024
33 changes: 20 additions & 13 deletions test/xpu/xpu_test_utils.py
@@ -570,7 +570,7 @@ def convert_dtype(obj, dtype, requires_grad=False):
 CriterionTest.test_cuda = CriterionTest_test_xpu

 from torch.testing._internal.common_methods_invocations import sample_inputs_cat_concat, S, M
-from torch.testing._internal.common_methods_invocations import make_tensor
+from torch.testing._internal.common_methods_invocations import make_tensor, mask_not_all_zeros
 from functools import partial
 from torch.testing._internal.opinfo.core import SampleInput
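Aside: `mask_not_all_zeros` is newly imported because the non-accumulate mask sample added below must write at least one element. The upstream helper is not shown in this diff; the following re-implementation is an assumption, sketched only to illustrate the contract its name and usage imply:

```python
import torch

def mask_not_all_zeros_sketch(shape):
    # Keep sampling random boolean masks until at least one element is True,
    # so an index_put sample with accumulate=False always has something to write.
    while True:
        mask = torch.randint(0, 2, shape, dtype=torch.bool)
        if mask.any():
            return mask
```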

@@ -604,6 +604,21 @@ def index_variable_nofp64(shape, max_indices, device=torch.device('cpu')):
     index = torch.rand(*shape, dtype=torch.float32, device=device).mul_(max_indices).floor_().long()
     return index

+def sample_inputs_index_put_nofp64(op_info, device, dtype, requires_grad, **kwargs):
+    make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+
+    for accumulate in [False, True]:
+        # Test with indices arg
+        yield SampleInput(
+            make_arg((S, S,)),
+            (index_variable_nofp64(2, S, device=device),),
+            make_arg((2, S)),
+            accumulate=accumulate)
+
+        # Test with mask arg
+        mask = torch.zeros(S, dtype=torch.bool) if accumulate else mask_not_all_zeros((S,))
+        yield SampleInput(
+            make_arg((S, S)), (mask, ), make_arg((S,)), accumulate=accumulate)

 def sample_inputs_softmax_variant_nofp64(
     op_info,
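For context, here is a minimal sketch of how the two kinds of samples yielded above exercise `Tensor.index_put_`. The concrete values (CPU tensors, `S = 5`) are assumptions chosen only for illustration; the PR itself runs these through the OpInfo machinery on XPU:

```python
import torch

S = 5
x = torch.randn(S, S)            # float32 input: no fp64 on the target device

# Integer-index variant: idx mirrors what index_variable_nofp64(2, S) produces.
idx = torch.randint(0, S, (2,))  # two row indices, dtype int64
values = torch.randn(2, S)
for accumulate in (False, True):
    out = x.clone().index_put_((idx,), values, accumulate=accumulate)

# Boolean-mask variant: rows where mask is True receive the (S,)-shaped values,
# broadcast across the selected rows; equivalent to x[mask] = values.
mask = torch.tensor([True, False, True, False, False])
out = x.clone().index_put_((mask,), torch.randn(S), accumulate=False)
```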
@@ -695,9 +710,6 @@ def __init__(self, patch_test_case=True) -> None:
         self.cuda_is_available = cuda.is_available
         self.cuda_is_bf16_supported = cuda.is_bf16_supported

-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            self.index_variable = common_methods_invocations.index_variable
-            self.reference_inputs_cat = common_methods_invocations.reference_inputs_cat

     def align_db_decorators(self, db):
         def gen_xpu_wrappers(op_name, wrappers):
@@ -774,18 +786,17 @@ def filter_fp64_sample_input(self, db):
                     opinfo.sample_inputs_func = sample_inputs_softmax_variant_nofp64
                 elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_like_fns.__name__:
                     opinfo.sample_inputs_func = sample_inputs_like_fns_nofp64
+                elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_index_put.__name__:
+                    opinfo.sample_inputs_func = sample_inputs_index_put_nofp64

+            if opinfo.reference_inputs_func != None and opinfo.reference_inputs_func.__name__ == common_methods_invocations.reference_inputs_cat.__name__:
+                opinfo.reference_inputs_func = reference_inputs_cat_nofp64

     def __enter__(self):
         # Monkey patch until we have a fancy way

         common_device_type.onlyCUDA = common_device_type.onlyXPU

-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = index_variable_nofp64
-            common_methods_invocations.reference_inputs_cat = reference_inputs_cat_nofp64

         class dtypesIfXPU(common_device_type.dtypes):
             def __init__(self, *args):
                 super().__init__(*args, device_type="xpu")
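The hook above matches generators by `__name__` rather than by object identity, so re-exported or wrapped references still hit the swap, and the change stays local to each OpInfo instead of monkey-patching `common_methods_invocations` module-wide (the approach being removed from `__enter__`). A toy sketch of that dispatch, using a hypothetical `FakeOpInfo` stand-in rather than torch's real OpInfo:

```python
from dataclasses import dataclass
from typing import Callable, Optional

def sample_inputs_index_put(*args, **kwargs):         # stand-in for the upstream generator
    yield "fp64-capable samples"

def sample_inputs_index_put_nofp64(*args, **kwargs):  # stand-in for the fp32-only variant
    yield "fp32-only samples"

@dataclass
class FakeOpInfo:  # hypothetical stand-in for torch's OpInfo
    sample_inputs_func: Callable
    reference_inputs_func: Optional[Callable] = None

db = [FakeOpInfo(sample_inputs_func=sample_inputs_index_put)]
for opinfo in db:
    # Match by name so the hook fires even if the stored reference
    # is not the exact same object as the module attribute.
    if opinfo.sample_inputs_func.__name__ == sample_inputs_index_put.__name__:
        opinfo.sample_inputs_func = sample_inputs_index_put_nofp64

assert next(db[0].sample_inputs_func()) == "fp32-only samples"
```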
@@ -909,10 +920,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_available = self.cuda_is_available
         cuda.is_bf16_supported = self.cuda_is_bf16_supported

-        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.index_variable = self.index_variable
-            common_methods_invocations.reference_inputs_cat = self.reference_inputs_cat
-

 # Copy the test cases from generic_base_class to generic_test_class.
 # It serves to reuse test cases. Regarding some newly added hardware,
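The deleted lines here were the restore half of a save/patch/restore cycle: `__init__` stashed the original `index_variable` and `reference_inputs_cat`, `__enter__` installed the nofp64 variants module-wide, and `__exit__` put the originals back. With the per-OpInfo hook, that cycle is no longer needed. A generic sketch of the pattern being retired (class and function names here are illustrative, not from the PR):

```python
import types

class PatchModuleAttr:
    """Context manager that swaps a module attribute and restores it on exit."""
    def __init__(self, module: types.ModuleType, name: str, replacement):
        self.module, self.name, self.replacement = module, name, replacement

    def __enter__(self):
        self.original = getattr(self.module, self.name)    # save
        setattr(self.module, self.name, self.replacement)  # patch
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        setattr(self.module, self.name, self.original)     # restore

# Usage sketch, mirroring what __enter__/__exit__ used to do by hand:
# with PatchModuleAttr(common_methods_invocations, "index_variable", index_variable_nofp64):
#     run_tests()
```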