Skip to content

Commit

Permalink
Allocation logic in QB fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Apr 26, 2024
1 parent 95c0ae3 commit 3382c24
Show file tree
Hide file tree
Showing 4 changed files with 162 additions and 55 deletions.
154 changes: 140 additions & 14 deletions RandLAPACK/comps/rl_qb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class QBalg {
int64_t b_sz,
T tol,
T* &Q,
T* &B,
T* &BT,
RandBLAS::RNGState<RNG> &state
) = 0;
};
Expand Down Expand Up @@ -58,9 +58,9 @@ class QB : public QBalg<T, RNG> {
/// or
/// (2) Q has k columns.
/// Each iteration involves sketching A from the right by a sketching
/// matrix with "b_sz" columns.
/// matrix with "block_sz" columns.
///
/// The number of columns in Q increase by "b_sz" at each iteration, unless
/// The number of columns in Q increases by "block_sz" at each iteration, unless
/// that would bring #cols(Q) > k. In that case, the final iteration only
/// adds enough columns to Q so that #cols(Q) == k.
/// The implementation relies on RowSketcher and RangeFinder,
Expand Down Expand Up @@ -96,16 +96,16 @@ class QB : public QBalg<T, RNG> {
///
/// @param[in] Q
/// Buffer for the Q-factor.
/// Q is REQUIRED to be either nullptr or to point to >= m * b_sz * sizeof(T) bytes.
/// We expect Q to be nullptr.
///
/// @param[in] B
/// @param[in] BT
/// Buffer for the B-factor.
/// B is REQUIRED to be either nullptr or to point to >= n * b_sz * sizeof(T) bytes.
/// We expect BT to be nullptr.
///
/// @param[out] Q
/// Has the same number of rows of A, and orthonormal columns.
///
/// @param[out] B
/// @param[out] BT
/// Holds the B-factor in transposed form: the number of rows in BT is equal to the number of columns in A.
///
/// @return = 0: successful exit
Expand All @@ -119,7 +119,7 @@ class QB : public QBalg<T, RNG> {
int64_t b_sz,
T tol,
T* &Q,
T* &B,
T* &BT,
RandBLAS::RNGState<RNG> &state
) override;

Expand All @@ -140,17 +140,143 @@ int QB<T, RNG>::call(
int64_t b_sz,
T tol,
T* &Q,
T* &B,
T* &BT,
RandBLAS::RNGState<RNG> &state
){
// #cols(Q) & #cols(BT) that are filled at a given iteration.
int64_t curr_sz = 0;
// #cols(Q) & #cols(BT) that will be filled at the end of a given iteration.
int64_t next_sz = 0;
tol = std::max(tol, 100 * std::numeric_limits<T>::epsilon());
T norm_B = 0.0;
T prev_err = 0.0;
T approx_err = 0.0;

int ctr = 0;
while(10 > ctr) {
Q = ( double * ) realloc(Q, ctr * sizeof( double ));
++ctr;
// Q and BT are expected to be nullptr; any non-null buffer passed in is freed and replaced below.
if(Q) free(Q);
if(BT) free(BT);
// Make sure Q, BT have space for one iteration
Q = ( T * ) calloc(m * b_sz, sizeof( T ) );
BT = ( T * ) calloc(n * b_sz, sizeof( T ) );
// Allocate buffers
T* QtQi = ( T * ) calloc( b_sz * b_sz, sizeof( T ) );
T* A_cpy = ( T * ) calloc( m * n, sizeof( T ) );
T* Q_i = ( T * ) calloc( m * b_sz, sizeof( T ) );
T* BT_i = ( T * ) calloc( b_sz * n, sizeof( T ) );

// Pre-compute the Frobenius norm of A
T norm_A = lapack::lange(Norm::Fro, m, n, A, m);

// Copy the initial data to avoid unwanted modification
lapack::lacpy(MatrixType::General, m, n, A, m, A_cpy, m);

while(curr_sz < k) {

// Dynamically changing block size.
b_sz = std::min(b_sz, k - curr_sz);
next_sz = curr_sz + b_sz;

// Allocate more space in the Q, BT, QtQi buffers if needed.
if (curr_sz != 0) {
Q = ( T * ) realloc(Q, next_sz * m * sizeof( T ));
BT = ( T * ) realloc(BT, next_sz * n * sizeof( T ));
QtQi = ( T * ) realloc(QtQi, next_sz * b_sz * sizeof( T ));
}

// Calling RangeFinder
if(this->RF_Obj.call(m, n, A_cpy, b_sz, Q_i, state)) {
// RF failed
k = curr_sz;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);
return 6;
}

if(this->orth_check) {
if (util::orthogonality_check(m, b_sz, b_sz, Q_i, this->verbosity)) {
// Lost orthonormality of Q
k = curr_sz;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);
return 4;
}
}

// No need to reorthogonalize on the 1st pass
if(curr_sz != 0) {
// Q_i = orth(Q_i - Q(Q'Q_i))
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, curr_sz, b_sz, m, 1.0, Q, m, Q_i, m, 0.0, QtQi, next_sz);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, b_sz, curr_sz, -1.0, Q, m, QtQi, next_sz, 1.0, Q_i, m);
this->Orth_Obj.call(m, b_sz, Q_i);
}

// B_i' = A' * Q_i
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, b_sz, m, 1.0, A_cpy, m, Q_i, m, 0.0, BT_i, n);

