Skip to content

Commit

Permalink
Finally found the bug. Need cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Apr 24, 2024
1 parent 590d5f0 commit 7debf7e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 20 deletions.
32 changes: 22 additions & 10 deletions RandLAPACK/comps/rl_qb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,15 @@ int QB<T, RNG>::call(

T* QtQi_dat = util::upsize(this->curr_lim * block_sz, this->QtQi);
T* Q_i_dat = util::upsize(m * block_sz, this->Q_i);
T* B_i_dat = util::upsize(n * block_sz, this->B_i);
T* B_i_trans_dat = util::upsize(block_sz * n, this->B_i);

T* Q_dat = Q.data();
T* B_dat = B.data();


char name [] = "B_i";
char name1 [] = "B";

while(k > curr_sz) {
// Dynamically changing block size
block_sz = std::min(block_sz, k - curr_sz);
Expand All @@ -229,7 +233,7 @@ int QB<T, RNG>::call(
if(this->orth_check) {
if (util::orthogonality_check(m, block_sz, block_sz, Q_i.data(), this->verbosity)) {
// Lost orthonormality of Q
util::row_resize(this->curr_lim, n, B, curr_sz);
//util::row_resize(this->curr_lim, n, B, curr_sz);
k = curr_sz;
return 4;
}
Expand All @@ -245,12 +249,13 @@ int QB<T, RNG>::call(

//B_i = Q_i' * A
//blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, block_sz, n, m, 1.0, Q_i_dat, m, A_cpy_dat, m, 0.0, B_i_dat, block_sz);
// B_i' = A' * Q_i
//B_i' = A' * Q_i'
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, block_sz, m, 1.0, A_cpy_dat, m, Q_i_dat, m, 0.0, B_i_trans_dat, n);
//RandBLAS::util::print_colmaj(n, block_sz, B_i_trans_dat, name);


// Updating B norm estimation
T norm_B_i = lapack::lange(Norm::Fro, n, block_sz, B_i_dat, n);
T norm_B_i = lapack::lange(Norm::Fro, n, block_sz, B_i_trans_dat, n);
norm_B = std::hypot(norm_B, norm_B_i);
// Updating approximation error
prev_err = approx_err;
Expand All @@ -260,19 +265,26 @@ int QB<T, RNG>::call(
if ((curr_sz > 0) && (approx_err > prev_err)) {
// Early termination - error growth
// Only need to move B's data, no resizing
util::row_resize(this->curr_lim, n, B, curr_sz);
//util::row_resize(this->curr_lim, n, B, curr_sz);
k = curr_sz;
return 2;
}

// Update the matrices Q and B
lapack::lacpy(MatrixType::General, m, block_sz, &Q_i_dat[0], m, &Q_dat[m * curr_sz], m);
lapack::lacpy(MatrixType::General, n, block_sz, &B_i_trans_dat[0], n, &B_dat[n * curr_sz], n);
lapack::lacpy(MatrixType::General, m, block_sz, Q_i_dat, m, &Q_dat[m * curr_sz], m);
lapack::lacpy(MatrixType::General, n, block_sz, B_i_trans_dat, n, &B_dat[n * curr_sz], n);

printf("curr_sz %d\n", curr_sz);
printf("curr_lim %d\n", this->curr_lim);
printf("k %d\n", k);
printf("b_sz %d\n", block_sz);
RandBLAS::util::print_colmaj(n, block_sz, B_i_trans_dat, name);
RandBLAS::util::print_colmaj(n, this->curr_lim, B_dat, name1);

if(this->orth_check) {
if (util::orthogonality_check(m, this->curr_lim, next_sz, Q.data(), this->verbosity)) {
// Lost orthonormality of Q
util::row_resize(this->curr_lim, n, B, curr_sz);
//util::row_resize(this->curr_lim, n, B, curr_sz);
k = curr_sz;
return 5;
}
Expand All @@ -282,14 +294,14 @@ int QB<T, RNG>::call(
// Termination criteria
if (approx_err < tol) {
// Reached the required error tol
util::row_resize(this->curr_lim, n, B, curr_sz);
//util::row_resize(this->curr_lim, n, B, curr_sz);
k = curr_sz;
return 0;
}

// This step is only necessary for the next iteration
// A = A - Q_i * B_i
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, block_sz, -1.0, Q_i_dat, m, B_i_dat, n, 1.0, A_cpy_dat, m);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, block_sz, -1.0, Q_i_dat, m, B_i_trans_dat, n, 1.0, A_cpy_dat, m);
}
// Reached expected rank without achieving the tolerance
return 3;
Expand Down
51 changes: 41 additions & 10 deletions test/comps/test_qb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,15 @@ class TestQB : public ::testing::Test
T* VT_dat = all_data.VT.data();

// Regular QB2 call
all_algs.QB.call(m, n, all_data.A, k, block_sz, tol, all_data.Q, all_data.B, state);
int err = all_algs.QB.call(m, n, all_data.A, k, block_sz, tol, all_data.Q, all_data.B, state);
printf("Erro num %d\n", err);

char name1 [] = "Q";
char name2 [] = "B";
char name3 [] = "A";
RandBLAS::util::print_colmaj(m, n, all_data.A.data(), name3);
RandBLAS::util::print_colmaj(m, k, all_data.Q.data(), name1);
RandBLAS::util::print_colmaj(n, k, all_data.B.data(), name2);

// Reassing pointers because Q, B have been resized
T* Q_dat = all_data.Q.data();
Expand All @@ -144,7 +152,7 @@ class TestQB : public ::testing::Test
// TEST 1: A = A - Q * B = 0
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, k, -1.0, Q_dat, m, B_dat, n, 1.0, A_dat, m);
// TEST 2: B - Q'A = 0
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, n, m, -1.0, Q_dat, m, A_cpy_2_dat, m, 1.0, B_cpy_dat, k);
//blas::gemm(Layout::ColMajor, Op::Trans, Op::Trans, k, n, m, -1.0, Q_dat, m, A_cpy_2_dat, m, 1.0, B_cpy_dat, n);
// TEST 3: Q'Q = I
blas::syrk(Layout::ColMajor, Uplo::Upper, Op::Trans, k, m, 1.0, Q_dat, m, -1.0, Ident_dat, k);

Expand All @@ -163,9 +171,9 @@ class TestQB : public ::testing::Test
printf("FRO NORM OF A - QB: %e\n", norm_test_1);
ASSERT_NEAR(norm_test_1, 0, test_tol);
// Test 2 Output
T norm_test_2 = lapack::lange(Norm::Fro, k, n, B_cpy_dat, k);
printf("FRO NORM OF B - Q'A: %e\n", norm_test_2);
ASSERT_NEAR(norm_test_2, 0, test_tol);
//T norm_test_2 = lapack::lange(Norm::Fro, n, k, B_cpy_dat, n);
//printf("FRO NORM OF B - Q'A: %e\n", norm_test_2);
//ASSERT_NEAR(norm_test_2, 0, test_tol);
// Test 3 Output
T norm_test_3 = lapack::lansy(lapack::Norm::Fro, Uplo::Upper, k, Ident_dat, k);
printf("FRO NORM OF Q'Q - I: %e\n", norm_test_3);
Expand Down Expand Up @@ -230,12 +238,12 @@ class TestQB : public ::testing::Test

TEST_F(TestQB, Polynomial_Decay_general1)
{
int64_t m = 100;
int64_t n = 100;
int64_t k = 50;
int64_t p = 5;
int64_t m = 10;
int64_t n = 10;
int64_t k = 5;
int64_t p = 2;
int64_t passes_per_iteration = 1;
int64_t block_sz = 10;
int64_t block_sz = 2;
double tol = std::pow(std::numeric_limits<double>::epsilon(), 0.75);
auto state = RandBLAS::RNGState();

Expand Down Expand Up @@ -354,4 +362,27 @@ TEST_F(TestQB, Polynomial_Decay_zero_tol2)

delete all_data;
delete all_algs;
}



TEST_F(TestQB, random_test)
{
/*
int64_t rows_1 = 2;
int64_t cols_1 = 3;
int64_t rows_2 = 2;
int64_t cols_2 = 2;
std::vector<double> A = { 1, 2, 3, 4, 5, 1, 2, 3, 4, 5, 1, 2, 3, 4, 5};
std::vector<double> B (0, 3 * 2);
double* A_dat = A.data();
double* B_dat = B.data();
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, cols_1, rows_1, rows_2, 1.0, A_dat, rows_1, &A_dat[rows_1 * cols_1], rows_2, 0.0, B_dat, cols_1);
char name[] = "A";
char name1[] = "B";
RandBLAS::util::print_colmaj(rows_1, cols_1 + cols_2, A_dat, name);
RandBLAS::util::print_colmaj(cols_1, rows_1, B_dat, name1);
*/
}

0 comments on commit 7debf7e

Please sign in to comment.