From dc01bc13bb260c1d574213097a46ebd704c13ccf Mon Sep 17 00:00:00 2001 From: Julio Perez Date: Tue, 25 Jun 2024 15:18:23 -0400 Subject: [PATCH] add in header for coo tfidf --- .../sparse/matrix/detail/preprocessing.cuh | 30 +++++++++++-------- .../raft/sparse/matrix/preprocessing.cuh | 8 +++++ 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh index 755fa7122f..59c9f06bd1 100644 --- a/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh +++ b/cpp/include/raft/sparse/matrix/detail/preprocessing.cuh @@ -181,19 +181,10 @@ void encode_tfidf(raft::resources& handle, raft::device_coordinate_structure_view coo_in, raft::device_vector_view values_out) { - auto rows = coo_in.get_rows(); - auto columns = coo_in.get_columns(); - auto values = coo_in.get_elements(); - auto doc_lengths = raft::make_device_vector(handle, columns.size()); - auto term_counts = raft::make_device_vector(handle, rows.size()); - auto [doc_count, avg_doc_length] = sparse_search_preprocess( - handle, rows, columns, values, doc_lengths.view(), term_counts.view()); - - raft::linalg::map(handle, - values_out, - tfidf(doc_count), - raft::make_const_mdspan(values), - raft::make_const_mdspan(term_counts.view())); + auto rows = coo_in.get_rows(); + auto columns = coo_in.get_columns(); + auto values = coo_in.get_elements(); + encode_tfidf(handle, rows, columns, values, values_out); } template @@ -215,6 +206,19 @@ void encode_tfidf(raft::resources& handle, raft::make_const_mdspan(term_counts.view())); } +template +void encode_bm25(raft::resources& handle, + raft::device_coordinate_structure_view coo_in, + raft::device_vector_view values_out, + float k_param = 1.6f, + float b_param = 0.75f) +{ + auto rows = coo_in.get_rows(); + auto columns = coo_in.get_columns(); + auto values = coo_in.get_elements(); + encode_bm25(handle, rows, columns, values, values_out, k_param, b_param); +} + template void encode_bm25(raft::resources& handle, raft::device_vector_view rows, diff --git a/cpp/include/raft/sparse/matrix/preprocessing.cuh b/cpp/include/raft/sparse/matrix/preprocessing.cuh index 510819c1b5..95898d0403 100644 --- a/cpp/include/raft/sparse/matrix/preprocessing.cuh +++ b/cpp/include/raft/sparse/matrix/preprocessing.cuh @@ -39,6 +39,14 @@ void encode_bm25(raft::resources& handle, handle, rows, columns, values, values_out, k_param, b_param); } +template +void encode_bm25(raft::resources& handle, + raft::device_coordinate_structure_view coo_in, + raft::device_vector_view values_out) +{ + return matrix::detail::encode_bm25(handle, coo_in, values_out); +} + template void encode_tfidf(raft::resources& handle, raft::device_vector_view rows,