Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Dice Distance for Dense Inputs #2359

Merged
merged 12 commits into from
Jun 24, 2024
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,8 @@ if(RAFT_COMPILE_LIBRARY)
src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu
src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu
src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu
Expand Down
82 changes: 79 additions & 3 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ using distance_tag = std::integral_constant<DistanceType, d>;
* - DistanceType::Canberra:
* - DistanceType::CorrelationExpanded:
* - DistanceType::CosineExpanded:
* - DistanceType::DiceExpanded:
* - DistanceType::HammingUnexpanded:
* - DistanceType::HellingerExpanded:
* - DistanceType::JensenShannon:
Expand Down Expand Up @@ -238,6 +239,79 @@ void distance_impl(raft::resources const& handle,
distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

/**
 * @brief Dice distance (expanded form) for dense inputs.
 *
 * Reduces `x` and `y` (identity_op followed by add_op) into norm buffers carved
 * out of `workspace`, then forwards the inputs together with those norms to
 * `pairwise_matrix_dispatch` using a `dice_distance_op`.
 *
 * @tparam DataT  input element type
 * @tparam AccT   accumulation type; also the element type of `workspace`
 * @tparam OutT   output element type
 * @tparam FinOpT type of the final operation forwarded to the dispatch
 * @tparam IdxT   index type
 *
 * @param handle        raft resources; the CUDA stream is obtained from it
 * @param distance_type tag selecting this overload (otherwise unused)
 * @param x             first input
 * @param y             second input
 * @param out           output buffer, forwarded to the dispatch
 * @param m             number of norms computed for `x`
 * @param n             number of norms computed for `y`
 * @param k             size forwarded to the reduction and the dispatch
 *                      (NOTE(review): presumably the length of each vector - confirm)
 * @param workspace     scratch buffer holding the norms; must not be null
 * @param worksize      size of `workspace` in bytes; must be at least
 *                      (m + n) * sizeof(AccT)
 * @param fin_op        final operation forwarded to the dispatch
 * @param is_row_major  layout flag forwarded to the reduction and the dispatch
 */
template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
                   distance_tag<DistanceType::DiceExpanded> distance_type,
                   const DataT* x,
                   const DataT* y,
                   OutT* out,
                   IdxT m,
                   IdxT n,
                   IdxT k,
                   AccT* workspace,
                   size_t worksize,
                   FinOpT fin_op,
                   bool is_row_major,
                   DataT)  // unused
{
  // raft distance supports inputs as float/double and outputs as uint8_t/float/double.
  static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))),
                "OutT can be uint8_t, float, double,"
                "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

  // Widen to size_t *before* adding: `m + n` evaluated in IdxT (int by default)
  // can overflow for large inputs, which would defeat this check.
  const size_t required_bytes = (static_cast<size_t>(m) + static_cast<size_t>(n)) * sizeof(AccT);
  ASSERT(!(worksize < required_bytes), "workspace size error");
  ASSERT(workspace != nullptr, "workspace is null");

  cudaStream_t stream = raft::resource::get_cuda_stream(handle);

  // The norms live in the caller-provided workspace. Note that initialising a
  // DataT* from an AccT* only compiles when DataT and AccT are the same type.
  DataT* x_norm = workspace;
  DataT* y_norm = workspace;

  // Single place for the reduction arguments shared by every call below.
  auto reduce_into = [&](DataT* norms, const DataT* data, IdxT count) {
    raft::linalg::reduce(norms,
                         data,
                         k,
                         count,
                         (AccT)0,
                         is_row_major,
                         true,
                         stream,
                         false,
                         raft::identity_op(),
                         raft::add_op());
  };

  // TODO: Column major case looks to have lower accuracy for X == Y,
  // perhaps the use of stridedSummationKernel could be causing this,
  // need to investigate and fix.
  if (x == y && is_row_major) {
    // Same buffer on both sides: reduce once and let y_norm alias x_norm.
    reduce_into(x_norm, x, std::max(m, n));
  } else {
    // Distinct norms: x occupies the first m entries, y the n entries after them.
    y_norm += m;
    reduce_into(x_norm, x, m);
    reduce_into(y_norm, y, n);
  }

  ops::dice_distance_op<DataT, AccT, IdxT> distance_op{};
  pairwise_matrix_dispatch<decltype(distance_op), DataT, AccT, OutT, FinOpT, IdxT>(
    distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major);
}

