Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten::round/round.decimals and their variants #647

Merged
merged 6 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions src/ATen/native/xpu/UnaryOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -972,6 +972,76 @@ Tensor& XPUNativeFunctions::ceil_out(const Tensor& self, Tensor& out) {
return out;
}


// Out-of-place aten::round: returns a new tensor whose elements are rounded
// to the nearest integer (ties resolved by the default rounding mode).
Tensor XPUNativeFunctions::round(const Tensor& self) {
  // Integral values are already whole numbers, so rounding is just a copy.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {
    return self.clone();
  }
  Tensor result;
  TensorIterator it;
  it.build_borrowing_unary_op(result, self);
  native::xpu::round_kernel(it);
  return it.output();
}

// In-place aten::round_: rounds `self` element-wise and returns it.
Tensor& XPUNativeFunctions::round_(Tensor& self) {
  // Nothing to do for integral dtypes; every element is already an integer.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {
    return self;
  }
  TensorIterator it;
  it.build_borrowing_unary_op(self, self);
  native::xpu::round_kernel(it);
  return self;
}

// aten::round.out: rounds `self` element-wise into the caller-provided `out`.
Tensor& XPUNativeFunctions::round_out(const Tensor& self, Tensor& out) {
  // For integral inputs rounding is the identity, so copying suffices.
  if (c10::isIntegralType(self.scalar_type(), /*includeBool=*/false)) {
    out.copy_(self);
    return out;
  }
  TensorIterator it;
  it.build_borrowing_unary_op(out, self);
  native::xpu::round_kernel(it);
  return out;
}

// aten::round.decimals: rounds to `decimals` decimal places. A negative
// `decimals` rounds to the left of the decimal point (tens, hundreds, ...).
Tensor XPUNativeFunctions::round(const Tensor& self, int64_t decimals) {
  Tensor result;
  TensorIterator it;
  it.build_borrowing_unary_op(result, self);
  // decimals == 0 degenerates to plain nearest-integer rounding.
  if (decimals == 0) {
    native::xpu::round_kernel(it);
  } else {
    native::xpu::round_decimals_kernel(it, decimals);
  }
  return it.output();
}

// aten::round_.decimals: in-place round of `self` to `decimals` places.
Tensor& XPUNativeFunctions::round_(Tensor& self, int64_t decimals) {
  TensorIterator it;
  it.build_borrowing_unary_op(self, self);
  // decimals == 0 degenerates to plain nearest-integer rounding.
  if (decimals == 0) {
    native::xpu::round_kernel(it);
  } else {
    native::xpu::round_decimals_kernel(it, decimals);
  }
  return self;
}

// aten::round.decimals_out: rounds `self` to `decimals` places into `out`.
Tensor& XPUNativeFunctions::round_out(
    const Tensor& self,
    int64_t decimals,
    Tensor& out) {
  TensorIterator it;
  it.build_borrowing_unary_op(out, self);
  // decimals == 0 degenerates to plain nearest-integer rounding.
  if (decimals == 0) {
    native::xpu::round_kernel(it);
  } else {
    native::xpu::round_decimals_kernel(it, decimals);
  }
  return out;
}

TensorIterator meta_floor(const Tensor& self, Tensor& out) {
// Note: this is consistent with NumPy
TORCH_CHECK(!self.is_complex(), "floor is not supported for complex inputs");
Expand Down
2 changes: 0 additions & 2 deletions src/ATen/native/xpu/XPUFallback.template
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"prod",
"prod.int_out",
"put_",
"round.decimals_out",
"round.out",
"rrelu_with_noise",
"__rshift__.Scalar",
"_scaled_dot_product_efficient_attention",
Expand Down
67 changes: 67 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,73 @@ void ceil_kernel(TensorIteratorBase& iter) {
});
}

// Generic path (float / Half / BFloat16): round in float precision using the
// current rounding mode, then convert back to the source type.
template <typename scalar_t>
inline scalar_t nearbyint_wrapper(scalar_t a) {
  const float rounded = std::nearbyintf(static_cast<float>(a));
  return static_cast<scalar_t>(rounded);
}

