From 0f167375c1b9d7eb392b66c8d8b874a32a43de98 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 24 May 2024 17:20:25 -0700 Subject: [PATCH 01/15] init --- cpp/src/cluster/detail/kmeans_balanced.cuh | 45 ++++++++++++++--- cpp/src/cluster/kmeans_balanced.cuh | 19 ++++--- cpp/src/neighbors/ivf_common.cuh | 19 +++++++ cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 49 +++++++++++++++++-- .../neighbors/ivf_flat/ivf_flat_search.cuh | 31 ++++++++++-- 5 files changed, 138 insertions(+), 25 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index a09b17532..0e42a672c 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -33,6 +33,7 @@ #include #include #include +#include #include #include #include @@ -141,12 +142,14 @@ inline std::enable_if_t> predict_core( raft::compose_op, raft::key_op>()); break; } + case cuvs::distance::DistanceType::CosineExpanded: case cuvs::distance::DistanceType::InnerProduct: { // TODO: pass buffer rmm::device_uvector distances(n_rows * n_clusters, stream, mr); MathT alpha = -1.0; - MathT beta = 0.0; + if (params.metric == cuvs::distance::DistanceType::CosineExpanded) alpha = 1.0; + MathT beta = 0.0; raft::linalg::gemm(handle, true, @@ -163,6 +166,14 @@ inline std::enable_if_t> predict_core( distances.data(), n_clusters, stream); + if (dataset_norm) { + raft::linalg::binary_div_skip_zero( + handle, + raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters), + raft::make_device_vector_view(dataset_norm, n_rows), + raft::linalg::Apply::ALONG_COLUMNS); + } auto distances_const_view = raft::make_device_matrix_view( distances.data(), n_rows, n_clusters); @@ -268,7 +279,8 @@ void calc_centers_and_sizes(const raft::resources& handle, const LabelT* labels, bool reset_counters, MappingOpT mapping_op, - rmm::device_async_resource_ref mr) + rmm::device_async_resource_ref mr, + const T* dataset_norm = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); @@ -290,13 +302,31 @@ void calc_centers_and_sizes(const raft::resources& handle, // Apply mapping only when the data and math types are different. if constexpr (std::is_same_v) { - raft::linalg::reduce_rows_by_key( - dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + raft::linalg::reduce_rows_by_key(dataset, + dim, + labels, + dataset_norm, + nullptr, + n_rows, + dim, + n_clusters, + centers, + stream, + reset_counters); } else { // todo(lsugy): use iterator from KV output of fusedL2NN cub::TransformInputIterator mapping_itr(dataset, mapping_op); - raft::linalg::reduce_rows_by_key( - mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + raft::linalg::reduce_rows_by_key(mapping_itr, + dim, + labels, + dataset_norm, + nullptr, + n_rows, + dim, + n_clusters, + centers, + stream, + reset_counters); } // Compute weight of each cluster @@ -394,7 +424,8 @@ void predict(const raft::resources& handle, std::is_same_v ? 0 : max_minibatch_size * dim, stream, mem_res); bool need_compute_norm = dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded); + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded); rmm::device_uvector cur_dataset_norm( need_compute_norm ? max_minibatch_size : 0, stream, mem_res); const MathT* dataset_norm_ptr = nullptr; diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 040d17b36..f8e43af67 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -335,13 +335,15 @@ template -void calc_centers_and_sizes(const raft::resources& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - raft::device_matrix_view centroids, - raft::device_vector_view cluster_sizes, - bool reset_counters = true, - MappingOpT mapping_op = raft::identity_op()) +void calc_centers_and_sizes( + const raft::resources& handle, + raft::device_matrix_view X, + raft::device_vector_view labels, + raft::device_matrix_view centroids, + raft::device_vector_view cluster_sizes, + bool reset_counters = true, + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); @@ -361,7 +363,8 @@ void calc_centers_and_sizes(const raft::resources& handle, labels.data_handle(), reset_counters, mapping_op, - raft::resource::get_workspace_resource(handle)); + raft::resource::get_workspace_resource(handle), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } } // namespace helpers diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index 60d43bed6..c6e9be11f 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -267,6 +267,25 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream); } } break; + case distance::DistanceType::CosineExpanded: { + float factor = (account_for_max_close ? 1.0 : -1.0) * scaling_factor * scaling_factor; + if (factor != 1.0) { + raft::linalg::unaryOp(out, + in, + len, + raft::compose_op(raft::mul_const_op{factor}, + raft::add_const_op{1}, + raft::cast_op{}), + stream); + } else { + raft::linalg::unaryOp( + out, + in, + len, + raft::compose_op(raft::add_const_op{1}, raft::cast_op{}), + stream); + } + } break; default: RAFT_FAIL("Unexpected metric."); } } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 811f9c231..98f28ddb9 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -126,7 +126,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, IdxT n_rows, uint32_t dim, uint32_t veclen, - IdxT batch_offset = 0) + IdxT batch_offset = 0, + const T* source_vecs_norms = nullptr) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; if (i >= n_rows) { return; } @@ -138,6 +139,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // Record the source vector id in the index list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i]; + const auto vec_norm = source_vecs_norms == nullptr ? T(1) : source_vecs_norms[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = raft::Pow2; @@ -157,7 +159,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; + list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j] * vec_norm; } } } @@ -230,13 +232,27 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); + auto batch_vectors_norms = raft::make_device_vector(handle, batch.size()); + std::optional> batch_vectors_norms_opt = std::nullopt; + // Normalize if necessary (Cosine) + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(handle, + batch_data_view, + batch_vectors_norms.view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + batch_vectors_norms_opt = + std::make_optional(raft::make_const_mdspan(batch_vectors_norms.view())); + } cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, batch_data_view, batch_labels_view, centroids_view, list_sizes_view, false, - utils::mapping{}); + utils::mapping{}, + batch_vectors_norms_opt); } } else { raft::stats::histogram(raft::stats::HistTypeAuto, @@ -280,6 +296,18 @@ void extend(raft::resources const& handle, for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); + auto batch_vectors_norms = raft::make_device_vector(handle, batch.size()); + T* batch_vectors_norms_ptr = nullptr; + // Normalize if necessary (Cosine) + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::norm(handle, + batch_data_view, + batch_vectors_norms.view(), + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + batch_vectors_norms_ptr = batch_vectors_norms.data_handle(); + } // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(batch.size(), block_dim.x)); @@ -293,7 +321,8 @@ void extend(raft::resources const& handle, batch.size(), dim, index->veclen(), - batch.offset()); + batch.offset(), + batch_vectors_norms_ptr); RAFT_CUDA_TRY(cudaPeekAtLastError()); if (batch.offset() > next_report_offset) { @@ -387,7 +416,17 @@ inline auto build(raft::resources const& handle, index.centers().data_handle(), index.n_lists(), index.dim()); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = static_cast(index.metric()); + kmeans_params.metric = index.metric(); + if (kmeans_params.metric == cuvs::distance::DistanceType::CosineExpanded) + kmeans_params.metric = cuvs::distance::DistanceType::InnerProduct; + // Normalize if necessary (Cosine) + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::row_normalize( + handle, + trainset_const_view, + raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()), + raft::linalg::NormType::L2Norm); + } cuvs::cluster::kmeans_balanced::fit( handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index f55e99c18..01ee913b7 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -30,6 +30,7 @@ #include // raft::resources #include // raft::linalg::gemm #include // raft::linalg::norm +#include // raft::linalg::row_normalize #include // raft::linalg::unary_op #include // matrix::detail::select_k @@ -81,6 +82,19 @@ void search_impl(raft::resources const& handle, rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); size_t float_query_size; + auto compute_metric = index.metric(); + rmm::device_uvector normalized_queries_dev(0, stream, search_mr); + const T* queries_ptr = queries; + + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + normalized_queries_dev.resize(n_queries * index.dim(), stream); + raft::linalg::row_normalize( + handle, + raft::make_device_matrix_view(queries, n_queries, index.dim()), + raft::make_device_matrix_view(normalized_queries_dev.data(), n_queries, index.dim()), + raft::linalg::NormType::L2Norm); + queries_ptr = normalized_queries_dev.data(); + } if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); } else { @@ -90,10 +104,10 @@ void search_impl(raft::resources const& handle, float* converted_queries_ptr = converted_queries_dev.data(); if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries); + converted_queries_ptr = const_cast(queries_ptr); } else { raft::linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); + converted_queries_ptr, queries_ptr, n_queries * index.dim(), utils::mapping{}, stream); } float alpha = 1.0f; @@ -122,6 +136,13 @@ void search_impl(raft::resources const& handle, RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); break; } + case cuvs::distance::DistanceType::CosineExpanded: { + compute_metric = cuvs::distance::DistanceType::InnerProduct; + select_min = false; + alpha = 1.0f; + beta = 0.0f; + break; + } default: { alpha = 1.0f; beta = 0.0f; @@ -166,7 +187,7 @@ void search_impl(raft::resources const& handle, nullptr, n_queries, queries_offset, - index.metric(), + compute_metric, n_probes, k, 0, @@ -217,11 +238,11 @@ void search_impl(raft::resources const& handle, ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, - queries, + queries_ptr, coarse_indices_dev.data(), n_queries, queries_offset, - index.metric(), + compute_metric, n_probes, k, max_samples, From 7fa5d7652bc016709c8b0ad4408a2d49bb2671dc Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 6 Jun 2024 08:43:28 -0700 Subject: [PATCH 02/15] add dataset_norm to kmeans balanced --- cpp/src/cluster/detail/kmeans_balanced.cuh | 17 ++-- cpp/src/cluster/kmeans_balanced.cuh | 9 +- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 88 +++++++++++------ .../neighbors/ivf_flat/ivf_flat_search.cuh | 2 +- cpp/test/neighbors/ann_ivf_flat.cuh | 94 +++++++++++++++++++ .../ann_ivf_flat/test_float_int64_t.cu | 6 +- 6 files changed, 178 insertions(+), 38 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 0e42a672c..5703cbe28 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -280,7 +280,7 @@ void calc_centers_and_sizes(const raft::resources& handle, bool reset_counters, MappingOpT mapping_op, rmm::device_async_resource_ref mr, - const T* dataset_norm = nullptr) + const MathT* dataset_norm = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); @@ -566,7 +566,8 @@ auto adjust_centers(MathT* centers, MathT threshold, MappingOpT mapping_op, rmm::cuda_stream_view stream, - rmm::device_async_resource_ref device_memory) -> bool + rmm::device_async_resource_ref device_memory, + const MathT* dataset_norm) -> bool { raft::common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); @@ -601,7 +602,8 @@ auto adjust_centers(MathT* centers, average, ofst, update_count.data(), - mapping_op); + mapping_op, + dataset_norm); adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync return adjusted; @@ -681,7 +683,8 @@ void balancing_em_iters(const raft::resources& handle, balancing_threshold, mapping_op, stream, - device_memory)) { + device_memory, + dataset_norm)) { if (balancing_counter++ >= balancing_pullback) { balancing_counter -= balancing_pullback; n_iters++; @@ -726,7 +729,8 @@ void balancing_em_iters(const raft::resources& handle, cluster_labels, true, mapping_op, - device_memory); + device_memory, + dataset_norm); } } @@ -770,7 +774,8 @@ void build_clusters(const raft::resources& handle, cluster_labels, true, mapping_op, - device_memory); + device_memory, + dataset_norm); // run EM balancing_em_iters(handle, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index f8e43af67..893d6eb27 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -136,7 +136,8 @@ void predict(const raft::resources& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); @@ -157,7 +158,9 @@ void predict(const raft::resources& handle, X.data_handle(), X.extent(0), labels.data_handle(), - mapping_op); + mapping_op, + raft::resource::get_workspace_resource(handle), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** @@ -343,7 +346,7 @@ void calc_centers_and_sizes( raft::device_vector_view cluster_sizes, bool reset_counters = true, MappingOpT mapping_op = raft::identity_op(), - std::optional> X_norm = std::nullopt) + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 98f28ddb9..40fcfa0d3 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -127,7 +127,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, uint32_t dim, uint32_t veclen, IdxT batch_offset = 0, - const T* source_vecs_norms = nullptr) + const float* source_vecs_norms = nullptr) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; if (i >= n_rows) { return; } @@ -159,11 +159,43 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j] * vec_norm; + list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j] / vec_norm; } } } +template +void myl2rownorm(raft::resources const& handle, + raft::device_matrix_view in, + raft::device_vector_view out, + raft::linalg::NormType type, + raft::linalg::Apply apply, + Lambda fin_op = raft::identity_op()) +{ + // Created to support two different data types for input and output + // To be replaced with raft::linalg::rowNorm once it supports different data types + auto constexpr row_major = std::is_same_v; + RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), + "Output should be equal to number of rows in Input"); + + raft::linalg::reduce(out.data_handle(), + in.data_handle(), + in.extent(1), + in.extent(0), + (OutType)0, + row_major, + true, + raft::resource::get_cuda_stream(handle), + false, + raft::sq_op(), + raft::add_op(), + fin_op); +} + /** See raft::neighbors::ivf_flat::extend docs */ template void extend(raft::resources const& handle, @@ -187,6 +219,9 @@ void extend(raft::resources const& handle, "You must pass data indices when the index is non-empty."); auto new_labels = raft::make_device_vector(handle, n_rows); + rmm::device_uvector new_vectors_norms(0, stream); + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) + new_vectors_norms.resize(n_rows, stream); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); auto orig_centroids_view = @@ -208,12 +243,28 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); + + // Compute norm if necessary (Cosine) + auto batch_vectors_norms = raft::make_device_vector_view(new_vectors_norms.data() + batch.offset(), batch.size()); + std::optional> batch_vectors_norms_opt = std::nullopt; + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + myl2rownorm(handle, + batch_data_view, + batch_vectors_norms, + raft::linalg::NormType::L2Norm, + raft::linalg::Apply::ALONG_ROWS, + raft::sqrt_op()); + batch_vectors_norms_opt = + std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); + } + cuvs::cluster::kmeans_balanced::predict(handle, kmeans_params, batch_data_view, orig_centroids_view, batch_labels_view, - utils::mapping{}); + utils::mapping{}, + batch_vectors_norms_opt); } auto* list_sizes_ptr = index->list_sizes().data_handle(); @@ -232,18 +283,11 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - auto batch_vectors_norms = raft::make_device_vector(handle, batch.size()); - std::optional> batch_vectors_norms_opt = std::nullopt; - // Normalize if necessary (Cosine) + auto batch_vectors_norms = raft::make_device_vector_view(new_vectors_norms.data() + batch.offset(), batch.size()); + std::optional> batch_vectors_norms_opt = std::nullopt; if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(handle, - batch_data_view, - batch_vectors_norms.view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op()); batch_vectors_norms_opt = - std::make_optional(raft::make_const_mdspan(batch_vectors_norms.view())); + std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); } cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, batch_data_view, @@ -296,18 +340,9 @@ void extend(raft::resources const& handle, for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_vectors_norms = raft::make_device_vector(handle, batch.size()); - T* batch_vectors_norms_ptr = nullptr; - // Normalize if necessary (Cosine) - if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(handle, - batch_data_view, - batch_vectors_norms.view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op()); - batch_vectors_norms_ptr = batch_vectors_norms.data_handle(); - } + float* batch_vectors_norms_ptr = nullptr; + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) + batch_vectors_norms_ptr = new_vectors_norms.data() + batch.offset(); // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(batch.size(), block_dim.x)); @@ -417,10 +452,9 @@ inline auto build(raft::resources const& handle, cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.n_iters = params.kmeans_n_iters; kmeans_params.metric = index.metric(); - if (kmeans_params.metric == cuvs::distance::DistanceType::CosineExpanded) - kmeans_params.metric = cuvs::distance::DistanceType::InnerProduct; // Normalize if necessary (Cosine) if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + kmeans_params.metric = cuvs::distance::DistanceType::InnerProduct; raft::linalg::row_normalize( handle, trainset_const_view, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 01ee913b7..23e0c9bef 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -138,7 +138,7 @@ void search_impl(raft::resources const& handle, } case cuvs::distance::DistanceType::CosineExpanded: { compute_metric = cuvs::distance::DistanceType::InnerProduct; - select_min = false; + //select_min = false; alpha = 1.0f; beta = 0.0f; break; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 9ce23cc7d..45e388edb 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -20,7 +20,9 @@ #include "naive_knn.cuh" #include +#include #include +#include #include #include @@ -252,6 +254,91 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } + void testIVFFlatCosine() + { + size_t queries_size = ps.num_queries * ps.k; + std::vector indices_ivfflat(queries_size); + std::vector indices_naive(queries_size); + std::vector distances_ivfflat(queries_size); + std::vector distances_naive(queries_size); + + { + rmm::device_uvector distances_naive_dev(queries_size, stream_); + rmm::device_uvector indices_naive_dev(queries_size, stream_); + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + + auto indices_out_view = raft::make_device_matrix_view( + indices_naive_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_naive_dev.data(), ps.num_queries, ps.k); + auto bfi = cuvs::neighbors::brute_force::build(handle_, database_view, ps.metric); + cuvs::neighbors::brute_force::search(handle_, + bfi, + search_queries_view, + indices_out_view, + dists_out_view); + + raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); + raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + { + // unless something is really wrong with clustering, this could serve as a lower bound on + // recall + double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); + + rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); + rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); + + { + ivf_flat::index_params index_params; + ivf_flat::search_params search_params; + index_params.n_lists = ps.nlist; + index_params.metric = ps.metric; + index_params.adaptive_centers = ps.adaptive_centers; + search_params.n_probes = ps.nprobe; + + index_params.add_data_on_build = true; + index_params.kmeans_trainset_fraction = 0.5; + index_params.metric_arg = 0; + + ivf_flat::index idx(handle_, index_params, ps.dim); + + auto database_view = raft::make_device_matrix_view( + (const DataT*)database.data(), ps.num_db_vecs, ps.dim); + idx = ivf_flat::build(handle_, index_params, database_view); + + auto search_queries_view = raft::make_device_matrix_view( + search_queries.data(), ps.num_queries, ps.dim); + auto indices_out_view = raft::make_device_matrix_view( + indices_ivfflat_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = raft::make_device_matrix_view( + distances_ivfflat_dev.data(), ps.num_queries, ps.k); + ivf_flat::search(handle_, + search_params, + idx, + search_queries_view, + indices_out_view, + dists_out_view); + + raft::update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); + raft::update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + raft::resource::sync_stream(handle_); + } + ASSERT_TRUE(eval_neighbours(indices_naive, + indices_ivfflat, + distances_naive, + distances_ivfflat, + ps.num_queries, + ps.k, + 0.001, + min_recall)); + } + } + void testPacker() { ivf_flat::index_params index_params; @@ -518,6 +605,13 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { rmm::device_uvector search_queries; }; +const std::vector> inputs_cosine = { + // test various dims (aligned and not aligned to vector sizes) + {2, 20, 1024, 3, 2, 4, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false} +}; + const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) {1000, 10000, 1, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index 0ce168f5e..331e6bd35 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -21,7 +21,11 @@ namespace cuvs::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_float; -TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { this->testIVFFlat(); } +TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { + this->testIVFFlatCosine(); + this->testIVFFlat(); + this->testPacker(); +} INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); From 5183b86c8c61652406c66ea4d99d089ecd678aa2 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Thu, 6 Jun 2024 08:49:00 -0700 Subject: [PATCH 03/15] fix kmeans --- cpp/src/cluster/detail/kmeans_balanced.cuh | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 5703cbe28..0a5e2beee 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -148,7 +148,6 @@ inline std::enable_if_t> predict_core( rmm::device_uvector distances(n_rows * n_clusters, stream, mr); MathT alpha = -1.0; - if (params.metric == cuvs::distance::DistanceType::CosineExpanded) alpha = 1.0; MathT beta = 0.0; raft::linalg::gemm(handle, @@ -481,7 +480,8 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL IdxT average, IdxT seed, IdxT* count, - MappingOpT mapping_op) + MappingOpT mapping_op, + const MathT* dataset_norm = nullptr) { IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); if (l >= n_clusters) return; @@ -508,10 +508,11 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); // Weight for the datapoint used to shift the center. const MathT wd = 1.0; + const MathT data_norm = dataset_norm == nullptr ? 1.0 : dataset_norm[i]; for (; j < dim; j += raft::WarpSize) { MathT val = 0; val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); + val += wd * mapping_op(dataset[j + dim * i]) / data_norm; val /= wc + wd; centers[j + dim * l] = val; } @@ -940,7 +941,8 @@ auto build_fine_clusters(const raft::resources& handle, cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded) { thrust::gather(raft::resource::get_thrust_policy(handle), mc_trainset_ids, mc_trainset_ids + k, From eebedff91c778e1455b63ad7c795ca49f0400a0a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 24 Jun 2024 09:16:15 -0700 Subject: [PATCH 04/15] Add cosine for kmeans build --- cpp/src/cluster/detail/kmeans_balanced.cuh | 49 ++++++++----- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 2 +- .../ivf_flat/ivf_flat_build_float_int64_t.cu | 69 ------------------ .../ivf_flat/ivf_flat_build_int8_t_int64_t.cu | 69 ------------------ .../ivf_flat_build_uint8_t_int64_t.cu | 69 ------------------ .../ivf_flat/ivf_flat_extend_float_int64_t.cu | 71 ------------------- .../ivf_flat_extend_int8_t_int64_t.cu | 71 ------------------- .../ivf_flat_extend_uint8_t_int64_t.cu | 71 ------------------- cpp/test/neighbors/ann_ivf_flat.cuh | 9 ++- .../ann_ivf_flat/test_float_int64_t.cu | 7 +- 10 files changed, 46 insertions(+), 441 deletions(-) delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 0a5e2beee..c520c8d51 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -349,13 +349,14 @@ void calc_centers_and_sizes(const raft::resources& handle, } /** Computes the L2 norm of the dataset, converting to MathT if necessary */ -template +template void compute_norm(const raft::resources& handle, MathT* dataset_norm, const T* dataset, IdxT dim, IdxT n_rows, MappingOpT mapping_op, + FinOpT norm_fin_op, std::optional mr = std::nullopt) { raft::common::nvtx::range fun_scope("compute_norm"); @@ -376,7 +377,7 @@ void compute_norm(const raft::resources& handle, } raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); + dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream, norm_fin_op); } /** @@ -441,8 +442,12 @@ void predict(const raft::resources& handle, // Compute the norm now if it hasn't been pre-computed. if (need_compute_norm) { - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); + if (params.metric == cuvs::distance::DistanceType::CosineExpanded) + compute_norm( + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, raft::sqrt_op{}, mr); + else + compute_norm( + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, raft::identity_op{}, mr); dataset_norm_ptr = cur_dataset_norm.data(); } else if (dataset_norm != nullptr) { dataset_norm_ptr = dataset_norm + offset; @@ -1001,7 +1006,8 @@ void build_hierarchical(const raft::resources& handle, IdxT n_rows, MathT* cluster_centers, IdxT n_clusters, - MappingOpT mapping_op) + MappingOpT mapping_op, + const MathT* dataset_norm = nullptr) { auto stream = raft::resource::get_cuda_stream(handle); using LabelT = uint32_t; @@ -1018,21 +1024,32 @@ void build_hierarchical(const raft::resources& handle, auto [max_minibatch_size, mem_per_row] = calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - // Precompute the L2 norm of the dataset if relevant. - const MathT* dataset_norm = nullptr; + // Precompute the L2 norm of the dataset if relevant and not yet computed. rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded)) { dataset_norm_buf.resize(n_rows, stream); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - device_memory); + if (params.metric == cuvs::distance::DistanceType::CosineExpanded) + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::sqrt_op{}, + device_memory); + else + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::identity_op{}, + device_memory); } dataset_norm = (const MathT*)dataset_norm_buf.data(); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 40fcfa0d3..73bb07f45 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -188,7 +188,7 @@ void myl2rownorm(raft::resources const& handle, in.extent(0), (OutType)0, row_major, - true, + apply == raft::linalg::Apply::ALONG_ROWS, raft::resource::get_cuda_stream(handle), false, raft::sq_op(), diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu deleted file mode 100644 index 56bb71094..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_float_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(float, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu deleted file mode 100644 index 4803868c0..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_int8_t_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(int8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu deleted file mode 100644 index e087f94c4..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_BUILD(T, IdxT) \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } \ - auto build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index( \ - std::move(cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset))); \ - } \ - \ - void build(raft::resources const& handle, \ - const cuvs::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - cuvs::neighbors::ivf_flat::index& idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::build(handle, params, dataset, idx); \ - } -CUVS_INST_IVF_FLAT_BUILD(uint8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_BUILD - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu deleted file mode 100644 index 2636067bf..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_float_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(float, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu deleted file mode 100644 index 191cb9f39..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(int8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu deleted file mode 100644 index 29b7e7b69..000000000 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is generated by generate_ivf_flat.py - * - * Make changes there and run in this directory: - * - * > python generate_ivf_flat.py - * - */ - -#include - -#include "ivf_flat_build.cuh" - -namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_EXTEND(T, IdxT) \ - auto extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } \ - auto extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const cuvs::neighbors::ivf_flat::index& orig_index) \ - ->cuvs::neighbors::ivf_flat::index \ - { \ - return cuvs::neighbors::ivf_flat::index(std::move( \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, orig_index))); \ - } \ - \ - void extend(raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - cuvs::neighbors::ivf_flat::index* idx) \ - { \ - cuvs::neighbors::ivf_flat::detail::extend(handle, new_vectors, new_indices, idx); \ - } -CUVS_INST_IVF_FLAT_EXTEND(uint8_t, int64_t); - -#undef CUVS_INST_IVF_FLAT_EXTEND - -} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 45e388edb..f1bd6b908 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -607,9 +607,12 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { const std::vector> inputs_cosine = { // test various dims (aligned and not aligned to vector sizes) - {2, 20, 1024, 3, 2, 4, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false} + {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 5, 100, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 1000, 5, 32, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, + {100, 1000, 8, 64, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, + {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false} }; const std::vector> inputs = { diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index 331e6bd35..78acd0ee4 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -22,11 +22,16 @@ namespace cuvs::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_float; TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { - this->testIVFFlatCosine(); this->testIVFFlat(); this->testPacker(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); +typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_float; +TEST_P(AnnIVFFlatTestF_cosine_float, AnnIVFFlat) { + this->testIVFFlatCosine(); +} +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_cosine_float, ::testing::ValuesIn(inputs_cosine)); + } // namespace cuvs::neighbors::ivf_flat From 0edef0906b395cae9db35b3bfdcc7521776de548 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 26 Jun 2024 10:11:43 -0700 Subject: [PATCH 05/15] Fix select_min in Cosine --- cpp/src/cluster/kmeans_balanced.cuh | 9 +++++++-- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 9 --------- cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh | 2 +- 3 files changed, 8 insertions(+), 12 deletions(-) diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 893d6eb27..07a234d99 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -71,13 +71,15 @@ namespace cuvs::cluster::kmeans_balanced { * @param[out] centroids The generated centroids [dim = n_clusters x n_features] * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic * datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] */ template void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), "Number of features in dataset and centroids are different"); @@ -95,7 +97,8 @@ void fit(const raft::resources& handle, X.extent(0), centroids.data_handle(), centroids.extent(0), - mapping_op); + mapping_op, + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** @@ -125,6 +128,7 @@ void fit(const raft::resources& handle, * @param[out] labels The output labels [dim = n_samples] * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic * datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] */ template (trainset.data(), n_rows_train, index.dim()), - raft::linalg::NormType::L2Norm); - } cuvs::cluster::kmeans_balanced::fit( handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 23e0c9bef..38b3f14ab 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -138,7 +138,7 @@ void search_impl(raft::resources const& handle, } case cuvs::distance::DistanceType::CosineExpanded: { compute_metric = cuvs::distance::DistanceType::InnerProduct; - //select_min = false; + select_min = cuvs::distance::is_min_close(compute_metric); alpha = 1.0f; beta = 0.0f; break; From 10d576401cfbcb513993c5aa8ca2f77fbce4494f Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 28 Jun 2024 09:07:59 -0700 Subject: [PATCH 06/15] Update Math type --- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 810a9e27d..7cda72a5c 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -117,7 +117,7 @@ auto clone(const raft::resources& res, const index& source) -> index +template RAFT_KERNEL build_index_kernel(const LabelT* labels, const T* source_vecs, const IdxT* source_ixs, @@ -128,6 +128,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, uint32_t dim, uint32_t veclen, IdxT batch_offset = 0, + MappingOpT mapping_op = raft::identity_op{}, const float* source_vecs_norms = nullptr) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; @@ -140,7 +141,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // Record the source vector id in the index list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i]; - const auto vec_norm = source_vecs_norms == nullptr ? T(1) : source_vecs_norms[i]; + const auto vec_norm = source_vecs_norms == nullptr ? float{1} : source_vecs_norms[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = raft::Pow2; @@ -160,7 +161,10 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j] / vec_norm; + if (source_vecs_norms != nullptr) + list_data[l * kIndexGroupSize + ingroup_id + j] = mapping_op(source_vecs[l + j]) / vec_norm; + else + list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; } } } @@ -360,6 +364,7 @@ void extend(raft::resources const& handle, dim, index->veclen(), batch.offset(), + utils::mapping{}, batch_vectors_norms_ptr); RAFT_CUDA_TRY(cudaPeekAtLastError()); @@ -533,7 +538,9 @@ inline void fill_refinement_index(raft::resources const& handle, list_sizes_ptr, n_queries * n_candidates, refinement_index->dim(), - refinement_index->veclen()); + refinement_index->veclen(), + IdxT(0), + utils::mapping{}); RAFT_CUDA_TRY(cudaPeekAtLastError()); } From 5651e726bdec34baedb0fbb02f80d237dd70b9ad Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Fri, 28 Jun 2024 10:10:57 -0700 Subject: [PATCH 07/15] Add test for int8 --- cpp/test/neighbors/ann_ivf_flat.cuh | 19 ++++++++++++++++--- .../ann_ivf_flat/test_int8_t_int64_t.cu | 5 +++++ 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 03dc8e557..76eb74b92 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -267,17 +267,30 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { rmm::device_uvector indices_naive_dev(queries_size, stream_); auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); +auto database_float = raft::make_device_matrix( + handle_, ps.num_db_vecs, ps.dim); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); - +auto search_queries_float = raft::make_device_matrix( + handle_, ps.num_queries, ps.dim); + raft::linalg::map(handle_, + database_float.view(), + [] __device__(DataT val) { return static_cast(val); }, + database_view); + + raft::linalg::map(handle_, + search_queries_float.view(), + [] __device__(DataT val) { return static_cast(val); }, + search_queries_view); auto indices_out_view = raft::make_device_matrix_view( indices_naive_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_naive_dev.data(), ps.num_queries, ps.k); - auto bfi = cuvs::neighbors::brute_force::build(handle_, database_view, ps.metric); + auto bfi = cuvs::neighbors::brute_force::build(handle_, + raft::make_const_mdspan(database_float.view()), ps.metric); cuvs::neighbors::brute_force::search(handle_, bfi, - search_queries_view, + raft::make_const_mdspan(search_queries_float.view()), indices_out_view, dists_out_view, std::nullopt); diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index 15935fd88..b08e306df 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -25,4 +25,9 @@ TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); +typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_int8; +TEST_P(AnnIVFFlatTestF_cosine_int8, AnnIVFFlat) { + this->testIVFFlatCosine(); +} +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_cosine_int8, ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat From 1da4619d6b4082149e085369fcba3263d2c9d13a Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Sat, 29 Jun 2024 06:47:42 -0700 Subject: [PATCH 08/15] Fix style --- cpp/src/cluster/detail/kmeans_balanced.cuh | 56 +++++++++------ cpp/src/cluster/kmeans_balanced.cuh | 44 ++++++------ cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 65 ++++++++++-------- .../neighbors/ivf_flat/ivf_flat_search.cuh | 10 ++- cpp/test/neighbors/ann_ivf_flat.cuh | 68 +++++++++---------- .../ann_ivf_flat/test_float_int64_t.cu | 12 ++-- .../ann_ivf_flat/test_int8_t_int64_t.cu | 8 +-- 7 files changed, 142 insertions(+), 121 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index c520c8d51..3ed947813 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -148,7 +148,7 @@ inline std::enable_if_t> predict_core( rmm::device_uvector distances(n_rows * n_clusters, stream, mr); MathT alpha = -1.0; - MathT beta = 0.0; + MathT beta = 0.0; raft::linalg::gemm(handle, true, @@ -443,11 +443,23 @@ void predict(const raft::resources& handle, // Compute the norm now if it hasn't been pre-computed. if (need_compute_norm) { if (params.metric == cuvs::distance::DistanceType::CosineExpanded) - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, raft::sqrt_op{}, mr); + compute_norm(handle, + cur_dataset_norm.data(), + cur_dataset_ptr, + dim, + minibatch_size, + mapping_op, + raft::sqrt_op{}, + mr); else - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, raft::identity_op{}, mr); + compute_norm(handle, + cur_dataset_norm.data(), + cur_dataset_ptr, + dim, + minibatch_size, + mapping_op, + raft::identity_op{}, + mr); dataset_norm_ptr = cur_dataset_norm.data(); } else if (dataset_norm != nullptr) { dataset_norm_ptr = dataset_norm + offset; @@ -512,7 +524,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL // We dump it for anomalously small clusters, but keep constant otherwise. const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; + const MathT wd = 1.0; const MathT data_norm = dataset_norm == nullptr ? 1.0 : dataset_norm[i]; for (; j < dim; j += raft::WarpSize) { MathT val = 0; @@ -1027,29 +1039,29 @@ void build_hierarchical(const raft::resources& handle, // Precompute the L2 norm of the dataset if relevant and not yet computed. rmm::device_uvector dataset_norm_buf(0, stream, device_memory); if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded || - params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || - params.metric == cuvs::distance::DistanceType::CosineExpanded)) { + params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || + params.metric == cuvs::distance::DistanceType::CosineExpanded)) { dataset_norm_buf.resize(n_rows, stream); for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); if (params.metric == cuvs::distance::DistanceType::CosineExpanded) compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - raft::sqrt_op{}, - device_memory); + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::sqrt_op{}, + device_memory); else compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - raft::identity_op{}, - device_memory); + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + raft::identity_op{}, + device_memory); } dataset_norm = (const MathT*)dataset_norm_buf.data(); } diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 07a234d99..705734ae3 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -78,7 +78,7 @@ void fit(const raft::resources& handle, cuvs::cluster::kmeans::balanced_params const& params, raft::device_matrix_view X, raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op(), + MappingOpT mapping_op = raft::identity_op(), std::optional> X_norm = std::nullopt) { RAFT_EXPECTS(X.extent(1) == centroids.extent(1), @@ -90,15 +90,16 @@ void fit(const raft::resources& handle, "The number of centroids must be strictly positive and cannot exceed the number of " "points in the training dataset."); - cuvs::cluster::kmeans::detail::build_hierarchical(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.data_handle(), - centroids.extent(0), - mapping_op, - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); + cuvs::cluster::kmeans::detail::build_hierarchical( + handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.data_handle(), + centroids.extent(0), + mapping_op, + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** @@ -154,17 +155,18 @@ void predict(const raft::resources& handle, static_cast(std::numeric_limits::max()), "The chosen label type cannot represent all cluster labels"); - cuvs::cluster::kmeans::detail::predict(handle, - params, - centroids.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - mapping_op, - raft::resource::get_workspace_resource(handle), - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); + cuvs::cluster::kmeans::detail::predict( + handle, + params, + centroids.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + mapping_op, + raft::resource::get_workspace_resource(handle), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); } /** diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 7cda72a5c..2e8f37ad2 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -117,7 +117,11 @@ auto clone(const raft::resources& res, const index& source) -> index +template RAFT_KERNEL build_index_kernel(const LabelT* labels, const T* source_vecs, const IdxT* source_ixs, @@ -127,8 +131,8 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, IdxT n_rows, uint32_t dim, uint32_t veclen, - IdxT batch_offset = 0, - MappingOpT mapping_op = raft::identity_op{}, + IdxT batch_offset = 0, + MappingOpT mapping_op = raft::identity_op{}, const float* source_vecs_norms = nullptr) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; @@ -175,30 +179,30 @@ template void myl2rownorm(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_vector_view out, - raft::linalg::NormType type, - raft::linalg::Apply apply, - Lambda fin_op = raft::identity_op()) + raft::device_matrix_view in, + raft::device_vector_view out, + raft::linalg::NormType type, + raft::linalg::Apply apply, + Lambda fin_op = raft::identity_op()) { // Created to support two different data types for input and output // To be replaced with raft::linalg::rowNorm once it supports different data types auto constexpr row_major = std::is_same_v; RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), - "Output should be equal to number of rows in Input"); + "Output should be equal to number of rows in Input"); raft::linalg::reduce(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - (OutType)0, - row_major, - apply == raft::linalg::Apply::ALONG_ROWS, - raft::resource::get_cuda_stream(handle), - false, - raft::sq_op(), - raft::add_op(), - fin_op); + in.data_handle(), + in.extent(1), + in.extent(0), + (OutType)0, + row_major, + apply == raft::linalg::Apply::ALONG_ROWS, + raft::resource::get_cuda_stream(handle), + false, + raft::sq_op(), + raft::add_op(), + fin_op); } /** See raft::neighbors::ivf_flat::extend docs */ @@ -225,7 +229,8 @@ void extend(raft::resources const& handle, auto new_labels = raft::make_device_mdarray( handle, raft::resource::get_large_workspace_resource(handle), raft::make_extents(n_rows)); - rmm::device_uvector new_vectors_norms(0, stream, raft::resource::get_large_workspace_resource(handle)); + rmm::device_uvector new_vectors_norms( + 0, stream, raft::resource::get_large_workspace_resource(handle)); if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) new_vectors_norms.resize(n_rows, stream); cuvs::cluster::kmeans::balanced_params kmeans_params; @@ -251,8 +256,10 @@ void extend(raft::resources const& handle, new_labels.data_handle() + batch.offset(), batch.size()); // Compute norm if necessary (Cosine) - auto batch_vectors_norms = raft::make_device_vector_view(new_vectors_norms.data() + batch.offset(), batch.size()); - std::optional> batch_vectors_norms_opt = std::nullopt; + auto batch_vectors_norms = raft::make_device_vector_view( + new_vectors_norms.data() + batch.offset(), batch.size()); + std::optional> batch_vectors_norms_opt = + std::nullopt; if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { myl2rownorm(handle, batch_data_view, @@ -260,8 +267,7 @@ void extend(raft::resources const& handle, raft::linalg::NormType::L2Norm, raft::linalg::Apply::ALONG_ROWS, raft::sqrt_op()); - batch_vectors_norms_opt = - std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); + batch_vectors_norms_opt = std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); } cuvs::cluster::kmeans_balanced::predict(handle, @@ -290,11 +296,12 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - auto batch_vectors_norms = raft::make_device_vector_view(new_vectors_norms.data() + batch.offset(), batch.size()); - std::optional> batch_vectors_norms_opt = std::nullopt; + auto batch_vectors_norms = raft::make_device_vector_view( + new_vectors_norms.data() + batch.offset(), batch.size()); + std::optional> batch_vectors_norms_opt = + std::nullopt; if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - batch_vectors_norms_opt = - std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); + batch_vectors_norms_opt = std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); } cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, batch_data_view, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 43c601221..f6f9c6fdb 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -89,11 +89,15 @@ void search_impl(raft::resources const& handle, if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { normalized_queries_dev.resize(n_queries * index.dim(), stream); - raft::linalg::row_normalize( // L2Norm + raft::linalg::row_normalize( // L2Norm handle, raft::make_device_matrix_view(queries, n_queries, index.dim()), raft::make_device_matrix_view(normalized_queries_dev.data(), n_queries, index.dim()), - T(0), raft::sq_op(), raft::add_op(), raft::sqrt_op(), T(1e-8)); + T(0), + raft::sq_op(), + raft::add_op(), + raft::sqrt_op(), + T(1e-8)); queries_ptr = normalized_queries_dev.data(); } if constexpr (std::is_integral_v) { @@ -139,7 +143,7 @@ void search_impl(raft::resources const& handle, } case cuvs::distance::DistanceType::CosineExpanded: { compute_metric = cuvs::distance::DistanceType::InnerProduct; - select_min = cuvs::distance::is_min_close(compute_metric); + select_min = cuvs::distance::is_min_close(compute_metric); alpha = 1.0f; beta = 0.0f; break; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 76eb74b92..0e1428399 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -267,33 +267,34 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { rmm::device_uvector indices_naive_dev(queries_size, stream_); auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.num_db_vecs, ps.dim); -auto database_float = raft::make_device_matrix( - handle_, ps.num_db_vecs, ps.dim); + auto database_float = raft::make_device_matrix(handle_, ps.num_db_vecs, ps.dim); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); -auto search_queries_float = raft::make_device_matrix( - handle_, ps.num_queries, ps.dim); - raft::linalg::map(handle_, - database_float.view(), - [] __device__(DataT val) { return static_cast(val); }, - database_view); - - raft::linalg::map(handle_, - search_queries_float.view(), - [] __device__(DataT val) { return static_cast(val); }, - search_queries_view); - auto indices_out_view = raft::make_device_matrix_view( - indices_naive_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_naive_dev.data(), ps.num_queries, ps.k); - auto bfi = cuvs::neighbors::brute_force::build(handle_, - raft::make_const_mdspan(database_float.view()), ps.metric); + auto search_queries_float = + raft::make_device_matrix(handle_, ps.num_queries, ps.dim); + raft::linalg::map( + handle_, + database_float.view(), + [] __device__(DataT val) { return static_cast(val); }, + database_view); + + raft::linalg::map( + handle_, + search_queries_float.view(), + [] __device__(DataT val) { return static_cast(val); }, + search_queries_view); + auto indices_out_view = + raft::make_device_matrix_view(indices_naive_dev.data(), ps.num_queries, ps.k); + auto dists_out_view = + raft::make_device_matrix_view(distances_naive_dev.data(), ps.num_queries, ps.k); + auto bfi = cuvs::neighbors::brute_force::build( + handle_, raft::make_const_mdspan(database_float.view()), ps.metric); cuvs::neighbors::brute_force::search(handle_, - bfi, - raft::make_const_mdspan(search_queries_float.view()), - indices_out_view, - dists_out_view, - std::nullopt); + bfi, + raft::make_const_mdspan(search_queries_float.view()), + indices_out_view, + dists_out_view, + std::nullopt); raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); @@ -331,15 +332,13 @@ auto search_queries_float = raft::make_device_matrix( indices_ivfflat_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); - ivf_flat::search(handle_, - search_params, - idx, - search_queries_view, - indices_out_view, - dists_out_view); - - raft::update_host(distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); - raft::update_host(indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); + ivf_flat::search( + handle_, search_params, idx, search_queries_view, indices_out_view, dists_out_view); + + raft::update_host( + distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); + raft::update_host( + indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); raft::resource::sync_stream(handle_); } ASSERT_TRUE(eval_neighbours(indices_naive, @@ -626,8 +625,7 @@ const std::vector> inputs_cosine = { {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 1000, 5, 32, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, {100, 1000, 8, 64, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, - {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false} -}; + {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}}; const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index f203d3446..fe5f56587 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -21,16 +21,14 @@ namespace cuvs::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_float; -TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { - this->testIVFFlat(); -} +TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_float; -TEST_P(AnnIVFFlatTestF_cosine_float, AnnIVFFlat) { - this->testIVFFlatCosine(); -} -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_cosine_float, ::testing::ValuesIn(inputs_cosine)); +TEST_P(AnnIVFFlatTestF_cosine_float, AnnIVFFlat) { this->testIVFFlatCosine(); } +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, + AnnIVFFlatTestF_cosine_float, + ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index b08e306df..d95d64f3b 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -26,8 +26,8 @@ TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_int8; -TEST_P(AnnIVFFlatTestF_cosine_int8, AnnIVFFlat) { - this->testIVFFlatCosine(); -} -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_cosine_int8, ::testing::ValuesIn(inputs_cosine)); +TEST_P(AnnIVFFlatTestF_cosine_int8, AnnIVFFlat) { this->testIVFFlatCosine(); } +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, + AnnIVFFlatTestF_cosine_int8, + ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat From 78e6a6f3f3be8c6db0bf4974cd0adac5c88954b9 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Sun, 7 Jul 2024 11:03:03 -0700 Subject: [PATCH 09/15] Fix kmeans for Cosine + Add norm computation for search --- cpp/src/cluster/detail/kmeans_balanced.cuh | 103 ++++--- cpp/src/cluster/kmeans_balanced.cuh | 19 +- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 100 +++---- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 260 ++++++++++++++---- .../neighbors/ivf_flat/ivf_flat_search.cuh | 73 ++--- cpp/src/neighbors/ivf_flat_index.cpp | 1 + cpp/test/neighbors/ann_ivf_flat.cuh | 9 +- 7 files changed, 358 insertions(+), 207 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index 3ed947813..e85fe32c9 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -142,7 +142,56 @@ inline std::enable_if_t> predict_core( raft::compose_op, raft::key_op>()); break; } - case cuvs::distance::DistanceType::CosineExpanded: + case cuvs::distance::DistanceType::CosineExpanded: { + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + + MathT alpha = -1.0; + MathT beta = 0.0; + + raft::linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); + + auto centroidsNorm = + raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); + raft::linalg::rowNorm(centroidsNorm.data_handle(), + centers, + dim, + n_clusters, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + + const auto* index_center_norm_ptr = centroidsNorm.data_handle(); + raft::linalg::map_offset( + handle, + raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters), + [=] __device__(const uint32_t idx, const float dist) { + const auto query = idx / n_clusters; + const auto cluster = idx % n_clusters; + return dist / (dataset_norm[query] * index_center_norm_ptr[cluster]); + }, + raft::make_device_matrix_view(distances.data(), n_rows, n_clusters)); + auto distances_const_view = raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); + break; + } case cuvs::distance::DistanceType::InnerProduct: { // TODO: pass buffer rmm::device_uvector distances(n_rows * n_clusters, stream, mr); @@ -278,8 +327,7 @@ void calc_centers_and_sizes(const raft::resources& handle, const LabelT* labels, bool reset_counters, MappingOpT mapping_op, - rmm::device_async_resource_ref mr, - const MathT* dataset_norm = nullptr) + rmm::device_async_resource_ref mr) { auto stream = raft::resource::get_cuda_stream(handle); @@ -301,31 +349,13 @@ void calc_centers_and_sizes(const raft::resources& handle, // Apply mapping only when the data and math types are different. if constexpr (std::is_same_v) { - raft::linalg::reduce_rows_by_key(dataset, - dim, - labels, - dataset_norm, - nullptr, - n_rows, - dim, - n_clusters, - centers, - stream, - reset_counters); + raft::linalg::reduce_rows_by_key( + dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); } else { // todo(lsugy): use iterator from KV output of fusedL2NN cub::TransformInputIterator mapping_itr(dataset, mapping_op); - raft::linalg::reduce_rows_by_key(mapping_itr, - dim, - labels, - dataset_norm, - nullptr, - n_rows, - dim, - n_clusters, - centers, - stream, - reset_counters); + raft::linalg::reduce_rows_by_key( + mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); } // Compute weight of each cluster @@ -497,8 +527,7 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL IdxT average, IdxT seed, IdxT* count, - MappingOpT mapping_op, - const MathT* dataset_norm = nullptr) + MappingOpT mapping_op) { IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); if (l >= n_clusters) return; @@ -524,12 +553,11 @@ __launch_bounds__((raft::WarpSize * BlockDimY)) RAFT_KERNEL // We dump it for anomalously small clusters, but keep constant otherwise. const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; - const MathT data_norm = dataset_norm == nullptr ? 1.0 : dataset_norm[i]; + const MathT wd = 1.0; for (; j < dim; j += raft::WarpSize) { MathT val = 0; val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]) / data_norm; + val += wd * mapping_op(dataset[j + dim * i]); val /= wc + wd; centers[j + dim * l] = val; } @@ -584,8 +612,7 @@ auto adjust_centers(MathT* centers, MathT threshold, MappingOpT mapping_op, rmm::cuda_stream_view stream, - rmm::device_async_resource_ref device_memory, - const MathT* dataset_norm) -> bool + rmm::device_async_resource_ref device_memory) -> bool { raft::common::nvtx::range fun_scope( "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); @@ -620,8 +647,7 @@ auto adjust_centers(MathT* centers, average, ofst, update_count.data(), - mapping_op, - dataset_norm); + mapping_op); adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync return adjusted; @@ -701,8 +727,7 @@ void balancing_em_iters(const raft::resources& handle, balancing_threshold, mapping_op, stream, - device_memory, - dataset_norm)) { + device_memory)) { if (balancing_counter++ >= balancing_pullback) { balancing_counter -= balancing_pullback; n_iters++; @@ -747,8 +772,7 @@ void balancing_em_iters(const raft::resources& handle, cluster_labels, true, mapping_op, - device_memory, - dataset_norm); + device_memory); } } @@ -792,8 +816,7 @@ void build_clusters(const raft::resources& handle, cluster_labels, true, mapping_op, - device_memory, - dataset_norm); + device_memory); // run EM balancing_em_iters(handle, diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index 705734ae3..e978c335f 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -345,15 +345,13 @@ template -void calc_centers_and_sizes( - const raft::resources& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - raft::device_matrix_view centroids, - raft::device_vector_view cluster_sizes, - bool reset_counters = true, - MappingOpT mapping_op = raft::identity_op(), - std::optional> X_norm = std::nullopt) +void calc_centers_and_sizes(const raft::resources& handle, + raft::device_matrix_view X, + raft::device_vector_view labels, + raft::device_matrix_view centroids, + raft::device_vector_view cluster_sizes, + bool reset_counters = true, + MappingOpT mapping_op = raft::identity_op()) { RAFT_EXPECTS(X.extent(0) == labels.extent(0), "Number of rows in dataset and labels are different"); @@ -373,8 +371,7 @@ void calc_centers_and_sizes( labels.data_handle(), reset_counters, mapping_op, - raft::resource::get_workspace_resource(handle), - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); + raft::resource::get_workspace_resource(handle)); } } // namespace helpers diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 2e8f37ad2..4a946faaf 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -117,11 +117,7 @@ auto clone(const raft::resources& res, const index& source) -> index +template RAFT_KERNEL build_index_kernel(const LabelT* labels, const T* source_vecs, const IdxT* source_ixs, @@ -131,9 +127,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, IdxT n_rows, uint32_t dim, uint32_t veclen, - IdxT batch_offset = 0, - MappingOpT mapping_op = raft::identity_op{}, - const float* source_vecs_norms = nullptr) + IdxT batch_offset = 0) { const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; if (i >= n_rows) { return; } @@ -145,7 +139,6 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // Record the source vector id in the index list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i]; - const auto vec_norm = source_vecs_norms == nullptr ? float{1} : source_vecs_norms[i]; // The data is written in interleaved groups of `index::kGroupSize` vectors using interleaved_group = raft::Pow2; @@ -165,10 +158,7 @@ RAFT_KERNEL build_index_kernel(const LabelT* labels, // NB: such `veclen` is selected, that `dim % veclen == 0` for (uint32_t l = 0; l < dim; l += veclen) { for (uint32_t j = 0; j < veclen; j++) { - if (source_vecs_norms != nullptr) - list_data[l * kIndexGroupSize + ingroup_id + j] = mapping_op(source_vecs[l + j]) / vec_norm; - else - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; + list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; } } } @@ -204,7 +194,6 @@ void myl2rownorm(raft::resources const& handle, raft::add_op(), fin_op); } - /** See raft::neighbors::ivf_flat::extend docs */ template void extend(raft::resources const& handle, @@ -229,10 +218,6 @@ void extend(raft::resources const& handle, auto new_labels = raft::make_device_mdarray( handle, raft::resource::get_large_workspace_resource(handle), raft::make_extents(n_rows)); - rmm::device_uvector new_vectors_norms( - 0, stream, raft::resource::get_large_workspace_resource(handle)); - if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) - new_vectors_norms.resize(n_rows, stream); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); auto orig_centroids_view = @@ -255,28 +240,12 @@ void extend(raft::resources const& handle, auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - // Compute norm if necessary (Cosine) - auto batch_vectors_norms = raft::make_device_vector_view( - new_vectors_norms.data() + batch.offset(), batch.size()); - std::optional> batch_vectors_norms_opt = - std::nullopt; - if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - myl2rownorm(handle, - batch_data_view, - batch_vectors_norms, - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op()); - batch_vectors_norms_opt = std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); - } - cuvs::cluster::kmeans_balanced::predict(handle, kmeans_params, batch_data_view, orig_centroids_view, batch_labels_view, - utils::mapping{}, - batch_vectors_norms_opt); + utils::mapping{}); } auto* list_sizes_ptr = index->list_sizes().data_handle(); @@ -296,21 +265,13 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - auto batch_vectors_norms = raft::make_device_vector_view( - new_vectors_norms.data() + batch.offset(), batch.size()); - std::optional> batch_vectors_norms_opt = - std::nullopt; - if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { - batch_vectors_norms_opt = std::make_optional(raft::make_const_mdspan(batch_vectors_norms)); - } cuvs::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, batch_data_view, batch_labels_view, centroids_view, list_sizes_view, false, - utils::mapping{}, - batch_vectors_norms_opt); + utils::mapping{}); } } else { raft::stats::histogram(raft::stats::HistTypeAuto, @@ -354,9 +315,6 @@ void extend(raft::resources const& handle, for (const auto& batch : vec_batches) { auto batch_data_view = raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - float* batch_vectors_norms_ptr = nullptr; - if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) - batch_vectors_norms_ptr = new_vectors_norms.data() + batch.offset(); // Kernel to insert the new vectors const dim3 block_dim(256); const dim3 grid_dim(raft::ceildiv(batch.size(), block_dim.x)); @@ -370,9 +328,7 @@ void extend(raft::resources const& handle, batch.size(), dim, index->veclen(), - batch.offset(), - utils::mapping{}, - batch_vectors_norms_ptr); + batch.offset()); RAFT_CUDA_TRY(cudaPeekAtLastError()); if (batch.offset() > next_report_offset) { @@ -388,6 +344,37 @@ void extend(raft::resources const& handle, if (!index->center_norms().has_value()) { index->allocate_center_norms(handle); if (index->center_norms().has_value()) { + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } else { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream); + } + RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); + } + } else if (index->center_norms().has_value() && index->adaptive_centers()) { + if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) { + raft::linalg::rowNorm(index->center_norms()->data_handle(), + index->centers().data_handle(), + dim, + n_lists, + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + } else { raft::linalg::rowNorm(index->center_norms()->data_handle(), index->centers().data_handle(), dim, @@ -395,16 +382,7 @@ void extend(raft::resources const& handle, raft::linalg::L2Norm, true, stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } - } else if (index->center_norms().has_value() && index->adaptive_centers()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); } } @@ -545,9 +523,7 @@ inline void fill_refinement_index(raft::resources const& handle, list_sizes_ptr, n_queries * n_candidates, refinement_index->dim(), - refinement_index->veclen(), - IdxT(0), - utils::mapping{}); + refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index ce29a7e7c..8af54413f 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -104,13 +104,16 @@ __device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit * values) */ -template +template struct loadAndComputeDist { Lambda compute_dist; AccT& dist; + AccT& norm_query; + AccT& norm_data; - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(AccT& dist, Lambda op, AccT& norm_query, AccT& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -134,6 +137,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < Veclen; ++k) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query += queryRegs[k] * queryRegs[k]; + norm_data += encV[k] * encV[k]; + } } } } @@ -163,7 +170,12 @@ struct loadAndComputeDist { const int d = (i * kUnroll + j) * Veclen; #pragma unroll for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV[k] * encV[k]; + } } } } @@ -184,20 +196,28 @@ struct loadAndComputeDist { raft::ldg(enc, data + loadDataIdx); #pragma unroll for (int k = 0; k < Veclen; k++) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), enc[k]); + T q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc[k] * enc[k]; + } } } } }; // This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -220,6 +240,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < veclen_int; k++) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -244,7 +268,12 @@ struct loadAndComputeDist { const int d = (i * kUnroll + j) * veclen_int; #pragma unroll for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, raft::shfl(queryReg, d + k, raft::WarpSize), encV[k]); + uint32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); + compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -267,6 +296,10 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; k++) { uint32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } } } } @@ -274,13 +307,16 @@ struct loadAndComputeDist { // Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while // using above common template of int2/int4 -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -294,6 +330,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, @@ -313,6 +353,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -330,17 +374,24 @@ struct loadAndComputeDist { uint32_t enc = reinterpret_cast(data)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -354,6 +405,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } @@ -374,6 +429,10 @@ struct loadAndComputeDist { uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -391,17 +450,24 @@ struct loadAndComputeDist { uint32_t enc = reinterpret_cast(data)[lane_id]; uint32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; uint32_t& dist; + uint32_t& norm_query; + uint32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(uint32_t& dist, Lambda op, uint32_t& norm_query, uint32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -415,6 +481,10 @@ struct loadAndComputeDist { uint32_t encV = data[loadIndex + j * kIndexGroupSize]; uint32_t queryRegs = query_shared[shmemIndex + j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query += queryRegs * queryRegs; + norm_data += encV * encV; + } } } @@ -434,6 +504,10 @@ struct loadAndComputeDist { uint32_t encV = data[lane_id + j * kIndexGroupSize]; uint32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += encV * encV; + } } } } @@ -451,18 +525,25 @@ struct loadAndComputeDist { uint32_t enc = data[lane_id]; uint32_t q = raft::shfl(queryReg, d, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += enc * enc; + } } } }; // This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; + int32_t& norm_query; + int32_t& norm_data; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -485,6 +566,10 @@ struct loadAndComputeDist { #pragma unroll for (int k = 0; k < veclen_int; k++) { compute_dist(dist, queryRegs[k], encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs[k], queryRegs[k], norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -513,6 +598,10 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; ++k) { int32_t q = raft::shfl(queryReg, d + k, raft::WarpSize); compute_dist(dist, q, encV[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(encV[k], encV[k], norm_data); + } } } } @@ -531,17 +620,24 @@ struct loadAndComputeDist { for (int k = 0; k < veclen_int; k++) { int32_t q = raft::shfl(queryReg, (d / 4) + k, raft::WarpSize); // Here 4 is for 1 - int; compute_dist(dist, q, enc[k]); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc[k], enc[k], norm_data); + } } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, @@ -554,6 +650,10 @@ struct loadAndComputeDist { int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; compute_dist(dist, queryRegs, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryRegs, queryRegs, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } @@ -574,6 +674,10 @@ struct loadAndComputeDist { int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); compute_dist(dist, q, encV); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(queryReg, queryReg, norm_query); + norm_data = raft::dp4a(encV, encV, norm_data); + } } } } @@ -588,16 +692,23 @@ struct loadAndComputeDist { int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; int32_t q = raft::shfl(queryReg, d / veclen, raft::WarpSize); compute_dist(dist, q, enc); + if constexpr (ComputeNorm) { + norm_query = raft::dp4a(q, q, norm_query); + norm_data = raft::dp4a(enc, enc, norm_data); + } } } }; -template -struct loadAndComputeDist { +template +struct loadAndComputeDist { Lambda compute_dist; int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) + int32_t& norm_query; + int32_t& norm_data; + __device__ __forceinline__ + loadAndComputeDist(int32_t& dist, Lambda op, int32_t& norm_query, int32_t& norm_data) + : dist(dist), compute_dist(op), norm_query(norm_query), norm_data(norm_data) { } @@ -609,6 +720,11 @@ struct loadAndComputeDist { #pragma unroll for (int j = 0; j < kUnroll; ++j) { compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += int32_t{query_shared[shmemIndex + j]} * int32_t{query_shared[shmemIndex + j]}; + norm_data += int32_t{data[loadIndex + j * kIndexGroupSize]} * + int32_t{data[loadIndex + j * kIndexGroupSize]}; + } } } @@ -625,9 +741,12 @@ struct loadAndComputeDist { for (int i = 0; i < raft::WarpSize / stride; ++i, data += stride * kIndexGroupSize) { #pragma unroll for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, - raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize), - data[lane_id + j * kIndexGroupSize]); + int32_t q = raft::shfl(queryReg, i * kUnroll + j, raft::WarpSize); + compute_dist(dist, q, data[lane_id + j * kIndexGroupSize]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += data[lane_id + j * kIndexGroupSize] * data[lane_id + j * kIndexGroupSize]; + } } } } @@ -638,7 +757,12 @@ struct loadAndComputeDist { const int loadDim = dimBlocks + lane_id; int32_t queryReg = loadDim < dim ? query[loadDim] : 0; for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, raft::shfl(queryReg, d, raft::WarpSize), data[lane_id]); + int32_t q = raft::shfl(queryReg, d, raft::WarpSize); + compute_dist(dist, q, data[lane_id]); + if constexpr (ComputeNorm) { + norm_query += q * q; + norm_data += int32_t{data[lane_id]} * int32_t{data[lane_id]}; + } } } }; @@ -691,6 +815,7 @@ using block_sort_t = typename flat_block_sort::typ template lc(dist, - compute_dist); + // Process first shm_assisted_dim dimensions (always using shared memory) + loadAndComputeDist lc( + dist, compute_dist, norm_query, norm_dataset); for (int pos = 0; pos < shm_assisted_dim; pos += raft::WarpSize, data += kIndexGroupSize * raft::WarpSize) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); } - } - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc( + dist, compute_dist, norm_query, norm_dataset); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += raft::WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT, ComputeNorm> lc( + dist, compute_dist, norm_query, norm_dataset); for (int pos = full_warps_along_dim; pos < dim; pos += Veclen, data += kIndexGroupSize * Veclen) { lc.runLoadShmemCompute(data, query_shared, lane_id, pos); @@ -814,7 +940,13 @@ RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) } // Enqueue one element per thread - const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; + float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; + + if constexpr (ComputeNorm) { + if (valid) + val = val / (raft::sqrt(static_cast(norm_query)) * + raft::sqrt(static_cast(norm_dataset))); + } if constexpr (kManageLocalTopK) { queue.add(val, sample_offset + vec_id); } else { @@ -864,6 +996,7 @@ uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMem template , + raft::identity_op>({}, {}, std::forward(args)...); + case cuvs::distance::DistanceType::CosineExpanded: + return launch_kernel // cuvs::selection::select_k #include // RAFT_LOG_TRACE #include -#include // raft::resources -#include // raft::linalg::gemm -#include // raft::linalg::norm -#include // raft::linalg::row_normalize -#include // raft::linalg::unary_op -#include // matrix::detail::select_k +#include // raft::resources +#include // raft::linalg::gemm +#include // raft::linalg::norm +#include // raft::linalg::unary_op #include @@ -81,25 +79,10 @@ void search_impl(raft::resources const& handle, // also we might need additional storage for select_k rmm::device_uvector indices_tmp_dev(0, stream, search_mr); rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); + auto distance_buffer_dev_view = raft::make_device_matrix_view( + distance_buffer_dev.data(), n_queries, index.n_lists()); size_t float_query_size; - auto compute_metric = index.metric(); - rmm::device_uvector normalized_queries_dev(0, stream, search_mr); - const T* queries_ptr = queries; - - if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { - normalized_queries_dev.resize(n_queries * index.dim(), stream); - raft::linalg::row_normalize( // L2Norm - handle, - raft::make_device_matrix_view(queries, n_queries, index.dim()), - raft::make_device_matrix_view(normalized_queries_dev.data(), n_queries, index.dim()), - T(0), - raft::sq_op(), - raft::add_op(), - raft::sqrt_op(), - T(1e-8)); - queries_ptr = normalized_queries_dev.data(); - } if constexpr (std::is_integral_v) { float_query_size = n_queries * index.dim(); } else { @@ -109,10 +92,10 @@ void search_impl(raft::resources const& handle, float* converted_queries_ptr = converted_queries_dev.data(); if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries_ptr); + converted_queries_ptr = const_cast(queries); } else { raft::linalg::unaryOp( - converted_queries_ptr, queries_ptr, n_queries * index.dim(), utils::mapping{}, stream); + converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); } float alpha = 1.0f; @@ -142,10 +125,17 @@ void search_impl(raft::resources const& handle, break; } case cuvs::distance::DistanceType::CosineExpanded: { - compute_metric = cuvs::distance::DistanceType::InnerProduct; - select_min = cuvs::distance::is_min_close(compute_metric); - alpha = 1.0f; - beta = 0.0f; + raft::linalg::rowNorm(query_norm_dev.data(), + converted_queries_ptr, + static_cast(index.dim()), + static_cast(n_queries), + raft::linalg::L2Norm, + true, + stream, + raft::sqrt_op{}); + select_min = !select_min; + alpha = 1.0f; + beta = 0.0f; break; } default: { @@ -170,12 +160,25 @@ void search_impl(raft::resources const& handle, index.n_lists(), stream); + if (index.metric() == cuvs::distance::DistanceType::CosineExpanded) { + auto n_lists = index.n_lists(); + const auto* q_norm_ptr = query_norm_dev.data(); + const auto* index_center_norm_ptr = index.center_norms()->data_handle(); + raft::linalg::map_offset( + handle, + distance_buffer_dev_view, + [=] __device__(const uint32_t idx, const float dist) { + const auto query = idx / n_lists; + const auto cluster = idx % n_lists; + return dist / (q_norm_ptr[query] * index_center_norm_ptr[cluster]); + }, + raft::make_const_mdspan(distance_buffer_dev_view)); + } RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); cuvs::selection::select_k( handle, - raft::make_device_matrix_view( - distance_buffer_dev.data(), n_queries, index.n_lists()), + raft::make_const_mdspan(distance_buffer_dev_view), std::nullopt, raft::make_device_matrix_view(coarse_distances_dev.data(), n_queries, n_probes), raft::make_device_matrix_view( @@ -194,7 +197,7 @@ void search_impl(raft::resources const& handle, nullptr, n_queries, queries_offset, - compute_metric, + index.metric(), n_probes, k, 0, @@ -245,11 +248,11 @@ void search_impl(raft::resources const& handle, ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( index, - queries_ptr, + queries, coarse_indices_dev.data(), n_queries, queries_offset, - compute_metric, + index.metric(), n_probes, k, max_samples, @@ -411,4 +414,4 @@ void search(raft::resources const& handle, cuvs::neighbors::filtering::none_ivf_sample_filter()); } -} // namespace cuvs::neighbors::ivf_flat::detail +} // namespace cuvs::neighbors::ivf_flat::detail \ No newline at end of file diff --git a/cpp/src/neighbors/ivf_flat_index.cpp b/cpp/src/neighbors/ivf_flat_index.cpp index b249a9c29..6f7d11e50 100644 --- a/cpp/src/neighbors/ivf_flat_index.cpp +++ b/cpp/src/neighbors/ivf_flat_index.cpp @@ -193,6 +193,7 @@ void index::allocate_center_norms(raft::resources const& res) case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2Unexpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: + case cuvs::distance::DistanceType::CosineExpanded: center_norms_ = raft::make_device_vector(res, n_lists()); break; default: center_norms_ = std::nullopt; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 0e1428399..51b068681 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -621,11 +621,14 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { const std::vector> inputs_cosine = { // test various dims (aligned and not aligned to vector sizes) {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 5, 100, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 5, 100, 60, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 60, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 1000, 5, 32, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, {100, 1000, 8, 64, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, - {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}}; + {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}, + {100, 1000, 2056, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}, + {10, 1000, 1, 16, 40, 124, cuvs::distance::DistanceType::CosineExpanded, false}, + {10, 1000, 2, 16, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}}; const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) From 42f138cd2d5d9c8df0cbcb8ac08cc3d9a63972e2 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Mon, 8 Jul 2024 08:35:45 -0700 Subject: [PATCH 10/15] Change sign and min selection --- cpp/src/neighbors/ivf_common.cuh | 20 +------- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 8 +-- .../neighbors/ivf_flat/ivf_flat_search.cuh | 5 +- cpp/test/neighbors/ann_ivf_flat.cuh | 50 +++++++++++++++---- .../ann_ivf_flat/test_uint8_t_int64_t.cu | 5 ++ 5 files changed, 52 insertions(+), 36 deletions(-) diff --git a/cpp/src/neighbors/ivf_common.cuh b/cpp/src/neighbors/ivf_common.cuh index c6e9be11f..fb73fb8a9 100644 --- a/cpp/src/neighbors/ivf_common.cuh +++ b/cpp/src/neighbors/ivf_common.cuh @@ -254,6 +254,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); } } break; + case distance::DistanceType::CosineExpanded: case distance::DistanceType::InnerProduct: { float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; if (factor != 1.0) { @@ -267,25 +268,6 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk] raft::linalg::unaryOp(out, in, len, raft::cast_op{}, stream); } } break; - case distance::DistanceType::CosineExpanded: { - float factor = (account_for_max_close ? 1.0 : -1.0) * scaling_factor * scaling_factor; - if (factor != 1.0) { - raft::linalg::unaryOp(out, - in, - len, - raft::compose_op(raft::mul_const_op{factor}, - raft::add_const_op{1}, - raft::cast_op{}), - stream); - } else { - raft::linalg::unaryOp( - out, - in, - len, - raft::compose_op(raft::add_const_op{1}, raft::cast_op{}), - stream); - } - } break; default: RAFT_FAIL("Unexpected metric."); } } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index 8af54413f..b19de5c6d 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -1196,14 +1196,16 @@ void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... arg case cuvs::distance::DistanceType::CosineExpanded: return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); + inner_prod_dist>( + {}, + raft::compose_op(raft::mul_const_op{-1.0f}, raft::add_const_op{1.0f}), + std::forward(args)...); // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index 885139f64..b27b09057 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -133,9 +133,8 @@ void search_impl(raft::resources const& handle, true, stream, raft::sqrt_op{}); - select_min = !select_min; - alpha = 1.0f; - beta = 0.0f; + alpha = -1.0f; + beta = 0.0f; break; } default: { diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 51b068681..18e39b8b8 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -322,9 +322,29 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivf_flat::index idx(handle_, index_params, ps.dim); - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - idx = ivf_flat::build(handle_, index_params, database_view); + IdxT half_of_data = ps.num_db_vecs / 2; + auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); + thrust::sequence( + raft::resource::get_thrust_policy(handle_), + thrust::device_pointer_cast(vector_indices.data_handle()), + thrust::device_pointer_cast(vector_indices.data_handle() + ps.num_db_vecs)); + raft::resource::sync_stream(handle_); + + auto database_view_half = + raft::make_device_matrix_view(database.data(), half_of_data, ps.dim); + auto database_view_other_half = raft::make_device_matrix_view( + database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); + + auto new_half_of_data_indices_view = raft::make_device_vector_view( + vector_indices.data_handle() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); + + idx = ivf_flat::build(handle_, index_params, database_view_half); + cuvs::neighbors::ivf_flat::extend( + handle_, + database_view_other_half, + std::make_optional>( + new_half_of_data_indices_view), + &idx); auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.num_queries, ps.dim); @@ -621,14 +641,22 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { const std::vector> inputs_cosine = { // test various dims (aligned and not aligned to vector sizes) {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 5, 100, 60, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 60, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {100, 1000, 5, 32, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, - {100, 1000, 8, 64, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}, - {100, 1000, 500, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}, - {100, 1000, 2056, 16, 10, 50, cuvs::distance::DistanceType::CosineExpanded, false}, - {10, 1000, 1, 16, 40, 124, cuvs::distance::DistanceType::CosineExpanded, false}, - {10, 1000, 2, 16, 40, 124, cuvs::distance::DistanceType::CosineExpanded, true}}; + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + + // test dims that do not fit into kernel shared memory limits + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, + {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, + {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}}; const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index 42a8dab2e..d5facdf80 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -25,4 +25,9 @@ TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); +typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_uint8; +TEST_P(AnnIVFFlatTestF_cosine_uint8, AnnIVFFlat) { this->testIVFFlatCosine(); } +INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, + AnnIVFFlatTestF_cosine_uint8, + ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat From 042d3a5e9a07e9be66a35a457cfad87df1332e6b Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 9 Jul 2024 01:35:16 -0700 Subject: [PATCH 11/15] Clean up unused code --- cpp/src/cluster/kmeans_balanced.cuh | 1 - cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 32 ------------------- .../ivf_flat/ivf_flat_interleaved_scan.cuh | 1 + .../neighbors/ivf_flat/ivf_flat_search.cuh | 2 +- cpp/test/neighbors/ann_ivf_flat.cuh | 4 +-- 5 files changed, 4 insertions(+), 36 deletions(-) diff --git a/cpp/src/cluster/kmeans_balanced.cuh b/cpp/src/cluster/kmeans_balanced.cuh index e978c335f..306989891 100644 --- a/cpp/src/cluster/kmeans_balanced.cuh +++ b/cpp/src/cluster/kmeans_balanced.cuh @@ -337,7 +337,6 @@ void build_clusters(const raft::resources& handle, * centers and sizes using the weighted average principle. * @param[in] mapping_op (optional) Functor to convert from the input datatype to the * arithmetic datatype. If DataT == MathT, this must be the identity. - * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] */ template -void myl2rownorm(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_vector_view out, - raft::linalg::NormType type, - raft::linalg::Apply apply, - Lambda fin_op = raft::identity_op()) -{ - // Created to support two different data types for input and output - // To be replaced with raft::linalg::rowNorm once it supports different data types - auto constexpr row_major = std::is_same_v; - RAFT_EXPECTS(static_cast(out.size()) == in.extent(0), - "Output should be equal to number of rows in Input"); - - raft::linalg::reduce(out.data_handle(), - in.data_handle(), - in.extent(1), - in.extent(0), - (OutType)0, - row_major, - apply == raft::linalg::Apply::ALONG_ROWS, - raft::resource::get_cuda_stream(handle), - false, - raft::sq_op(), - raft::add_op(), - fin_op); -} /** See raft::neighbors::ivf_flat::extend docs */ template void extend(raft::resources const& handle, @@ -239,7 +208,6 @@ void extend(raft::resources const& handle, raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); auto batch_labels_view = raft::make_device_vector_view( new_labels.data_handle() + batch.offset(), batch.size()); - cuvs::cluster::kmeans_balanced::predict(handle, kmeans_params, batch_data_view, diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh index b19de5c6d..86ef55928 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_interleaved_scan.cuh @@ -1194,6 +1194,7 @@ void launch_with_fixed_consts(cuvs::distance::DistanceType metric, Args&&... arg inner_prod_dist, raft::identity_op>({}, {}, std::forward(args)...); case cuvs::distance::DistanceType::CosineExpanded: + // NB: "Ascending" is reversed because the post-processing step is done after that sort return launch_kernel> { raft::linalg::map( handle_, database_float.view(), - [] __device__(DataT val) { return static_cast(val); }, + cuvs::spatial::knn::detail::utils::mapping{}, database_view); raft::linalg::map( handle_, search_queries_float.view(), - [] __device__(DataT val) { return static_cast(val); }, + cuvs::spatial::knn::detail::utils::mapping{}, search_queries_view); auto indices_out_view = raft::make_device_matrix_view(indices_naive_dev.data(), ps.num_queries, ps.k); From 208eac2202b0d8adf08987313c9fc968bac98eac Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Tue, 9 Jul 2024 06:50:48 -0700 Subject: [PATCH 12/15] Add documentation for cosine ivf flat --- python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx index 03d254995..342ef3cd6 100644 --- a/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/cuvs/cuvs/neighbors/ivf_flat/ivf_flat.pyx @@ -58,12 +58,14 @@ cdef class IndexParams: metric : str, default = "sqeuclidean" String denoting the metric type. Valid values for metric: ["sqeuclidean", "inner_product", - "euclidean"], where + "euclidean", "cosine"], where - sqeuclidean is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, - euclidean is the euclidean distance - inner product distance is defined as distance(a, b) = \\sum_i a_i * b_i. + - cosine distance is defined as + distance(a, b) = 1 - \\sum_i a_i * b_i / ( ||a||_2 * ||b||_2). kmeans_n_iters : int, default = 20 The number of iterations searching for kmeans centers during index building. From 4f3b664fa0bb3ba46021ae2f5fda2a9d9b66cddb Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 10 Jul 2024 09:03:17 -0700 Subject: [PATCH 13/15] Address reviews, use fused distance --- cpp/src/cluster/detail/kmeans_balanced.cuh | 73 ++++----- cpp/test/neighbors/ann_ivf_flat.cuh | 155 ++---------------- .../ann_ivf_flat/test_float_int64_t.cu | 6 - .../ann_ivf_flat/test_int8_t_int64_t.cu | 5 - .../ann_ivf_flat/test_uint8_t_int64_t.cu | 5 - cpp/test/neighbors/naive_knn.cuh | 14 +- 6 files changed, 60 insertions(+), 198 deletions(-) diff --git a/cpp/src/cluster/detail/kmeans_balanced.cuh b/cpp/src/cluster/detail/kmeans_balanced.cuh index e85fe32c9..34bb22e85 100644 --- a/cpp/src/cluster/detail/kmeans_balanced.cuh +++ b/cpp/src/cluster/detail/kmeans_balanced.cuh @@ -143,26 +143,16 @@ inline std::enable_if_t> predict_core( break; } case cuvs::distance::DistanceType::CosineExpanded: { - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - - MathT alpha = -1.0; - MathT beta = 0.0; + auto workspace = raft::make_device_mdarray( + handle, mr, raft::make_extents((sizeof(int)) * n_rows)); - raft::linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, raft::make_extents(n_rows)); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + thrust::fill(raft::resource::get_thrust_policy(handle), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + minClusterAndDistance.size(), + initial_value); auto centroidsNorm = raft::make_device_mdarray(handle, mr, raft::make_extents(n_clusters)); @@ -175,21 +165,28 @@ inline std::enable_if_t> predict_core( stream, raft::sqrt_op{}); - const auto* index_center_norm_ptr = centroidsNorm.data_handle(); - raft::linalg::map_offset( - handle, - raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters), - [=] __device__(const uint32_t idx, const float dist) { - const auto query = idx / n_clusters; - const auto cluster = idx % n_clusters; - return dist / (dataset_norm[query] * index_center_norm_ptr[cluster]); - }, - raft::make_device_matrix_view(distances.data(), n_rows, n_clusters)); - auto distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); + cuvs::distance::fusedDistanceNNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + dataset_norm, + centroidsNorm.data_handle(), + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + false, + false, + true, + params.metric, + 0.0f, + stream); + // Copy keys to output labels + thrust::transform(raft::resource::get_thrust_policy(handle), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + n_rows, + labels, + raft::compose_op, raft::key_op>()); break; } case cuvs::distance::DistanceType::InnerProduct: { @@ -214,14 +211,6 @@ inline std::enable_if_t> predict_core( distances.data(), n_clusters, stream); - if (dataset_norm) { - raft::linalg::binary_div_skip_zero( - handle, - raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters), - raft::make_device_vector_view(dataset_norm, n_rows), - raft::linalg::Apply::ALONG_COLUMNS); - } auto distances_const_view = raft::make_device_matrix_view( distances.data(), n_rows, n_clusters); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 5199dc74c..379730a8a 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -254,124 +254,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { } } - void testIVFFlatCosine() - { - size_t queries_size = ps.num_queries * ps.k; - std::vector indices_ivfflat(queries_size); - std::vector indices_naive(queries_size); - std::vector distances_ivfflat(queries_size); - std::vector distances_naive(queries_size); - - { - rmm::device_uvector distances_naive_dev(queries_size, stream_); - rmm::device_uvector indices_naive_dev(queries_size, stream_); - auto database_view = raft::make_device_matrix_view( - (const DataT*)database.data(), ps.num_db_vecs, ps.dim); - auto database_float = raft::make_device_matrix(handle_, ps.num_db_vecs, ps.dim); - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto search_queries_float = - raft::make_device_matrix(handle_, ps.num_queries, ps.dim); - raft::linalg::map( - handle_, - database_float.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - database_view); - - raft::linalg::map( - handle_, - search_queries_float.view(), - cuvs::spatial::knn::detail::utils::mapping{}, - search_queries_view); - auto indices_out_view = - raft::make_device_matrix_view(indices_naive_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = - raft::make_device_matrix_view(distances_naive_dev.data(), ps.num_queries, ps.k); - auto bfi = cuvs::neighbors::brute_force::build( - handle_, raft::make_const_mdspan(database_float.view()), ps.metric); - cuvs::neighbors::brute_force::search(handle_, - bfi, - raft::make_const_mdspan(search_queries_float.view()), - indices_out_view, - dists_out_view, - std::nullopt); - - raft::update_host(distances_naive.data(), distances_naive_dev.data(), queries_size, stream_); - raft::update_host(indices_naive.data(), indices_naive_dev.data(), queries_size, stream_); - raft::resource::sync_stream(handle_); - } - { - // unless something is really wrong with clustering, this could serve as a lower bound on - // recall - double min_recall = static_cast(ps.nprobe) / static_cast(ps.nlist); - - rmm::device_uvector distances_ivfflat_dev(queries_size, stream_); - rmm::device_uvector indices_ivfflat_dev(queries_size, stream_); - - { - ivf_flat::index_params index_params; - ivf_flat::search_params search_params; - index_params.n_lists = ps.nlist; - index_params.metric = ps.metric; - index_params.adaptive_centers = ps.adaptive_centers; - search_params.n_probes = ps.nprobe; - - index_params.add_data_on_build = true; - index_params.kmeans_trainset_fraction = 0.5; - index_params.metric_arg = 0; - - ivf_flat::index idx(handle_, index_params, ps.dim); - - IdxT half_of_data = ps.num_db_vecs / 2; - auto vector_indices = raft::make_device_vector(handle_, ps.num_db_vecs); - thrust::sequence( - raft::resource::get_thrust_policy(handle_), - thrust::device_pointer_cast(vector_indices.data_handle()), - thrust::device_pointer_cast(vector_indices.data_handle() + ps.num_db_vecs)); - raft::resource::sync_stream(handle_); - - auto database_view_half = - raft::make_device_matrix_view(database.data(), half_of_data, ps.dim); - auto database_view_other_half = raft::make_device_matrix_view( - database.data() + half_of_data * ps.dim, IdxT(ps.num_db_vecs) - half_of_data, ps.dim); - - auto new_half_of_data_indices_view = raft::make_device_vector_view( - vector_indices.data_handle() + half_of_data, IdxT(ps.num_db_vecs) - half_of_data); - - idx = ivf_flat::build(handle_, index_params, database_view_half); - cuvs::neighbors::ivf_flat::extend( - handle_, - database_view_other_half, - std::make_optional>( - new_half_of_data_indices_view), - &idx); - - auto search_queries_view = raft::make_device_matrix_view( - search_queries.data(), ps.num_queries, ps.dim); - auto indices_out_view = raft::make_device_matrix_view( - indices_ivfflat_dev.data(), ps.num_queries, ps.k); - auto dists_out_view = raft::make_device_matrix_view( - distances_ivfflat_dev.data(), ps.num_queries, ps.k); - ivf_flat::search( - handle_, search_params, idx, search_queries_view, indices_out_view, dists_out_view); - - raft::update_host( - distances_ivfflat.data(), distances_ivfflat_dev.data(), queries_size, stream_); - raft::update_host( - indices_ivfflat.data(), indices_ivfflat_dev.data(), queries_size, stream_); - raft::resource::sync_stream(handle_); - } - ASSERT_TRUE(eval_neighbours(indices_naive, - indices_ivfflat, - distances_naive, - distances_ivfflat, - ps.num_queries, - ps.k, - 0.001, - min_recall)); - } - } - void testPacker() { ivf_flat::index_params index_params; @@ -638,45 +520,39 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { rmm::device_uvector search_queries; }; -const std::vector> inputs_cosine = { - // test various dims (aligned and not aligned to vector sizes) - {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - - // test dims that do not fit into kernel shared memory limits - {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, - {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, - {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}}; - const std::vector> inputs = { // test various dims (aligned and not aligned to vector sizes) {1000, 10000, 1, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 3, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 4, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, false}, + {1000, 10000, 5, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::L2SqrtExpanded, true}, + {1000, 10000, 8, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, // test dims that do not fit into kernel shared memory limits {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2048, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 2049, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2050, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 2051, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 2052, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2053, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 10000, 2056, 16, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, // various random combinations {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, @@ -686,12 +562,14 @@ const std::vector> inputs = { {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, // host input data {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, + {100, 1000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, @@ -710,6 +588,7 @@ const std::vector> inputs = { {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::InnerProduct, false}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, true}, + {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, true}, // test radix_sort for getting the cluster selection {1000, diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index fe5f56587..0ce168f5e 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -25,10 +25,4 @@ TEST_P(AnnIVFFlatTestF_float, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_float, ::testing::ValuesIn(inputs)); -typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_float; -TEST_P(AnnIVFFlatTestF_cosine_float, AnnIVFFlat) { this->testIVFFlatCosine(); } -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, - AnnIVFFlatTestF_cosine_float, - ::testing::ValuesIn(inputs_cosine)); - } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index d95d64f3b..15935fd88 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -25,9 +25,4 @@ TEST_P(AnnIVFFlatTestF_int8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_int8, ::testing::ValuesIn(inputs)); -typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_int8; -TEST_P(AnnIVFFlatTestF_cosine_int8, AnnIVFFlat) { this->testIVFFlatCosine(); } -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, - AnnIVFFlatTestF_cosine_int8, - ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index d5facdf80..42a8dab2e 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -25,9 +25,4 @@ TEST_P(AnnIVFFlatTestF_uint8, AnnIVFFlat) { this->testIVFFlat(); } INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, AnnIVFFlatTestF_uint8, ::testing::ValuesIn(inputs)); -typedef AnnIVFFlatTest AnnIVFFlatTestF_cosine_uint8; -TEST_P(AnnIVFFlatTestF_cosine_uint8, AnnIVFFlat) { this->testIVFFlatCosine(); } -INSTANTIATE_TEST_CASE_P(AnnIVFFlatTest, - AnnIVFFlatTestF_cosine_uint8, - ::testing::ValuesIn(inputs_cosine)); } // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/test/neighbors/naive_knn.cuh b/cpp/test/neighbors/naive_knn.cuh index 90547150f..553e667aa 100644 --- a/cpp/test/neighbors/naive_knn.cuh +++ b/cpp/test/neighbors/naive_knn.cuh @@ -41,7 +41,9 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, if (midx >= m) return; IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y); for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) { - EvalT acc = EvalT(0); + EvalT acc = EvalT(0); + EvalT normX = EvalT(0); + EvalT normY = EvalT(0); for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; @@ -51,6 +53,11 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, case cuvs::distance::DistanceType::InnerProduct: { acc += xv * yv; } break; + case cuvs::distance::DistanceType::CosineExpanded: { + acc += xv * yv; + normX += xv * xv; + normY += yv * yv; + } break; case cuvs::distance::DistanceType::L2SqrtExpanded: case cuvs::distance::DistanceType::L2SqrtUnexpanded: case cuvs::distance::DistanceType::L2Expanded: @@ -66,6 +73,9 @@ RAFT_KERNEL naive_distance_kernel(EvalT* dist, case cuvs::distance::DistanceType::L2SqrtUnexpanded: { acc = raft::sqrt(acc); } break; + case cuvs::distance::DistanceType::CosineExpanded: { + acc = 1 - acc / (raft::sqrt(normX) * raft::sqrt(normY)); + } default: break; } dist[midx * n + nidx] = acc; @@ -118,7 +128,7 @@ void naive_knn(raft::resources const& handle, static_cast(k), dist_topk + offset * k, indices_topk + offset * k, - type != cuvs::distance::DistanceType::InnerProduct, + cuvs::distance::is_min_close(type), mr); } RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); From 4de8399a6698a8ce4341db3a240efe1cc1295ea1 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 7 Aug 2024 09:20:29 -0700 Subject: [PATCH 14/15] Address reviews, add pytest, update docstring --- cpp/include/cuvs/cluster/kmeans.hpp | 1 + cpp/include/cuvs/neighbors/ivf_flat.hpp | 72 +++++++++++++++++++ cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 3 +- cpp/test/neighbors/ann_ivf_flat.cuh | 41 ++++++++++- python/cuvs/cuvs/test/test_ivf_flat.py | 3 +- 5 files changed, 116 insertions(+), 4 deletions(-) diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index c6fe4d046..75205fa4f 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -120,6 +120,7 @@ struct params : base_params { * Simple object to specify hyper-parameters to the balanced k-means algorithm. * * The following metrics are currently supported in k-means balanced: + * - CosineExpanded * - InnerProduct * - L2Expanded * - L2SqrtExpanded diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 918fef5af..44502f942 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -304,6 +304,12 @@ struct index : cuvs::neighbors::index { /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -327,6 +333,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -351,6 +363,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -374,6 +392,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -398,6 +422,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -421,6 +451,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Usage example: * @code{.cpp} * using namespace cuvs::neighbors; @@ -445,6 +481,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -475,6 +517,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -506,6 +554,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -536,6 +590,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -567,6 +627,12 @@ void build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. @@ -597,6 +663,12 @@ auto build(raft::resources const& handle, /** * @brief Build the index from the dataset for efficient search. * + * NB: Currently, the following distance metrics are supported: + * - L2Expanded + * - L2Unexpanded + * - InnerProduct + * - CosineExpanded + * * Note, if index_params.add_data_on_build is set to true, the user can set a * stream pool in the input raft::resource with at least one stream to enable kernel and copy * overlapping. diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 0f81c00b8..fb110d810 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -406,7 +406,8 @@ inline auto build(raft::resources const& handle, "unsupported data type"); RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - + RAFT_EXPECTS(params.metric != cuvs::distance::DistanceType::CosineExpanded || dim > 1, + "Cosine metric requires more than one dim"); index index(handle, params, dim); utils::memzero( index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 002f98ef3..e7386db08 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -567,46 +567,75 @@ const std::vector> inputs = { // various random combinations {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false}, - {3, 20, 1024, 3, 3, 4, cuvs::distance::DistanceType::CosineExpanded, false}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, // host input data {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, - {100, 1000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, // // host input data with prefetching for kernel copy overlapping {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true, true}, {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::InnerProduct, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::InnerProduct, true}, + {10000, 131072, 8, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, true}, {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, 10000, 4096, 20, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false}, // test splitting the big query batches (> max gridDim.y) into smaller batches {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, false}, + {100000, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, false}, {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::InnerProduct, false}, + {1000000, 1024, 32, 10, 256, 256, cuvs::distance::DistanceType::CosineExpanded, false}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::InnerProduct, true}, {98306, 1024, 32, 10, 64, 64, cuvs::distance::DistanceType::CosineExpanded, true}, @@ -627,6 +656,14 @@ const std::vector> inputs = { raft::matrix::detail::select::warpsort::kMaxCapacity * 4, cuvs::distance::DistanceType::InnerProduct, false}, + {1000, + 10000, + 16, + 10, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + raft::matrix::detail::select::warpsort::kMaxCapacity * 4, + cuvs::distance::DistanceType::CosineExpanded, + false}, // The following two test cases should show very similar recall. // num_queries, num_db_vecs, dim, k, nprobe, nlist, metric, adaptive_centers diff --git a/python/cuvs/cuvs/test/test_ivf_flat.py b/python/cuvs/cuvs/test/test_ivf_flat.py index bb50d3573..9dd4097dc 100644 --- a/python/cuvs/cuvs/test/test_ivf_flat.py +++ b/python/cuvs/cuvs/test/test_ivf_flat.py @@ -92,6 +92,7 @@ def run_ivf_flat_build_search_test( skl_metric = { "sqeuclidean": "sqeuclidean", "inner_product": "cosine", + "cosine": "cosine", "euclidean": "euclidean", }[metric] nn_skl = NearestNeighbors( @@ -107,7 +108,7 @@ def run_ivf_flat_build_search_test( @pytest.mark.parametrize("inplace", [True, False]) @pytest.mark.parametrize("dtype", [np.float32]) @pytest.mark.parametrize( - "metric", ["sqeuclidean", "inner_product", "euclidean"] + "metric", ["sqeuclidean", "inner_product", "euclidean", "cosine"] ) def test_ivf_flat(inplace, dtype, metric): run_ivf_flat_build_search_test( From 1a0a2b0cca7959ec0457181826d270f6bc179e35 Mon Sep 17 00:00:00 2001 From: Mickael Ide Date: Wed, 7 Aug 2024 09:37:12 -0700 Subject: [PATCH 15/15] Fix style --- cpp/test/neighbors/ann_ivf_flat.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index e7386db08..17ec84097 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -583,19 +583,19 @@ const std::vector> inputs = { // host input data {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 50, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 10000, 16, 10, 70, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::L2Expanded, false, true}, - {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {100, 10000, 16, 10, 20, 512, cuvs::distance::DistanceType::CosineExpanded, false, true}, {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {20, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {1000, 100000, 16, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::L2Expanded, false, true}, - {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, + {10000, 131072, 8, 10, 20, 1024, cuvs::distance::DistanceType::CosineExpanded, false, true}, // // host input data with prefetching for kernel copy overlapping {1000, 10000, 16, 10, 40, 1024, cuvs::distance::DistanceType::L2Expanded, false, true, true},