Skip to content

Commit 088beb5

Browse files
Add more checks and test
1 parent d26851e commit 088beb5

File tree

4 files changed

+65
-14
lines changed

4 files changed

+65
-14
lines changed

dpnp/backend/extensions/sycl_ext/histogram_common.cpp

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,20 @@
2424
//*****************************************************************************
2525

2626
#include <algorithm>
27+
#include <limits>
2728
#include <string>
2829
#include <unordered_map>
2930
#include <vector>
3031

3132
#include "dpctl4pybind11.hpp"
33+
#include "utils/type_dispatch.hpp"
3234
#include <pybind11/pybind11.h>
3335

3436
#include "histogram_common.hpp"
3537

38+
namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
3639
using dpctl::tensor::usm_ndarray;
40+
using dpctl_td_ns::typenum_t;
3741

3842
namespace histogram
3943
{
@@ -45,9 +49,8 @@ void validate(const usm_ndarray &sample,
4549
{
4650
auto exec_q = sample.get_queue();
4751
using array_ptr = const usm_ndarray *;
48-
using array_list = std::vector<array_ptr>;
4952

50-
array_list arrays{&sample, &bins, &histogram};
53+
std::vector<array_ptr> arrays{&sample, &bins, &histogram};
5154
std::unordered_map<array_ptr, std::string> names = {
5255
{arrays[0], "sample"}, {arrays[1], "bins"}, {arrays[2], "histogram"}};
5356

@@ -94,10 +97,21 @@ void validate(const usm_ndarray &sample,
9497
std::to_string(histogram.get_ndim()) + "d");
9598
}
9699

97-
if (weights_ptr && weights_ptr->get_ndim() != 1) {
98-
throw py::value_error(get_name(weights_ptr) +
99-
" parameter must be 1d. Actual " +
100-
std::to_string(weights_ptr->get_ndim()) + "d");
100+
if (weights_ptr) {
101+
if (weights_ptr->get_ndim() != 1) {
102+
throw py::value_error(
103+
get_name(weights_ptr) + " parameter must be 1d. Actual " +
104+
std::to_string(weights_ptr->get_ndim()) + "d");
105+
}
106+
107+
auto sample_size = sample.get_size();
108+
auto weights_size = weights_ptr->get_size();
109+
if (sample.get_size() != weights_ptr->get_size()) {
110+
throw py::value_error(
111+
get_name(&sample) + " size (" + std::to_string(sample_size) +
112+
") and " + get_name(weights_ptr) + " size (" +
113+
std::to_string(weights_size) + ")" + " must match");
114+
}
101115
}
102116

103117
if (sample.get_ndim() > 2) {
@@ -143,6 +157,30 @@ void validate(const usm_ndarray &sample,
143157
" expected to have size = " + std::to_string(expected_hist_size) +
144158
". Actual " + std::to_string(histogram.get_size()));
145159
}
160+
161+
int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
162+
if (histogram.get_size() > max_hist_size) {
163+
throw py::value_error(get_name(&histogram) +
164+
" parameter size expected to be less than " +
165+
std::to_string(max_hist_size) + ". Actual " +
166+
std::to_string(histogram.get_size()));
167+
}
168+
169+
auto array_types = dpctl_td_ns::usm_ndarray_types();
170+
auto hist_type = static_cast<typenum_t>(
171+
array_types.typenum_to_lookup_id(histogram.get_typenum()));
172+
if (histogram.get_elemsize() == 8 && hist_type != typenum_t::CFLOAT) {
173+
auto device = exec_q.get_device();
174+
bool _64bit_atomics = device.has(sycl::aspect::atomic64);
175+
176+
if (!_64bit_atomics) {
177+
auto device_name = device.get_info<sycl::info::device::name>();
178+
throw py::value_error(
179+
get_name(&histogram) +
180+
" parameter has 64-bit type, but 64-bit atomics " +
181+
" are not supported for " + device_name);
182+
}
183+
}
146184
}
147185

148186
} // namespace histogram

dpnp/backend/extensions/sycl_ext/histogram_common.hpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ struct IsNan
333333
static bool isnan(const T &v)
334334
{
335335
if constexpr (std::is_floating_point<T>::value) {
336-
return std::isnan(v);
336+
return sycl::isnan(v);
337337
}
338338

339339
return false;
@@ -347,7 +347,7 @@ struct IsNan<std::complex<T>>
347347
{
348348
T real1 = std::real(v);
349349
T imag1 = std::imag(v);
350-
return std::isnan(real1) || std::isnan(imag1);
350+
return sycl::isnan(real1) || sycl::isnan(imag1);
351351
}
352352
};
353353

@@ -418,9 +418,9 @@ void submit_histogram(const T *in,
418418
size_t size,
419419
size_t dims,
420420
uint32_t WorkPI,
421-
HistImpl &hist,
422-
Edges &edges,
423-
Weights &weights,
421+
const HistImpl &hist,
422+
const Edges &edges,
423+
const Weights &weights,
424424
sycl::nd_range<1> nd_range,
425425
sycl::handler &cgh)
426426
{

dpnp/dpnp_iface_histograms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def _result_type(dtype1, dtype2, has_fp64):
329329
return rt
330330

331331

332-
def _aling_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
332+
def _align_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
333333
a_bin_dtype = _result_type(a_dtype, bins_dtype, has_fp64)
334334

335335
supported_types = (dpnp.float32, dpnp.int64, dpnp.uint64, dpnp.complex64)
@@ -465,7 +465,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
465465
queue = a.sycl_queue
466466
has_fp64 = queue.sycl_device.has_aspect_fp64
467467

468-
a_bin_dtype, hist_dtype = _aling_dtypes(
468+
a_bin_dtype, hist_dtype = _align_dtypes(
469469
a.dtype, bin_edges.dtype, ntype, has_fp64
470470
)
471471

@@ -485,7 +485,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
485485
else None
486486
)
487487

488-
n_usm_type = "shared" if usm_type == "host" else usm_type
488+
n_usm_type = "device" if usm_type == "host" else usm_type
489489
n_casted = dpnp.zeros(
490490
bin_edges.size - 1,
491491
dtype=hist_dtype,

tests/test_histogram.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,19 @@ def test_weights_another_sycl_queue(self):
492492
with assert_raises(ValueError):
493493
dpnp.histogram(v, weights=w)
494494

495+
@pytest.mark.parametrize(
496+
"bins_count",
497+
[10, 10**2, 10**3, 10**4, 10**5, 10**6],
498+
)
499+
def test_different_bins_amount(self, bins_count):
500+
v = numpy.linspace(0, bins_count, bins_count, dtype=numpy.float32)
501+
iv = dpnp.array(v)
502+
503+
expected_hist, expected_edges = numpy.histogram(v, bins=bins_count)
504+
result_hist, result_edges = dpnp.histogram(iv, bins=bins_count)
505+
assert_array_equal(result_hist, expected_hist)
506+
assert_allclose(result_edges, expected_edges)
507+
495508

496509
class TestHistogramBinEdges:
497510
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)