Skip to content

Commit

Permalink
change to switch
Browse files Browse the repository at this point in the history
  • Loading branch information
jinsolp committed Sep 22, 2024
1 parent fea3dc8 commit 904ab1b
Showing 1 changed file with 112 additions and 98 deletions.
210 changes: 112 additions & 98 deletions cpp/src/hdbscan/detail/reachability.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -166,56 +166,65 @@ void compute_knn(const raft::handle_t& handle,
// pass value_idx through to knn.
rmm::device_uvector<int64_t> int64_indices(k * n_search_items, stream);

if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) {
std::vector<value_t*> inputs;
inputs.push_back(const_cast<value_t*>(X));

std::vector<int> sizes;
sizes.push_back(m);

// perform knn
brute_force_knn(handle,
inputs,
sizes,
n,
const_cast<value_t*>(search_items),
n_search_items,
int64_indices.data(),
dists,
k,
true,
true,
metric);
} else { // NN_DESCENT
RAFT_EXPECTS(static_cast<size_t>(k) <= build_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");

auto epilogue = DistancePostProcessSqrt<int64_t, float>{};
build_params.return_distances = true;
auto graph = get_graph_nnd<int64_t, float>(handle, X, m, n, epilogue, build_params);

size_t TPB = 256;
size_t num_blocks = static_cast<size_t>((m + TPB) / TPB);

auto indices_d =
raft::make_device_matrix<int64_t, int64_t>(handle, m, build_params.graph_degree);

raft::copy(
indices_d.data_handle(), graph.graph().data_handle(), m * build_params.graph_degree, stream);

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for HDBSCAN");
copy_first_k_cols_shift_zero<float>
<<<num_blocks, TPB, 0, stream>>>(dists,
graph.distances().value().data_handle(),
static_cast<size_t>(k),
build_params.graph_degree,
m);
copy_first_k_cols_shift_self<int64_t><<<num_blocks, TPB, 0, stream>>>(int64_indices.data(),
indices_d.data_handle(),
static_cast<size_t>(k),
build_params.graph_degree,
m);
switch (build_algo) {
case Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN: {
std::vector<value_t*> inputs;
inputs.push_back(const_cast<value_t*>(X));

std::vector<int> sizes;
sizes.push_back(m);

// perform knn
brute_force_knn(handle,
inputs,
sizes,
n,
const_cast<value_t*>(search_items),
n_search_items,
int64_indices.data(),
dists,
k,
true,
true,
metric);
break;
}

case Common::GRAPH_BUILD_ALGO::NN_DESCENT: {
RAFT_EXPECTS(static_cast<size_t>(k) <= build_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");

auto epilogue = DistancePostProcessSqrt<int64_t, float>{};
build_params.return_distances = true;
auto graph = get_graph_nnd<int64_t, float>(handle, X, m, n, epilogue, build_params);

size_t TPB = 256;
size_t num_blocks = static_cast<size_t>((m + TPB) / TPB);

auto indices_d =
raft::make_device_matrix<int64_t, int64_t>(handle, m, build_params.graph_degree);

raft::copy(indices_d.data_handle(),
graph.graph().data_handle(),
m * build_params.graph_degree,
stream);

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for HDBSCAN");
copy_first_k_cols_shift_zero<float>
<<<num_blocks, TPB, 0, stream>>>(dists,
graph.distances().value().data_handle(),
static_cast<size_t>(k),
build_params.graph_degree,
m);
copy_first_k_cols_shift_self<int64_t>
<<<num_blocks, TPB, 0, stream>>>(int64_indices.data(),
indices_d.data_handle(),
static_cast<size_t>(k),
build_params.graph_degree,
m);
break;
}
}

// convert from current knn's 64-bit to 32-bit.
Expand Down Expand Up @@ -329,55 +338,60 @@ void mutual_reachability_knn_l2(
// `A type local to a function cannot be used in the template argument of the
// enclosing parent function (and any parent classes) of an extended __device__
// or __host__ __device__ lambda`
switch (build_algo) {
case Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN: {
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha};
auto X_view = raft::make_device_matrix_view(X, m, n);
std::vector<raft::device_matrix_view<const value_t, size_t>> index = {X_view};

raft::neighbors::brute_force::knn<value_idx, value_t>(
handle,
index,
X_view,
raft::make_device_matrix_view(out_inds, m, static_cast<size_t>(k)),
raft::make_device_matrix_view(out_dists, m, static_cast<size_t>(k)),
// TODO: expand distance metrics to support more than just L2 distance
// https://github.com/rapidsai/cuml/issues/5301
raft::distance::DistanceType::L2SqrtExpanded,
std::make_optional<float>(2.0f),
std::nullopt,
epilogue);
break;
}

if (build_algo == Common::GRAPH_BUILD_ALGO::BRUTE_FORCE_KNN) {
auto epilogue = ReachabilityPostProcess<value_idx, value_t>{core_dists, alpha};
auto X_view = raft::make_device_matrix_view(X, m, n);
std::vector<raft::device_matrix_view<const value_t, size_t>> index = {X_view};

raft::neighbors::brute_force::knn<value_idx, value_t>(
handle,
index,
X_view,
raft::make_device_matrix_view(out_inds, m, static_cast<size_t>(k)),
raft::make_device_matrix_view(out_dists, m, static_cast<size_t>(k)),
// TODO: expand distance metrics to support more than just L2 distance
// https://github.com/rapidsai/cuml/issues/5301
raft::distance::DistanceType::L2SqrtExpanded,
std::make_optional<float>(2.0f),
std::nullopt,
epilogue);
} else {
RAFT_EXPECTS(static_cast<size_t>(k) <= build_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");

auto epilogue = ReachabilityPostProcessSqrt<value_idx, value_t>(core_dists, alpha);
build_params.return_distances = true;
auto graph = get_graph_nnd<value_idx, value_t>(handle, X, m, n, epilogue, build_params);

auto indices_d =
raft::make_device_matrix<value_idx, value_idx>(handle, m, build_params.graph_degree);

raft::copy(indices_d.data_handle(),
graph.graph().data_handle(),
m * build_params.graph_degree,
handle.get_stream());

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for HDBSCAN");

raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(m),
static_cast<int64_t>(k)};

auto out_knn_dists_view = raft::make_device_matrix_view(out_dists, m, (size_t)k);
raft::matrix::slice<float, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords);
auto out_knn_indices_view =
raft::make_device_matrix_view<value_idx, int64_t>(out_inds, m, (size_t)k);
raft::matrix::slice<value_idx, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords);
case Common::GRAPH_BUILD_ALGO::NN_DESCENT: {
RAFT_EXPECTS(static_cast<size_t>(k) <= build_params.graph_degree,
"n_neighbors should be smaller than the graph degree computed by nn descent");

auto epilogue = ReachabilityPostProcessSqrt<value_idx, value_t>(core_dists, alpha);
build_params.return_distances = true;
auto graph = get_graph_nnd<value_idx, value_t>(handle, X, m, n, epilogue, build_params);

auto indices_d =
raft::make_device_matrix<value_idx, value_idx>(handle, m, build_params.graph_degree);

raft::copy(indices_d.data_handle(),
graph.graph().data_handle(),
m * build_params.graph_degree,
handle.get_stream());

RAFT_EXPECTS(graph.distances().has_value(),
"return_distances for nn descent should be set to true to be used for HDBSCAN");

raft::matrix::slice_coordinates coords{static_cast<int64_t>(0),
static_cast<int64_t>(0),
static_cast<int64_t>(m),
static_cast<int64_t>(k)};

auto out_knn_dists_view = raft::make_device_matrix_view(out_dists, m, (size_t)k);
raft::matrix::slice<float, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords);
auto out_knn_indices_view =
raft::make_device_matrix_view<value_idx, int64_t>(out_inds, m, (size_t)k);
raft::matrix::slice<value_idx, int64_t, raft::row_major>(
handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords);
break;
}
}
}

Expand Down

0 comments on commit 904ab1b

Please sign in to comment.