template <typename DataT, typename AccT, typename OutT, typename FinOpT, typename IdxT = int>
void distance_impl(raft::resources const& handle,
distance_tag<DistanceType::HammingUnexpanded> distance_type,
Expand Down Expand Up @@ -794,9 +868,11 @@ template <raft::distance::DistanceType distanceType,
typename Index_ = int>
size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k)
{
size_t worksize = 0;
constexpr bool is_allocated = (distanceType <= raft::distance::DistanceType::CosineExpanded) ||
(distanceType == raft::distance::DistanceType::CorrelationExpanded);
size_t worksize = 0;
constexpr bool is_allocated =
(distanceType <= raft::distance::DistanceType::CosineExpanded) ||
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ||
(distanceType == raft::distance::DistanceType::DiceExpanded);
constexpr int numOfBuffers =
(distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1;

Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/distance/detail/distance_ops/all_ops.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
#include <raft/distance/detail/distance_ops/canberra.cuh>
#include <raft/distance/detail/distance_ops/correlation.cuh>
#include <raft/distance/detail/distance_ops/cosine.cuh>
#include <raft/distance/detail/distance_ops/dice.cuh>
#include <raft/distance/detail/distance_ops/hamming.cuh>
#include <raft/distance/detail/distance_ops/hellinger.cuh>
#include <raft/distance/detail/distance_ops/jensen_shannon.cuh>
Expand Down
85 changes: 85 additions & 0 deletions cpp/include/raft/distance/detail/distance_ops/dice.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {

// Epilogue functor handed to the CUTLASS based kernel (see get_cutlass_op()
// in dice_distance_op). The three-argument call turns a pair of norms and an
// accumulated value into 1 - 2 * accVal / (aNorm + bNorm); the one-argument
// call returns its input converted to AccT.
template <typename DataT, typename AccT>
struct dice_cutlass_op {
  __device__ dice_cutlass_op() noexcept {}

  __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
  {
    // `auto` keeps whatever type the arithmetic naturally promotes to, so the
    // value is narrowed to AccT exactly once, after the division.
    const auto norm_total = aNorm + bNorm;
    const auto twice_acc  = 2 * accVal;
    const AccT similarity = static_cast<AccT>(twice_acc / norm_total);
    return static_cast<AccT>(1.0) - similarity;
  }

  __device__ AccT operator()(DataT aData) const noexcept
  {
    return aData;
  }
};

/**
 * @brief the expanded dice distance matrix calculation
 *
 * It computes the following equation:
 *
 *    d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) )
 *
 * `core` accumulates the x ⋅ y term one product at a time; `epilog` combines
 * each accumulated value with the pre-computed norms of the corresponding
 * x and y entries to produce the final distance.
 *
 * @tparam DataType element type of the inputs and of the norms
 * @tparam AccType  type of the accumulators
 * @tparam IdxType  index type
 */
template <typename DataType, typename AccType, typename IdxType>
struct dice_distance_op {
  using DataT = DataType;
  using AccT  = AccType;
  using IdxT  = IdxType;

  // Load norms of input data
  static constexpr bool use_norms = true;
  // Whether the core function requires so many instructions that it makes sense
  // to reduce loop unrolling, etc. We do this to keep compile times in check.
  static constexpr bool expensive_inner_loop = false;

  // Size of shared memory. This is normally decided by the kernel policy, but
  // some ops such as correlation_distance_op use more.
  // Here the policy's size is extended by (Mblk + Nblk) DataT elements
  // (NOTE(review): presumably the staging area for the norms - confirm).
  template <typename Policy>
  static constexpr size_t shared_mem_size()
  {
    return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT));
  }

  // Inner-loop step: add the product of one pair of elements to the accumulator.
  DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }

  // Converts every accumulator in the per-thread tile, in place, into the
  // dice distance using the norm of its row (regxn) and of its column (regyn).
  // gridStrideX / gridStrideY are not used by this op.
  template <typename Policy>
  DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh],
                 DataT* regxn,
                 DataT* regyn,
                 IdxT gridStrideX,
                 IdxT gridStrideY) const
  {
#pragma unroll
    for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
#pragma unroll
      for (int j = 0; j < Policy::AccColsPerTh; ++j) {
        // Use an AccT-typed one rather than the double literal 1.0 so that the
        // subtraction is not promoted to double precision when AccT is float
        // (this also matches dice_cutlass_op).
        acc[i][j] = static_cast<AccT>(1) - (2 * acc[i][j] / (regxn[i] + regyn[j]));
      }
    }
  }

  // Epilogue functor used by the CUTLASS based kernel.
  constexpr dice_cutlass_op<DataT, AccT> get_cutlass_op() const
  {
    return dice_cutlass_op<DataT, AccT>();
  }
};

} // namespace raft::distance::detail::ops
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -120,6 +120,10 @@ instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int);
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
Expand Down
48 changes: 48 additions & 0 deletions cpp/include/raft/distance/distance-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -286,6 +290,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -362,6 +370,10 @@ instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_distance(
Expand Down Expand Up @@ -429,6 +441,10 @@ instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::CosineExpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::CosineExpanded, double, double, double, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, double, double, double, int);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::HammingUnexpanded, float, float, float, int);
instantiate_raft_distance_getWorkspaceSize(
Expand Down Expand Up @@ -547,6 +563,22 @@ instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineE
double,
int,
raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
int,
raft::layout_c_contiguous);
instantiate_raft_distance_getWorkspaceSize(
raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
int,
raft::layout_f_contiguous);
instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded,
float,
float,
Expand Down Expand Up @@ -822,6 +854,22 @@ instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded,
double,
raft::layout_f_contiguous,
int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int);
instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
raft::layout_c_contiguous,
int);
instantiate_raft_distance_distance(
raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int);
instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded,
double,
double,
double,
raft::layout_f_contiguous,
int);
instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded,
float,
float,
Expand Down
3 changes: 3 additions & 0 deletions cpp/include/raft/distance/distance-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,9 @@ void pairwise_distance(raft::resources const& handle,
case DistanceType::RusselRaoExpanded:
dispatch(std::integral_constant<DistanceType, DistanceType::RusselRaoExpanded>{});
break;
case DistanceType::DiceExpanded:
dispatch(std::integral_constant<DistanceType, DistanceType::DiceExpanded>{});
break;
default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric);
};
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright (c) 2021-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by dispatch_00_generate.py
*
* Make changes there and run in this directory:
*
* > python dispatch_00_generate.py
*
*/

