Skip to content

Commit

Permalink
Fixed a bug in RSVD.
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Jun 3, 2024
1 parent 6e3aa07 commit a634360
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 91 deletions.
28 changes: 14 additions & 14 deletions RandLAPACK/drivers/rl_rsvd.hh
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class RSVDalg {
T tol,
T* &U,
T* &S,
T* &VT,
T* &V,
RandBLAS::RNGState<RNG> &state
) = 0;
};
Expand Down Expand Up @@ -80,7 +80,7 @@ class RSVD : public RSVDalg<T, RNG> {
/// Initially, may not have any space allocated for it.
///
/// @param[in] V
/// Buffer for the \transpose{V}-factor.
/// Buffer for the V-factor.
/// Initially, may not have any space allocated for it.
///
/// @param[out] U
Expand All @@ -89,8 +89,8 @@ class RSVD : public RSVDalg<T, RNG> {
/// @param[out] S
/// Stores k-by-k factor \Sigma.
///
/// @param[out] VT
/// Stores k-by-n factor \transpose{V}.
/// @param[out] V
/// Stores n-by-k factor V.
///
/// @returns 0 if successful

Expand All @@ -102,7 +102,7 @@ class RSVD : public RSVDalg<T, RNG> {
T tol,
T* &U,
T* &S,
T* &VT,
T* &V,
RandBLAS::RNGState<RNG> &state
) override;

Expand All @@ -122,28 +122,28 @@ int RSVD<T, RNG>::call(
T tol,
T* &U,
T* &S,
T* &VT,
T* &V,
RandBLAS::RNGState<RNG> &state
){
T* Q = nullptr;
T* B = nullptr;
T* BT = nullptr;
// Q and B sizes will be adjusted automatically
this->QB_Obj.call(m, n, A, k, this->block_sz, tol, Q, B, state);
this->QB_Obj.call(m, n, A, k, this->block_sz, tol, Q, BT, state);

T* U_buf = ( T * ) calloc(k * k, sizeof( T ) );
T* UT_buf = ( T * ) calloc(k * k, sizeof( T ) );
// Making sure all vectors are large enough
U = ( T * ) calloc(m * k, sizeof( T ) );
S = ( T * ) calloc(k, sizeof( T ) );
VT = ( T * ) calloc(n * k, sizeof( T ) );
V = ( T * ) calloc(n * k, sizeof( T ) );

// SVD of B
lapack::gesdd(Job::SomeVec, k, n, B, k, S, U_buf, k, VT, k);
lapack::gesdd(Job::SomeVec, n, k, BT, n, S, V, n, UT_buf, k);
// Adjusting U
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, Q, m, U_buf, k, 0.0, U, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, k, k, 1.0, Q, m, UT_buf, k, 0.0, U, m);

free(Q);
free(B);
free(U_buf);
free(BT);
free(UT_buf);
return 0;
}

Expand Down
32 changes: 16 additions & 16 deletions benchmark/bench_CQRRP/CQRRP_speed_comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ static void call_all_algs(
printf("\nITERATION %d\n", i);
// Testing GEQP3
auto start_geqp3 = high_resolution_clock::now();
//lapack::geqp3(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data());
lapack::geqp3(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data());
auto stop_geqp3 = high_resolution_clock::now();
dur_geqp3 = duration_cast<microseconds>(stop_geqp3 - start_geqp3).count();
printf("TOTAL TIME FOR GEQP3 %ld\n", dur_geqp3);
Expand All @@ -114,7 +114,7 @@ static void call_all_algs(

// Testing CQRRP - best setup
auto start_cqrrp = high_resolution_clock::now();
CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state_alg);
//CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state_alg);
auto stop_cqrrp = high_resolution_clock::now();
dur_cqrrp = duration_cast<microseconds>(stop_cqrrp - start_cqrrp).count();
printf("TOTAL TIME FOR CQRRP %ld\n", dur_cqrrp);
Expand All @@ -123,12 +123,12 @@ static void call_all_algs(
state_gen = state;
state_alg = state;
// Clear and re-generate data
data_regen(m_info, all_data, state_gen, 0);
//data_regen(m_info, all_data, state_gen, 0);

// Testing CQRRP - using QP3
CQRRP_blocked.use_qp3 = true;
auto start_cqrrp_qp3 = high_resolution_clock::now();
CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state_alg);
//CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state_alg);
auto stop_cqrrp_qp3 = high_resolution_clock::now();
CQRRP_blocked.use_qp3 = false;
dur_cqrrp_qp3 = duration_cast<microseconds>(stop_cqrrp_qp3 - start_cqrrp_qp3).count();
Expand All @@ -138,11 +138,11 @@ static void call_all_algs(
state_gen = state;
state_alg = state;
// Clear and re-generate data
data_regen(m_info, all_data, state_gen, 1);
//data_regen(m_info, all_data, state_gen, 1);

// Testing HQRRP DEFAULT
auto start_hqrrp = high_resolution_clock::now();
RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 0, state_alg, (T*) nullptr);
//RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 0, state_alg, (T*) nullptr);
auto stop_hqrrp = high_resolution_clock::now();
dur_hqrrp = duration_cast<microseconds>(stop_hqrrp - start_hqrrp).count();
printf("TOTAL TIME FOR HQRRP %ld\n", dur_hqrrp);
Expand All @@ -151,11 +151,11 @@ static void call_all_algs(
state_gen = state;
state_alg = state;
// Clear and re-generate data
data_regen(m_info, all_data, state_gen, 1);
//data_regen(m_info, all_data, state_gen, 1);

// Testing HQRRP with GEQRF
auto start_hqrrp_geqrf = high_resolution_clock::now();
RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 1, state_alg, (T*) nullptr);
//RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 1, state_alg, (T*) nullptr);
auto stop_hqrrp_geqrf = high_resolution_clock::now();
dur_hqrrp_geqrf = duration_cast<microseconds>(stop_hqrrp_geqrf - start_hqrrp_geqrf).count();
printf("TOTAL TIME FOR HQRRP WITH GEQRF %ld\n", dur_hqrrp_geqrf);
Expand All @@ -164,11 +164,11 @@ static void call_all_algs(
state_gen = state;
state_alg = state;
// Clear and re-generate data
data_regen(m_info, all_data, state_gen, 1);
//data_regen(m_info, all_data, state_gen, 1);

// Testing HQRRP with CholQR
auto start_hqrrp_cholqr = high_resolution_clock::now();
RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 2, state_alg, (T*) nullptr);
//RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, panel_pivoting, 2, state_alg, (T*) nullptr);
auto stop_hqrrp_cholqr = high_resolution_clock::now();
dur_hqrrp_cholqr = duration_cast<microseconds>(stop_hqrrp_cholqr - start_hqrrp_cholqr).count();
printf("TOTAL TIME FOR HQRRP WITH CHOLQRQ %ld\n", dur_hqrrp_cholqr);
Expand All @@ -177,7 +177,7 @@ static void call_all_algs(
state_gen = state;
state_alg = state;
// Clear and re-generate data
data_regen(m_info, all_data, state_gen, 0);
//data_regen(m_info, all_data, state_gen, 0);

std::ofstream file(output_filename, std::ios::app);
file << dur_cqrrp << ", " << dur_cqrrp_qp3 << ", " << dur_hqrrp << ", " << dur_hqrrp_geqrf << ", " << dur_hqrrp_geqrf << ", " << dur_geqrf << ", " << dur_geqp3 << ",\n";
Expand All @@ -186,18 +186,18 @@ static void call_all_algs(

int main() {
// Declare parameters
int64_t m = std::pow(2, 16);
int64_t n = std::pow(2, 16);
int64_t m = 10000;
int64_t n = 10000;
double d_factor = 1.25;
int64_t b_sz_start = 256;
int64_t b_sz_end = 2048;
int64_t b_sz_start = 64;
int64_t b_sz_end = 64;
double tol = std::pow(std::numeric_limits<double>::epsilon(), 0.85);
auto state = RandBLAS::RNGState();
auto state_constant = state;
// Timing results
std::vector<long> res;
// Number of algorithm runs. We only record best times.
int64_t numruns = 2;
int64_t numruns = 50;

// Allocate basic workspace
QR_speed_benchmark_data<double> all_data(m, n, tol, d_factor);
Expand Down
67 changes: 22 additions & 45 deletions test/comps/test_qb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@ class TestQB : public ::testing::Test
int64_t rank;
std::vector<T> A;
std::vector<T> Q;
std::vector<T> B;
std::vector<T> B_cpy;
std::vector<T> BT;
std::vector<T> BT_cpy;
std::vector<T> A_hat;
std::vector<T> A_k;
std::vector<T> A_cpy;
Expand All @@ -37,7 +37,7 @@ class TestQB : public ::testing::Test

QBTestData(int64_t m, int64_t n, int64_t k) :
A(m * n, 0.0),
B_cpy(k * n, 0.0),
BT_cpy(k * n, 0.0),
A_hat(m * n, 0.0),
A_k(m * n, 0.0),
A_cpy(m * n, 0.0),
Expand Down Expand Up @@ -119,16 +119,16 @@ class TestQB : public ::testing::Test
T* S_dat = all_data.S.data();
T* VT_dat = all_data.VT.data();

T* Q = nullptr;
T* B = nullptr;
T* Q = nullptr;
T* BT = nullptr;

// Regular QB2 call
all_algs.QB.call(m, n, all_data.A.data(), k, block_sz, tol, Q, B, state);
all_algs.QB.call(m, n, all_data.A.data(), k, block_sz, tol, Q, BT, state);

// Reassigning pointers because Q, B have been resized
T* Q_dat = Q;
T* B_dat = B;
T* B_cpy_dat = all_data.B_cpy.data();
T* BT_dat = BT;
T* BT_cpy_dat = all_data.BT_cpy.data();

printf("Inner dimension of QB: %-25ld\n", k);

Expand All @@ -137,14 +137,14 @@ class TestQB : public ::testing::Test
// Generate a reference identity
RandLAPACK::util::eye(k, k, Ident);
// Buffer for testing B
blas::copy(k * n, B_dat, 1, B_cpy_dat, 1);
blas::copy(k * n, BT_dat, 1, BT_cpy_dat, 1);

// A_hat = Q * B
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, 1.0, Q_dat, m, B_dat, n, 0.0, A_hat_dat, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, 1.0, Q_dat, m, BT_dat, n, 0.0, A_hat_dat, m);
// TEST 1: A = A - Q * B = 0
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, -1.0, Q_dat, m, B_dat, n, 1.0, A_dat, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, -1.0, Q_dat, m, BT_dat, n, 1.0, A_dat, m);
// TEST 2: B - Q'A = 0
//blas::gemm(Layout::ColMajor, Op::Trans, Op::Trans, k, n, m, -1.0, Q_dat, m, A_cpy_2_dat, m, 1.0, B_cpy_dat, n);
//blas::gemm(Layout::ColMajor, Op::Trans, Op::Trans, k, n, m, -1.0, Q_dat, m, A_cpy_2_dat, m, 1.0, BT_cpy_dat, n);
// TEST 3: Q'Q = I
blas::syrk(Layout::ColMajor, Uplo::Upper, Op::Trans, k, m, 1.0, Q_dat, m, -1.0, Ident_dat, k);

Expand All @@ -163,7 +163,7 @@ class TestQB : public ::testing::Test
printf("FRO NORM OF A - QB: %e\n", norm_test_1);
ASSERT_NEAR(norm_test_1, 0, test_tol);
// Test 2 Output
//T norm_test_2 = lapack::lange(Norm::Fro, n, k, B_cpy_dat, n);
//T norm_test_2 = lapack::lange(Norm::Fro, n, k, BT_cpy_dat, n);
//printf("FRO NORM OF B - Q'A: %e\n", norm_test_2);
//ASSERT_NEAR(norm_test_2, 0, test_tol);
// Test 3 Output
Expand All @@ -175,7 +175,7 @@ class TestQB : public ::testing::Test
printf("FRO NORM OF A_k - QB: %e\n", norm_test_4);
ASSERT_NEAR(norm_test_4, 0, test_tol);
free(Q);
free(B);
free(BT);
}

/// k = min(m, n) test for CholQRCP:
Expand All @@ -197,25 +197,25 @@ class TestQB : public ::testing::Test

T* A_dat = all_data.A.data();
T* Q_dat = all_data.Q.data();
T* B_dat = all_data.B.data();
T* BT_dat = all_data.BT.data();
T* A_hat_dat = all_data.A_hat.data();

T* Q = nullptr;
T* B = nullptr;
T* BT = nullptr;

// Regular QB2 call
all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, Q, B, state);
all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, Q, BT, state);

