enable hook to avoid fp64 issues on ARC for softmax and zeros_like #840

Merged: 1 commit, Aug 29, 2024
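Background for this change: Intel Arc (ARC) GPUs expose no native fp64, so OpInfo sample inputs that request torch.double fail on XPU. The hook added here swaps in fp64-free sample-input generators whenever the device reports the capability flag below. A minimal sketch of the probe (the property-string check matches the one used throughout the diff; the surrounding code is illustrative):

import torch

# Devices without fp64 report "has_fp64=0" in their device-properties
# string; the same substring check gates all of the patching below.
props = str(torch.xpu.get_device_properties(0))
if "has_fp64=0" in props:
    print("fp64 unsupported on this XPU; using nofp64 sample inputs")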
45 changes: 42 additions & 3 deletions test/xpu/xpu_test_utils.py
@@ -638,6 +638,30 @@ def sample_inputs_softmax_variant_nofp64(
         SampleInput(make_arg(shape), args=dim, kwargs=kwargs) for shape, dim in cases
     )
 
+def sample_inputs_like_fns_nofp64(self, device, dtype, requires_grad, **kwargs):
+
+    inputs = [
+        ((), {}),
+        ((S, S), {}),
+        ((0, S, 0), {}),
+        ((S,), {'dtype': dtype, 'device': device}),
+        # Hard-code some dtypes/devices. We want to test cases where the
+        # (dtype, device) is different from the input's (dtype, device).
+        # Disabled for ARC (no fp64 support):
+        # ((S,), {'dtype': torch.double}),
+        ((S,), {'device': 'cpu'}),
+        # Disabled for ARC (no fp64 support):
+        # ((S,), {'dtype': torch.double, 'device': 'cpu'}),
+    ]
+    if torch.cuda.is_available():
+        inputs.append(((S,), {'device': 'cuda'}))
+
+    for shape, kwargs in inputs:
+        t = make_tensor(shape, dtype=dtype, device=device,
+                        low=None, high=None,
+                        requires_grad=requires_grad)
+        yield SampleInput(t, **kwargs)
+
 class XPUPatchForImport:
     def __init__(self, patch_test_case=True) -> None:
         self.test_package = (
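As a quick illustration of how such a generator is consumed: each yielded SampleInput pairs a freshly made tensor with per-sample kwargs (possible dtype/device overrides), and a like-function test simply replays them. A hedged sketch (hypothetical driver loop; the first argument, normally the OpInfo, is unused by this generator, and S is the small-dimension constant from common_methods_invocations):

import torch

# Exercise zeros_like once per sample; no sample requests torch.double,
# so this also passes on ARC devices without fp64.
for sample in sample_inputs_like_fns_nofp64(None, "xpu", torch.float32, requires_grad=False):
    out = torch.zeros_like(sample.input, **sample.kwargs)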
@@ -672,7 +696,6 @@ def __init__(self, patch_test_case=True) -> None:
         self.cuda_is_bf16_supported = cuda.is_bf16_supported
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            self.sample_inputs_softmax_variant = common_methods_invocations.sample_inputs_softmax_variant
             self.index_variable = common_methods_invocations.index_variable
             self.reference_inputs_cat = common_methods_invocations.reference_inputs_cat
 
@@ -737,13 +760,29 @@ def align_supported_dtypes(self, db):
                 opinfo.dtypesIfXPU = set(filter(lambda x: (x not in fp64_dtypes), list(opinfo.dtypesIfXPU)))
                 opinfo.backward_dtypes = tuple(filter(lambda x: (x not in fp64_dtypes), list(opinfo.backward_dtypes)))
 
+    def filter_fp64_sample_input(self, db):
+        # Only for platforms without fp64 support
+        if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
+            for opinfo in db:
+                if opinfo.name in _xpu_computation_op_list:
+                    if opinfo.variant_test_name == "with_dtype" and \
+                       opinfo.name in ["log_softmax", "softmax", "nn.functional.softmin", ] and \
+                       get_wrapped_fn(opinfo.sample_inputs_func) != opinfo.sample_inputs_func and \
+                       get_wrapped_fn(opinfo.sample_inputs_func).func.__name__ == common_methods_invocations.sample_inputs_softmax_variant.__name__:
+                        opinfo.sample_inputs_func = torch.no_grad()(partial(sample_inputs_softmax_variant_nofp64, with_dtype=True))
+                    elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_softmax_variant.__name__:
+                        opinfo.sample_inputs_func = sample_inputs_softmax_variant_nofp64
+                    elif opinfo.sample_inputs_func.__name__ == common_methods_invocations.sample_inputs_like_fns.__name__:
+                        opinfo.sample_inputs_func = sample_inputs_like_fns_nofp64
+
+
+
     def __enter__(self):
         # Monkey patch until we have a fancy way
 
         common_device_type.onlyCUDA = common_device_type.onlyXPU
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.sample_inputs_softmax_variant = sample_inputs_softmax_variant_nofp64
             common_methods_invocations.index_variable = index_variable_nofp64
             common_methods_invocations.reference_inputs_cat = reference_inputs_cat_nofp64
 
@@ -773,6 +812,7 @@ def __init__(self, *args):
         ]:
             self.align_supported_dtypes(db)
             self.align_db_decorators(db)
+            self.filter_fp64_sample_input(db)
         self.align_db_decorators(module_db)
         common_methods_invocations.python_ref_db = [
             op
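Note that the dispatch in filter_fp64_sample_input matches sample-input functions by __name__ rather than by identity: the with_dtype softmax variants store their generator wrapped as torch.no_grad()(partial(...)), so the raw function object never compares equal. A standalone sketch of the idea, with a hypothetical unwrap helper standing in for get_wrapped_fn:

from functools import partial

def unwrap(fn):
    # Hypothetical stand-in for get_wrapped_fn: torch.no_grad()(fn) wraps
    # via functools.wraps, which records the original under __wrapped__.
    while hasattr(fn, "__wrapped__"):
        fn = fn.__wrapped__
    return fn

def base_name(fn):
    # A partial keeps its underlying function in .func.
    inner = unwrap(fn)
    return inner.func.__name__ if isinstance(inner, partial) else inner.__name__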
@@ -870,7 +910,6 @@ def __exit__(self, exc_type, exc_value, traceback):
         cuda.is_bf16_supported = self.cuda_is_bf16_supported
 
         if "has_fp64=0" in str(torch.xpu.get_device_properties(0)):
-            common_methods_invocations.sample_inputs_softmax_variant = self.sample_inputs_softmax_variant
             common_methods_invocations.index_variable = self.index_variable
             common_methods_invocations.reference_inputs_cat = self.reference_inputs_cat
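The surrounding class follows a plain save/patch/restore pattern: originals are captured in __init__, replaced in __enter__, and restored in __exit__, so the monkey patches never leak past the patched import. The same pattern in miniature (illustrative reduction, not part of the PR):

class PatchNoFp64:
    def __init__(self):
        # capture the original so the patch is reversible
        self._index_variable = common_methods_invocations.index_variable

    def __enter__(self):
        common_methods_invocations.index_variable = index_variable_nofp64
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # restore unconditionally, even if the patched test raised
        common_methods_invocations.index_variable = self._index_variable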
