Skip to content

Commit

Permalink
finish merge
Browse files Browse the repository at this point in the history
  • Loading branch information
rileyjmurray committed Sep 8, 2024
2 parents 7e8b41e + 237c2f1 commit 37759e9
Show file tree
Hide file tree
Showing 37 changed files with 2,579 additions and 950 deletions.
86 changes: 37 additions & 49 deletions RandLAPACK/comps/rl_orth.hh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Stabilization {
virtual int call(
int64_t m,
int64_t k,
std::vector<T> &Q
T* A
) = 0;
};

Expand All @@ -29,7 +29,7 @@ class CholQRQ : public Stabilization<T> {

CholQRQ(bool c_check, bool verb) {
cond_check = c_check;
verbosity = verb;
verbose = verb;
chol_fail = false;
};

Expand All @@ -44,10 +44,10 @@ class CholQRQ : public Stabilization<T> {
/// @param[in] k
/// The number of columns in the matrix A.
///
/// @param[in] Q
/// @param[in] A
/// The m-by-k matrix, stored in a column-major format.
///
/// @param[out] Q
/// @param[out] A
/// Overwritten with an orthogonal Q-factor.
///
///
Expand All @@ -56,65 +56,58 @@ class CholQRQ : public Stabilization<T> {
int call(
int64_t m,
int64_t k,
std::vector<T> &Q
T* A
);

public:
bool chol_fail;
bool cond_check;
bool verbosity;

// CholQR-specific
std::vector<T> Q_gram;
std::vector<T> Q_gram_cpy;
std::vector<T> s;
bool verbose;
};

// -----------------------------------------------------------------------------
/// Cholesky-QR orthogonalization of the m-by-k column-major matrix A, in place.
///
/// Forms the upper triangle of the Gram matrix A'A, computes its Cholesky
/// factor R, and overwrites A with A * inv(R).
///
/// @param[in] m
///     Number of rows of A (also used as its leading dimension).
/// @param[in] k
///     Number of columns of A.
/// @param[in,out] A
///     On entry, the m-by-k matrix. On a zero return, overwritten with the
///     result of the triangular solve. Left untouched on every failure path.
///
/// @return 0 on success; 1 if the workspace could not be allocated, if the
///     Cholesky factorization reported failure (chol_fail is then set), or if
///     the optional condition-number check exceeded 1/sqrt(eps).
template <typename T>
int CholQRQ<T>::call(
    int64_t m,
    int64_t k,
    T* A
){
    // k-by-k workspace for the Gram matrix; released on every exit path below.
    T* A_gram = static_cast<T*>(calloc(k * k, sizeof(T)));
    if (A_gram == nullptr && k != 0) {
        // Allocation failure. calloc may legitimately return null for a
        // zero-sized request, so k == 0 is not treated as an error here.
        // chol_fail is deliberately not set: the factorization never ran.
        if (this->verbose) {
            printf("CHOLESKY QR FAILED\n");
        }
        return 1;
    }

    // Find normal equation A'A - just the upper triangular portion.
    blas::syrk(Layout::ColMajor, Uplo::Upper, Op::Trans, k, m, 1.0, A, m, 0.0, A_gram, k);

    // Positive definite Cholesky factorization of the Gram matrix.
    if (lapack::potrf(Uplo::Upper, k, A_gram, k)) {
        if (this->verbose) {
            printf("CHOLESKY QR FAILED\n");
        }
        this->chol_fail = true; // scheme failure
        free(A_gram);
        return 1;
    }

    // Scheme may succeed, but output garbage: optionally reject factors whose
    // reported condition number exceeds 1/sqrt(machine epsilon).
    if (this->cond_check) {
        if (util::cond_num_check(k, k, A_gram, this->verbose) > (1 / std::sqrt(std::numeric_limits<T>::epsilon()))) {
            free(A_gram);
            return 1;
        }
    }

    // A <- A * inv(R), with R the upper-triangular Cholesky factor.
    blas::trsm(Layout::ColMajor, Side::Right, Uplo::Upper, Op::NoTrans, Diag::NonUnit, m, k, 1.0, A_gram, k, A, m);
    free(A_gram);
    return 0;
}





template <typename T>
class HQRQ : public Stabilization<T> {
public:

HQRQ(bool c_check, bool verb) {
cond_check = c_check;
verbosity = verb;
verbose = verb;
};

/// Performs a Householder QR factorization. Outputs the Q-factor only.
Expand Down Expand Up @@ -143,35 +136,34 @@ class HQRQ : public Stabilization<T> {
int call(
int64_t m,
int64_t k,
std::vector<T> &Q
T* A
);

public:
std::vector<T> tau;
bool cond_check;
bool verbosity;
bool verbose;
};

// -----------------------------------------------------------------------------
/// Householder QR of the m-by-n column-major matrix A, in place, keeping only
/// the Q-factor.
///
/// @param[in] m
///     Number of rows of A (also used as its leading dimension).
/// @param[in] n
///     Number of columns of A.
/// @param[in,out] A
///     On entry, the m-by-n matrix. On a zero return, overwritten by
///     geqrf followed by ungqr.
///
/// @return 0 on success; 1 if the workspace could not be allocated or if
///     geqrf returned a nonzero status.
template <typename T>
int HQRQ<T>::call(
    int64_t m,
    int64_t n,
    T* A
) {
    // Done via regular LAPACK's QR.
    // tau holds the scalar factors of the elementary reflectors produced by
    // geqrf (length min(m,n) is required; n entries are allocated, zero-filled).
    T* tau = static_cast<T*>(calloc(n, sizeof(T)));
    if (tau == nullptr && n != 0) {
        // Allocation failure. calloc may legitimately return null for a
        // zero-sized request, so n == 0 is not treated as an error here.
        return 1;
    }

    if (lapack::geqrf(m, n, A, m, tau)) {
        free(tau);
        return 1; // Failure condition
    }

    // Replace the reflector representation stored in A by the explicit factor.
    lapack::ungqr(m, n, n, A, m, tau);
    free(tau);
    return 0;
}

Expand All @@ -181,7 +173,7 @@ class PLUL : public Stabilization<T> {

PLUL(bool c_check, bool verb) {
this->cond_check = c_check;
this->verbosity = verb;
this->verbose = verb;
};

/// Performs an LU factorization with partial pivoting (via getrf). Outputs the L-factor only, with the pivot row interchanges applied to it.
Expand Down Expand Up @@ -210,13 +202,12 @@ class PLUL : public Stabilization<T> {
int call(
int64_t m,
int64_t k,
std::vector<T> &Q
T* A
);

public:
std::vector<int64_t> ipiv;
bool cond_check;
bool verbosity;
bool verbose;
};


Expand All @@ -225,22 +216,19 @@ template <typename T>
// Pivoted-LU stabilization of the m-by-n column-major matrix A, in place.
//
// Runs getrf on A, extracts the L-factor with util::get_L, then applies the
// recorded row interchanges to it with laswp.
//
// m, n : dimensions of A (m is also used as its leading dimension).
// A    : overwritten on a zero return.
//
// Returns 0 on success; 1 if the pivot workspace could not be allocated or if
// getrf returned a nonzero status.
int PLUL<T>::call(
    int64_t m,
    int64_t n,
    T* A
){
    // Pivot indices written by getrf; zero-filled, n entries allocated.
    int64_t* ipiv = static_cast<int64_t*>(calloc(n, sizeof(int64_t)));
    if (ipiv == nullptr && n != 0) {
        // Allocation failure. calloc may legitimately return null for a
        // zero-sized request, so n == 0 is not treated as an error here.
        return 1;
    }

    if (lapack::getrf(m, n, A, m, ipiv)) {
        free(ipiv);
        return 1; // failure condition
    }

    util::get_L(m, n, A, 1);

    // getrf defines only min(m, n) pivots. Restricting the interchange range
    // to that count keeps laswp from reading the zero-filled tail of ipiv as
    // row indices when m < n; for m >= n this is the same call as before.
    const int64_t num_pivots = (m < n) ? m : n;
    lapack::laswp(n, A, m, 1, num_pivots, ipiv, 1);

    free(ipiv);
    return 0;
}

Expand Down
Loading

0 comments on commit 37759e9

Please sign in to comment.