diff --git a/src/ATen/native/xpu/ReduceOps.cpp b/src/ATen/native/xpu/ReduceOps.cpp
index 14e60d6fe..a8fd993ee 100644
--- a/src/ATen/native/xpu/ReduceOps.cpp
+++ b/src/ATen/native/xpu/ReduceOps.cpp
@@ -870,6 +870,124 @@ Tensor XPUNativeFunctions::amin(
   return out;
 }
 
+static ScalarType get_result_or_self_value_dtype(
+    const Tensor& self,
+    const Tensor& result,
+    const std::optional<ScalarType>& dtype) {
+  if (result.defined()) {
+    return result.scalar_type();
+  } else {
+    return dtype.value_or(toRealValueType(self.scalar_type()));
+  }
+}
+
+Tensor& norm_scalaropt_dim_dtype_meta(
+    const Tensor& self,
+    const OptionalScalarRef p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype,
+    Tensor& result) {
+  TORCH_CHECK(
+      at::isFloatingType(dtype) || at::isComplexType(dtype),
+      "norm(): the desired output dtype should be either floating point or complex. "
+      "Got ",
+      dtype,
+      " instead.");
+  auto out_dtype = get_result_or_self_value_dtype(self, result, dtype);
+  return resize_reduction(result, self, dim, keepdim, out_dtype);
+}
+
+static void impl_func_norm(
+    const Tensor& self,
+    const OptionalScalarRef& opt_p,
+    IntArrayRef dim,
+    bool keepdim,
+    optional<ScalarType> opt_dtype,
+    const Tensor& result) {
+  // Left this implementation without deprecating it as it is called in a number
+  // of places in the codebase. We should swap those by linalg_vector_norm
+  auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
+  at::linalg_vector_norm_out(
+      const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
+}
+
+Tensor XPUNativeFunctions::norm(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype) {
+  Tensor result;
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
+  impl_func_norm(self, p_, dim, keepdim, dtype, result);
+  return result;
+}
+
+Tensor& XPUNativeFunctions::norm_out(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    ScalarType dtype,
+    Tensor& result) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
+  impl_func_norm(self, p_, dim, keepdim, dtype, result);
+  return result;
+}
+
+Tensor& norm_scalaropt_dim_meta(
+    const Tensor& self,
+    const OptionalScalarRef p,
+    IntArrayRef dim,
+    bool keepdim,
+    Tensor& result) {
+  TORCH_CHECK(
+      at::isFloatingType(self.scalar_type()) ||
+          at::isComplexType(self.scalar_type()),
+      "norm(): input dtype should be either floating point or complex. "
+      "Got ",
+      self.scalar_type(),
+      " instead.");
+
+  auto out_dtype = get_result_or_self_value_dtype(self, result, c10::nullopt);
+  return resize_reduction(result, self, dim, keepdim, out_dtype);
+}
+
+Tensor XPUNativeFunctions::norm(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  Tensor result;
+  result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
+  impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
+  return result;
+}
+
+Tensor& XPUNativeFunctions::norm_out(
+    const Tensor& self,
+    const std::optional<Scalar>& p,
+    IntArrayRef dim,
+    bool keepdim,
+    Tensor& result) {
+  auto p_ =
+      (p.has_value() ? at::OptionalScalarRef(&(p.value()))
+                     : at::OptionalScalarRef());
+  result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
+  impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
+  return result;
+}
+
 TensorIterator meta_aminmax(
     const Tensor& self,
     std::optional<int64_t> dim_opt,
diff --git a/src/ATen/native/xpu/XPUFallback.template b/src/ATen/native/xpu/XPUFallback.template
index 75274d4cf..8cd4f27bb 100644
--- a/src/ATen/native/xpu/XPUFallback.template
+++ b/src/ATen/native/xpu/XPUFallback.template
@@ -244,7 +244,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
       "nanmedian.dim_values",
       "nansum",
       "nextafter.out",
-      "norm.out",
       "ormqr",
       "_pdist_backward",
       "_pdist_forward",
diff --git a/test/xpu/xpu_test_utils.py b/test/xpu/xpu_test_utils.py
index c2f0f15f1..905dae530 100644
--- a/test/xpu/xpu_test_utils.py
+++ b/test/xpu/xpu_test_utils.py
@@ -141,6 +141,7 @@
     "std_mean",
     "var",
     "var_mean",
+    "norm",
     "hypot",
     "unfold",
     "uniform",
diff --git a/yaml/xpu_functions.yaml b/yaml/xpu_functions.yaml
index b2990fb8a..e30d8814b 100644
--- a/yaml/xpu_functions.yaml
+++ b/yaml/xpu_functions.yaml
@@ -638,4 +638,8 @@ supported:
   - ceil
   - ceil_
   - ceil.out
+  - norm.ScalarOpt_dim_dtype
+  - norm.dtype_out
+  - norm.ScalarOpt_dim
+  - norm.out
   - nan_to_num.out
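
Usage sketch (illustrative, not part of the patch): with the four norm overloads registered above and "norm.out" removed from the CPU fallback list, torch.norm on an XPU tensor dispatches to the native XPU path (which forwards to linalg_vector_norm) instead of falling back to CPU. A minimal smoke test, assuming a PyTorch build that includes this change and an available XPU device:

    import torch

    x = torch.randn(4, 5, device="xpu")

    # norm.ScalarOpt_dim: L2 norm reduced over dim 1 -> shape (4,)
    n1 = torch.norm(x, p=2, dim=1)

    # norm.ScalarOpt_dim_dtype: same reduction, computed/returned as float64
    n2 = torch.norm(x, p=2, dim=1, dtype=torch.float64)

    # norm.out: write into a preallocated output tensor
    out = torch.empty(5, device="xpu")
    torch.norm(x, p=2, dim=0, out=out)

    # norm.dtype_out: preallocated output with an explicit dtype
    out64 = torch.empty(5, dtype=torch.float64, device="xpu")
    torch.norm(x, p=2, dim=0, dtype=torch.float64, out=out64)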