
Add aten::histc and its variants #739

Merged 13 commits on Oct 21, 2024
27 changes: 27 additions & 0 deletions src/ATen/native/xpu/SummaryOps.cpp
@@ -1,3 +1,4 @@
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/sycl/SummaryOpsKernels.h>
#include <ATen/xpu/XPUNativeFunctions.h>
#include <comm/SYCLContext.h>
@@ -21,4 +22,30 @@ Tensor XPUNativeFunctions::bincount(
return native::xpu::bincount_kernel(self, weights, minlength);
}

Tensor XPUNativeFunctions::histc(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max) {
if (self.scalar_type() == ScalarType::Half) {
AT_ERROR("HalfTensor is not supported");
}
// See Note [Writing Nondeterministic Operations]
Contributor:
Let's move the determinism check to the kernel level, since it is related to the kernel algorithm.
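A minimal sketch of the suggested change (hypothetical, not part of this PR): the alert would move out of the XPUNativeFunctions::histc wrapper below and into _histc_kernel in SummaryOpsKernels.cpp, next to the atomicAdd-based algorithm it describes.

// Hypothetical: determinism alert raised at the kernel level instead of
// in the XPUNativeFunctions::histc wrapper.
Tensor _histc_kernel(
    const Tensor& self,
    int64_t nbins,
    const Scalar& min,
    const Scalar& max) {
  // Nondeterministic because the histogram kernel accumulates with atomicAdd.
  globalContext().alertNotDeterministic("_histc_xpu");
  return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "_histc_xpu", [&] {
    using bounds_t = at::acc_type_device<scalar_t, kXPU>;
    return _histc_template<scalar_t>(
        self, nbins, min.to<bounds_t>(), max.to<bounds_t>());
  });
}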

// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("_histc_xpu");
return native::xpu::_histc_kernel(self, nbins, min, max);
}

Tensor& XPUNativeFunctions::histc_out(
const Tensor& self,
int64_t bins,
const Scalar& min,
const Scalar& max,
Tensor& result) {
auto ret = histc(self, bins, min, max);
at::native::resize_output(result, ret.sizes());
result.copy_(ret);
return result;
}
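This out= wrapper follows the usual ATen pattern: compute the functional result, resize_output the destination to match, then copy into it. A hedged usage sketch (tensor names and shapes here are illustrative, not from the PR):

// Illustrative call site: `result` may start with any shape; histc_out
// resizes it to {bins} and fills it. Passing min == max (both 0) makes the
// kernel derive the bounds from the data itself.
Tensor input = at::randn({100}, at::device(at::kXPU));
Tensor result = at::empty({0}, input.options());
XPUNativeFunctions::histc_out(input, /*bins=*/10, /*min=*/0, /*max=*/0, result);
// result now has shape {10}.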

} // namespace at
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -191,7 +191,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"hardshrink_backward.grad_input",
"hardshrink.out",
"heaviside.out",
"histc",
"i0.out",
"igammac.out",
"igamma.out",
55 changes: 55 additions & 0 deletions src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
@@ -185,6 +185,61 @@ void tensor_histogram(

return;
}

template <typename input_t>
Tensor _histc_template(
const Tensor& self,
int64_t nbins,
at::acc_type_device<input_t, kXPU> min,
at::acc_type_device<input_t, kXPU> max) {
if (nbins <= 0) {
AT_ERROR("bins must be > 0");
}
Tensor output = at::zeros(
{nbins},
self.scalar_type(),
std::nullopt /* layout */,
DeviceType::XPU,
std::nullopt /* pin_memory */);
input_t minvalue = min;
input_t maxvalue = max;
if (min == max && self.numel() > 0) {
minvalue = *self.min().cpu().data_ptr<input_t>();
maxvalue = *self.max().cpu().data_ptr<input_t>();
}
if (minvalue == maxvalue) {
minvalue = minvalue - 1;
maxvalue = maxvalue + 1;
}

TORCH_CHECK(
!(std::isinf(minvalue) || std::isinf(maxvalue) || std::isnan(minvalue) ||
std::isnan(maxvalue)),
"range of [",
minvalue,
", ",
maxvalue,
"] is not finite");

TORCH_CHECK(minvalue < maxvalue, "max must be larger than min");

tensor_histogram<input_t, input_t, false>(
output, self, Tensor(), nbins, minvalue, maxvalue);
return output;
}
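A worked example of the binning semantics implemented above (this matches the documented behavior of torch.histc; the values are illustrative):

// 4 equal-width bins over [0, 3]: [0, 0.75), [0.75, 1.5), [1.5, 2.25), [2.25, 3].
// The two 1.0 values land in the second bin and the 2.0 in the third,
// so the resulting counts are {0, 2, 1, 0}.
Tensor t = at::tensor({1.0f, 2.0f, 1.0f});
Tensor h = at::histc(t, /*bins=*/4, /*min=*/0, /*max=*/3);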

Tensor _histc_kernel(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max) {
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "_histc_xpu", [&] {
using bounds_t = at::acc_type_device<scalar_t, kXPU>;
return _histc_template<scalar_t>(
self, nbins, min.to<bounds_t>(), max.to<bounds_t>());
});
}
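For readers unfamiliar with the dispatch macro: AT_DISPATCH_ALL_TYPES expands, roughly, to a switch over the runtime dtype that instantiates the lambda once per supported scalar type. A simplified sketch of that expansion (not the literal macro output):

switch (self.scalar_type()) {
  case ScalarType::Float: {
    using scalar_t = float;
    using bounds_t = at::acc_type_device<scalar_t, kXPU>;
    return _histc_template<scalar_t>(
        self, nbins, min.to<bounds_t>(), max.to<bounds_t>());
  }
  // ... analogous cases for Double, Byte, Char, Short, Int, Long ...
  default:
    TORCH_CHECK(false, "\"_histc_xpu\" not implemented for this dtype");
}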

template <typename input_t, typename weights_t>
Tensor bincount_template(
const Tensor& self,
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/SummaryOpsKernels.h
@@ -8,4 +8,10 @@ Tensor bincount_kernel(
const Tensor& weights,
int64_t minlength);

Tensor _histc_kernel(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max);

} // namespace at::native::xpu
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -222,6 +222,7 @@
"argmin",
"conj_physical",
"histogram",
"histc",
"repeat_interleave",
"fmax",
"fmin",
2 changes: 2 additions & 0 deletions yaml/xpu_functions.yaml
@@ -703,6 +703,8 @@ supported:
- histogram.bins_tensor_out
- histogram.bin_ct
- histogram.bin_ct_out
- histc
- histc.out
- repeat_interleave.Tensor
- norm.ScalarOpt_dim_dtype
- norm.dtype_out