Skip to content

Commit

Permalink
Fix IVF-PQ API to use device_vector_view (#1384)
Browse files Browse the repository at this point in the history
This PR mainly intends to replace `device_matrix_view` for `ivf_pq::extend` to `device_vector_view`.
There are also a few updates to the documentation to reflect the current API.
The order of the arguments in the API is not touched.

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Victor Lafargue (https://github.com/viclafargue)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1384
  • Loading branch information
lowener authored Apr 13, 2023
1 parent 35c2f1c commit 7c5b638
Show file tree
Hide file tree
Showing 12 changed files with 140 additions and 130 deletions.
23 changes: 10 additions & 13 deletions cpp/include/raft/neighbors/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,11 @@ auto build(raft::device_resources const& handle,
* // use default search parameters
* ivf_flat::search_params search_params;
* // search K nearest neighbours for each of the N queries
* ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k);
* ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists);
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
*
* @param[in] handle
* @param[in] params configure the index building
Expand Down Expand Up @@ -139,13 +138,11 @@ auto build(raft::device_resources const& handle,
* // use default search parameters
* ivf_flat::search_params search_params;
* // search K nearest neighbours for each of the N queries
* ivf_flat::search(handle, index, queries, out_inds, out_dists, search_params, k);
* ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists);
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
* @tparam int_t precision / type of integral arguments
* @tparam matrix_idx_t matrix indexing type
*
* @param[in] handle
* @param[in] params configure the index building
Expand Down Expand Up @@ -232,15 +229,16 @@ auto extend(raft::device_resources const& handle,
* // train the index from a [N, D] dataset
* auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset);
* // fill the index with the data
* auto index = ivf_flat::extend(handle, index_empty, dataset);
* std::optional<raft::device_vector_view<const idx_t, idx_t>> no_op = std::nullopt;
* auto index = ivf_flat::extend(handle, index_empty, no_op, dataset);
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices in the source dataset
*
* @param[in] handle
* @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows].
* @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[in] orig_index original index
Expand Down Expand Up @@ -314,7 +312,7 @@ void extend(raft::device_resources const& handle,
* index_params.add_data_on_build = false; // don't populate index on build
* index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training
* // train the index from a [N, D] dataset
* auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset);
* auto index_empty = ivf_flat::build(handle, index_params, dataset);
* // fill the index with the data
* std::optional<raft::device_vector_view<const idx_t, idx_t>> no_op = std::nullopt;
* ivf_flat::extend(handle, dataset, no_opt, &index_empty);
Expand All @@ -325,7 +323,7 @@ void extend(raft::device_resources const& handle,
*
* @param[in] handle
* @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()]
* @param[in] new_indices optional raft::device_matrix_view to a vector of indices [n_rows].
* @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows].
* If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] index pointer to index, to be overwritten in-place
Expand Down Expand Up @@ -422,15 +420,14 @@ void search(raft::device_resources const& handle,
* ivf_flat::search_params search_params;
* // Use the same allocator across multiple searches to reduce the number of
* // cuda memory allocations
* ivf_flat::search(handle, index, queries1, out_inds1, out_dists1, search_params, K);
* ivf_flat::search(handle, index, queries2, out_inds2, out_dists2, search_params, K);
* ivf_flat::search(handle, index, queries3, out_inds3, out_dists3, search_params, K);
* ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1);
* ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2);
* ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3);
* ...
* @endcode
*
* @tparam value_t data element type
* @tparam idx_t type of the indices
* @tparam int_t precision / type of integral arguments
*
* @param[in] handle
* @param[in] params configure the search
Expand Down
8 changes: 4 additions & 4 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,15 @@ index<IdxT> build(raft::device_resources const& handle,
*
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* @param[in] new_indices a device vector view to a vector of indices [n_rows].
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] idx
*/
template <typename T, typename IdxT>
index<IdxT> extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,
const index<IdxT>& idx)
{
ASSERT(new_vectors.extent(1) == idx.dim(),
Expand All @@ -104,15 +104,15 @@ index<IdxT> extend(raft::device_resources const& handle,
*
* @param[in] handle
* @param[in] new_vectors a device matrix view to a row-major matrix [n_rows, idx.dim()]
* @param[in] new_indices a device matrix view to a vector of indices [n_rows].
* @param[in] new_indices a device vector view to a vector of indices [n_rows].
* If the original index is empty (`idx.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
* @param[inout] idx
*/
template <typename T, typename IdxT>
void extend(raft::device_resources const& handle,
raft::device_matrix_view<const T, IdxT, row_major> new_vectors,
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices,
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices,
index<IdxT>* idx)
{
ASSERT(new_vectors.extent(1) == idx->dim(),
Expand Down
34 changes: 16 additions & 18 deletions cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,24 +35,22 @@ namespace raft::neighbors::ivf_pq {
// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
extern template auto build(raft::device_resources const&, \
const raft::neighbors::ivf_pq::index_params&, \
raft::device_matrix_view<const T, IdxT, row_major>) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template auto extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
const raft::neighbors::ivf_pq::index<IdxT>&) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void extend( \
raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>>, \
raft::neighbors::ivf_pq::index<IdxT>*);
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
extern template auto build(raft::device_resources const&, \
const raft::neighbors::ivf_pq::index_params&, \
raft::device_matrix_view<const T, IdxT, row_major>) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template auto extend(raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_vector_view<const IdxT, IdxT>>, \
const raft::neighbors::ivf_pq::index<IdxT>&) \
->raft::neighbors::ivf_pq::index<IdxT>; \
\
extern template void extend(raft::device_resources const&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
std::optional<raft::device_vector_view<const IdxT, IdxT>>, \
raft::neighbors::ivf_pq::index<IdxT>*);

RAFT_DECL_BUILD_EXTEND(float, int64_t)
RAFT_DECL_BUILD_EXTEND(int8_t, int64_t)
Expand Down
40 changes: 20 additions & 20 deletions cpp/include/raft_runtime/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,26 +23,26 @@ namespace raft::runtime::neighbors::ivf_pq {
// We define overloads for build and extend with void return type. This is used in the Cython
// wrappers, where exception handling is not compatible with return type that has nontrivial
// constructor.
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
[[nodiscard]] raft::neighbors::ivf_pq::index<IdxT> build( \
raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset); \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
[[nodiscard]] raft::neighbors::ivf_pq::index<IdxT> extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices, \
const raft::neighbors::ivf_pq::index<IdxT>& idx); \
\
void extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices, \
#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \
[[nodiscard]] raft::neighbors::ivf_pq::index<IdxT> build( \
raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset); \
\
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx); \
\
[[nodiscard]] raft::neighbors::ivf_pq::index<IdxT> extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const raft::neighbors::ivf_pq::index<IdxT>& idx); \
\
void extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_pq::index<IdxT>* idx);

RAFT_DECL_BUILD_EXTEND(float, int64_t);
Expand Down
58 changes: 29 additions & 29 deletions cpp/src/neighbors/ivfpq_build.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,35 @@

namespace raft::runtime::neighbors::ivf_pq {

#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
raft::neighbors::ivf_pq::index<IdxT> build( \
raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset) \
{ \
return raft::neighbors::ivf_pq::build<T, IdxT>(handle, params, dataset); \
} \
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
*idx = raft::neighbors::ivf_pq::build<T, IdxT>(handle, params, dataset); \
} \
raft::neighbors::ivf_pq::index<IdxT> extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices, \
const raft::neighbors::ivf_pq::index<IdxT>& idx) \
{ \
return raft::neighbors::ivf_pq::extend<T, IdxT>(handle, new_vectors, new_indices, idx); \
} \
void extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_matrix_view<const IdxT, IdxT, row_major>> new_indices, \
raft::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
raft::neighbors::ivf_pq::extend<T, IdxT>(handle, new_vectors, new_indices, idx); \
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
raft::neighbors::ivf_pq::index<IdxT> build( \
raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset) \
{ \
return raft::neighbors::ivf_pq::build<T, IdxT>(handle, params, dataset); \
} \
void build(raft::device_resources const& handle, \
const raft::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset, \
raft::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
*idx = raft::neighbors::ivf_pq::build<T, IdxT>(handle, params, dataset); \
} \
raft::neighbors::ivf_pq::index<IdxT> extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const raft::neighbors::ivf_pq::index<IdxT>& idx) \
{ \
return raft::neighbors::ivf_pq::extend<T, IdxT>(handle, new_vectors, new_indices, idx); \
} \
void extend(raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
raft::neighbors::ivf_pq::extend<T, IdxT>(handle, new_vectors, new_indices, idx); \
}

RAFT_INST_BUILD_EXTEND(float, int64_t);
Expand Down
Loading

0 comments on commit 7c5b638

Please sign in to comment.