// Reassigning pointers because Q, B have been resized
Q_dat = Q;
B_dat = B;
Q_dat = Q;
BT_dat = BT;

printf("Inner dimension of QB: %ld\n", k_est);

// A_hat = Q * B
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k_est, 1.0, Q_dat, m, B_dat, n, 0.0, A_hat_dat, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k_est, 1.0, Q_dat, m, BT_dat, n, 0.0, A_hat_dat, m);
// TEST 1: A = A - Q * B = 0
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k_est, -1.0, Q_dat, m, B_dat, n, 1.0, A_dat, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k_est, -1.0, Q_dat, m, BT_dat, n, 1.0, A_dat, m);

T norm_test_1 = lapack::lange(Norm::Fro, m, n, A_dat, m);
T test_tol = std::pow(std::numeric_limits<T>::epsilon(), 0.75);
Expand All @@ -231,7 +231,7 @@ class TestQB : public ::testing::Test
EXPECT_TRUE(norm_test_1 <= (tol * norm_A));
}
free(Q);
free(B);
free(BT);
}
};

Expand Down Expand Up @@ -363,26 +363,3 @@ TEST_F(TestQB, Polynomial_Decay_zero_tol2)
delete all_data;
delete all_algs;
}


TEST_F(TestQB, random_test)
{
/*
int64_t rows_1 = 2;
int64_t cols_1 = 3;
int64_t rows_2 = 2;
int64_t cols_2 = 2;
std::vector<double> A = { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
std::vector<double> B (0, 3 * 2);
double* A_dat = A.data();
double* B_dat = B.data();
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, cols_1, rows_1, rows_2, 1.0, A_dat, rows_1, &A_dat[rows_1 * cols_1], rows_2, 0.0, B_dat, cols_1);
char name[] = "A";
char name1[] = "B";
RandBLAS::util::print_colmaj(rows_1, cols_1 + cols_2, A_dat, name);
RandBLAS::util::print_colmaj(cols_1, rows_1, B_dat, name1);
*/

}
Loading

0 comments on commit a634360

Please sign in to comment.