Skip to content

Commit

Permalink
basic test for separable KRILL
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jul 11, 2024
1 parent 9d01500 commit 0d1d2cf
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 24 deletions.
16 changes: 16 additions & 0 deletions test/drivers/test_krillx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using RandBLAS::DenseDist;
using RandBLAS::SparseDist;
using RandBLAS::RNGState;
using RandLAPACK::linops::RegExplicitSymLinOp;
using RandLAPACK::linops::SEKLO;
using RandLAPACK_Testing::polynomial_decay_psd;


Expand Down Expand Up @@ -173,3 +174,18 @@ TEST_F(TestKrillx, test_krill_separable_rpchol) {
run_krill_separable(1, G_linop, k);
}
}

TEST_F(TestKrillx, test_krill_separable_squared_exp_kernel) {
using T = double;
T mu_min = 1e-2;
vector<T> mus {mu_min, mu_min*10, mu_min*100};
for (uint32_t key = 0; key < 5; ++key) {
//auto G = polynomial_decay_psd(m, 1e12, (T) decay, key);
//RegExplicitSymLinOp G_linop(m, G.data(), m, mus);
vector<T> X0 = RandLAPACK_Testing::random_gaussian_mat<T>(5, m, key);
SEKLO G_linop(m, X0.data(), 5, 3.0, mus);
int64_t k = 128;
run_krill_separable(0, G_linop, k);
run_krill_separable(1, G_linop, k);
}
}
28 changes: 4 additions & 24 deletions test/misc/test_pdkernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <RandBLAS.hh>
#include "../RandLAPACK/RandBLAS/test/comparison.hh"
#include "../moremats.hh"

#include <math.h>
#include <gtest/gtest.h>
Expand All @@ -13,27 +14,6 @@ using RandBLAS::DenseDist;
using blas::Layout;
using std::vector;

template <typename T>
vector<T> random_gaussian_mat(int64_t m, int64_t n, uint32_t seed) {
RandBLAS::DenseDist D(m, n);
RNGState state(seed);
vector<T> mat(m*n);
RandBLAS::fill_dense(D, mat.data(), state);
return mat;
}

template <typename T, typename RNG>
RNGState<RNG> left_multiply_by_orthmat(int64_t m, int64_t n, std::vector<T> &A, RNGState<RNG> state) {
using std::vector;
vector<T> U(m * m, 0.0);
RandBLAS::DenseDist DU(m, m);
auto out_state = RandBLAS::fill_dense(DU, U.data(), state).second;
vector<T> tau(m, 0.0);
lapack::geqrf(m, m, U.data(), m, tau.data());
lapack::ormqr(blas::Side::Left, blas::Op::NoTrans, m, n, m, U.data(), m, tau.data(), A.data(), m);
return out_state;
}

class TestPDK_SquaredExponential : public ::testing::Test {
protected:

Expand All @@ -49,7 +29,7 @@ class TestPDK_SquaredExponential : public ::testing::Test {
void run_same_blockimpl_vs_entrywise(int64_t d, int64_t n, T bandwidth, uint32_t seed) {
vector<T> K_blockimpl(n*n, 0.0);
vector<T> K_entrywise(n*n, 0.0);
vector<T> X = random_gaussian_mat<T>(d, n, seed);
vector<T> X = RandLAPACK_Testing::random_gaussian_mat<T>(d, n, seed);
vector<T> squared_norms(n, 0.0);
T* X_ = X.data();
for (int64_t i = 0; i < n; ++i) {
Expand Down Expand Up @@ -79,7 +59,7 @@ class TestPDK_SquaredExponential : public ::testing::Test {
*/
template <typename T>
void run_all_same_column(int64_t d, int64_t n, uint32_t seed) {
vector<T> c = random_gaussian_mat<T>(d, 1, seed);
vector<T> c = RandLAPACK_Testing::random_gaussian_mat<T>(d, 1, seed);
vector<T> X(d*n, 0.0);
T* _X = X.data();
T* _c = c.data();
Expand Down Expand Up @@ -112,7 +92,7 @@ class TestPDK_SquaredExponential : public ::testing::Test {
for (int64_t i = 0; i < n; ++i)
X[i+i*n] = 1.0;
RNGState state(seed);
left_multiply_by_orthmat(n, n, X, state);
RandLAPACK_Testing::left_multiply_by_orthmat(n, n, X, state);
vector<T> squarednorms(n, 1.0);
vector<T> K(n*n, 0.0);
RandLAPACK::squared_exp_kernel_submatrix(
Expand Down
22 changes: 22 additions & 0 deletions test/moremats.hh
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,26 @@ vector<T> polynomial_decay_psd(int64_t m, T cond_num, T exponent, uint32_t seed)
return G;
}

template <typename T>
vector<T> random_gaussian_mat(int64_t m, int64_t n, uint32_t seed) {
RandBLAS::DenseDist D(m, n);
RNGState state(seed);
vector<T> mat(m*n);
RandBLAS::fill_dense(D, mat.data(), state);
return mat;
}

template <typename T, typename RNG>
RNGState<RNG> left_multiply_by_orthmat(int64_t m, int64_t n, std::vector<T> &A, RNGState<RNG> state) {
using std::vector;
vector<T> U(m * m, 0.0);
RandBLAS::DenseDist DU(m, m);
auto out_state = RandBLAS::fill_dense(DU, U.data(), state).second;
vector<T> tau(m, 0.0);
lapack::geqrf(m, m, U.data(), m, tau.data());
lapack::ormqr(blas::Side::Left, blas::Op::NoTrans, m, n, m, U.data(), m, tau.data(), A.data(), m);
return out_state;
}


}

0 comments on commit 0d1d2cf

Please sign in to comment.