// double overload: round directly at full precision instead of going through
// the float path used by the generic template above.
inline double nearbyint_wrapper(double a) {
  return std::nearbyint(a);
}

// Complex overloads: round the real and imaginary components independently.
// NOTE: the `#pragma push` / `#pragma pop` pair that surrounded these
// overloads was an nvcc-specific diagnostic directive carried over from the
// CUDA kernel; without a matching diag_suppress it is a no-op and only
// produces unknown-pragma warnings on host/SYCL compilers, so it is removed.
inline c10::complex<float> nearbyint_wrapper(c10::complex<float> a) {
  // a.real()/a.imag() already yield float; no cast needed.
  return c10::complex<float>(
      std::nearbyintf(a.real()), std::nearbyintf(a.imag()));
}

inline c10::complex<double> nearbyint_wrapper(c10::complex<double> a) {
  return c10::complex<double>(
      std::nearbyint(a.real()), std::nearbyint(a.imag()));
}

// Elementwise functor for aten::round: rounds `a` to the nearest integral
// value via the type-appropriate nearbyint_wrapper overload (ties follow the
// default rounding mode, i.e. round-half-to-even).
template <typename scalar_t>
struct RoundFunctor {
  scalar_t operator()(scalar_t a) const {
    return nearbyint_wrapper(a);
  }
};

// Elementwise functor for aten::round.decimals.
// `ten_pow_decimals` is 10^|decimals|; `neg_flag` records the sign of the
// original `decimals`. For negative decimals the input is divided by the
// power of ten before rounding and multiplied back afterwards; for
// non-negative decimals the operations are reversed.
template <typename scalar_t>
struct RoundDecimalsFunctor {
  RoundDecimalsFunctor(scalar_t ten_pow_decimals, bool neg_flag)
      : ten_pow_decimals_(ten_pow_decimals), neg_flag_(neg_flag) {}

  scalar_t operator()(scalar_t a) const {
    if (neg_flag_) {
      return std::nearbyint(a / ten_pow_decimals_) * ten_pow_decimals_;
    }
    return std::nearbyint(a * ten_pow_decimals_) / ten_pow_decimals_;
  }

 private:
  scalar_t ten_pow_decimals_;
  bool neg_flag_;
};

// Launches the elementwise nearest-integer rounding kernel over `iter`.
// Dispatches over float, double, Half and BFloat16 only; the aten::round
// entry points above copy/return integral tensors before reaching here.
void round_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "round_xpu", [&]() {
        gpu_kernel(iter, RoundFunctor<scalar_t>());
      });
}

// Launches the round-to-`decimals`-places kernel over `iter`. The sign of
// `decimals` is folded into a flag so the functor can divide (negative
// decimals) or multiply (positive decimals) by 10^|decimals| before rounding.
void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "round_xpu", [&]() {
        const bool neg_flag = decimals < 0;
        const auto ten_pow_decimals = static_cast<scalar_t>(
            std::pow(10, neg_flag ? -decimals : decimals));
        gpu_kernel(
            iter, RoundDecimalsFunctor<scalar_t>(ten_pow_decimals, neg_flag));
      });
}

template <typename scalar_t>
struct FloorFunctor {
scalar_t operator()(scalar_t a) const {
Expand Down
4 changes: 4 additions & 0 deletions src/ATen/native/xpu/sycl/UnaryFractionKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ void floor_kernel(TensorIteratorBase& iter);

void ceil_kernel(TensorIteratorBase& iter);

void round_kernel(TensorIteratorBase& iter);

void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals);

void frac_kernel(TensorIteratorBase& iter);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@
"sigmoid",
"logsigmoid",
"sgn",
"round",
"nn.functional.embedding_bag",
"bucketize",
"searchsorted",
Expand Down
6 changes: 6 additions & 0 deletions yaml/xpu_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,12 @@ supported:
- ceil
- ceil_
- ceil.out
- round
- round_
- round.out
- round.decimals
- round_.decimals
- round.decimals_out
- histogram.bins_tensor
- histogram.bins_tensor_out
- histogram.bin_ct
Expand Down