diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh index d53765c373..a2dc9c00d7 100644 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ b/cpp/include/raft/neighbors/detail/nn_descent.cuh @@ -1435,16 +1435,16 @@ void build(raft::resources const& res, GNND nnd(res, build_config); - if(idx.distances().has_value()) { + if (idx.distances().has_value()) { nnd.build(dataset.data_handle(), - dataset.extent(0), - int_graph.data_handle(), - params.return_distances, - idx.distances().value().data_handle()); + dataset.extent(0), + int_graph.data_handle(), + params.return_distances, + idx.distances().value().data_handle()); } else { - RAFT_EXPECTS(false, "Distance view not allocated. Using NN Descent requires return_distances to be true"); + RAFT_EXPECTS( + false, "Distance view not allocated. Using NN Descent requires return_distances to be true"); } - #pragma omp parallel for for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { diff --git a/cpp/include/raft/neighbors/nn_descent_types.hpp b/cpp/include/raft/neighbors/nn_descent_types.hpp index 4bad85908d..9838047bb5 100644 --- a/cpp/include/raft/neighbors/nn_descent_types.hpp +++ b/cpp/include/raft/neighbors/nn_descent_types.hpp @@ -155,7 +155,8 @@ struct index : ann::index { } /** neighborhood graph distances [size, graph-degree] */ - [[nodiscard]] inline auto distances() noexcept -> std::optional> + [[nodiscard]] inline auto distances() noexcept + -> std::optional> { return distances_view_; } diff --git a/cpp/test/neighbors/ann_nn_descent.cuh b/cpp/test/neighbors/ann_nn_descent.cuh index 953da2d42b..f74cadb415 100644 --- a/cpp/test/neighbors/ann_nn_descent.cuh +++ b/cpp/test/neighbors/ann_nn_descent.cuh @@ -108,18 +108,22 @@ class AnnNNDescentTest : public ::testing::TestWithParam { auto index = nn_descent::build(handle_, index_params, database_host_view); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if(index.distances().has_value()) { - raft::copy( - distances_NNDescent.data(), index.distances().value().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); } - + } else { auto index = nn_descent::build(handle_, index_params, database_view); raft::copy( indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_); - if(index.distances().has_value()) { - raft::copy( - distances_NNDescent.data(), index.distances().value().data_handle(), queries_size, stream_); + if (index.distances().has_value()) { + raft::copy(distances_NNDescent.data(), + index.distances().value().data_handle(), + queries_size, + stream_); } }; }