diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index 4bf9116d28..449336ef20 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -193,6 +193,7 @@ void approx_knn_build_index(raft::handle_t& handle, auto index_view = raft::make_device_matrix_view(index_array, n, D); if (ivf_ft_pams) { + index->nprobe = ivf_ft_pams->nprobe; cuvs::neighbors::ivf_flat::index_params params; params.metric = static_cast(metric); params.metric_arg = metricArg; @@ -201,7 +202,7 @@ void approx_knn_build_index(raft::handle_t& handle, index->ivf_flat = std::make_unique>( cuvs::neighbors::ivf_flat::build(handle, params, index_view)); } else if (ivf_pq_pams) { - index->nprobe = dynamic_cast(params)->nprobe; + index->nprobe = ivf_pq_pams->nprobe; cuvs::neighbors::ivf_pq::index_params params; params.metric = static_cast(metric); params.metric_arg = metricArg; @@ -225,18 +226,20 @@ void approx_knn_search(raft::handle_t& handle, float* query_array, int n) { - auto query_view = - raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); auto indices_view = raft::make_device_matrix_view(indices, n, k); auto distances_view = raft::make_device_matrix_view(distances, n, k); if (index->ivf_flat) { + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_flat->dim()); cuvs::neighbors::ivf_flat::search_params params; params.n_probes = index->nprobe; cuvs::neighbors::ivf_flat::search( handle, params, *index->ivf_flat, query_view, indices_view, distances_view); } else if (index->ivf_pq) { + auto query_view = + raft::make_device_matrix_view(query_array, n, index->ivf_pq->dim()); cuvs::neighbors::ivf_pq::search_params params; params.n_probes = index->nprobe;