From 904ab1b79050869ca931d7c0c500c9c5f812c3de Mon Sep 17 00:00:00 2001 From: soleee99 Date: Sun, 22 Sep 2024 03:28:57 +0000 Subject: [PATCH] change to switch --- cpp/src/hdbscan/detail/reachability.cuh | 210 +++++++++++++----------- 1 file changed, 112 insertions(+), 98 deletions(-) diff --git a/cpp/src/hdbscan/detail/reachability.cuh b/cpp/src/hdbscan/detail/reachability.cuh index c5910a6213..9538536723 100644 --- a/cpp/src/hdbscan/detail/reachability.cuh +++ b/cpp/src/hdbscan/detail/reachability.cuh @@ -166,56 +166,65 @@ void compute_knn(const raft::handle_t& handle, // pass value_idx through to knn. rmm::device_uvector int64_indices(k * n_search_items, stream); - if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) { - std::vector inputs; - inputs.push_back(const_cast(X)); - - std::vector sizes; - sizes.push_back(m); - - // perform knn - brute_force_knn(handle, - inputs, - sizes, - n, - const_cast(search_items), - n_search_items, - int64_indices.data(), - dists, - k, - true, - true, - metric); - } else { // NN_DESCENT - RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, - "n_neighbors should be smaller than the graph degree computed by nn descent"); - - auto epilogue = DistancePostProcessSqrt{}; - build_params.return_distances = true; - auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); - - size_t TPB = 256; - size_t num_blocks = static_cast((m + TPB) / TPB); - - auto indices_d = - raft::make_device_matrix(handle, m, build_params.graph_degree); - - raft::copy( - indices_d.data_handle(), graph.graph().data_handle(), m * build_params.graph_degree, stream); - - RAFT_EXPECTS(graph.distances().has_value(), - "return_distances for nn descent should be set to true to be used for HDBSCAN"); - copy_first_k_cols_shift_zero - <<>>(dists, - graph.distances().value().data_handle(), - static_cast(k), - build_params.graph_degree, - m); - copy_first_k_cols_shift_self<<>>(int64_indices.data(), - indices_d.data_handle(), - static_cast(k), - build_params.graph_degree, - m); + switch (build_algo) { + case Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN: { + std::vector inputs; + inputs.push_back(const_cast(X)); + + std::vector sizes; + sizes.push_back(m); + + // perform knn + brute_force_knn(handle, + inputs, + sizes, + n, + const_cast(search_items), + n_search_items, + int64_indices.data(), + dists, + k, + true, + true, + metric); + break; + } + + case Common::GRAPH_BUILD_ALGO::NN_DESCENT: { + RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, + "n_neighbors should be smaller than the graph degree computed by nn descent"); + + auto epilogue = DistancePostProcessSqrt{}; + build_params.return_distances = true; + auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); + + size_t TPB = 256; + size_t num_blocks = static_cast((m + TPB) / TPB); + + auto indices_d = + raft::make_device_matrix(handle, m, build_params.graph_degree); + + raft::copy(indices_d.data_handle(), + graph.graph().data_handle(), + m * build_params.graph_degree, + stream); + + RAFT_EXPECTS(graph.distances().has_value(), + "return_distances for nn descent should be set to true to be used for HDBSCAN"); + copy_first_k_cols_shift_zero + <<>>(dists, + graph.distances().value().data_handle(), + static_cast(k), + build_params.graph_degree, + m); + copy_first_k_cols_shift_self + <<>>(int64_indices.data(), + indices_d.data_handle(), + static_cast(k), + build_params.graph_degree, + m); + break; + } } // convert from current knn's 64-bit to 32-bit. @@ -329,55 +338,60 @@ void mutual_reachability_knn_l2( // `A type local to a function cannot be used in the template argument of the // enclosing parent function (and any parent classes) of an extended __device__ // or __host__ __device__ lambda` + switch (build_algo) { + case Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN: { + auto epilogue = ReachabilityPostProcess{core_dists, alpha}; + auto X_view = raft::make_device_matrix_view(X, m, n); + std::vector> index = {X_view}; + + raft::neighbors::brute_force::knn( + handle, + index, + X_view, + raft::make_device_matrix_view(out_inds, m, static_cast(k)), + raft::make_device_matrix_view(out_dists, m, static_cast(k)), + // TODO: expand distance metrics to support more than just L2 distance + // https://github.com/rapidsai/cuml/issues/5301 + raft::distance::DistanceType::L2SqrtExpanded, + std::make_optional(2.0f), + std::nullopt, + epilogue); + break; + } - if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) { - auto epilogue = ReachabilityPostProcess{core_dists, alpha}; - auto X_view = raft::make_device_matrix_view(X, m, n); - std::vector> index = {X_view}; - - raft::neighbors::brute_force::knn( - handle, - index, - X_view, - raft::make_device_matrix_view(out_inds, m, static_cast(k)), - raft::make_device_matrix_view(out_dists, m, static_cast(k)), - // TODO: expand distance metrics to support more than just L2 distance - // https://github.com/rapidsai/cuml/issues/5301 - raft::distance::DistanceType::L2SqrtExpanded, - std::make_optional(2.0f), - std::nullopt, - epilogue); - } else { - RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, - "n_neighbors should be smaller than the graph degree computed by nn descent"); - - auto epilogue = ReachabilityPostProcessSqrt(core_dists, alpha); - build_params.return_distances = true; - auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); - - auto indices_d = - raft::make_device_matrix(handle, m, build_params.graph_degree); - - raft::copy(indices_d.data_handle(), - graph.graph().data_handle(), - m * build_params.graph_degree, - handle.get_stream()); - - RAFT_EXPECTS(graph.distances().has_value(), - "return_distances for nn descent should be set to true to be used for HDBSCAN"); - - raft::matrix::slice_coordinates coords{static_cast(0), - static_cast(0), - static_cast(m), - static_cast(k)}; - - auto out_knn_dists_view = raft::make_device_matrix_view(out_dists, m, (size_t)k); - raft::matrix::slice( - handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); - auto out_knn_indices_view = - raft::make_device_matrix_view(out_inds, m, (size_t)k); - raft::matrix::slice( - handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); + case Common::GRAPH_BUILD_ALGO::NN_DESCENT: { + RAFT_EXPECTS(static_cast(k) <= build_params.graph_degree, + "n_neighbors should be smaller than the graph degree computed by nn descent"); + + auto epilogue = ReachabilityPostProcessSqrt(core_dists, alpha); + build_params.return_distances = true; + auto graph = get_graph_nnd(handle, X, m, n, epilogue, build_params); + + auto indices_d = + raft::make_device_matrix(handle, m, build_params.graph_degree); + + raft::copy(indices_d.data_handle(), + graph.graph().data_handle(), + m * build_params.graph_degree, + handle.get_stream()); + + RAFT_EXPECTS(graph.distances().has_value(), + "return_distances for nn descent should be set to true to be used for HDBSCAN"); + + raft::matrix::slice_coordinates coords{static_cast(0), + static_cast(0), + static_cast(m), + static_cast(k)}; + + auto out_knn_dists_view = raft::make_device_matrix_view(out_dists, m, (size_t)k); + raft::matrix::slice( + handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); + auto out_knn_indices_view = + raft::make_device_matrix_view(out_inds, m, (size_t)k); + raft::matrix::slice( + handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); + break; + } } }