From a1ecfcb98403cfc754de73ca9eb9927ff2b61525 Mon Sep 17 00:00:00 2001 From: aamijar Date: Mon, 17 Jun 2024 16:55:29 +0000 Subject: [PATCH] fix order --- cpp/include/raft/distance/detail/distance.cuh | 86 +++++++++---------- 1 file changed, 43 insertions(+), 43 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index a329ff2ce4..b708360074 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -195,7 +195,7 @@ void distance_impl(raft::resources const& handle, template void distance_impl(raft::resources const& handle, - distance_tag distance_type, + distance_tag distance_type, const DataT* x, const DataT* y, OutT* out, @@ -224,51 +224,24 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::reduce(x_norm, - x, - k, - std::max(m, n), - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::rowNorm( + x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); + raft::linalg::rowNorm( + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::rowNorm( + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } - ops::dice_distance_op distance_op{}; + ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template void distance_impl(raft::resources const& handle, - distance_tag distance_type, + distance_tag distance_type, const DataT* x, const DataT* y, OutT* out, @@ -297,17 +270,44 @@ void distance_impl(raft::resources const& handle, // perhaps the use of stridedSummationKernel could be causing this, // need to investigate and fix. if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::reduce(x_norm, + x, + k, + std::max(m, n), + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); } else { y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + raft::linalg::reduce(x_norm, + x, + k, + m, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); + raft::linalg::reduce(y_norm, + y, + k, + n, + (AccT)0, + is_row_major, + true, + stream, + false, + raft::identity_op(), + raft::add_op()); } - ops::cosine_distance_op distance_op{}; + ops::dice_distance_op distance_op{}; pairwise_matrix_dispatch( distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); }