// Updating B norm estimation
T norm_B_i = lapack::lange(Norm::Fro, n, b_sz, BT_i, n);
norm_B = std::hypot(norm_B, norm_B_i);
// Updating approximation error
prev_err = approx_err;
approx_err = std::sqrt(std::abs(norm_A - norm_B)) * (std::sqrt(norm_A + norm_B) / norm_A);

// Early termination - handling round-off error accumulation
if ((curr_sz > 0) && (approx_err > prev_err)) {
// Early termination - error has grown.
k = curr_sz;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);
return 2;
}

// Update the matrices Q and B
lapack::lacpy(MatrixType::General, m, b_sz, Q_i, m, &Q[m * curr_sz], m);
lapack::lacpy(MatrixType::General, n, b_sz, BT_i, n, &BT[n * curr_sz], n);

if(this->orth_check) {
if (util::orthogonality_check(m, next_sz, next_sz, Q, this->verbosity)) {
// Lost orthonormality of Q
k = curr_sz;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);
return 5;
}
}

// Update #cols(Q) & #cols(B)
curr_sz += b_sz;

// Termination criteria
if (approx_err < tol) {
// Reached the required error tol
k = curr_sz;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);
return 0;
}

// This step is only necessary for the next iteration
// A = A - Q_i * B_i
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, b_sz, -1.0, Q_i, m, BT_i, n, 1.0, A_cpy, m);
}

return 1;
free(A_cpy);
free(QtQi);
free(Q_i);
free(BT_i);

// Reached expected rank without achieving the tolerance
return 3;
}

} // end namespace RandLAPACK
Expand Down
9 changes: 4 additions & 5 deletions RandLAPACK/drivers/rl_rsvd.hh
Original file line number Diff line number Diff line change
Expand Up @@ -129,11 +129,11 @@ int RSVD<T, RNG>::call(
std::vector<T> &VT,
RandBLAS::RNGState<RNG> &state
){
T* Q = NULL;
T* B = NULL;
T* Q = nullptr;
T* B = nullptr;
// Q and B sizes will be adjusted automatically
this->QB_Obj.call(m, n, A.data(), k, this->block_sz, tol, Q, B, state);
/*

// Making sure all vectors are large enough
util::upsize(m * k, U);
util::upsize(k * k, this->U_buf);
Expand All @@ -144,8 +144,7 @@ int RSVD<T, RNG>::call(
lapack::gesdd(Job::SomeVec, k, n, B, k, S.data(), this->U_buf.data(), k, VT.data(), k);
// Adjusting U
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, Q, m, this->U_buf.data(), k, 0.0, U.data(), m);
*/
fprintf( stderr, "Segfault on Free?\n");

free(Q);
free(B);
return 0;
Expand Down
48 changes: 16 additions & 32 deletions test/comps/test_qb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class TestQB : public ::testing::Test
QBTestData<T> &all_data,
alg_type &all_algs,
RandBLAS::RNGState<RNG> &state) {
/*

auto m = all_data.row;
auto n = all_data.col;
auto k = all_data.rank;
Expand All @@ -119,25 +119,16 @@ class TestQB : public ::testing::Test
T* s_dat = all_data.s.data();
T* S_dat = all_data.S.data();
T* VT_dat = all_data.VT.data();
*/
int64_t m = 0;
int64_t n = 0;
int64_t k = 0;
T* A;
T* Q;
T* B;

T* Q = nullptr;
T* B = nullptr;

// Regular QB2 call
all_algs.QB.call(m, n, A, k, block_sz, tol, Q, B, state);
all_algs.QB.call(m, n, all_data.A.data(), k, block_sz, tol, Q, B, state);

printf("%f\n", Q[0]);
free(A);
free(Q);
free(B);
/*
// Reassign pointers because Q, B have been resized
T* Q_dat = all_data.Q.data();
T* B_dat = all_data.B.data();
T* Q_dat = Q;
T* B_dat = B;
T* B_cpy_dat = all_data.B_cpy.data();

printf("Inner dimension of QB: %-25ld\n", k);
Expand Down Expand Up @@ -184,7 +175,6 @@ class TestQB : public ::testing::Test
T norm_test_4 = lapack::lange(Norm::Fro, m, n, A_hat_dat, m);
printf("FRO NORM OF A_k - QB: %e\n", norm_test_4);
ASSERT_NEAR(norm_test_4, 0, test_tol);
*/
}

