Skip to content

Commit

Permalink
Add more checks and test
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderKalistratov committed Sep 13, 2024
1 parent d26851e commit 088beb5
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 14 deletions.
50 changes: 44 additions & 6 deletions dpnp/backend/extensions/sycl_ext/histogram_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,20 @@
//*****************************************************************************

#include <algorithm>
#include <limits>
#include <string>
#include <unordered_map>
#include <vector>

#include "dpctl4pybind11.hpp"
#include "utils/type_dispatch.hpp"
#include <pybind11/pybind11.h>

#include "histogram_common.hpp"

namespace dpctl_td_ns = dpctl::tensor::type_dispatch;
using dpctl::tensor::usm_ndarray;
using dpctl_td_ns::typenum_t;

namespace histogram
{
Expand All @@ -45,9 +49,8 @@ void validate(const usm_ndarray &sample,
{
auto exec_q = sample.get_queue();
using array_ptr = const usm_ndarray *;
using array_list = std::vector<array_ptr>;

array_list arrays{&sample, &bins, &histogram};
std::vector<array_ptr> arrays{&sample, &bins, &histogram};
std::unordered_map<array_ptr, std::string> names = {
{arrays[0], "sample"}, {arrays[1], "bins"}, {arrays[2], "histogram"}};

Expand Down Expand Up @@ -94,10 +97,21 @@ void validate(const usm_ndarray &sample,
std::to_string(histogram.get_ndim()) + "d");
}

if (weights_ptr && weights_ptr->get_ndim() != 1) {
throw py::value_error(get_name(weights_ptr) +
" parameter must be 1d. Actual " +
std::to_string(weights_ptr->get_ndim()) + "d");
if (weights_ptr) {
if (weights_ptr->get_ndim() != 1) {
throw py::value_error(
get_name(weights_ptr) + " parameter must be 1d. Actual " +
std::to_string(weights_ptr->get_ndim()) + "d");
}

auto sample_size = sample.get_size();
auto weights_size = weights_ptr->get_size();
if (sample.get_size() != weights_ptr->get_size()) {
throw py::value_error(
get_name(&sample) + " size (" + std::to_string(sample_size) +
") and " + get_name(weights_ptr) + " size (" +
std::to_string(weights_size) + ")" + " must match");
}
}

if (sample.get_ndim() > 2) {
Expand Down Expand Up @@ -143,6 +157,30 @@ void validate(const usm_ndarray &sample,
" expected to have size = " + std::to_string(expected_hist_size) +
". Actual " + std::to_string(histogram.get_size()));
}

int64_t max_hist_size = std::numeric_limits<uint32_t>::max() - 1;
if (histogram.get_size() > max_hist_size) {
throw py::value_error(get_name(&histogram) +
" parameter size expected to be less than " +
std::to_string(max_hist_size) + ". Actual " +
std::to_string(histogram.get_size()));
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
auto hist_type = static_cast<typenum_t>(
array_types.typenum_to_lookup_id(histogram.get_typenum()));
if (histogram.get_elemsize() == 8 && hist_type != typenum_t::CFLOAT) {
auto device = exec_q.get_device();
bool _64bit_atomics = device.has(sycl::aspect::atomic64);

if (!_64bit_atomics) {
auto device_name = device.get_info<sycl::info::device::name>();
throw py::value_error(
get_name(&histogram) +
" parameter has 64-bit type, but 64-bit atomics " +
" are not supported for " + device_name);
}
}
}

} // namespace histogram
10 changes: 5 additions & 5 deletions dpnp/backend/extensions/sycl_ext/histogram_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ struct IsNan
static bool isnan(const T &v)
{
if constexpr (std::is_floating_point<T>::value) {
return std::isnan(v);
return sycl::isnan(v);
}

return false;
Expand All @@ -347,7 +347,7 @@ struct IsNan<std::complex<T>>
{
T real1 = std::real(v);
T imag1 = std::imag(v);
return std::isnan(real1) || std::isnan(imag1);
return sycl::isnan(real1) || sycl::isnan(imag1);
}
};

Expand Down Expand Up @@ -418,9 +418,9 @@ void submit_histogram(const T *in,
size_t size,
size_t dims,
uint32_t WorkPI,
HistImpl &hist,
Edges &edges,
Weights &weights,
const HistImpl &hist,
const Edges &edges,
const Weights &weights,
sycl::nd_range<1> nd_range,
sycl::handler &cgh)
{
Expand Down
6 changes: 3 additions & 3 deletions dpnp/dpnp_iface_histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def _result_type(dtype1, dtype2, has_fp64):
return rt


def _aling_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
def _align_dtypes(a_dtype, bins_dtype, ntype, has_fp64):
a_bin_dtype = _result_type(a_dtype, bins_dtype, has_fp64)

supported_types = (dpnp.float32, dpnp.int64, dpnp.uint64, dpnp.complex64)
Expand Down Expand Up @@ -465,7 +465,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
queue = a.sycl_queue
has_fp64 = queue.sycl_device.has_aspect_fp64

a_bin_dtype, hist_dtype = _aling_dtypes(
a_bin_dtype, hist_dtype = _align_dtypes(
a.dtype, bin_edges.dtype, ntype, has_fp64
)

Expand All @@ -485,7 +485,7 @@ def histogram(a, bins=10, range=None, density=None, weights=None):
else None
)

n_usm_type = "shared" if usm_type == "host" else usm_type
n_usm_type = "device" if usm_type == "host" else usm_type
n_casted = dpnp.zeros(
bin_edges.size - 1,
dtype=hist_dtype,
Expand Down
13 changes: 13 additions & 0 deletions tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,19 @@ def test_weights_another_sycl_queue(self):
with assert_raises(ValueError):
dpnp.histogram(v, weights=w)

@pytest.mark.parametrize(
"bins_count",
[10, 10**2, 10**3, 10**4, 10**5, 10**6],
)
def test_different_bins_amount(self, bins_count):
v = numpy.linspace(0, bins_count, bins_count, dtype=numpy.float32)
iv = dpnp.array(v)

expected_hist, expected_edges = numpy.histogram(v, bins=bins_count)
result_hist, result_edges = dpnp.histogram(iv, bins=bins_count)
assert_array_equal(result_hist, expected_hist)
assert_allclose(result_edges, expected_edges)


class TestHistogramBinEdges:
@pytest.mark.parametrize(
Expand Down

0 comments on commit 088beb5

Please sign in to comment.