Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Support for Cosine distance in IVF-Flat #179

Merged
merged 22 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 97 additions & 18 deletions cpp/src/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <raft/linalg/add.cuh>
#include <raft/linalg/gemm.cuh>
#include <raft/linalg/map.cuh>
#include <raft/linalg/matrix_vector.cuh>
#include <raft/linalg/matrix_vector_op.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/normalize.cuh>
Expand Down Expand Up @@ -141,6 +142,53 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::CosineExpanded: {
lowener marked this conversation as resolved.
Show resolved Hide resolved
lowener marked this conversation as resolved.
Show resolved Hide resolved
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, raft::make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, raft::make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, raft::make_extents<IdxT>(n_clusters));
raft::linalg::rowNorm<MathT, IdxT>(centroidsNorm.data_handle(),
centers,
dim,
n_clusters,
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op{});

cuvs::distance::fusedDistanceNNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
false,
false,
true,
params.metric,
0.0f,
stream);
// Copy keys to output labels
thrust::transform(raft::resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
break;
}
case cuvs::distance::DistanceType::InnerProduct: {
// TODO: pass buffer
rmm::device_uvector<MathT> distances(n_rows * n_clusters, stream, mr);
Expand Down Expand Up @@ -320,13 +368,14 @@ void calc_centers_and_sizes(const raft::resources& handle,
}

/** Computes the L2 norm of the dataset, converting to MathT if necessary */
template <typename T, typename MathT, typename IdxT, typename MappingOpT>
template <typename T, typename MathT, typename IdxT, typename MappingOpT, typename FinOpT>
void compute_norm(const raft::resources& handle,
MathT* dataset_norm,
const T* dataset,
IdxT dim,
IdxT n_rows,
MappingOpT mapping_op,
FinOpT norm_fin_op,
std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("compute_norm");
Expand All @@ -347,7 +396,7 @@ void compute_norm(const raft::resources& handle,
}

raft::linalg::rowNorm<MathT, IdxT>(
dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream);
dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream, norm_fin_op);
}

/**
Expand Down Expand Up @@ -394,7 +443,8 @@ void predict(const raft::resources& handle,
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
bool need_compute_norm =
dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded);
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded);
rmm::device_uvector<MathT> cur_dataset_norm(
need_compute_norm ? max_minibatch_size : 0, stream, mem_res);
const MathT* dataset_norm_ptr = nullptr;
Expand All @@ -411,8 +461,24 @@ void predict(const raft::resources& handle,

// Compute the norm now if it hasn't been pre-computed.
if (need_compute_norm) {
compute_norm(
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res);
if (params.metric == cuvs::distance::DistanceType::CosineExpanded)
compute_norm(handle,
cur_dataset_norm.data(),
cur_dataset_ptr,
dim,
minibatch_size,
mapping_op,
raft::sqrt_op{},
mr);
else
compute_norm(handle,
cur_dataset_norm.data(),
cur_dataset_ptr,
dim,
minibatch_size,
mapping_op,
raft::identity_op{},
mr);
dataset_norm_ptr = cur_dataset_norm.data();
} else if (dataset_norm != nullptr) {
dataset_norm_ptr = dataset_norm + offset;
Expand Down Expand Up @@ -904,7 +970,8 @@ auto build_fine_clusters(const raft::resources& handle,
cub::TransformInputIterator<MathT, MappingOpT, const T*> mapping_itr(dataset_mptr, mapping_op);
raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream);
if (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded) {
thrust::gather(raft::resource::get_thrust_policy(handle),
mc_trainset_ids,
mc_trainset_ids + k,
Expand Down Expand Up @@ -963,7 +1030,8 @@ void build_hierarchical(const raft::resources& handle,
IdxT n_rows,
MathT* cluster_centers,
IdxT n_clusters,
MappingOpT mapping_op)
MappingOpT mapping_op,
const MathT* dataset_norm = nullptr)
{
auto stream = raft::resource::get_cuda_stream(handle);
using LabelT = uint32_t;
Expand All @@ -980,21 +1048,32 @@ void build_hierarchical(const raft::resources& handle,
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

// Precompute the L2 norm of the dataset if relevant.
const MathT* dataset_norm = nullptr;
// Precompute the L2 norm of the dataset if relevant and not yet computed.
rmm::device_uvector<MathT> dataset_norm_buf(0, stream, device_memory);
if (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded) {
if (dataset_norm == nullptr && (params.metric == cuvs::distance::DistanceType::L2Expanded ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
params.metric == cuvs::distance::DistanceType::CosineExpanded)) {
dataset_norm_buf.resize(n_rows, stream);
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
IdxT minibatch_size = std::min<IdxT>(max_minibatch_size, n_rows - offset);
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
device_memory);
if (params.metric == cuvs::distance::DistanceType::CosineExpanded)
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
raft::sqrt_op{},
device_memory);
else
compute_norm(handle,
dataset_norm_buf.data() + offset,
dataset + dim * offset,
dim,
minibatch_size,
mapping_op,
raft::identity_op{},
device_memory);
}
dataset_norm = (const MathT*)dataset_norm_buf.data();
}
Expand Down
47 changes: 28 additions & 19 deletions cpp/src/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,15 @@ namespace cuvs::cluster::kmeans_balanced {
* @param[out] centroids The generated centroids [dim = n_clusters x n_features]
* @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic
* datatype. If DataT == MathT, this must be the identity.
* @param[in] X_norm (optional) Dataset's row norms [dim = n_samples]
*/
template <typename DataT, typename MathT, typename IndexT, typename MappingOpT = raft::identity_op>
void fit(const raft::resources& handle,
cuvs::cluster::kmeans::balanced_params const& params,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<MathT, IndexT> centroids,
MappingOpT mapping_op = raft::identity_op())
MappingOpT mapping_op = raft::identity_op(),
std::optional<raft::device_vector_view<const MathT, IndexT>> X_norm = std::nullopt)
{
RAFT_EXPECTS(X.extent(1) == centroids.extent(1),
"Number of features in dataset and centroids are different");
Expand All @@ -88,14 +90,16 @@ void fit(const raft::resources& handle,
"The number of centroids must be strictly positive and cannot exceed the number of "
"points in the training dataset.");

cuvs::cluster::kmeans::detail::build_hierarchical(handle,
params,
X.extent(1),
X.data_handle(),
X.extent(0),
centroids.data_handle(),
centroids.extent(0),
mapping_op);
cuvs::cluster::kmeans::detail::build_hierarchical(
handle,
params,
X.extent(1),
X.data_handle(),
X.extent(0),
centroids.data_handle(),
centroids.extent(0),
mapping_op,
X_norm.has_value() ? X_norm.value().data_handle() : nullptr);
}

