Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: strides is now a std::ptrdiff_t #190

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 16 additions & 14 deletions include/libpy/ndarray_view.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@ slice_impl(const T& view, std::int64_t start, std::int64_t stop, std::int64_t st
}

std::int64_t size = (low >= high) ? 0 : (high - low - 1) / adj_step + 1;
std::int64_t stride = view.strides()[0] * step;
std::ptrdiff_t stride = view.strides()[0] * step;

return {static_cast<std::size_t>(start), {static_cast<std::size_t>(size)}, {stride}};
return {static_cast<std::size_t>(start),
{static_cast<std::size_t>(size)},
{static_cast<std::ptrdiff_t>(stride)}};
}
} // namespace detail

Expand All @@ -66,7 +68,7 @@ class ndarray_view {
friend class ndarray_view;

std::array<std::size_t, ndim> m_shape;
std::array<std::int64_t, ndim> m_strides;
std::array<std::ptrdiff_t, ndim> m_strides;
buffer_type m_buffer;

std::ptrdiff_t pos_to_index(const std::array<std::size_t, ndim>& pos) const {
Expand All @@ -79,7 +81,7 @@ class ndarray_view {

ndarray_view(buffer_type buffer,
const std::array<std::size_t, ndim> shape,
const std::array<std::int64_t, ndim>& strides)
const std::array<std::ptrdiff_t, ndim>& strides)
: m_shape(shape), m_strides(strides), m_buffer(buffer) {}

public:
Expand Down Expand Up @@ -120,7 +122,7 @@ class ndarray_view {
}

std::array<std::size_t, ndim> shape;
std::array<std::int64_t, ndim> strides;
std::array<std::ptrdiff_t, ndim> strides;
for (int ix = 0; ix < buf->ndim; ++ix) {
shape[ix] = static_cast<std::size_t>(buf->shape[ix]);
strides[ix] = static_cast<std::int64_t>(buf->strides[ix]);
Expand Down Expand Up @@ -167,7 +169,7 @@ class ndarray_view {
*/
ndarray_view(T* buffer,
const std::array<std::size_t, ndim> shape,
const std::array<std::int64_t, ndim>& strides)
const std::array<std::ptrdiff_t, ndim>& strides)
: ndarray_view(reinterpret_cast<buffer_type>(buffer), shape, strides) {}

/** Access the element at the given index with bounds checking.
Expand Down Expand Up @@ -228,7 +230,7 @@ class ndarray_view {

/** The number of bytes to go from one element to the next.
*/
const std::array<std::int64_t, ndim>& strides() const {
const std::array<std::ptrdiff_t, ndim>& strides() const {
return m_strides;
}

Expand Down Expand Up @@ -591,7 +593,7 @@ class any_ref_ndarray_view {
friend class any_ref_ndarray_view;

std::array<std::size_t, ndim> m_shape;
std::array<std::int64_t, ndim> m_strides;
std::array<std::ptrdiff_t, ndim> m_strides;
buffer_type m_buffer;
any_vtable m_vtable;

Expand Down Expand Up @@ -697,10 +699,10 @@ class any_ref_ndarray_view {
}

std::array<std::size_t, ndim> shape;
std::array<std::int64_t, ndim> strides;
std::array<std::ptrdiff_t, ndim> strides;
for (int ix = 0; ix < buf->ndim; ++ix) {
shape[ix] = static_cast<std::size_t>(buf->shape[ix]);
strides[ix] = static_cast<std::int64_t>(buf->strides[ix]);
strides[ix] = static_cast<std::ptrdiff_t>(buf->strides[ix]);
}

return {ndarray_view<T, ndim>{static_cast<buffer_type>(buf->buf),
Expand Down Expand Up @@ -746,7 +748,7 @@ class any_ref_ndarray_view {

any_ref_ndarray_view(buffer_type buffer,
const std::array<std::size_t, ndim> shape,
const std::array<std::int64_t, ndim>& strides,
const std::array<std::ptrdiff_t, ndim>& strides,
const py::any_vtable& vtable)
: m_shape(shape), m_strides(strides), m_buffer(buffer), m_vtable(vtable) {}

Expand All @@ -763,7 +765,7 @@ class any_ref_ndarray_view {
template<typename U>
any_ref_ndarray_view(U* buffer,
const std::array<std::size_t, ndim> shape,
const std::array<std::int64_t, ndim>& strides)
const std::array<std::ptrdiff_t, ndim>& strides)
: any_ref_ndarray_view(reinterpret_cast<buffer_type>(buffer),
shape,
strides,
Expand Down Expand Up @@ -827,7 +829,7 @@ class any_ref_ndarray_view {

/** The number of bytes to go from one element to the next.
*/
const std::array<std::int64_t, ndim>& strides() const {
const std::array<std::ptrdiff_t, ndim>& strides() const {
return m_strides;
}

Expand Down Expand Up @@ -1410,7 +1412,7 @@ struct from_object<ndarray_view<T, ndim>> {
}

std::array<std::size_t, ndim> shape{0};
std::array<std::int64_t, ndim> strides{0};
std::array<std::ptrdiff_t, ndim> strides{0};

std::copy_n(PyArray_SHAPE(array), ndim, shape.begin());
std::copy_n(PyArray_STRIDES(array), ndim, strides.begin());
Expand Down