From 3c7f0c729fd581c8099d6abbff5fa6bd8e4402d5 Mon Sep 17 00:00:00 2001
From: yuchengliu1
Date: Sun, 28 Jul 2024 18:59:42 +0800
Subject: [PATCH 1/3] update skip

---
 test/xpu/run_test_with_skip.py | 18 ++++--------------
 1 file changed, 4 insertions(+), 14 deletions(-)

diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 719af3ca4..2be563638 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -1261,12 +1261,12 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_RReLU_with_up_down_cuda",
     # AssertionError: Scalars are not close!
     "test_RReLU_with_up_down_scalar_cuda",
-    # lstm: AssertionError: Scalars are not equal!
+    # RNN falls back to CPU
     "test_cudnn_weight_format",
     # NotImplementedError: Could not run 'aten::_indices' with arguments from the 'SparseXPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build).
     "test_EmbeddingBag_sparse_cuda",
     "test_Embedding_sparse_cuda",
-    # AssertionError: 'XPU error: device-side assert triggered' not found in ' File "", line 8\n def test_cross_entropy_loss_2d_out_of_bounds_class_index(self):\n ^\nIndentationError: expected an indented block\n'
+    # incorrect assert in LossNLL2d
     "test_cross_entropy_loss_2d_out_of_bounds_class_index_xpu_float16",
     "test_cross_entropy_loss_2d_out_of_bounds_class_index_xpu_float32",
     # AssertionError: MultiheadAttention does not support NestedTensor outside of its fast path. The fast path was not hit because some Tensor argument's device is neither one of cpu, cuda or privateuseone
@@ -1289,8 +1289,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_rnn_retain_variables_xpu_float64",
     "test_transformerencoderlayer_xpu_float64",
     "test_variable_sequence_xpu_float64",
-    # CPU fallback fails
-    # AssertionError: Tensor-likes are not close!
+    # native_group_norm: RuntimeError: Expected X.is_contiguous(memory_format) to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
     "test_GroupNorm_memory_format_xpu",
     # AssertionError: Scalars are not close!
     "test_InstanceNorm1d_general_xpu",
@@ -1308,8 +1307,7 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_upsamplingBiMode2d_nonsupported_dtypes_antialias_True_num_channels_3_mode_bilinear_uint8_xpu_uint8",
     "test_upsamplingBiMode2d_nonsupported_dtypes_antialias_True_num_channels_5_mode_bicubic_uint8_xpu_uint8",
     "test_upsamplingBiMode2d_nonsupported_dtypes_antialias_True_num_channels_5_mode_bilinear_uint8_xpu_uint8",
-    "test_grid_sample_error_checking",
-    # Failed: Unexpected success
+    # upsamplingNearest2d: Failed: Unexpected success
     "test_upsamplingNearest2d_launch_fail_xpu",
     # CPU fallback could not cover
     # NotImplementedError: Could not run 'aten::_thnn_fused_gru_cell' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build pro...
@@ -1321,18 +1319,10 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     # AssertionError: False is not true
     "test_ctc_loss_cudnn_xpu", # want "xpu" in function name
     "test_ctc_loss_cudnn_tensor", # want "xpu" in function name
-    # NotImplementedError: Could not run 'aten::batch_norm_stats' with arguments from the 'CPU' backend.
-    "test_sync_batchnorm_accuracy_cuda",
-    # NotImplementedError: Could not run 'aten::batch_norm_backward_elemt' with arguments from the 'CPU' backend.
- "test_sync_batchnorm_backward_elemt", # RuntimeError: "smooth_l1_backward_cpu_out" not implemented for 'Half' "test_SmoothL1Loss_no_batch_dim_mean_cuda_half", "test_SmoothL1Loss_no_batch_dim_none_cuda_half", "test_SmoothL1Loss_no_batch_dim_sum_cuda_half", - # RuntimeError: "mse_backward_cpu_out" not implemented for 'Half' - "test_MSELoss_no_batch_dim_mean_cuda_half", - "test_MSELoss_no_batch_dim_none_cuda_half", - "test_MSELoss_no_batch_dim_sum_cuda_half", # RuntimeError: "multilabel_margin_loss_forward_out_frame" not implemented for 'Half' "test_MultiLabelMarginLoss_no_batch_dim_mean_cuda_half", "test_MultiLabelMarginLoss_no_batch_dim_none_cuda_half", From 854bf7f4a7d5f0784e6af0f74e9cb4eaa85bf7d6 Mon Sep 17 00:00:00 2001 From: yuchengliu1 Date: Sun, 28 Jul 2024 23:12:16 +0800 Subject: [PATCH 2/3] update --- test/xpu/test_nn_xpu.py | 165 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 165 insertions(+) diff --git a/test/xpu/test_nn_xpu.py b/test/xpu/test_nn_xpu.py index b91800473..d273af278 100644 --- a/test/xpu/test_nn_xpu.py +++ b/test/xpu/test_nn_xpu.py @@ -12,6 +12,7 @@ import torch import torch.nn.functional as F from torch import nn +import torch.nn.utils.rnn as rnn_utils from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, @@ -1686,6 +1687,170 @@ def _batch_norm_stats(data, memory_format, mean_axes): _batch_norm_stats(torch.randn(1, 96, 112, 112, 112, dtype=torch.float, device='xpu'), torch.channels_last_3d, (0, 2, 3, 4)) TestNN.test_sync_batchnorm_accuracy_cuda=_test_sync_batchnorm_accuracy_xpu +@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last]) +@parametrize_test("mode", ["bilinear", "bicubic"]) +@parametrize_test("antialias", [True, False]) +@parametrize_test("align_corners", [True, False]) +@parametrize_test("num_channels", [3, 5]) +@parametrize_test("output_size", [32, 600]) +@parametrize_test("check_as_unsqueezed_3d_tensor", [True, False]) +@parametrize_test("non_contig", [False, "sliced", "restrided"]) +@parametrize_test("batch_size", [1, 5]) +def _test_upsamplingBiMode2d_consistency( + self, + device, + memory_format, + mode, + antialias, + align_corners, + num_channels, + output_size, + check_as_unsqueezed_3d_tensor, + non_contig, + batch_size, +): + # Check output value consistency between resized_input_uint8 and resized input_float + if torch.device(device).type == "xpu": + raise SkipTest("XPU implementation is not yet supporting uint8") + + torch.manual_seed(0) + + # - input range is set to [30, 220] for bicubic mode, because the bicubic kernel may create + # [intermediate] values outside of the [0, 255] range, which need + # to be clipped in uint8 path, but not in float path. This isn't + # an issue with bilinear kernel. + input_range = (30, 220) if mode == "bicubic" else (0, 256) + input_ui8 = torch.randint(*input_range, size=(batch_size, num_channels, 400, 400), dtype=torch.uint8, device=device) + input_ui8 = input_ui8.contiguous(memory_format=memory_format) + + if non_contig == "sliced": + input_ui8 = input_ui8[:, :, 10:-10, 10:-10] + elif non_contig == "restrided": + input_ui8 = input_ui8[:, :, ::2, ::2] + + if batch_size == 1 and check_as_unsqueezed_3d_tensor: + input_ui8 = input_ui8[0, ...] + input_ui8 = input_ui8[None, ...] 
+
+    input_f32 = input_ui8.float()
+
+    output_f32 = F.interpolate(
+        input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
+    ).round().clip(0, 255)
+    output_ui8 = F.interpolate(
+        input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=antialias
+    )
+
+    if non_contig is False:
+        self.assertTrue(input_ui8.is_contiguous(memory_format=memory_format))
+
+    # FIXME if-clause shows the current behaviour which is definitely unexpected.
+    # Ideally we want to fix it such that both the ui8 and f32 outputs are also channels_last
+    # See for more details: https://github.com/pytorch/pytorch/pull/100373
+    if batch_size == 1 and check_as_unsqueezed_3d_tensor and memory_format == torch.channels_last:
+        self.assertTrue(output_ui8.is_contiguous())
+        self.assertTrue(output_f32.is_contiguous())
+    else:
+        self.assertTrue(output_ui8.is_contiguous(memory_format=memory_format))
+        self.assertTrue(output_f32.is_contiguous(memory_format=memory_format))
+
+    if mode == "bilinear":
+        torch.testing.assert_close(output_f32, output_ui8.float(), rtol=0, atol=1)
+    else:
+        diff = (output_f32 - output_ui8.float()).abs()
+        self.assertLess(diff.max(), 15)
+
+        threshold = 2
+        percent = 3
+        self.assertLess((diff > threshold).float().mean(), percent / 100)
+
+        threshold = 5
+        percent = 1
+        self.assertLess((diff > threshold).float().mean(), percent / 100)
+
+        self.assertLess(diff.mean(), 0.4)
+TestNNDeviceType.test_upsamplingBiMode2d_consistency=_test_upsamplingBiMode2d_consistency
+
+@parametrize_test("memory_format", [torch.contiguous_format, torch.channels_last])
+@parametrize_test("align_corners", [True, False])
+@parametrize_test("input_size, output_size", [(399, 437), (403, 377)])
+def _test_upsamplingBiLinear2d_consistency_interp_size_bug(self, device, memory_format, align_corners, input_size, output_size):
+    # Non-regression test for https://github.com/pytorch/pytorch/pull/101403
+
+    if torch.device(device).type == "xpu":
+        raise SkipTest("XPU implementation does not yet support uint8")
+
+    mode = "bilinear"
+    input_ui8 = torch.randint(0, 256, size=(1, 3, input_size, input_size), dtype=torch.uint8, device=device)
+    input_ui8 = input_ui8.contiguous(memory_format=memory_format)
+    input_f32 = input_ui8.float()
+
+    output_f32 = F.interpolate(
+        input_f32, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
+    ).round().to(torch.uint8)
+    output_ui8 = F.interpolate(
+        input_ui8, size=(output_size, output_size), mode=mode, align_corners=align_corners, antialias=False
+    )
+    torch.testing.assert_close(output_f32, output_ui8, atol=1, rtol=0)
+TestNNDeviceType.test_upsamplingBiLinear2d_consistency_interp_size_bug=_test_upsamplingBiLinear2d_consistency_interp_size_bug
+
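+# Check that a PackedSequence built on CPU moves to XPU as a whole via
+# .to(device), and that pad_packed_sequence keeps the result on XPU with
+# its dtype preserved.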
+def _test_device_mask(self, device):
+    def is_xpu(packed):
+        return packed.data.device.type == "xpu"
+    for enforce_sorted in [True, False]:
+        padded, lengths = self._padded_sequence('cpu', torch.float)
+        packed = rnn_utils.pack_padded_sequence(
+            padded, lengths, enforce_sorted=enforce_sorted)
+        self.assertFalse(is_xpu(packed))
+        packed = packed.to(device)
+        self.assertTrue(is_xpu(packed))
+        unpacked, _ = rnn_utils.pad_packed_sequence(packed)
+        self.assertTrue(is_xpu(unpacked))
+        self.assertEqual(unpacked.dtype, torch.float)
+TestNNDeviceType.test_device_mask=_test_device_mask
+
+def _test_overwrite_module_params_on_conversion_cpu_device(self, device):
+    # Test that under the current default settings
+    # (`torch.__future__.get_overwrite_module_params_on_conversion() == False`),
+    # a view to a module's parameters is not pointing to the same storage as
+    # its base variable after converting the module to a different device.
+    m = nn.Linear(20, 10)
+    mw = m.weight[:]
+    m.to(device)
+    with torch.no_grad():
+        # Without using `torch.no_grad()`, this will leak CUDA memory.
+        # (Issue is filed at https://github.com/pytorch/pytorch/issues/21875)
+        mw[0][0] = 5
+        self.assertTrue(mw[0][0].device.type == "cpu")
+        self.assertTrue(mw._base[0][0].device.type == "xpu")
+
+    try:
+        torch.__future__.set_overwrite_module_params_on_conversion(True)
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # a view to a module's parameters is still pointing to the same storage as
+        # its base variable after converting the module to a different device.
+        m = nn.Linear(20, 10)
+        mw = m.weight[:]
+        m.to(device)
+        with torch.no_grad():
+            mw[0][0] = 5
+        self.assertTrue(mw[0][0] == mw._base[0][0])
+
+        # Test that if `torch.__future__.get_overwrite_module_params_on_conversion() == True`,
+        # `cpu_module.to("cuda")` doesn't preserve previous references to
+        # `cpu_module`'s parameters or gradients.
+        m = nn.Linear(20, 10)
+        m.weight.grad = torch.randn(10, 20)
+        weight_ref = m.weight
+        weight_grad_ref = m.weight.grad
+        m.to(device)
+        self.assertNotEqual(weight_ref.device, m.weight.device)
+        self.assertNotEqual(weight_grad_ref.device, m.weight.grad.device)
+    finally:
+        torch.__future__.set_overwrite_module_params_on_conversion(False)
+TestNNDeviceType.test_overwrite_module_params_on_conversion_cpu_device=_test_overwrite_module_params_on_conversion_cpu_device
+
 def _test_ctc_loss_xpu(self, device):
     batch_size = 16
     input_length = 30

From 2c13adfe45977443aa58ec004c422fbc6cedacee Mon Sep 17 00:00:00 2001
From: yuchengliu1
Date: Sun, 28 Jul 2024 23:15:50 +0800
Subject: [PATCH 3/3] update skip list

---
 test/xpu/run_test_with_skip.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/test/xpu/run_test_with_skip.py b/test/xpu/run_test_with_skip.py
index 2be563638..0b6a0dddf 100644
--- a/test/xpu/run_test_with_skip.py
+++ b/test/xpu/run_test_with_skip.py
@@ -1295,9 +1295,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_InstanceNorm1d_general_xpu",
     "test_InstanceNorm2d_general_xpu",
     "test_InstanceNorm3d_general_xpu",
-    # AssertionError: False is not true
-    "test_device_mask_xpu",
-    "test_overwrite_module_params_on_conversion_cpu_device_xpu",
     # AssertionError: RuntimeError not raised
     "test_upsamplingBiMode2d_nonsupported_dtypes_antialias_False_num_channels_3_mode_bicubic_uint8_xpu_uint8",
     "test_upsamplingBiMode2d_nonsupported_dtypes_antialias_False_num_channels_3_mode_bilinear_uint8_xpu_uint8",
@@ -1327,9 +1324,6 @@ def launch_test(test_case, skip_list=None, exe_list=None):
     "test_MultiLabelMarginLoss_no_batch_dim_mean_cuda_half",
     "test_MultiLabelMarginLoss_no_batch_dim_none_cuda_half",
     "test_MultiLabelMarginLoss_no_batch_dim_sum_cuda_half",
-    # align CUDA to skip, XPU implementation is not yet supporting uint8
-    "test_upsamplingBiMode2d_consistency",
-    "test_upsamplingBiLinear2d_consistency_interp_size_bug",
 )

 res += launch_test("test_nn_xpu.py", skip_list)