From cf2326f0a2a95b82db8654de5e04eb43d04cc701 Mon Sep 17 00:00:00 2001
From: xytintel
Date: Wed, 10 Jul 2024 06:15:10 +0000
Subject: [PATCH 1/2] add norm and its variants

---
 src/ATen/native/xpu/ReduceOps.cpp        | 121 ++++++++++++++++++++++-
 src/ATen/native/xpu/XPUFallback.template |   1 -
 test/xpu/xpu_test_utils.py               |   1 +
 yaml/xpu_functions.yaml                  |   4 +
 4 files changed, 125 insertions(+), 2 deletions(-)

diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index f9b512723..69a520790 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -725,7 +725,7 @@ Tensor XPUNativeFunctions::argmax(
   return out;
 }
 
-static Tensor amax_amin_meta(
+static Tensor& amax_amin_meta(
     Tensor& result,
     const char* name,
     const Tensor& self,
@@ -803,4 +803,123 @@ Tensor XPUNativeFunctions::amin(
   return out;
 }
 
+static ScalarType get_result_or_self_value_dtype(
+    const Tensor& self,
+    const Tensor& result,
+    const std::optional<ScalarType>& dtype) {
+  if (result.defined()) {
+    return result.scalar_type();
+  } else {
+    return dtype.value_or(toRealValueType(self.scalar_type()));
+  }
+}
+
+Tensor& norm_scalaropt_dim_dtype_meta(
+    const Tensor& self,
+    const OptionalScalarRef p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype,
+    Tensor& result) {
+  TORCH_CHECK(
+      at::isFloatingType(dtype) || at::isComplexType(dtype),
+      "norm(): the desired output dtype should be either floating point or complex. "
+      "Got ",
+      dtype,
+      " instead.");
+  auto out_dtype = get_result_or_self_value_dtype(self, result, dtype);
+  return resize_reduction(result, self, dim, keepdim, out_dtype);
+}
+
+static void impl_func_norm(
+    const Tensor& self,
+    const OptionalScalarRef& opt_p,
+    IntArrayRef dim,
+    bool keepdim,
+    optional<ScalarType> opt_dtype,
+    const Tensor& result) {
+  // Left this implementation without deprecating it as it is called in a number
+  // of places in the codebase. We should swap those by linalg_vector_norm
+  auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
+  at::linalg_vector_norm_out(
+      const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
+}
+
+// wrapper_CUDA_norm_ScalarOpt_dim_dtype
+Tensor XPUNativeFunctions::norm(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype) {
+  Tensor result;
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
+  impl_func_norm(self, p_, dim, keepdim, dtype, result);
+  return result;
+}
+
+Tensor& XPUNativeFunctions::norm_out(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype,
+    Tensor& result) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
+  impl_func_norm(self, p_, dim, keepdim, dtype, result);
+  return result;
+}
+
+Tensor& norm_scalaropt_dim_meta(
+    const Tensor& self,
+    const OptionalScalarRef p,
+    IntArrayRef dim,
+    bool keepdim,
+    Tensor& result) {
+  TORCH_CHECK(
+      at::isFloatingType(self.scalar_type()) ||
+          at::isComplexType(self.scalar_type()),
+      "norm(): input dtype should be either floating point or complex. "
+      "Got ",
+      self.scalar_type(),
+      " instead.");
+
+  auto out_dtype = get_result_or_self_value_dtype(self, result, c10::nullopt);
+  return resize_reduction(result, self, dim, keepdim, out_dtype);
+}
+
+Tensor XPUNativeFunctions::norm(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  Tensor result;
+  result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
+  impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
+  return result;
+}
+
+Tensor& XPUNativeFunctions::norm_out(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    Tensor& result) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
+  impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
+  return result;
+}
+
 } // namespace at
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index fa7fafe13..ed7ba84e1 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -292,7 +292,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "nansum",
       "nan_to_num.out",
       "nextafter.out",
-      "norm.out",
       "ormqr",
       "_pdist_backward",
       "_pdist_forward",
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index 35c29d96b..b3613adc6 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -157,6 +157,7 @@
     "renorm",
    "lerp",
     "conj_physical",
+    "norm",
 ]
 
 _xpu_tensor_factory_op_list = [
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index 2ecc6790b..bd2c2ea97 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -511,3 +511,7 @@ supported:
   - ceil
   - ceil_
   - ceil.out
+  - norm.ScalarOpt_dim_dtype
+  - norm.dtype_out
+  - norm.ScalarOpt_dim
+  - norm.out

From 6557d9fff589f8dd6b5685ba92659a3f5c5f5eba Mon Sep 17 00:00:00 2001
From: Feng Yuan
Date: Sat, 20 Jul 2024 14:04:00 +0800
Subject: [PATCH 2/2] Remove unnecessary comments

---
 src/ATen/native/xpu/ReduceOps.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index dc0971025..a8fd993ee 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -912,7 +912,6 @@ static void impl_func_norm(
       const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
 }
 
-// wrapper_CUDA_norm_ScalarOpt_dim_dtype
 Tensor XPUNativeFunctions::norm(
     const Tensor& self,
     const std::optional<Scalar>& p,