Add aten::norm and its variants (#556)
Task list (see the usage sketch below):
- [x] norm.ScalarOpt_dim_dtype
- [x] norm.dtype_out
- [x] norm.ScalarOpt_dim
- [x] norm.out
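All four schema variants are reachable from the ordinary `torch.norm` Python API. A minimal usage sketch, assuming a PyTorch build with XPU support (the `"xpu"` device string, shapes, and tensor names are illustrative, not part of this change):

```python
import torch

x = torch.randn(4, 5, device="xpu")  # assumes an XPU-enabled build
out_f32 = torch.empty(0, device="xpu")
out_f64 = torch.empty(0, device="xpu", dtype=torch.float64)

# norm.ScalarOpt_dim: p and dim given, output dtype follows the input
torch.norm(x, p=2, dim=[1])

# norm.ScalarOpt_dim_dtype: same call with an explicit output dtype
torch.norm(x, p=2, dim=[1], dtype=torch.float64)

# norm.out / norm.dtype_out: the out= forms of the two calls above
torch.norm(x, p=2, dim=[1], out=out_f32)
torch.norm(x, p=2, dim=[1], dtype=torch.float64, out=out_f64)
```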

---------

Co-authored-by: Feng Yuan <[email protected]>
xytintel and fengyuan14 authored Jul 20, 2024
1 parent 2258cb4 commit ab8bff1
Showing 4 changed files with 123 additions and 1 deletion.
118 changes: 118 additions & 0 deletions src/ATen/native/xpu/ReduceOps.cpp
@@ -870,6 +870,124 @@ Tensor XPUNativeFunctions::amin(
return out;
}

static ScalarType get_result_or_self_value_dtype(
const Tensor& self,
const Tensor& result,
const std::optional<ScalarType>& dtype) {
if (result.defined()) {
return result.scalar_type();
} else {
return dtype.value_or(toRealValueType(self.scalar_type()));
}
}
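The helper above resolves the output dtype by priority: a defined `out` tensor wins, then an explicit `dtype` argument, and otherwise the real-valued counterpart of the input dtype via `toRealValueType`, so complex inputs yield real-valued norms. A quick device-agnostic illustration of that last rule (a sketch, not part of this change):

```python
import torch

z = torch.ones(3, dtype=torch.complex64)
# No out= and no dtype= supplied: the result dtype falls back to the
# real value type of complex64, i.e. float32.
print(torch.norm(z).dtype)  # torch.float32
```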

Tensor& norm_scalaropt_dim_dtype_meta(
const Tensor& self,
const OptionalScalarRef p,
IntArrayRef dim,
bool keepdim,
ScalarType dtype,
Tensor& result) {
TORCH_CHECK(
at::isFloatingType(dtype) || at::isComplexType(dtype),
"norm(): the desired output dtype should be either floating point or complex. "
"Got ",
dtype,
" instead.");
auto out_dtype = get_result_or_self_value_dtype(self, result, dtype);
return resize_reduction(result, self, dim, keepdim, out_dtype);
}

static void impl_func_norm(
const Tensor& self,
const OptionalScalarRef& opt_p,
IntArrayRef dim,
bool keepdim,
std::optional<ScalarType> opt_dtype,
const Tensor& result) {
// This implementation is kept (not deprecated) because it is called in a
// number of places in the codebase. Those call sites should eventually be
// migrated to linalg_vector_norm.
auto p = opt_p.has_value() ? opt_p.get() : Scalar(2.0).to<double>();
at::linalg_vector_norm_out(
const_cast<Tensor&>(result), self, p, dim, keepdim, opt_dtype);
}
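As the comment notes, `norm` is a thin wrapper over `linalg_vector_norm`, with `p` defaulting to 2 when absent. A small sanity check of that equivalence from the Python side (a sketch; device and shapes are illustrative):

```python
import torch

x = torch.randn(4, 5)
a = torch.norm(x, p=2, dim=1)                  # routed through linalg_vector_norm
b = torch.linalg.vector_norm(x, ord=2, dim=1)  # called directly
assert torch.allclose(a, b)
```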

Tensor XPUNativeFunctions::norm(
const Tensor& self,
const std::optional<Scalar>& p,
IntArrayRef dim,
bool keepdim,
ScalarType dtype) {
Tensor result;
auto p_ =
(p.has_value() ? at::OptionalScalarRef(&(p.value()))
: at::OptionalScalarRef());
result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
impl_func_norm(self, p_, dim, keepdim, dtype, result);
return result;
}

Tensor& XPUNativeFunctions::norm_out(
const Tensor& self,
const std::optional<Scalar>& p,
IntArrayRef dim,
bool keepdim,
ScalarType dtype,
Tensor& result) {
auto p_ =
(p.has_value() ? at::OptionalScalarRef(&(p.value()))
: at::OptionalScalarRef());
result = norm_scalaropt_dim_dtype_meta(self, p_, dim, keepdim, dtype, result);
impl_func_norm(self, p_, dim, keepdim, dtype, result);
return result;
}

Tensor& norm_scalaropt_dim_meta(
const Tensor& self,
const OptionalScalarRef p,
IntArrayRef dim,
bool keepdim,
Tensor& result) {
TORCH_CHECK(
at::isFloatingType(self.scalar_type()) ||
at::isComplexType(self.scalar_type()),
"norm(): input dtype should be either floating point or complex. "
"Got ",
self.scalar_type(),
" instead.");

auto out_dtype = get_result_or_self_value_dtype(self, result, c10::nullopt);
return resize_reduction(result, self, dim, keepdim, out_dtype);
}
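Unlike the `dtype` variant above, which validates the requested output dtype, this overload validates the input dtype itself, so integral inputs are rejected up front. Roughly what that looks like from Python (exact error text may vary between builds):

```python
import torch

try:
    torch.norm(torch.ones(3, dtype=torch.int32), p=2, dim=[0])
except RuntimeError as e:
    print(e)  # e.g. "norm(): input dtype should be either floating point or complex ..."
```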

Tensor XPUNativeFunctions::norm(
const Tensor& self,
const std::optional<Scalar>& p,
IntArrayRef dim,
bool keepdim) {
auto p_ =
(p.has_value() ? at::OptionalScalarRef(&(p.value()))
: at::OptionalScalarRef());
Tensor result;
result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
return result;
}

Tensor& XPUNativeFunctions::norm_out(
const Tensor& self,
const std::optional<Scalar>& p,
IntArrayRef dim,
bool keepdim,
Tensor& result) {
auto p_ =
(p.has_value() ? at::OptionalScalarRef(&(p.value()))
: at::OptionalScalarRef());
result = norm_scalaropt_dim_meta(self, p_, dim, keepdim, result);
impl_func_norm(self, p_, dim, keepdim, c10::nullopt, result);
return result;
}
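Both `out` overloads route through `resize_reduction`, so a preallocated `out` tensor is resized to the reduced shape rather than needing the right size in advance. A brief sketch (shapes illustrative):

```python
import torch

x = torch.randn(4, 5)
result = torch.empty(0)  # deliberately the wrong size
torch.norm(x, p=2, dim=[1], out=result)
print(result.shape)      # torch.Size([4]) after the reduction resize
```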

TensorIterator meta_aminmax(
const Tensor& self,
std::optional<int64_t> dim_opt,
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -244,7 +244,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"nanmedian.dim_values",
"nansum",
"nextafter.out",
"norm.out",
"ormqr",
"_pdist_backward",
"_pdist_forward",
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -141,6 +141,7 @@
"std_mean",
"var",
"var_mean",
"norm",
"hypot",
"unfold",
"uniform",
4 changes: 4 additions & 0 deletions yaml/xpu_functions.yaml
@@ -638,4 +638,8 @@ supported:
- ceil
- ceil_
- ceil.out
- norm.ScalarOpt_dim_dtype
- norm.dtype_out
- norm.ScalarOpt_dim
- norm.out
- nan_to_num.out
