Skip to content

Commit

Permalink
Rename optional helper function (#1345)
Browse files Browse the repository at this point in the history
Authors:
  - Victor Lafargue (https://github.com/viclafargue)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1345
  • Loading branch information
viclafargue authored Mar 16, 2023
1 parent 9f2a64f commit 8386807
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/common/mdspan.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ cdef device_matrix_view[int8_t, int64_t, row_major] get_dmv_int8(
cdef device_matrix_view[int64_t, int64_t, row_major] get_dmv_int64(
array, check_shape) except *

cdef optional[device_matrix_view[int64_t, int64_t, row_major]] create_optional(
cdef optional[device_matrix_view[int64_t, int64_t, row_major]] make_optional_view_int64( # noqa: E501
device_matrix_view[int64_t, int64_t, row_major]& dmv) except *
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/common/mdspan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -194,5 +194,5 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \


cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \
create_optional(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501
make_optional_view_int64(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501
return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv)
8 changes: 4 additions & 4 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ from rmm._lib.memory_resource cimport (
cimport pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq as c_ivf_pq
from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major
from pylibraft.common.mdspan cimport (
create_optional,
get_dmv_float,
get_dmv_int8,
get_dmv_int64,
get_dmv_uint8,
make_optional_view_int64,
)
from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport (
index_params,
Expand Down Expand Up @@ -521,19 +521,19 @@ def extend(Index index, new_vectors, new_indices, handle=None):
with cuda_interruptible():
c_ivf_pq.extend(deref(handle_),
get_dmv_float(vecs_cai, check_shape=True),
create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
make_optional_view_int64(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
index.index)
elif vecs_dt == np.int8:
with cuda_interruptible():
c_ivf_pq.extend(deref(handle_),
get_dmv_int8(vecs_cai, check_shape=True),
create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
make_optional_view_int64(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
index.index)
elif vecs_dt == np.uint8:
with cuda_interruptible():
c_ivf_pq.extend(deref(handle_),
get_dmv_uint8(vecs_cai, check_shape=True),
create_optional(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
make_optional_view_int64(get_dmv_int64(idx_cai, check_shape=False)), # noqa: E501
index.index)
else:
raise TypeError("query dtype %s not supported" % vecs_dt)
Expand Down

0 comments on commit 8386807

Please sign in to comment.