Skip to content

Commit

Permalink
fix styling
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp committed Jun 13, 2024
1 parent 3ce8cdf commit 4ec3ab0
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 15 deletions.
14 changes: 7 additions & 7 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1435,16 +1435,16 @@ void build(raft::resources const& res,
GNND<const T, int> nnd(res, build_config);
if(idx.distances().has_value()) {
if (idx.distances().has_value()) {
nnd.build(dataset.data_handle(),
dataset.extent(0),
int_graph.data_handle(),
params.return_distances,
idx.distances().value().data_handle());
dataset.extent(0),
int_graph.data_handle(),
params.return_distances,
idx.distances().value().data_handle());
} else {
RAFT_EXPECTS(false, "Distance view not allocated. Using NN Descent requires return_distances to be true");
RAFT_EXPECTS(
false, "Distance view not allocated. Using NN Descent requires return_distances to be true");
}
#pragma omp parallel for
for (size_t i = 0; i < static_cast<size_t>(dataset.extent(0)); i++) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/nn_descent_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ struct index : ann::index {
}

/** neighborhood graph distances [size, graph-degree] */
[[nodiscard]] inline auto distances() noexcept -> std::optional<device_matrix_view<T, int64_t, row_major>>
[[nodiscard]] inline auto distances() noexcept
-> std::optional<device_matrix_view<T, int64_t, row_major>>
{
return distances_view_;
}
Expand Down
18 changes: 11 additions & 7 deletions cpp/test/neighbors/ann_nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,22 @@ class AnnNNDescentTest : public ::testing::TestWithParam<AnnNNDescentInputs> {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_host_view);
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if(index.distances().has_value()) {
raft::copy(
distances_NNDescent.data(), index.distances().value().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}

} else {
auto index = nn_descent::build<DataT, IdxT>(handle_, index_params, database_view);
raft::copy(
indices_NNDescent.data(), index.graph().data_handle(), queries_size, stream_);
if(index.distances().has_value()) {
raft::copy(
distances_NNDescent.data(), index.distances().value().data_handle(), queries_size, stream_);
if (index.distances().has_value()) {
raft::copy(distances_NNDescent.data(),
index.distances().value().data_handle(),
queries_size,
stream_);
}
};
}
Expand Down

0 comments on commit 4ec3ab0

Please sign in to comment.