Skip to content

Commit

Permalink
Update before QB logic change
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Apr 26, 2024
1 parent a0b1f50 commit c433fe4
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 19 deletions.
63 changes: 46 additions & 17 deletions RandLAPACK/comps/rl_qb.hh
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,11 @@ class QB : public QBalg<T, RNG> {
///
/// @param[in] Q
/// Buffer for the Q-factor.
/// Initially, may not have any space allocated for it.
/// Q is REQUIRED to be either nullptr or to point to >= m * b_sz * sizeof(T) bytes.
///
/// @param[in] B
/// Buffer for the B-factor.
/// Initially, may not have any space allocated for it.
/// B is REQUIRED to be either nullptr or to point to >= n * b_sz * sizeof(T) bytes.
///
/// @param[out] Q
/// Has the same number of rows of A, and orthonormal columns.
Expand Down Expand Up @@ -161,20 +161,21 @@ int QB<T, RNG>::call(
T prev_err = 0.0;
T approx_err = 0.0;

fprintf( stderr,"%d\n", m);
fprintf( stderr,"%d\n", n);
fprintf( stderr,"%ld\n", block_sz);

// Make sure Q, B have space for at least one iteration
if(!Q)
Q = ( T * ) realloc(Q, m * block_sz * sizeof( T ) );
if(!B)
B = ( T * ) realloc(B, n * block_sz * sizeof( T ) );

T* A_cpy = ( T * ) calloc( m * n, sizeof( T ) );
T* QtQi = ( T * ) calloc( this->curr_lim * block_sz, sizeof( T ) );
T* Q_i = ( T * ) calloc( m * block_sz, sizeof( T ) );
T* B_i_trans = ( T * ) calloc( block_sz * n, sizeof( T ) );
// Make sure Q, B have space for at least one iteration

if(!Q) {
Q = ( T * ) realloc(Q, m * n * sizeof( T ) );
}
if(!B) {
B = ( T * ) realloc(B, n * n * sizeof( T ) );
}


// pre-compute nrom
T norm_A = lapack::lange(Norm::Fro, m, n, A, m);
// Immediate termination criteria
Expand All @@ -190,23 +191,34 @@ int QB<T, RNG>::call(

// Copy the initial data to avoid unwanted modification TODO #1
lapack::lacpy(MatrixType::General, m, n, A, m, A_cpy, m);

int ctr = 0;
while(k > curr_sz) {
// Dynamically changing block size
block_sz = std::min(block_sz, k - curr_sz);
next_sz = curr_sz + block_sz;

fprintf( stderr,"Next sz %d\n", next_sz);
fprintf( stderr,"this->curr_lim %d\n", this->curr_lim);
// Make sure we have enough space for everything
if(next_sz > this->curr_lim) {
this->curr_lim = std::min(2 * this->curr_lim, k);
Q = ( T * ) realloc(Q, this->curr_lim * m * sizeof( T ));
B = ( T * ) realloc(B, this->curr_lim * n * sizeof( T ));
QtQi = ( T * ) realloc(QtQi, this->curr_lim * block_sz * sizeof( T ));
Q = ( T * ) realloc(Q, (this->curr_lim) * m * sizeof( T ));
B = ( T * ) realloc(B, (this->curr_lim) * n * sizeof( T ));
QtQi = ( T * ) realloc(QtQi, (this->curr_lim) * block_sz * sizeof( T ));
}

fprintf( stderr, "Size Q %ld\n", (this->curr_lim) * m );
fprintf( stderr, "Size B %ld\n", (this->curr_lim) * n );
fprintf( stderr, "Size QtQi %ld\n", (this->curr_lim) * block_sz );

// Calling RangeFinder
if(this->RF_Obj.call(m, n, A_cpy, block_sz, Q_i, state))
if(this->RF_Obj.call(m, n, A_cpy, block_sz, Q_i, state)) {
free(A_cpy);
free(QtQi);
free(Q_i);
free(B_i_trans);
return 6; // RF failed
}

if(this->orth_check) {
if (util::orthogonality_check(m, block_sz, block_sz, Q_i, this->verbosity)) {
Expand Down Expand Up @@ -250,10 +262,24 @@ int QB<T, RNG>::call(
return 2;
}

fprintf( stderr,"before copy\n");

fprintf( stderr,"Curr_sz %d\n", curr_sz);

if (ctr == 1) {
free(A_cpy);
free(QtQi);
free(Q_i);
free(B_i_trans);
return 0;
}

// Update the matrices Q and B
lapack::lacpy(MatrixType::General, m, block_sz, Q_i, m, &Q[m * curr_sz], m);
lapack::lacpy(MatrixType::General, n, block_sz, B_i_trans, n, &B[n * curr_sz], n);

fprintf( stderr,"Copy\n");

if(this->orth_check) {
if (util::orthogonality_check(m, this->curr_lim, next_sz, Q, this->verbosity)) {
// Lost orthonormality of Q
Expand All @@ -266,6 +292,8 @@ int QB<T, RNG>::call(
}
}

fprintf( stderr,"After orth\n\n");

curr_sz += block_sz;
// Termination criteria
if (approx_err < tol) {
Expand All @@ -277,10 +305,11 @@ int QB<T, RNG>::call(
free(B_i_trans);
return 0;
}

// This step is only necessary for the next iteration
// A = A - Q_i * B_i
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, m, n, block_sz, -1.0, Q_i, m, B_i_trans, n, 1.0, A_cpy, m);
++ ctr;
}

free(A_cpy);
Expand Down
4 changes: 2 additions & 2 deletions RandLAPACK/drivers/rl_rsvd.hh
Original file line number Diff line number Diff line change
Expand Up @@ -144,11 +144,11 @@ int RSVD<T, RNG>::call(
lapack::gesdd(Job::SomeVec, k, n, B, k, S.data(), this->U_buf.data(), k, VT.data(), k);
// Adjusting U
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, k, 1.0, Q, m, this->U_buf.data(), k, 0.0, U.data(), m);
*/
fprintf( stderr, "Segfault on Free?\n");
free(Q);
free(B);
return 0;
*/
}

} // end namespace RandLAPACK
Expand Down

0 comments on commit c433fe4

Please sign in to comment.