Skip to content

Commit

Permalink
Exposing kernel gram APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
cjnolet committed Nov 15, 2024
1 parent 557c2aa commit 37e7140
Show file tree
Hide file tree
Showing 17 changed files with 913 additions and 368 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ if(BUILD_SHARED_LIBS)
src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int64_t.cu
src/distance/detail/fused_distance_nn.cu
src/distance/distance.cu
src/distance/kernel_gram.cu
src/distance/pairwise_distance.cu
src/neighbors/brute_force.cu
src/neighbors/cagra_build_float.cu
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
* Copyright (c) 2022-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,10 +16,295 @@

#pragma once

#include "gram_matrix.hpp"
#include <cublas.h>
#include <cuvs/distance/distance.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/mdspan.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

namespace cuvs::distance::kernels::detail {
namespace cuvs::distance::kernels {

template <typename math_t>
using dense_input_matrix_view_t = raft::device_matrix_view<const math_t, int, raft::layout_stride>;
template <typename math_t>
using dense_output_matrix_view_t = raft::device_matrix_view<math_t, int, raft::layout_stride>;
template <typename math_t>
using csr_input_matrix_view_t = raft::device_csr_matrix_view<const math_t, int, int, int>;

/**
* Base class for general Gram matrices
* A Gram matrix is the Hermitian matrix of inner probucts G_ik = <x_i, x_k>
* Here, the inner product is evaluated for all elements from vectors sets X1,
* and X2.
*
* To be more precise, on exit the output buffer will store:
* - if is_row_major == true: out[j+k*n1] = <x1_j, x2_k>,
* - if is_row_major == false: out[j*n2 + k] = <x1_j, x2_k>,
* where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector
* from the x2 set.
*/
template <typename math_t>
class GramMatrixBase {
protected:
cublasHandle_t cublas_handle;
bool legacy_interface;

public:
GramMatrixBase() : legacy_interface(false){};
[[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle)
: cublas_handle(cublas_handle), legacy_interface(true){};

virtual ~GramMatrixBase(){};

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(raft::resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr);

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr);

/** Convenience function to evaluate the Gram matrix for two vector sets.
* Vector sets are provided in Matrix format
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 optional L2-norm of x1's rows for computation within RBF.
* @param norm_x2 optional L2-norm of x2's rows for computation within RBF.
*/
void operator()(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1 = nullptr,
math_t* norm_x2 = nullptr);

// unfortunately, 'evaluate' cannot be templatized as it needs to be virtual

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(raft::resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
* @param norm_x1 unused.
* @param norm_x2 unused.
*/
virtual void evaluate(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out,
math_t* norm_x1,
math_t* norm_x2);

/** Evaluate the Gram matrix for two vector sets using simple dot product.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1 (usually it is n1)
* @param ld2 leading dimension of x2 (usually it is n2)
* @param ld_out leading dimension of out (usually it is n1)
*/
[[deprecated]] virtual void evaluate(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1,
int ld2,
int ld_out);

/** Convenience function to evaluate the Gram matrix for two vector sets.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
[[deprecated]] void operator()(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1 = 0,
int ld2 = 0,
int ld_out = 0);

protected:
/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] x1 device array of vectors, size [n1*n_cols]
* @param [in] n1 number vectors in x1
* @param [in] n_cols number of columns (features) in x1 and x2
* @param [in] x2 device array of vectors, size [n2*n_cols]
* @param [in] n2 number vectors in x2
* @param [out] out device buffer to store the Gram matrix, size [n1*n2]
* @param [in] is_row_major whether the input and output matrices are in row
* major format
* @param [in] stream cuda stream
* @param ld1 leading dimension of x1
* @param ld2 leading dimension of x2
* @param ld_out leading dimension of out
*/
[[deprecated]] void linear(const math_t* x1,
int n1,
int n_cols,
const math_t* x2,
int n2,
math_t* out,
bool is_row_major,
cudaStream_t stream,
int ld1,
int ld2,
int ld_out);

protected:
bool get_is_row_major(dense_output_matrix_view_t<math_t> matrix);
bool get_is_row_major(dense_input_matrix_view_t<math_t> matrix);
bool get_is_col_major(dense_output_matrix_view_t<math_t> matrix);
bool get_is_col_major(dense_input_matrix_view_t<math_t> matrix);

/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 dense device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
*/
void linear(raft::resources const& handle,
dense_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out);

/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 dense device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
*/
void linear(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
dense_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out);

/** Calculates the Gram matrix using simple dot product between vector sets.
*
* out = x1 * x2
*
* Can be used as a building block for more complex kernel functions.
*
* @param [in] handle raft handle
* @param [in] x1 csr device matrix view, size [n1*n_cols]
* @param [in] x2 csr device matrix view, size [n2*n_cols]
* @param [out] out dense device matrix view for the Gram matrix, size [n1*n2]
*/
void linear(raft::resources const& handle,
csr_input_matrix_view_t<math_t> x1,
csr_input_matrix_view_t<math_t> x2,
dense_output_matrix_view_t<math_t> out);
};

template <typename math_t>
class KernelFactory {
public:
static GramMatrixBase<math_t>* create(KernelParams params);
[[deprecated]] static GramMatrixBase<math_t>* create(KernelParams params, cublasHandle_t handle);
};

/**
* Create a kernel matrix using polynomial kernel function.
Expand Down Expand Up @@ -377,5 +662,4 @@ class RBFKernel : public GramMatrixBase<math_t> {
int ld2,
int ld_out);
};

}; // end namespace cuvs::distance::kernels::detail
}; // end namespace cuvs::distance::kernels
6 changes: 3 additions & 3 deletions cpp/src/distance/detail/kernels/gram_matrix.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include "../../distance.cuh"
#include "gram_matrix.hpp"
#include <cuvs/distance/distance.hpp>
#include <cuvs/distance/grammian.hpp>
#include <raft/core/device_csr_matrix.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>
Expand All @@ -25,7 +25,7 @@
#include <raft/sparse/distance/distance.cuh>
#include <raft/sparse/linalg/spmm.hpp>

namespace cuvs::distance::kernels::detail {
namespace cuvs::distance::kernels {

/**
* Base class for general Gram matrices
Expand Down Expand Up @@ -475,4 +475,4 @@ void GramMatrixBase<math_t>::linear(raft::resources const& handle,
}
}

}; // end namespace cuvs::distance::kernels::detail
}; // namespace cuvs::distance::kernels
Loading

0 comments on commit 37e7140

Please sign in to comment.