#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/all_ops.cuh> // ops::*
#include <raft/distance/detail/pairwise_matrix/dispatch-inl.cuh> // dispatch
#include <raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh>
#include <raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh>
// Expands to an explicit instantiation of
// raft::distance::detail::pairwise_matrix_dispatch for the distance op template
// OpT (instantiated as OpT<DataT, AccT, IdxT>) and the given data, accumulator,
// output, final-op and index types.
// NOTE: this file is generated (see the header comment); lasting changes belong
// in dispatch_00_generate.py.
#define instantiate_raft_distance_detail_pairwise_matrix_dispatch(                     \
OpT, DataT, AccT, OutT, FinOpT, IdxT)                                                \
template void raft::distance::detail::                                               \
pairwise_matrix_dispatch<OpT<DataT, AccT, IdxT>, DataT, AccT, OutT, FinOpT, IdxT>( \
OpT<DataT, AccT, IdxT> distance_op,                                              \
IdxT m,                                                                          \
IdxT n,                                                                          \
IdxT k,                                                                          \
const DataT* x,                                                                  \
const DataT* y,                                                                  \
const DataT* x_norm,                                                             \
const DataT* y_norm,                                                             \
OutT* out,                                                                       \
FinOpT fin_op,                                                                   \
cudaStream_t stream,                                                             \
bool is_row_major)

// The single instantiation provided by this translation unit: dice distance
// with double data, accumulator and output, identity final op and int indices.
instantiate_raft_distance_detail_pairwise_matrix_dispatch(
raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int);

// Keep the helper macro local to this file.
#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch
Loading
Loading