From 24918dbe6c5154f149194d2797762069654003c3 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 9 Aug 2024 08:51:24 +0000
Subject: [PATCH 1/7] add histc

---
 src/ATen/native/xpu/SummaryOps.cpp            | 27 +++++++++
 src/ATen/native/xpu/XPUFallback.template      |  1 -
 .../native/xpu/sycl/SummaryOpsKernels.cpp     | 55 +++++++++++++++++++
 src/ATen/native/xpu/sycl/SummaryOpsKernels.h  |  6 ++
 test/xpu/xpu_test_utils.py                    |  2 +
 yaml/xpu_functions.yaml                       |  4 ++
 6 files changed, 94 insertions(+), 1 deletion(-)

diff --git a/src/ATen/native/xpu/SummaryOps.cpp b/src/ATen/native/xpu/SummaryOps.cpp
index cf4cf3f27..221a6050a 100644
--- a/src/ATen/native/xpu/SummaryOps.cpp
+++ b/src/ATen/native/xpu/SummaryOps.cpp
@@ -1,3 +1,4 @@
+#include
 #include
 #include
 #include
@@ -21,4 +22,30 @@ Tensor XPUNativeFunctions::bincount(
   return native::xpu::bincount_kernel(self, weights, minlength);
 }
 
+Tensor XPUNativeFunctions::histc(
+    const Tensor& self,
+    int64_t nbins,
+    const Scalar& min,
+    const Scalar& max) {
+  if (self.scalar_type() == ScalarType::Half) {
+    AT_ERROR("HalfTensor is not supported");
+  }
+  // See Note [Writing Nondeterministic Operations]
+  // Nondeterministic because of atomicAdd usage
+  globalContext().alertNotDeterministic("_histc_xpu");
+  return native::xpu::_histc_kernel(self, nbins, min, max);
+}
+
+Tensor& XPUNativeFunctions::histc_out(
+    const Tensor& self,
+    int64_t bins,
+    const Scalar& min,
+    const Scalar& max,
+    Tensor& result) {
+  auto ret = histc(self, bins, min, max);
+  at::native::resize_output(result, ret.sizes());
+  result.copy_(ret);
+  return result;
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 5b2d6e5ff..bb9586e91 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -191,7 +191,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "hardshrink_backward.grad_input",
       "hardshrink.out",
       "heaviside.out",
-      "histc",
       "i0.out",
       "igammac.out",
       "igamma.out",
diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
index 265170821..fb50da5bf 100644
--- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
@@ -185,6 +185,61 @@ void tensor_histogram(
   return;
 }
 
+template <typename input_t>
+Tensor _histc_template(
+    const Tensor& self,
+    int64_t nbins,
+    at::acc_type_device<input_t, kXPU> min,
+    at::acc_type_device<input_t, kXPU> max) {
+  if (nbins <= 0) {
+    AT_ERROR("bins must be > 0");
+  }
+  Tensor output = at::zeros(
+      {nbins},
+      self.scalar_type(),
+      std::nullopt /* layout */,
+      DeviceType::XPU,
+      std::nullopt /* pin_memory */);
+  input_t minvalue = min;
+  input_t maxvalue = max;
+  if (min == max && self.numel() > 0) {
+    minvalue = *self.min().cpu().data_ptr<input_t>();
+    maxvalue = *self.max().cpu().data_ptr<input_t>();
+  }
+  if (minvalue == maxvalue) {
+    minvalue = minvalue - 1;
+    maxvalue = maxvalue + 1;
+  }
+
+  TORCH_CHECK(
+      !(std::isinf(minvalue) || std::isinf(maxvalue) || std::isnan(minvalue) ||
+        std::isnan(maxvalue)),
+      "range of [",
+      minvalue,
+      ", ",
+      maxvalue,
+      "] is not finite");
+
+  TORCH_CHECK(minvalue < maxvalue, "max must be larger than min");
+
+  tensor_histogram<input_t, input_t, false>(
+      output, self, Tensor(), nbins, minvalue, maxvalue);
+  return output;
+}
+
+Tensor _histc_kernel(
+    const Tensor& self,
+    int64_t nbins,
+    const Scalar& min,
+    const Scalar& max) {
+  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "_histc_xpu", [&] {
+    using bounds_t = at::acc_type_device<scalar_t, kXPU>;
+    return _histc_template<scalar_t>(
+        self, nbins, min.to<bounds_t>(), max.to<bounds_t>());
+  });
+}
+
 template <typename input_t, typename weights_t>
 Tensor bincount_template(
     const Tensor& self,
diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.h b/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
index becf3ac54..87b3abff2 100644
--- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
@@ -8,4 +8,10 @@ Tensor bincount_kernel(
     const Tensor& weights,
     int64_t minlength);
 
+Tensor _histc_kernel(
+    const Tensor& self,
+    int64_t nbins,
+    const Scalar& min,
+    const Scalar& max);
+
 } // namespace at::native::xpu
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 483af50f7..ad0d61475 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -207,6 +207,7 @@
     "_native_batch_norm_legit",
     "_batch_norm_with_update",
     "bincount",
+    "histc",
     "cross",
     "renorm",
     "digamma",
@@ -222,6 +223,7 @@
     "argmin",
     "conj_physical",
     "histogram",
+    "histc",
    "repeat_interleave",
     "fmax",
     "fmin",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index db3df0667..aa7eb29a6 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -553,6 +553,8 @@ supported:
   - upsample_bicubic2d
   - upsample_bicubic2d.out
   - bincount
+  - histc
+  - histc.out
   - _embedding_bag
   - _embedding_bag_forward_only
   - _embedding_bag_backward
@@ -703,6 +705,8 @@ supported:
   - histogram.bins_tensor_out
   - histogram.bin_ct
   - histogram.bin_ct_out
+  - histc
+  - histc.out
   - repeat_interleave.Tensor
   - norm.ScalarOpt_dim_dtype
   - norm.dtype_out
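[Reviewer note, not part of the patch series] A quick way to sanity-check the op added in PATCH 1/7 is to compare the XPU result against the CPU reference through the public `torch.histc` API. A minimal sketch, assuming a build with a working XPU device:

```python
import torch

# Four values spread evenly over [0, 3] should land in four distinct bins.
x = torch.tensor([0.0, 1.0, 2.0, 3.0], device="xpu")
hist_xpu = torch.histc(x, bins=4, min=0, max=3)
hist_cpu = torch.histc(x.cpu(), bins=4, min=0, max=3)
assert torch.equal(hist_xpu.cpu(), hist_cpu)  # expected: tensor([1., 1., 1., 1.])
```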
From 5773c06d850b949592b678623a92b5bd6c63b4c2 Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 9 Aug 2024 08:53:52 +0000
Subject: [PATCH 2/7] revise

---
 yaml/xpu_functions.yaml | 2 --
 1 file changed, 2 deletions(-)

diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index aa7eb29a6..e6082c9be 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -553,8 +553,6 @@ supported:
   - upsample_bicubic2d
   - upsample_bicubic2d.out
   - bincount
-  - histc
-  - histc.out
   - _embedding_bag
   - _embedding_bag_forward_only
   - _embedding_bag_backward

From 4505481624c3a8fd829f8a28e312a0cf4b0aa2de Mon Sep 17 00:00:00 2001
From: yucai
Date: Fri, 9 Aug 2024 09:16:09 +0000
Subject: [PATCH 3/7] revise

---
 test/xpu/xpu_test_utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index ad0d61475..c0fbe0128 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -207,7 +207,6 @@
     "_native_batch_norm_legit",
     "_batch_norm_with_update",
     "bincount",
-    "histc",
     "cross",
     "renorm",
     "digamma",
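[Reviewer note, not part of the patch series] For context on the range handling in `_histc_template` from PATCH 1/7: when the requested range is empty (`min == max`), the kernel falls back to the data range, then widens a still-degenerate range by plus/minus 1. Below is a plain-Python reference sketch of those semantics; `histc_reference` is a hypothetical helper, and the last-bin clamp is assumed to match the CPU/CUDA `histc` rule that `value == max` falls into the final bin:

```python
import math

def histc_reference(values, nbins, lo, hi):
    # Mirror _histc_template: an empty requested range falls back to the
    # data range, and a still-degenerate range is padded by +/-1.
    if lo == hi and values:
        lo, hi = min(values), max(values)
    if lo == hi:
        lo, hi = lo - 1, hi + 1
    assert math.isfinite(lo) and math.isfinite(hi) and lo < hi
    hist = [0] * nbins
    for v in values:
        if lo <= v <= hi:  # values outside [lo, hi] are ignored
            i = min(int((v - lo) * nbins / (hi - lo)), nbins - 1)
            hist[i] += 1
    return hist

print(histc_reference([0.0, 1.0, 2.0, 3.0], 4, 0.0, 3.0))  # [1, 1, 1, 1]
```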
From da07b9768077ad803dd8100e77d556a5ed49d9e7 Mon Sep 17 00:00:00 2001
From: yucai
Date: Tue, 27 Aug 2024 02:58:23 +0000
Subject: [PATCH 4/7] revise

---
 src/ATen/native/xpu/SummaryOps.cpp             | 3 ---
 src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp | 3 +++
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/ATen/native/xpu/SummaryOps.cpp b/src/ATen/native/xpu/SummaryOps.cpp
index 221a6050a..5cf398484 100644
--- a/src/ATen/native/xpu/SummaryOps.cpp
+++ b/src/ATen/native/xpu/SummaryOps.cpp
@@ -27,9 +27,6 @@ Tensor XPUNativeFunctions::histc(
     int64_t nbins,
     const Scalar& min,
     const Scalar& max) {
-  if (self.scalar_type() == ScalarType::Half) {
-    AT_ERROR("HalfTensor is not supported");
-  }
   // See Note [Writing Nondeterministic Operations]
   // Nondeterministic because of atomicAdd usage
   globalContext().alertNotDeterministic("_histc_xpu");
diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
index fb50da5bf..dd20683ee 100644
--- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
@@ -233,6 +233,9 @@ Tensor _histc_kernel(
     int64_t nbins,
     const Scalar& min,
     const Scalar& max) {
+  if (self.scalar_type() == ScalarType::Half) {
+    AT_ERROR("HalfTensor is not supported");
+  }
   return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "_histc_xpu", [&] {
     using bounds_t = at::acc_type_device<scalar_t, kXPU>;
     return _histc_template<scalar_t>(

From bf340c1782ed29957b330fc4997d8119173c58e1 Mon Sep 17 00:00:00 2001
From: yucai
Date: Tue, 8 Oct 2024 07:17:07 +0000
Subject: [PATCH 5/7] revise yaml

---
 src/ATen/native/xpu/sycl/SummaryOpsKernels.h | 2 +-
 yaml/xpu_functions.yaml                      | 2 --
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.h b/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
index af30fc093..d01e6bf00 100644
--- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
+++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.h
@@ -6,7 +6,7 @@ namespace at::native::xpu {
 
 TORCH_XPU_API Tensor
 bincount_kernel(const Tensor& self, const Tensor& weights, int64_t minlength);
 
-Tensor _histc_kernel(
+TORCH_XPU_API Tensor _histc_kernel(
     const Tensor& self,
     int64_t nbins,
     const Scalar& min,
     const Scalar& max);
 
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 1f9221ba8..8d92fa60e 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -706,8 +706,6 @@ supported:
   - histogram.bins_tensor_out
   - histogram.bin_ct
   - histogram.bin_ct_out
-  - histc
-  - histc.out
   - repeat_interleave.Tensor
   - norm.ScalarOpt_dim_dtype
   - norm.dtype_out
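[Reviewer note, not part of the patch series] PATCH 1/7 registers the op as nondeterministic via `alertNotDeterministic("_histc_xpu")`, and PATCH 6/7 below stops skipping `test_nondeterministic_alert_histc_xpu` accordingly. With deterministic algorithms enforced, the call is expected to raise rather than run. A hedged sketch, again assuming an XPU device:

```python
import torch

torch.use_deterministic_algorithms(True)
x = torch.randn(1000, device="xpu")
try:
    torch.histc(x, bins=10, min=-3, max=3)
except RuntimeError as err:
    # histc accumulates bins with atomicAdd, so deterministic mode
    # refuses to run it instead of returning run-to-run unstable counts.
    print(err)
finally:
    torch.use_deterministic_algorithms(False)
```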
From 178b614cb2f55f0b5906ebbb837a1e3b2ea8101f Mon Sep 17 00:00:00 2001
From: xytintel
Date: Sun, 20 Oct 2024 14:44:43 +0000
Subject: [PATCH 6/7] rebase code

---
 src/ATen/native/xpu/LayerNorm.cpp              |  3 ++-
 src/ATen/native/xpu/SummaryOps.cpp             | 10 ++++++----
 src/ATen/native/xpu/UpSample.h                 |  4 ++--
 src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp |  4 ++--
 test/xpu/skip_list_common.py                   |  2 +-
 test/xpu/test_linalg_xpu.py                    |  6 ++++++
 yaml/native/native_functions.yaml              | 14 +++++++-------
 7 files changed, 26 insertions(+), 17 deletions(-)

diff --git a/src/ATen/native/xpu/LayerNorm.cpp b/src/ATen/native/xpu/LayerNorm.cpp
index 0addcd718..65f828633 100644
--- a/src/ATen/native/xpu/LayerNorm.cpp
+++ b/src/ATen/native/xpu/LayerNorm.cpp
@@ -69,7 +69,8 @@ ::std::tuple<Tensor, Tensor, Tensor> layer_norm_xpu(
   for (const auto idx : c10::irange(axis)) {
     stat_shape.push_back(input_shape[idx]);
   }
-  for (const auto C10_UNUSED idx : c10::irange(axis, input.dim())) {
+  for (const auto idx : c10::irange(axis, input.dim())) {
+    (void)idx;
     stat_shape.push_back(1);
   }
 
diff --git a/src/ATen/native/xpu/SummaryOps.cpp b/src/ATen/native/xpu/SummaryOps.cpp
index b0d800062..953004227 100644
--- a/src/ATen/native/xpu/SummaryOps.cpp
+++ b/src/ATen/native/xpu/SummaryOps.cpp
@@ -3,8 +3,10 @@
 #include
 #include
 
+
 namespace at {
 namespace native {
+
 Tensor _bincount_xpu(
     const Tensor& self,
     const c10::optional<Tensor>& weights_opt,
@@ -22,9 +24,8 @@ Tensor _bincount_xpu(
   return native::xpu::bincount_kernel(self, weights, minlength);
 }
-} // namespace native
 
-Tensor XPUNativeFunctions::histc(
+Tensor _histc_xpu(
     const Tensor& self,
     int64_t nbins,
     const Scalar& min,
     const Scalar& max) {
@@ -35,16 +36,17 @@ Tensor XPUNativeFunctions::histc(
   return native::xpu::_histc_kernel(self, nbins, min, max);
 }
 
-Tensor& XPUNativeFunctions::histc_out(
+Tensor& _histc_out_xpu(
     const Tensor& self,
     int64_t bins,
     const Scalar& min,
     const Scalar& max,
     Tensor& result) {
-  auto ret = histc(self, bins, min, max);
+  auto ret = _histc_xpu(self, bins, min, max);
   at::native::resize_output(result, ret.sizes());
   result.copy_(ret);
   return result;
 }
 
+} // namespace native
 } // namespace at
diff --git a/src/ATen/native/xpu/UpSample.h b/src/ATen/native/xpu/UpSample.h
index 44e9f5829..91050659d 100644
--- a/src/ATen/native/xpu/UpSample.h
+++ b/src/ATen/native/xpu/UpSample.h
@@ -12,7 +12,7 @@
 
 namespace at::native::xpu {
 
-inline C10_UNUSED std::array<int64_t, 4> upsample_2d_common_check(
+inline std::array<int64_t, 4> upsample_2d_common_check(
     IntArrayRef input_size,
     IntArrayRef output_size) {
   TORCH_CHECK(
@@ -228,7 +228,7 @@ static scalar_t upsample_get_value_bounded(
   return data[batch][channel][access_y][access_x];
 }
 
-static C10_UNUSED std::array<int64_t, 3> upsample_1d_common_check(
+inline std::array<int64_t, 3> upsample_1d_common_check(
     IntArrayRef input_size,
     IntArrayRef output_size) {
   TORCH_CHECK(
diff --git a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
index ffb6e5bfc..8603c3b60 100644
--- a/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
+++ b/src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
@@ -206,8 +206,8 @@ Tensor _histc_template(
   input_t minvalue = min;
   input_t maxvalue = max;
   if (min == max && self.numel() > 0) {
-    minvalue = *self.min().cpu().data_ptr<input_t>();
-    maxvalue = *self.max().cpu().data_ptr<input_t>();
+    minvalue = *self.min().cpu().const_data_ptr<input_t>();
+    maxvalue = *self.max().cpu().const_data_ptr<input_t>();
   }
   if (minvalue == maxvalue) {
     minvalue = minvalue - 1;
diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index ce1c07fb6..84120b9bf 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -98,6 +98,7 @@
         "test_out_geqrf_xpu_float32",
         "test_out_narrow_copy_xpu_float32",
         "test_out_ormqr_xpu_float32",
+        "test_out_histc_cuda_float32",
         # XFAIL of CUDA, XPU got unexpected success
         "test_python_ref__refs_div_no_rounding_mode_xpu_complex32",
@@ -2523,7 +2524,6 @@
         "test_nondeterministic_alert_ReplicationPad3d_xpu",
         "test_nondeterministic_alert_grid_sample_2d_xpu",
         "test_nondeterministic_alert_grid_sample_3d_xpu",
-        "test_nondeterministic_alert_histc_xpu",
         "test_nondeterministic_alert_interpolate_bicubic_xpu",
         "test_nondeterministic_alert_interpolate_bilinear_xpu",
         "test_nondeterministic_alert_interpolate_trilinear_xpu",
diff --git a/test/xpu/test_linalg_xpu.py b/test/xpu/test_linalg_xpu.py
index 80b30b54e..3c1c8aed7 100644
--- a/test/xpu/test_linalg_xpu.py
+++ b/test/xpu/test_linalg_xpu.py
@@ -213,6 +213,11 @@ def matmul_small_brute_force_3d_Nd(self, device, dtype):
         y = make_arg(size_y, noncontiguous=nctg_y)
         self.check_single_matmul(x, y)
 
+@setBlasBackendsToDefaultFinally
+@unittest.skip("xpu not support ck blas library")
+def ck_blas_library(self):
+    pass
+
 with XPUPatchForImport(False):
     from test_linalg import TestLinalg
 
@@ -227,6 +232,7 @@
 TestLinalg.test_matmul_small_brute_force_1d_Nd=matmul_small_brute_force_1d_Nd
 TestLinalg.test_matmul_small_brute_force_2d_Nd=matmul_small_brute_force_2d_Nd
 TestLinalg.test_matmul_small_brute_force_3d_Nd=matmul_small_brute_force_3d_Nd
+TestLinalg.test_ck_blas_library = ck_blas_library
 TestLinalg._default_dtype_check_enabled = True
 
 instantiate_device_type_tests(TestLinalg, globals(), only_for=("xpu"), allow_xpu=True)
diff --git a/yaml/native/native_functions.yaml b/yaml/native/native_functions.yaml
index b169c6b59..3f913b26f 100644
--- a/yaml/native/native_functions.yaml
+++ b/yaml/native/native_functions.yaml
@@ -5297,14 +5297,14 @@
   dispatch:
     XPU: replication_pad3d_backward_xpu
 
-# - func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
-#   dispatch:
-#     XPU: histogram_histc_out
+- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    XPU: _histc_out_xpu
 
-# - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
-#   variants: method, function
-#   dispatch:
-#     XPU: histogram_histc
+- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
+  variants: method, function
+  dispatch:
+    XPU: _histc_xpu
 
 - func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
   dispatch:

From d55c914a0d9f4daff61f7511b6c39c38b48fbcf9 Mon Sep 17 00:00:00 2001
From: Yutao Xu
Date: Mon, 21 Oct 2024 09:37:34 +0800
Subject: [PATCH 7/7] Update skip_list_common.py

---
 test/xpu/skip_list_common.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/xpu/skip_list_common.py b/test/xpu/skip_list_common.py
index 6655b9b29..4be10452e 100644
--- a/test/xpu/skip_list_common.py
+++ b/test/xpu/skip_list_common.py
@@ -97,7 +97,7 @@
         "test_out_geqrf_xpu_float32",
         "test_out_narrow_copy_xpu_float32",
         "test_out_ormqr_xpu_float32",
-        "test_out_histc_cuda_float32",
+        "test_out_histc_xpu_float32",
         # XFAIL of CUDA, XPU got unexpected success
         "test_python_ref__refs_div_no_rounding_mode_xpu_complex32",
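[Reviewer note, not part of the patch series] The `out=` path (`_histc_out_xpu`, wired up in PATCH 6/7) computes a fresh result and then moves it into the caller's tensor via `resize_output` plus `copy_`, so even a zero-sized `out` is acceptable. A small usage sketch, assuming an XPU device:

```python
import torch

x = torch.tensor([0.0, 1.0, 2.0, 3.0], device="xpu")
out = torch.empty(0, device="xpu")  # will be resized to shape (4,)
torch.histc(x, bins=4, min=0, max=3, out=out)
print(out)  # expected: tensor([1., 1., 1., 1.], device='xpu:0')
```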