Skip to content

Commit

Permalink
fix ivf-fla
Browse files Browse the repository at this point in the history
  • Loading branch information
benfred committed Sep 27, 2024
1 parent dc1f7a6 commit fdd18c5
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ void approx_knn_build_index(raft::handle_t& handle,
auto index_view = raft::make_device_matrix_view<const float, int64_t>(index_array, n, D);

if (ivf_ft_pams) {
index->nprobe = ivf_ft_pams->nprobe;
cuvs::neighbors::ivf_flat::index_params params;
params.metric = static_cast<cuvs::distance::DistanceType>(metric);
params.metric_arg = metricArg;
Expand All @@ -201,7 +202,7 @@ void approx_knn_build_index(raft::handle_t& handle,
index->ivf_flat = std::make_unique<cuvs::neighbors::ivf_flat::index<float, int64_t>>(
cuvs::neighbors::ivf_flat::build(handle, params, index_view));
} else if (ivf_pq_pams) {
index->nprobe = dynamic_cast<const IVFParam*>(params)->nprobe;
index->nprobe = ivf_pq_pams->nprobe;
cuvs::neighbors::ivf_pq::index_params params;
params.metric = static_cast<cuvs::distance::DistanceType>(metric);
params.metric_arg = metricArg;
Expand All @@ -225,18 +226,20 @@ void approx_knn_search(raft::handle_t& handle,
float* query_array,
int n)
{
auto query_view =
raft::make_device_matrix_view<const float, int64_t>(query_array, n, index->ivf_pq->dim());
auto indices_view = raft::make_device_matrix_view<int64_t, int64_t>(indices, n, k);
auto distances_view = raft::make_device_matrix_view<float, int64_t>(distances, n, k);

if (index->ivf_flat) {
auto query_view =
raft::make_device_matrix_view<const float, int64_t>(query_array, n, index->ivf_flat->dim());
cuvs::neighbors::ivf_flat::search_params params;
params.n_probes = index->nprobe;

cuvs::neighbors::ivf_flat::search(
handle, params, *index->ivf_flat, query_view, indices_view, distances_view);
} else if (index->ivf_pq) {
auto query_view =
raft::make_device_matrix_view<const float, int64_t>(query_array, n, index->ivf_pq->dim());
cuvs::neighbors::ivf_pq::search_params params;
params.n_probes = index->nprobe;

Expand Down

0 comments on commit fdd18c5

Please sign in to comment.