Skip to content

Commit

Permalink
fix order
Browse files Browse the repository at this point in the history
  • Loading branch information
aamijar committed Jun 17, 2024
1 parent 4b64c37 commit a1ecfcb
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ void distance_impl(raft::resources const& handle,

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
distance_tag<DistanceType::DiceExpanded> distance_type,
distance_tag<DistanceType::CosineExpanded> distance_type,
const DataT* x,
const DataT* y,
OutT* out,
Expand Down Expand Up @@ -224,51 +224,24 @@ void distance_impl(raft::resources const& handle,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::reduce(x_norm,
x,
k,
std::max(m, n),
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::rowNorm(
x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
} else {
y_norm += m;
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::reduce(y_norm,
y,
k,
n,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
}

ops::dice_distance_op<DataT, AccT, IdxT> distance_op{};
ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
distance_tag<DistanceType::CosineExpanded> distance_type,
distance_tag<DistanceType::DiceExpanded> distance_type,
const DataT* x,
const DataT* y,
OutT* out,
Expand Down Expand Up @@ -297,17 +270,44 @@ void distance_impl(raft::resources const& handle,
// perhaps the use of stridedSummationKernel could be causing this,
// need to investigate and fix.
if (x == y && is_row_major) {
raft::linalg::rowNorm(
x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::reduce(x_norm,
x,
k,
std::max(m, n),
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
} else {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::rowNorm(
y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
raft::linalg::reduce(x_norm,
x,
k,
m,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
raft::linalg::reduce(y_norm,
y,
k,
n,
(AccT)0,
is_row_major,
true,
stream,
false,
raft::identity_op(),
raft::add_op());
}

ops::cosine_distance_op<DataT, AccT, IdxT> distance_op{};
ops::dice_distance_op<DataT, AccT, IdxT> distance_op{};
pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}
Expand Down

0 comments on commit a1ecfcb

Please sign in to comment.