diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index c99003758e..f21729a695 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -118,31 +118,42 @@ inline void launcher(const raft::handle_t& handle, // TODO: use nndescent from cuvs RAFT_EXPECTS(static_cast(n_neighbors) <= params->nn_descent_params.graph_degree, "n_neighbors should be smaller than the graph degree computed by nn descent"); + RAFT_EXPECTS(params->nn_descent_params.return_distances, + "return_distances for nn descent should be set to true to be used for UMAP"); auto graph = get_graph_nnd(handle, inputsA, params); - auto indices_d = raft::make_device_matrix( - handle, inputsA.n, params->nn_descent_params.graph_degree); - - raft::copy(indices_d.data_handle(), - graph.graph().data_handle(), - inputsA.n * params->nn_descent_params.graph_degree, - stream); - + // `graph.graph()` is a host array (n x graph_degree). + // Slice and copy to a temporary host array (n x n_neighbors), then copy + // that to the output device array `out.knn_indices` (n x n_neighbors). + // TODO: force graph_degree = n_neighbors so the temporary host array and + // slice isn't necessary. + auto temp_indices_h = raft::make_host_matrix(inputsA.n, n_neighbors); + size_t graph_degree = params->nn_descent_params.graph_degree; +#pragma omp parallel for + for (size_t i = 0; i < static_cast(inputsA.n); i++) { + for (int j = 0; j < n_neighbors; j++) { + auto target = temp_indices_h.data_handle(); + auto source = graph.graph().data_handle(); + target[i * n_neighbors + j] = source[i * graph_degree + j]; + } + } + raft::copy(handle, + raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors), + temp_indices_h.view()); + + // `graph.distances()` is a device array (n x graph_degree). + // Slice and copy to the output device array `out.knn_dists` (n x n_neighbors). + // TODO: force graph_degree = n_neighbors so this slice isn't necessary. raft::matrix::slice_coordinates coords{static_cast(0), static_cast(0), static_cast(inputsA.n), static_cast(n_neighbors)}; - - RAFT_EXPECTS(graph.distances().has_value(), - "return_distances for nn descent should be set to true to be used for UMAP"); - auto out_knn_dists_view = raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors); raft::matrix::slice( - handle, raft::make_const_mdspan(graph.distances().value()), out_knn_dists_view, coords); - auto out_knn_indices_view = - raft::make_device_matrix_view(out.knn_indices, inputsA.n, n_neighbors); - raft::matrix::slice( - handle, raft::make_const_mdspan(indices_d.view()), out_knn_indices_view, coords); + handle, + raft::make_const_mdspan(graph.distances().value()), + raft::make_device_matrix_view(out.knn_dists, inputsA.n, n_neighbors), + coords); } }