Skip to content

Commit

Permalink
reduce code duplication in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Jul 11, 2024
1 parent 99f97b9 commit 9d01500
Showing 1 changed file with 35 additions and 76 deletions.
111 changes: 35 additions & 76 deletions test/misc/test_pdkernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,63 +180,15 @@ class TestPDK_SEKLO : public ::testing::Test {
virtual void TearDown() {};

template <typename T>
void run_no_reg(int64_t m, int64_t d, T bandwidth, uint32_t seed) {
RNGState state_x(seed);
DenseDist D(d, m);
vector<T> X_vec(d*m);
T* X = X_vec.data();
RandBLAS::fill_dense(D, X, state_x);
vector<T> regs{};
RandLAPACK::linops::SEKLO K(m, X, d, bandwidth, regs);

vector<T> eye(m * m, 0.0);
vector<T> sq_colnorms(m, 0.0);
for (int64_t i = 0; i < m; ++i) {
eye[i + m*i] = 1.0;
sq_colnorms[i] = std::pow(blas::nrm2(d, X + i*d, 1), 2);
}
vector<T> K_out_expect(m * m, 0.0);

// (alpha, beta) = (0.25, 0.0)
T alpha = 0.25;
RandLAPACK::squared_exp_kernel_submatrix(
d, m, X, sq_colnorms.data(), m, m, K_out_expect.data(), 0, 0, bandwidth
);
blas::scal(m * m, alpha, K_out_expect.data(), 1);
vector<T> K_out_actual1(m * m, 1.0);
K(blas::Layout::ColMajor, m, alpha, eye.data(), m, 0.0, K_out_actual1.data(), m);

T atol = d * std::numeric_limits<T>::epsilon() * (1.0 + std::pow(bandwidth, -2));
test::comparison::matrices_approx_equal(
blas::Layout::ColMajor, blas::Op::NoTrans, m, m, K_out_actual1.data(), m,
K_out_expect.data(), m, __PRETTY_FUNCTION__, __FILE__, __LINE__, atol, atol
);

// Expected output when (alpha, beta) = (0.25, 0.3)
T beta = 0.3;
for (int i = 0; i < m*m; ++i)
K_out_expect[i] += beta;
vector<T> K_out_actual2(m * m, 1.0);
K(blas::Layout::ColMajor, m, alpha, eye.data(), m, beta, K_out_actual2.data(), m);

test::comparison::matrices_approx_equal(
blas::Layout::ColMajor, blas::Op::NoTrans, m, m, K_out_actual2.data(), m,
K_out_expect.data(), m, __PRETTY_FUNCTION__, __FILE__, __LINE__, atol, atol
);
return;
}

template <typename T>
void run_with_reg(T reg, int64_t m, int64_t d, uint32_t seed) {
T bandwidth = 1.1;
void run(T bandwidth, T reg, int64_t m, int64_t d, uint32_t seed, bool use_reg = true) {
RNGState state_x(seed);
DenseDist D(d, m);
vector<T> X_vec(d*m);
T* X = X_vec.data();
RandBLAS::fill_dense(D, X, state_x);
vector<T> regs(1,reg);
RandLAPACK::linops::SEKLO K(m, X, d, bandwidth, regs);
K.set_eval_includes_reg(true);
K.set_eval_includes_reg(use_reg);

vector<T> eye(m * m, 0.0);
vector<T> sq_colnorms(m, 0.0);
Expand All @@ -252,8 +204,9 @@ class TestPDK_SEKLO : public ::testing::Test {
d, m, X, sq_colnorms.data(), m, m, K_out_expect.data(), 0, 0, bandwidth
);
blas::scal(m * m, alpha, K_out_expect.data(), 1);
for (int i = 0; i < m; ++i) {
K_out_expect[i + i*m] += alpha * reg;
if (use_reg) {
for (int i = 0; i < m; ++i)
K_out_expect[i + i*m] += alpha * reg;
}
vector<T> K_out_actual1(m * m, 1.0);
K(blas::Layout::ColMajor, m, alpha, eye.data(), m, 0.0, K_out_actual1.data(), m);
Expand All @@ -280,50 +233,56 @@ class TestPDK_SEKLO : public ::testing::Test {

};

TEST_F(TestPDK_SEKLO, no_reg_apply_to_eye_m100_d3) {
TEST_F(TestPDK_SEKLO, apply_to_eye_m100_d3) {
double mu = 0.123;
for (uint32_t i = 77; i < 80; ++i) {
run_no_reg(100, 3, 1.0, i);
run_no_reg(100, 3, 2.0, i);
run_no_reg(100, 3, 2.345678, i);
run(1.0, mu, 100, 3, i, false);
run(2.0, mu, 100, 3, i, false);
run(2.345678, mu, 100, 3, i, false);
}
}

TEST_F(TestPDK_SEKLO, no_reg_apply_to_eye_m256_d4) {
TEST_F(TestPDK_SEKLO, apply_to_eye_m256_d4) {
double mu = 0.123;
for (uint32_t i = 77; i < 80; ++i) {
run_no_reg(256, 4, 1.0, i);
run_no_reg(256, 4, 2.0, i);
run_no_reg(256, 4, 2.345678, i);
run(1.0, mu, 256, 4, i, false);
run(2.0, mu, 256, 4, i, false);
run(2.345678, mu, 256, 4, i, false);
}
}

TEST_F(TestPDK_SEKLO, no_reg_apply_to_eye_m999_d7) {
TEST_F(TestPDK_SEKLO, apply_to_eye_m999_d7) {
double mu = 0.123;
for (uint32_t i = 77; i < 80; ++i) {
run_no_reg(999, 7, 1.0, i);
run_no_reg(999, 7, 2.0, i);
run_no_reg(999, 7, 2.345678, i);
run(1.0, mu, 999, 7, i, false);
run(2.0, mu, 999, 7, i, false);
run(2.345678, mu, 999, 7, i, false);
}
}

TEST_F(TestPDK_SEKLO, yes_reg_apply_to_eye_m100_d3) {
TEST_F(TestPDK_SEKLO, reg_apply_to_eye_m100_d3) {
double bandwidth = 1.1;
for (uint32_t i = 77; i < 80; ++i) {
run_with_reg(0.1, 100, 3, i);
run_with_reg(1.0, 100, 3, i);
run_with_reg(7.654321, 100, 3, i);
run(bandwidth, 0.1, 100, 3, i);
run(bandwidth, 1.0, 100, 3, i);
run(bandwidth, 7.654321, 100, 3, i);
}
}

TEST_F(TestPDK_SEKLO, yes_reg_apply_to_eye_m256_d4) {
TEST_F(TestPDK_SEKLO, reg_apply_to_eye_m256_d4) {
double bandwidth = 1.1;
for (uint32_t i = 77; i < 80; ++i) {
run_with_reg(0.1, 256, 4, i);
run_with_reg(1.0, 256, 4, i);
run_with_reg(7.654321, 256, 4, i);
run(bandwidth, 0.1, 256, 4, i);
run(bandwidth, 1.0, 256, 4, i);
run(bandwidth, 7.654321, 256, 4, i);
}
}

TEST_F(TestPDK_SEKLO, yes_reg_apply_to_eye_m257_d5) {
TEST_F(TestPDK_SEKLO, reg_apply_to_eye_m257_d5) {
double bandwidth = 1.1;
for (uint32_t i = 77; i < 80; ++i) {
run_with_reg(0.1, 257, 5, i);
run_with_reg(1.0, 257, 5, i);
run_with_reg(7.654321, 257, 5, i);
run(bandwidth, 0.1, 257, 5, i);
run(bandwidth, 1.0, 257, 5, i);
run(bandwidth, 7.654321, 257, 5, i);
}
}

0 comments on commit 9d01500

Please sign in to comment.