Add aten::histc and its variants (#739)
Co-authored-by: Yutao Xu <[email protected]>
Authored by yucai-intel and xytintel on Oct 21, 2024
1 parent: d507290 · commit: 46177ff
Showing 7 changed files with 100 additions and 10 deletions.
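For context, `aten::histc` computes a histogram of a tensor's values over `bins` equal-width bins between `min` and `max`, ignoring elements outside that range. A minimal usage sketch of the newly added XPU path, assuming a PyTorch build that includes torch-xpu-ops with this commit and an available `xpu` device:

```python
import torch

# 1000 samples, 10 equal-width bins over [-3, 3] on the XPU backend.
x = torch.randn(1000, device="xpu")
hist = torch.histc(x, bins=10, min=-3.0, max=3.0)
print(hist.cpu())  # per-bin counts; elements outside [min, max] are not counted
```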
28 changes: 27 additions & 1 deletion src/ATen/native/xpu/SummaryOps.cpp
@@ -1,9 +1,12 @@
#include <ATen/native/Resize.h>
#include <ATen/native/xpu/sycl/SummaryOpsKernels.h>
#include <comm/SYCLContext.h>

#include <xpu/ATen/ops/bincount_native.h>

namespace at {
namespace native {

Tensor _bincount_xpu(
const Tensor& self,
const c10::optional<Tensor>& weights_opt,
@@ -21,6 +24,29 @@ Tensor _bincount_xpu(

return native::xpu::bincount_kernel(self, weights, minlength);
}
} // namespace native

Tensor _histc_xpu(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max) {
// See Note [Writing Nondeterministic Operations]
// Nondeterministic because of atomicAdd usage
globalContext().alertNotDeterministic("_histc_xpu");
return native::xpu::_histc_kernel(self, nbins, min, max);
}

Tensor& _histc_out_xpu(
const Tensor& self,
int64_t bins,
const Scalar& min,
const Scalar& max,
Tensor& result) {
auto ret = _histc_xpu(self, bins, min, max);
at::native::resize_output(result, ret.sizes());
result.copy_(ret);
return result;
}

} // namespace native
} // namespace at
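`_histc_out_xpu` reuses the functional kernel: it computes the histogram into a temporary tensor, resizes `result` with `at::native::resize_output`, and copies the counts over. A rough sketch of the resulting `out=` behavior from Python (illustrative, same build assumptions as above):

```python
import torch

x = torch.randn(1000, device="xpu")
out = torch.empty(0, device="xpu")
torch.histc(x, bins=10, min=-3.0, max=3.0, out=out)
# Roughly equivalent to:
#   tmp = torch.histc(x, bins=10, min=-3.0, max=3.0)
#   out.resize_(tmp.shape).copy_(tmp)
print(out.shape)  # torch.Size([10])
```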
1 change: 0 additions & 1 deletion src/ATen/native/xpu/XPUFallback.template
@@ -181,7 +181,6 @@ TORCH_LIBRARY_IMPL(aten, XPU, m) {
"_fused_moving_avg_obs_fq_helper",
"geqrf",
"heaviside.out",
"histc",
"i0.out",
"igammac.out",
"igamma.out",
58 changes: 58 additions & 0 deletions src/ATen/native/xpu/sycl/SummaryOpsKernels.cpp
@@ -187,6 +187,64 @@ void tensor_histogram(

return;
}

template <typename input_t>
Tensor _histc_template(
const Tensor& self,
int64_t nbins,
at::acc_type_device<input_t, kXPU> min,
at::acc_type_device<input_t, kXPU> max) {
if (nbins <= 0) {
AT_ERROR("bins must be > 0");
}
Tensor output = at::zeros(
{nbins},
self.scalar_type(),
std::nullopt /* layout */,
DeviceType::XPU,
std::nullopt /* pin_memory */);
input_t minvalue = min;
input_t maxvalue = max;
if (min == max && self.numel() > 0) {
minvalue = *self.min().cpu().const_data_ptr<input_t>();
maxvalue = *self.max().cpu().const_data_ptr<input_t>();
}
if (minvalue == maxvalue) {
minvalue = minvalue - 1;
maxvalue = maxvalue + 1;
}

TORCH_CHECK(
!(std::isinf(minvalue) || std::isinf(maxvalue) || std::isnan(minvalue) ||
std::isnan(maxvalue)),
"range of [",
minvalue,
", ",
maxvalue,
"] is not finite");

TORCH_CHECK(minvalue < maxvalue, "max must be larger than min");

tensor_histogram<input_t, input_t, false>(
output, self, Tensor(), nbins, minvalue, maxvalue);
return output;
}

Tensor _histc_kernel(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max) {
if (self.scalar_type() == ScalarType::Half) {
AT_ERROR("HalfTensor is not supported");
}
return AT_DISPATCH_ALL_TYPES(self.scalar_type(), "_histc_xpu", [&] {
using bounds_t = at::acc_type_device<scalar_t, kXPU>;
return _histc_template<scalar_t>(
self, nbins, min.to<bounds_t>(), max.to<bounds_t>());
});
}

template <typename input_t, typename weights_t>
Tensor bincount_template(
const Tensor& self,
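The bound selection in `_histc_template` above follows the same approach as the existing CUDA kernel: if `min == max`, the bounds fall back to the tensor's own extremes, a still-degenerate range is widened by one on each side, and the result must be finite with `min < max`. A rough Python rendering of that logic (illustrative only, not part of the commit):

```python
import math
import torch

def histc_bounds(x: torch.Tensor, min_val: float, max_val: float) -> tuple[float, float]:
    """Mirror the bound selection done in _histc_template (illustrative)."""
    lo, hi = min_val, max_val
    if lo == hi and x.numel() > 0:
        # Caller passed an empty range: use the tensor's actual min/max.
        lo, hi = x.min().item(), x.max().item()
    if lo == hi:
        # Still degenerate (e.g. a constant tensor): widen by 1 on each side.
        lo, hi = lo - 1, hi + 1
    if math.isinf(lo) or math.isinf(hi) or math.isnan(lo) or math.isnan(hi):
        raise RuntimeError(f"range of [{lo}, {hi}] is not finite")
    if not lo < hi:
        raise RuntimeError("max must be larger than min")
    return lo, hi
```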
6 changes: 6 additions & 0 deletions src/ATen/native/xpu/sycl/SummaryOpsKernels.h
@@ -6,4 +6,10 @@ namespace at::native::xpu {
TORCH_XPU_API Tensor
bincount_kernel(const Tensor& self, const Tensor& weights, int64_t minlength);

TORCH_XPU_API Tensor _histc_kernel(
const Tensor& self,
int64_t nbins,
const Scalar& min,
const Scalar& max);

} // namespace at::native::xpu
2 changes: 1 addition & 1 deletion test/xpu/skip_list_common.py
@@ -97,6 +97,7 @@
"test_out_geqrf_xpu_float32",
"test_out_narrow_copy_xpu_float32",
"test_out_ormqr_xpu_float32",
"test_out_histc_xpu_float32",

# XFAIL of CUDA, XPU got unexpected success
"test_python_ref__refs_div_no_rounding_mode_xpu_complex32",
@@ -2522,7 +2523,6 @@
"test_nondeterministic_alert_ReplicationPad3d_xpu",
"test_nondeterministic_alert_grid_sample_2d_xpu",
"test_nondeterministic_alert_grid_sample_3d_xpu",
"test_nondeterministic_alert_histc_xpu",
"test_nondeterministic_alert_interpolate_bicubic_xpu",
"test_nondeterministic_alert_interpolate_bilinear_xpu",
"test_nondeterministic_alert_interpolate_trilinear_xpu",
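The skip-list changes track the new kernel: `test_out_histc_xpu_float32` joins the existing group of skipped `test_out_*` cases, and `test_nondeterministic_alert_histc_xpu` is no longer skipped now that `_histc_xpu` calls `alertNotDeterministic`. In practice the alert means `histc` refuses to run under deterministic mode on XPU, as sketched below (same build assumptions as above):

```python
import torch

x = torch.randn(1000, device="xpu")
torch.use_deterministic_algorithms(True)
try:
    # Expected to raise: _histc_xpu reports itself as nondeterministic (atomic adds).
    torch.histc(x, bins=10, min=-3.0, max=3.0)
except RuntimeError as err:
    print(err)
finally:
    torch.use_deterministic_algorithms(False)
```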
1 change: 1 addition & 0 deletions test/xpu/xpu_test_utils.py
@@ -237,6 +237,7 @@
"angle",
"conj_physical",
"histogram",
"histc",
"repeat_interleave",
"fmax",
"fmin",
14 changes: 7 additions & 7 deletions yaml/native/native_functions.yaml
@@ -5311,14 +5311,14 @@
dispatch:
XPU: replication_pad3d_backward_xpu

# - func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
# dispatch:
# XPU: histogram_histc_out
- func: histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
XPU: _histc_out_xpu

# - func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
# variants: method, function
# dispatch:
# XPU: histogram_histc
- func: histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor
variants: method, function
dispatch:
XPU: _histc_xpu

- func: histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)
dispatch:
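Uncommenting these two entries in `native_functions.yaml` is what wires `aten::histc` and `aten::histc.out` to the new XPU kernels; with `variants: method, function`, both call forms reach `_histc_xpu`. A quick sanity check of the two variants (same assumptions as above):

```python
import torch

x = torch.randn(1000, device="xpu")
a = torch.histc(x, bins=10, min=-3.0, max=3.0)  # function variant
b = x.histc(bins=10, min=-3.0, max=3.0)         # method variant
print(torch.equal(a.cpu(), b.cpu()))  # expected True: both dispatch to _histc_xpu
```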