/**
Expand Down Expand Up @@ -125,6 +129,7 @@ void fit(const raft::resources& handle,
* @param[out] labels The output labels [dim = n_samples]
* @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic
* datatype. If DataT == MathT, this must be the identity.
* @param[in] X_norm (optional) Dataset's row norms [dim = n_samples]
*/
template <typename DataT,
typename MathT,
Expand All @@ -136,7 +141,8 @@ void predict(const raft::resources& handle,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<const MathT, IndexT> centroids,
raft::device_vector_view<LabelT, IndexT> labels,
MappingOpT mapping_op = raft::identity_op())
MappingOpT mapping_op = raft::identity_op(),
std::optional<raft::device_vector_view<const MathT, IndexT>> X_norm = std::nullopt)
{
RAFT_EXPECTS(X.extent(0) == labels.extent(0),
"Number of rows in dataset and labels are different");
Expand All @@ -149,15 +155,18 @@ void predict(const raft::resources& handle,
static_cast<uint64_t>(std::numeric_limits<LabelT>::max()),
"The chosen label type cannot represent all cluster labels");

cuvs::cluster::kmeans::detail::predict(handle,
params,
centroids.data_handle(),
centroids.extent(0),
X.extent(1),
X.data_handle(),
X.extent(0),
labels.data_handle(),
mapping_op);
cuvs::cluster::kmeans::detail::predict(
handle,
params,
centroids.data_handle(),
centroids.extent(0),
X.extent(1),
X.data_handle(),
X.extent(0),
labels.data_handle(),
mapping_op,
raft::resource::get_workspace_resource(handle),
X_norm.has_value() ? X_norm.value().data_handle() : nullptr);
}

/**
Expand Down
1 change: 1 addition & 0 deletions cpp/src/neighbors/ivf_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ void postprocess_distances(ScoreOutT* out, // [n_queries, topk]
raft::linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream);
}
} break;
case distance::DistanceType::CosineExpanded:
case distance::DistanceType::InnerProduct: {
float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor;
if (factor != 1.0) {
Expand Down
42 changes: 32 additions & 10 deletions cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -312,23 +312,45 @@ void extend(raft::resources const& handle,
if (!index->center_norms().has_value()) {
index->allocate_center_norms(handle);
if (index->center_norms().has_value()) {
if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::rowNorm(index->center_norms()->data_handle(),
index->centers().data_handle(),
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op{});
} else {
raft::linalg::rowNorm(index->center_norms()->data_handle(),
index->centers().data_handle(),
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream);
}
RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
} else if (index->center_norms().has_value() && index->adaptive_centers()) {
if (index->metric() == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::rowNorm(index->center_norms()->data_handle(),
index->centers().data_handle(),
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream,
raft::sqrt_op{});
} else {
raft::linalg::rowNorm(index->center_norms()->data_handle(),
index->centers().data_handle(),
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream);
RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
} else if (index->center_norms().has_value() && index->adaptive_centers()) {
raft::linalg::rowNorm(index->center_norms()->data_handle(),
index->centers().data_handle(),
dim,
n_lists,
raft::linalg::L2Norm,
true,
stream);
RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min<uint32_t>(dim, 20));
}
}
Expand Down Expand Up @@ -391,7 +413,7 @@ inline auto build(raft::resources const& handle,
index.centers().data_handle(), index.n_lists(), index.dim());
cuvs::cluster::kmeans::balanced_params kmeans_params;
kmeans_params.n_iters = params.kmeans_n_iters;
kmeans_params.metric = static_cast<cuvs::distance::DistanceType>(index.metric());
kmeans_params.metric = index.metric();
cuvs::cluster::kmeans_balanced::fit(
handle, kmeans_params, trainset_const_view, centers_view, utils::mapping<float>{});
}
Expand Down
Loading
Loading