/// k = min(m, n) test for CholQRCP:
Expand All @@ -209,12 +199,15 @@ class TestQB : public ::testing::Test
T* B_dat = all_data.B.data();
T* A_hat_dat = all_data.A_hat.data();

T* Q = nullptr;
T* B = nullptr;

// Regular QB2 call
all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, all_data.Q.data(), all_data.B.data(), state);
all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, Q, B, state);

// Reassign pointers because Q, B have been resized
Q_dat = all_data.Q.data();
B_dat = all_data.B.data();
Q_dat = Q;
B_dat = B;

printf("Inner dimension of QB: %ld\n", k_est);

Expand Down Expand Up @@ -249,9 +242,9 @@ class TestQB : public ::testing::Test

TEST_F(TestQB, Polynomial_Decay_general1)
{
int64_t m = 10;
int64_t n = 10;
int64_t k = 5;
int64_t m = 100;
int64_t n = 100;
int64_t k = 50;
int64_t p = 2;
int64_t passes_per_iteration = 1;
int64_t block_sz = 2;
Expand Down Expand Up @@ -280,14 +273,6 @@ TEST_F(TestQB, Polynomial_Decay_general1)
}


/// Scratch test: hands a pointer to rand_fun and prints the first element
/// of whatever buffer it ends up pointing to.
TEST_F(TestQB, test_rand)
{
    // Start from nullptr rather than an indeterminate value, so that a
    // helper that fails to set the pointer is detectable.
    int* ptr = nullptr;
    // NOTE(review): presumably rand_fun takes the pointer by reference and
    // allocates the buffer — confirm against its definition. Ownership of
    // the buffer (who frees it, and how) is not visible here.
    rand_fun(ptr);
    // Fail cleanly instead of dereferencing a null pointer.
    ASSERT_NE(ptr, nullptr);
    printf("%d\n", ptr[0]);
}

/*
TEST_F(TestQB, Polynomial_Decay_general2)
{
int64_t m = 100;
Expand Down Expand Up @@ -385,7 +370,6 @@ TEST_F(TestQB, Polynomial_Decay_zero_tol2)
}


*/
TEST_F(TestQB, random_test)
{
/*
Expand Down
6 changes: 2 additions & 4 deletions test/drivers/test_rsvd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ class TestRSVD : public ::testing::Test

// Regular QB2 call
all_algs.RSVD.call(m, n, all_data.A, k, tol, all_data.U1, all_data.s1, all_data.VT1, state);
/*

// Construct A_approx_determ = U1 * S1 * VT1

// Turn vector into diagonal matrix
Expand All @@ -157,10 +157,9 @@ class TestRSVD : public ::testing::Test
T norm_test_4 = lapack::lange(Norm::Fro, m, n, A_approx_determ_dat, m);
printf("FRO NORM OF A_k - QB: %e\n", norm_test_4);
//ASSERT_NEAR(norm_test_4, 0, std::pow(std::numeric_limits<T>::epsilon(), 0.625));
*/
}
};
/*

TEST_F(TestRSVD, SimpleTest)
{
int64_t m = 100;
Expand Down Expand Up @@ -188,4 +187,3 @@ TEST_F(TestRSVD, SimpleTest)
computational_helper(all_data);
test_RSVD1_general(tol, all_data, all_algs, state);
}
*/

0 comments on commit 3382c24

Please sign in to comment.