Skip to content

Commit

Permalink
add in header for coo tfidf
Browse files Browse the repository at this point in the history
  • Loading branch information
jperez999 committed Jun 25, 2024
1 parent 2d82aca commit dc01bc1
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
30 changes: 17 additions & 13 deletions cpp/include/raft/sparse/matrix/detail/preprocessing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -181,19 +181,10 @@ void encode_tfidf(raft::resources& handle,
raft::device_coordinate_structure_view<T1, T1, T2> coo_in,
raft::device_vector_view<T2, IdxT> values_out)
{
auto rows = coo_in.get_rows();
auto columns = coo_in.get_columns();
auto values = coo_in.get_elements();
auto doc_lengths = raft::make_device_vector<float, IdxT>(handle, columns.size());
auto term_counts = raft::make_device_vector<float, IdxT>(handle, rows.size());
auto [doc_count, avg_doc_length] = sparse_search_preprocess<int, float>(
handle, rows, columns, values, doc_lengths.view(), term_counts.view());

raft::linalg::map(handle,
values_out,
tfidf(doc_count),
raft::make_const_mdspan(values),
raft::make_const_mdspan(term_counts.view()));
auto rows = coo_in.get_rows();
auto columns = coo_in.get_columns();
auto values = coo_in.get_elements();
encode_tfidf(handle, rows, columns, values, values_out);
}

template <typename T1, typename T2, typename IdxT>
Expand All @@ -215,6 +206,19 @@ void encode_tfidf(raft::resources& handle,
raft::make_const_mdspan(term_counts.view()));
}

/**
 * @brief BM25 encoding entry point for input supplied as a coordinate structure view.
 *
 * Thin adapter: reads the rows, columns and elements out of `coo_in` and forwards them,
 * together with the remaining arguments, to the `encode_bm25` overload that accepts those
 * three pieces separately.
 *
 * NOTE(review): the overload being forwarded to appears to be declared after this
 * definition in this header; confirm it is visible to name lookup when this template is
 * instantiated.
 *
 * @tparam T1 first two template arguments of the coordinate structure view
 * @tparam T2 element type of `values_out` (and third template argument of the view)
 * @tparam IdxT index type of `values_out`
 * @param[in] handle raft resources object, passed through unchanged
 * @param[in] coo_in coordinate structure view providing rows, columns and elements
 * @param[out] values_out vector view handed to the forwarded call as its output argument
 * @param[in] k_param forwarded unchanged (presumably the BM25 k1 term — TODO confirm)
 * @param[in] b_param forwarded unchanged (presumably the BM25 b term — TODO confirm)
 */
template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
                 raft::device_coordinate_structure_view<T1, T1, T2> coo_in,
                 raft::device_vector_view<T2, IdxT> values_out,
                 float k_param = 1.6f,
                 float b_param = 0.75f)
{
  // Unpack the view directly at the call site; no intermediate locals are needed.
  encode_bm25(handle,
              coo_in.get_rows(),
              coo_in.get_columns(),
              coo_in.get_elements(),
              values_out,
              k_param,
              b_param);
}

template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
raft::device_vector_view<T1, IdxT> rows,
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/sparse/matrix/preprocessing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,14 @@ void encode_bm25(raft::resources& handle,
handle, rows, columns, values, values_out, k_param, b_param);
}

template <typename T1, typename T2, typename IdxT>
void encode_bm25(raft::resources& handle,
raft::device_coordinate_structure_view<T1, T1, T2> coo_in,
raft::device_vector_view<T2, IdxT> values_out)
{
return matrix::detail::encode_bm25<T1, T2, IdxT>(handle, coo_in, values_out);
}

template <typename T1, typename T2, typename IdxT>
void encode_tfidf(raft::resources& handle,
raft::device_vector_view<T1, IdxT> rows,
Expand Down

0 comments on commit dc01bc1

Please sign in to comment.