From 3382c245a54cddf4f6a298e12297afbc1d65f060 Mon Sep 17 00:00:00 2001 From: TeachRaccooon Date: Fri, 26 Apr 2024 15:05:29 -0700 Subject: [PATCH] Allocation logic in QB fixed --- RandLAPACK/comps/rl_qb.hh | 154 ++++++++++++++++++++++++++++++---- RandLAPACK/drivers/rl_rsvd.hh | 9 +- test/comps/test_qb.cc | 48 ++++------- test/drivers/test_rsvd.cc | 6 +- 4 files changed, 162 insertions(+), 55 deletions(-) diff --git a/RandLAPACK/comps/rl_qb.hh b/RandLAPACK/comps/rl_qb.hh index 8f751ed0..9320ae1e 100644 --- a/RandLAPACK/comps/rl_qb.hh +++ b/RandLAPACK/comps/rl_qb.hh @@ -29,7 +29,7 @@ class QBalg { int64_t b_sz, T tol, T* &Q, - T* &B, + T* &BT, RandBLAS::RNGState &state ) = 0; }; @@ -58,9 +58,9 @@ class QB : public QBalg { /// or /// (2) Q has k columns. /// Each iteration involves sketching A from the right by a sketching - /// matrix with "b_sz" columns. + /// matrix with "block_sz" columns. /// - /// The number of columns in Q increase by "b_sz" at each iteration, unless + /// The number of columns in Q increase by "block_sz" at each iteration, unless /// that would bring #cols(Q) > k. In that case, the final iteration only /// adds enough columns to Q so that #cols(Q) == k. /// The implementation relies on RowSketcher and RangeFinder, @@ -96,16 +96,16 @@ class QB : public QBalg { /// /// @param[in] Q /// Buffer for the Q-factor. - /// Q is REQUIRED to be either nullptr or to point to >= m * b_sz * sizeof(T) bytes. + /// We expect Q to be nullptr. /// - /// @param[in] B + /// @param[in] BT /// Buffer for the B-factor. - /// B is REQUIRED to be either nullptr or to point to >= n * b_sz * sizeof(T) bytes. + /// We expect BT to be nullptr. /// /// @param[out] Q /// Has the same number of rows of A, and orthonormal columns. /// - /// @param[out] B + /// @param[out] BT /// Number of rows in B is equal to number of columns in A (B is returned in a transposed format). /// /// @return = 0: successful exit @@ -119,7 +119,7 @@ class QB : public QBalg { int64_t b_sz, T tol, T* &Q, - T* &B, + T* &BT, RandBLAS::RNGState &state ) override; @@ -140,17 +140,143 @@ int QB::call( int64_t b_sz, T tol, T* &Q, - T* &B, + T* &BT, RandBLAS::RNGState &state ){ + // #cols(Q) & #cols(BT) that are filled at a given iteration. + int64_t curr_sz = 0; + // #cols(Q) & #cols(BT) that will be filled at the end of a given iteration. + int64_t next_sz = 0; + tol = std::max(tol, 100 * std::numeric_limits::epsilon()); + T norm_B = 0.0; + T prev_err = 0.0; + T approx_err = 0.0; - int ctr = 0; - while(10 > ctr) { - Q = ( double * ) realloc(Q, ctr * sizeof( double )); - ++ctr; + // We require Q, B to be nullptr. + if(Q) free(Q); + if(BT) free(BT); + // Make sure Q, B have space for one iteration + Q = ( T * ) calloc(m * b_sz, sizeof( T ) ); + BT = ( T * ) calloc(n * b_sz, sizeof( T ) ); + // Allocate buffers + T* QtQi = ( T * ) calloc( b_sz * b_sz, sizeof( T ) ); + T* A_cpy = ( T * ) calloc( m * n, sizeof( T ) ); + T* Q_i = ( T * ) calloc( m * b_sz, sizeof( T ) ); + T* BT_i = ( T * ) calloc( b_sz * n, sizeof( T ) ); + + // pre-compute nrom + T norm_A = lapack::lange(Norm::Fro, m, n, A, m); + + // Copy the initial data to avoid unwanted modification + lapack::lacpy(MatrixType::General, m, n, A, m, A_cpy, m); + + while(curr_sz < k) { + + // Dynamically changing block size. + b_sz = std::min(b_sz, k - curr_sz); + next_sz = curr_sz + b_sz; + + // Allocate more space in Q, B, QtQi buffer if needed. + if (curr_sz != 0) { + Q = ( T * ) realloc(Q, next_sz * m * sizeof( T )); + BT = ( T * ) realloc(BT, next_sz * n * sizeof( T )); + QtQi = ( T * ) realloc(QtQi, next_sz * b_sz * sizeof( T )); + } + + // Calling RangeFinder + if(this->RF_Obj.call(m, n, A_cpy, b_sz, Q_i, state)) { + // RF failed + k = curr_sz; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + return 6; + } + + if(this->orth_check) { + if (util::orthogonality_check(m, b_sz, b_sz, Q_i, this->verbosity)) { + // Lost orthonormality of Q + k = curr_sz; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + return 4; + } + } + + // No need to reorthogonalize on the 1st pass + if(curr_sz != 0) { + // Q_i = orth(Q_i - Q(Q'Q_i)) + blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, curr_sz, b_sz, m, 1.0, Q, m, Q_i, m, 0.0, QtQi, next_sz); + blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, b_sz, curr_sz, -1.0, Q, m, QtQi, next_sz, 1.0, Q_i, m); + this->Orth_Obj.call(m, b_sz, Q_i); + } + + //B_i' = A' * Q_i' + blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, n, b_sz, m, 1.0, A_cpy, m, Q_i, m, 0.0, BT_i, n); + + // Updating B norm estimation + T norm_B_i = lapack::lange(Norm::Fro, n, b_sz, BT_i, n); + norm_B = std::hypot(norm_B, norm_B_i); + // Updating approximation error + prev_err = approx_err; + approx_err = std::sqrt(std::abs(norm_A - norm_B)) * (std::sqrt(norm_A + norm_B) / norm_A); + + // Early termination - handling round-off error accumulation + if ((curr_sz > 0) && (approx_err > prev_err)) { + // Early termination - error has grown. + k = curr_sz; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + return 2; + } + + // Update the matrices Q and B + lapack::lacpy(MatrixType::General, m, b_sz, Q_i, m, &Q[m * curr_sz], m); + lapack::lacpy(MatrixType::General, n, b_sz, BT_i, n, &BT[n * curr_sz], n); + + if(this->orth_check) { + if (util::orthogonality_check(m, next_sz, next_sz, Q, this->verbosity)) { + // Lost orthonormality of Q + k = curr_sz; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + return 5; + } + } + + // Update #cols(Q) & #cols(B) + curr_sz += b_sz; + + // Termination criteria + if (approx_err < tol) { + // Reached the required error tol + k = curr_sz; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + return 0; + } + + // This step is only necessary for the next iteration + // A = A - Q_i * B_i + blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, b_sz, -1.0, Q_i, m, BT_i, n, 1.0, A_cpy, m); } - return 1; + free(A_cpy); + free(QtQi); + free(Q_i); + free(BT_i); + + // Reached expected rank without achieving the tolerance + return 3; } } // end namespace RandLAPACK diff --git a/RandLAPACK/drivers/rl_rsvd.hh b/RandLAPACK/drivers/rl_rsvd.hh index 26d44118..db86934c 100644 --- a/RandLAPACK/drivers/rl_rsvd.hh +++ b/RandLAPACK/drivers/rl_rsvd.hh @@ -129,11 +129,11 @@ int RSVD::call( std::vector &VT, RandBLAS::RNGState &state ){ - T* Q = NULL; - T* B = NULL; + T* Q = nullptr; + T* B = nullptr; // Q and B sizes will be adjusted automatically this->QB_Obj.call(m, n, A.data(), k, this->block_sz, tol, Q, B, state); -/* + // Making sure all vectors are large enough util::upsize(m * k, U); util::upsize(k * k, this->U_buf); @@ -144,8 +144,7 @@ int RSVD::call( lapack::gesdd(Job::SomeVec, k, n, B, k, S.data(), this->U_buf.data(), k, VT.data(), k); // Adjusting U blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, Q, m, this->U_buf.data(), k, 0.0, U.data(), m); -*/ - fprintf( stderr, "Segfault on Free?\n"); + free(Q); free(B); return 0; diff --git a/test/comps/test_qb.cc b/test/comps/test_qb.cc index 686b9468..c8811acd 100644 --- a/test/comps/test_qb.cc +++ b/test/comps/test_qb.cc @@ -105,7 +105,7 @@ class TestQB : public ::testing::Test QBTestData &all_data, alg_type &all_algs, RandBLAS::RNGState &state) { -/* + auto m = all_data.row; auto n = all_data.col; auto k = all_data.rank; @@ -119,25 +119,16 @@ class TestQB : public ::testing::Test T* s_dat = all_data.s.data(); T* S_dat = all_data.S.data(); T* VT_dat = all_data.VT.data(); -*/ - int64_t m = 0; - int64_t n = 0; - int64_t k = 0; - T* A; - T* Q; - T* B; + + T* Q = nullptr; + T* B = nullptr; // Regular QB2 call - all_algs.QB.call(m, n, A, k, block_sz, tol, Q, B, state); + all_algs.QB.call(m, n, all_data.A.data(), k, block_sz, tol, Q, B, state); - printf("%f\n", Q[0]); - free(A); - free(Q); - free(B); -/* // Reassing pointers because Q, B have been resized - T* Q_dat = all_data.Q.data(); - T* B_dat = all_data.B.data(); + T* Q_dat = Q; + T* B_dat = B; T* B_cpy_dat = all_data.B_cpy.data(); printf("Inner dimension of QB: %-25ld\n", k); @@ -184,7 +175,6 @@ class TestQB : public ::testing::Test T norm_test_4 = lapack::lange(Norm::Fro, m, n, A_hat_dat, m); printf("FRO NORM OF A_k - QB: %e\n", norm_test_4); ASSERT_NEAR(norm_test_4, 0, test_tol); -*/ } /// k = min(m, n) test for CholQRCP: @@ -209,12 +199,15 @@ class TestQB : public ::testing::Test T* B_dat = all_data.B.data(); T* A_hat_dat = all_data.A_hat.data(); + T* Q = nullptr; + T* B = nullptr; + // Regular QB2 call - all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, all_data.Q.data(), all_data.B.data(), state); + all_algs.QB.call(m, n, all_data.A.data(), k_est, block_sz, tol, Q, B, state); // Reassing pointers because Q, B have been resized - Q_dat = all_data.Q.data(); - B_dat = all_data.B.data(); + Q_dat = Q; + B_dat = B; printf("Inner dimension of QB: %ld\n", k_est); @@ -249,9 +242,9 @@ class TestQB : public ::testing::Test TEST_F(TestQB, Polynomial_Decay_general1) { - int64_t m = 10; - int64_t n = 10; - int64_t k = 5; + int64_t m = 100; + int64_t n = 100; + int64_t k = 50; int64_t p = 2; int64_t passes_per_iteration = 1; int64_t block_sz = 2; @@ -280,14 +273,6 @@ TEST_F(TestQB, Polynomial_Decay_general1) } -TEST_F(TestQB, test_rand) -{ - int* ptr; - rand_fun(ptr); - printf("%d\n", ptr[0]); -} - -/* TEST_F(TestQB, Polynomial_Decay_general2) { int64_t m = 100; @@ -385,7 +370,6 @@ TEST_F(TestQB, Polynomial_Decay_zero_tol2) } -*/ TEST_F(TestQB, random_test) { /* diff --git a/test/drivers/test_rsvd.cc b/test/drivers/test_rsvd.cc index 7d8d3b35..bdd6a1a3 100644 --- a/test/drivers/test_rsvd.cc +++ b/test/drivers/test_rsvd.cc @@ -132,7 +132,7 @@ class TestRSVD : public ::testing::Test // Regular QB2 call all_algs.RSVD.call(m, n, all_data.A, k, tol, all_data.U1, all_data.s1, all_data.VT1, state); -/* + // Construnct A_approx_determ = U1 * S1 * VT1 // Turn vector into diagonal matrix @@ -157,10 +157,9 @@ class TestRSVD : public ::testing::Test T norm_test_4 = lapack::lange(Norm::Fro, m, n, A_approx_determ_dat, m); printf("FRO NORM OF A_k - QB: %e\n", norm_test_4); //ASSERT_NEAR(norm_test_4, 0, std::pow(std::numeric_limits::epsilon(), 0.625)); -*/ } }; -/* + TEST_F(TestRSVD, SimpleTest) { int64_t m = 100; @@ -188,4 +187,3 @@ TEST_F(TestRSVD, SimpleTest) computational_helper(all_data); test_RSVD1_general(tol, all_data, all_algs, state); } -*/ \ No newline at end of file