Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed May 6, 2024
1 parent 7b7488d commit 7a4bc7d
Showing 1 changed file with 102 additions and 79 deletions.
181 changes: 102 additions & 79 deletions benchmark/bench_CQRRP/CQRRP_linear_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ template <typename T>
struct QR_solver_benchmark_data {
int64_t row;
int64_t col;
float tolerance;
T tol_d;
float tol_s;
T sampling_factor;
std::vector<T> A;
std::vector<float> A_single;
Expand All @@ -28,12 +29,11 @@ struct QR_solver_benchmark_data {
std::vector<T> x;
std::vector<float> work;
std::vector<T> work_double;
std::vector<T> work_double2;
std::vector<T> x_solution;
std::vector<T> b_constant;
std::vector<T> ATA;
std::vector<float> R_buffer;

QR_solver_benchmark_data(int64_t m, int64_t n, float tol, T d_factor) :
QR_solver_benchmark_data(int64_t m, int64_t n, T tolerance_d, float tolerance_s, T d_factor) :
A(m * n, 0.0),
A_single(m * n, 0.0),
tau(n, 0.0),
Expand All @@ -42,14 +42,14 @@ struct QR_solver_benchmark_data {
x(n, 0),
work(m, 0),
work_double(m, 0),
work_double2(m, 0),
x_solution(n, 0),
b_constant(m, 0),
ATA(n * n, 0.0)
R_buffer(n * n, 0)
{
row = m;
col = n;
tolerance = tol;
tol_d = tolerance_d;
tol_s = tolerance_s;
sampling_factor = d_factor;
}
};
Expand Down Expand Up @@ -81,26 +81,27 @@ static long ICQRRP_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,
RandBLAS::RNGState<RNG> state) {
auto m = all_data.row;
auto n = all_data.col;
auto tol = all_data.tolerance;
auto tol_d = all_data.tol_d;
auto tol_s = all_data.tol_s;
auto d_factor = all_data.sampling_factor;

// Set up a single-precision ICQRRP
RandLAPACK::CQRRP_blocked<float, r123::Philox4x32> CQRRP_blocked(false, tol, b_sz);
RandLAPACK::CQRRP_blocked<float, r123::Philox4x32> CQRRP_blocked(false, tol_d, b_sz);
CQRRP_blocked.nnz = 2;
CQRRP_blocked.num_threads = 8;

T* A = all_data.A.data(); // double
float* A_single = all_data.A_single.data(); // single
float* tau = all_data.tau.data(); // single
int64_t* J = all_data.J.data(); // single
T* b = all_data.b.data(); // double
T* x = all_data.x.data(); // double
float* work = all_data.work.data(); // single
T* work_double = all_data.work_double.data(); // double
int64_t* J = all_data.J.data();
float* R_buf = all_data.R_buffer.data(); // single

long dur_cqrrp_refine = 0;

auto start_cqrrp_refine = high_resolution_clock::now();
long dur_icqrrp_refine = 0;
auto start_icqrrp_refine = high_resolution_clock::now();
// Cast the input matrix A down to single precision.
// Copy columns in parallel.
#pragma omp parallel for
Expand All @@ -109,6 +110,9 @@ static long ICQRRP_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,

// Call single-precision ICQRRP
CQRRP_blocked.call(m, n, A_single, m, d_factor, tau, J, state);
// Address Pivoting by permuting columns of A
RandLAPACK::util::col_swap(m, n, n, A, m, all_data.J);

// Vector all_data.x has already been initialized to all 0;
// repeat below until some tolerance threshold is reached:
int ctr = 0;
Expand All @@ -117,40 +121,41 @@ static long ICQRRP_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,
// Need to make sure that work = b before this.
// After this, r will be stored in work (m by 1).
std::transform(b, &b[m], work, [](double d) { return static_cast<float>(d); });
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, n, A, m, x, 1, -1.0, work, 1);
// 2. Solve Qy = Pr for y in single precision.
// Since Q' = inv(Q); y = Q'Pr.
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, -1.0, A, m, x, 1, 1.0, work, 1);
// 2. Solve Qy = r for y in single precision.
// Since Q' = inv(Q); y = Q'r.
// After this, y will be stored in work (m by 1).
lapack::laswp(1, work, m, 1, n, J, 1);
//lapack::laswp(1, work, m, 1, n, J, 1);
lapack::ormqr(Side::Left, Op::Trans, m, 1, n, A_single, m, tau, work, m);
// 3. Solve Rz = y for z in single precision.
// A_single stores single-precision R.
// After this, z will be stored in work (n by 1).
blas::trmv(Layout::ColMajor, Uplo::Upper, Op::NoTrans, Diag::NonUnit, n, A_single, m, work, 1);
blas::trsv(Layout::ColMajor, Uplo::Upper, Op::NoTrans, Diag::NonUnit, n, A_single, m, work, 1);
// 4. Check if ||z|| <= tol.
T nrm = blas::nrm2(n, work, 1);
if (nrm <= tol)
printf("Tol is %.20e\n", tol_s);
printf("Nrm is %.20e\n\n", nrm);
if (nrm <= tol_s)
break;
printf("%.20e\n", nrm);
// 5. Transform work into a double-precision array
std::transform(work, &work[n], work_double, [](float f) { return static_cast<double>(f); });
// 6. x = x + z.
blas::axpy(n, 1.0, x, 1, work, 1);
++ctr;
blas::axpy(n, 1.0, work, 1, x, 1);
++ctr;
}
auto stop_cqrrp_refine = high_resolution_clock::now();
dur_cqrrp_refine = duration_cast<microseconds>(stop_cqrrp_refine - start_cqrrp_refine).count();
auto stop_icqrrp_refine = high_resolution_clock::now();
dur_icqrrp_refine = duration_cast<microseconds>(stop_icqrrp_refine - start_icqrrp_refine).count();

return dur_cqrrp_refine;
return dur_icqrrp_refine;
}

template <typename T>
static long GEQRF_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,
QR_solver_benchmark_data<T> &all_data) {
auto m = all_data.row;
auto n = all_data.col;
auto tol = all_data.tolerance;
auto d_factor = all_data.sampling_factor;
auto tol_s = all_data.tol_s;
auto tol_d = all_data.tol_d;

T* A = all_data.A.data(); // double
float* A_single = all_data.A_single.data(); // single
Expand All @@ -160,18 +165,12 @@ static long GEQRF_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,
float* work = all_data.work.data(); // single
T* work_double = all_data.work_double.data(); // double

long dur_geqrf_refine = 0;
T residual_inf_nrm = 0;
T solution_inf_nrm = 0;
T A_inf_nrm = lapack::lange(Norm::Fro, m, n, A, m);

long dur_geqrf_refine = 0;
auto start_geqrf_refine = high_resolution_clock::now();


char name [] = "A";
char name1 [] = "b";
char name2 [] = "work";
char name3 [] = "x";
RandBLAS::util::print_colmaj(m, n, A, name);
RandBLAS::util::print_colmaj(m, 1, b, name1);

// Cast the input matrix A down to single precision.
// Copy columns in parallel.
#pragma omp parallel for
Expand All @@ -183,33 +182,40 @@ static long GEQRF_refinement(RandLAPACK::gen::mat_gen_info<T> m_info,
// Vector all_data.x has already been initialized to all 0;
// repeat below until some tolerance threshold is reached:
int ctr = 0;
while (ctr < 10) {
while (ctr < 5000) {
// 1. Solve r = b - Ax for r in double precision.
// Need to make sure that work = b before this.
// After this, r will be stored in work (m by 1).
std::transform(b, &b[m], work, [](double d) { return static_cast<float>(d); });
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, n, A, m, x, 1, -1.0, work, 1);
RandBLAS::util::print_colmaj(m, 1, work, name2);
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, -1.0, A, m, x, 1, 1.0, work, 1);
residual_inf_nrm = lapack::lange(Norm::Inf, m, 1, work, m);
// 2. Solve Qy = r for y in single precision.
// Since Q' = inv(Q); y = Q'r.
// After this, y will be stored in work (m by 1).
lapack::ormqr(Side::Left, Op::Trans, m, 1, n, A_single, m, tau, work, m);
RandBLAS::util::print_colmaj(m, 1, work, name2);
// 3. Solve Rz = y for z in single precision.
// A_single stores single-precision R.
// After this, z will be stored in work (n by 1).
blas::trmv(Layout::ColMajor, Uplo::Upper, Op::NoTrans, Diag::NonUnit, n, A_single, m, work, 1);
RandBLAS::util::print_colmaj(m, 1, work, name2);
// 4. Check if ||z|| <= tol.
T nrm = blas::nrm2(n, work, 1);
if (nrm <= tol)
break;
printf("%.20e\n", nrm);
// 5. Transform work into a double-precision array
blas::trsv(Layout::ColMajor, Uplo::Upper, Op::NoTrans, Diag::NonUnit, n, A_single, m, work, 1);
// 4. Transform work into a double-precision array
std::transform(work, &work[n], work_double, [](float f) { return static_cast<double>(f); });
// 6. x = x + z.
blas::axpy(n, -1.0, work, 1, x, 1);
RandBLAS::util::print_colmaj(n, 1, x, name3);
// 5. x = x + z.
blas::axpy(n, 1.0, work_double, 1, x, 1);

printf("Iteration %d\n", ctr);
printf("Ratio %.20e\n", blas::nrm2(n, work_double, 1) / blas::nrm2(n, x, 1));
printf("Tol %.20e\n\n", tol_d);

if(blas::nrm2(n, work_double, 1) / blas::nrm2(n, x, 1) < tol_d)
break;
/*
// Check termination criteria
solution_inf_nrm = lapack::lange(Norm::Fro, n, 1, x, n);
printf("residual_inf_nrm is %.20e\n", residual_inf_nrm);
printf("expr is %.20e\n\n", tol_d * std::sqrt(n) * solution_inf_nrm * A_inf_nrm);
if(residual_inf_nrm < tol_d * std::sqrt(n) * solution_inf_nrm * A_inf_nrm)
break;
*/
++ctr;
}
auto stop_geqrf_refine = high_resolution_clock::now();
Expand All @@ -222,7 +228,6 @@ template <typename T>
static T forward_error(RandLAPACK::gen::mat_gen_info<T> m_info,
QR_solver_benchmark_data<T> &all_data) {
auto n = all_data.col;
T* A = all_data.A.data(); // double
T* x_solution = all_data.x_solution.data();
T* x = all_data.x.data();

Expand All @@ -238,22 +243,18 @@ static T backward_error(RandLAPACK::gen::mat_gen_info<T> m_info,
auto n = all_data.col;

T* A = all_data.A.data();
T* ATA = all_data.ATA.data();
T* x = all_data.x.data();
T* b = all_data.b.data();
T* work_double = all_data.work_double.data();
T* work_double2 = all_data.work_double2.data();

// ||A'Ax - A'Ab||
blas::syrk(Layout::ColMajor, Uplo::Upper, Op::Trans, n, m, 1.0, A, m, 0.0, ATA, n);
blas::gemv(Layout::ColMajor, Op::NoTrans, n, n, -1.0, ATA, n, x, 1, 0.0, work_double, 1);
blas::gemv(Layout::ColMajor, Op::NoTrans, n, n, -1.0, ATA, n, b, 1, 0.0, work_double2, 1);
blas::axpy(n, -1.0, work_double, 1, work_double2, 1);
T nrm_numerator = blas::nrm2(n, work_double2, 1);

// ||Ax-b||||A||_F
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, -1.0, A, m, x, 1, -1.0, b, 1);
T nrm_denominator = blas::nrm2(m, b, 1) * lapack::lange(Norm::Fro, m, n, A, m);
// Compute Ax - b
blas::gemv(Layout::ColMajor, Op::NoTrans, m, n, 1.0, A, m, x, 1, -1.0, b, 1);
T nrm1 = blas::nrm2(m, b, 1);
// Compute ||A'(Ax - b)|| = ||A'Ax - A'b||
blas::gemv(Layout::ColMajor, Op::Trans, m, n, 1.0, A, m, b, 1, 0.0, work_double, 1);
T nrm_numerator = blas::nrm2(n, work_double, 1);
// Compute ||Ax-b||||A||_F
T nrm_denominator = nrm1 * lapack::lange(Norm::Fro, m, n, A, m);

return (nrm_numerator / nrm_denominator);
}
Expand All @@ -269,8 +270,6 @@ static void call_all_algs(

auto m = all_data.row;
auto n = all_data.col;
auto tol = all_data.tolerance;
auto d_factor = all_data.sampling_factor;

// timing vars
long dur_gels = 0;
Expand All @@ -286,6 +285,13 @@ static void call_all_algs(
// Making sure the states are unchanged
auto state_gen = state;

char name [] = "A";
char name1 [] = "x";
char name2 [] = "b";

//RandBLAS::util::print_colmaj(m, n, all_data.A.data(), name);
//RandBLAS::util::print_colmaj(m, 1, all_data.b.data(), name2);

for (int i = 0; i < numruns; ++i) {
printf("ITERATION %d\n", i);
// Testing GELS
Expand All @@ -295,25 +301,41 @@ static void call_all_algs(
dur_gels = duration_cast<microseconds>(stop_gels - start_gels).count();
printf("TOTAL TIME FOR GELS %ld\n", dur_gels);

// b now stores the solution vector.
// For some reason, the original solution is actually bad.
blas::copy(n, all_data.b.data(), 1, all_data.x_solution.data(), 1);

data_regen(m_info, all_data, state_gen, 0);
state_gen = state;

// Testing ICQRRP + iterative refinement
// dur_cqrrp_refine = ICQRRP_refinement( m_info, all_data, b_sz, state);
dur_cqrrp_refine = ICQRRP_refinement( m_info, all_data, b_sz, state);
printf("TOTAL TIME FOR ICQRRP + refinement %ld\n", dur_cqrrp_refine);
// MATRIX A HAS BEEN PIVOTED, BE CAREFUL
RandLAPACK::gen::mat_gen(m_info, all_data.A.data(), state);

//RandBLAS::util::print_colmaj(m, n, all_data.A.data(), name);
//RandBLAS::util::print_colmaj(n, 1, all_data.x.data(), name1);

// backward_err_geqrf = backward_error(m_info, all_data);
// forward_err_geqrf = forward_error(m_info, all_data);
backward_err_geqrf = backward_error(m_info, all_data);
forward_err_geqrf = forward_error(m_info, all_data);
printf("F_err_GEQRF: %e\n", forward_err_geqrf);
printf("B_err_GEQRF: %e\n", backward_err_geqrf);

// data_regen(m_info, all_data, state_gen, 1);
data_regen(m_info, all_data, state_gen, 1);
state_gen = state;

// Testing GEQRF + iterative refinement
dur_geqrf_refine = GEQRF_refinement( m_info, all_data);
printf("TOTAL TIME FOR GEQRF + refinement %ld\n", dur_cqrrp_refine);

//RandBLAS::util::print_colmaj(m, n, all_data.A.data(), name);
//RandBLAS::util::print_colmaj(n, 1, all_data.x.data(), name1);

backward_err_cqrrp = backward_error(m_info, all_data);
forward_err_cqrrp = forward_error(m_info, all_data);
printf("F_err_CQRRP: %e\n", forward_err_cqrrp);
printf("B_err_CQRRP: %e\n", backward_err_cqrrp);

data_regen(m_info, all_data, state_gen, 1);
state_gen = state;
Expand All @@ -333,13 +355,14 @@ int main(int argc, char *argv[]) {
}

// Declare parameters
int64_t m = 0;
int64_t n = 0;
double d_factor = 1.25;
int64_t b_sz_start = 8;
int64_t b_sz_end = 8;
double tol = std::pow(std::numeric_limits<double>::epsilon(), 0.85);
auto state = RandBLAS::RNGState();
int64_t m = 0;
int64_t n = 0;
double d_factor = 1.25;
int64_t b_sz_start = 8;
int64_t b_sz_end = 8;
double tol_d = std::pow(std::numeric_limits<double>::epsilon(), 0.85);
float tol_s = std::pow(std::numeric_limits<float>::epsilon(), 0.95);
auto state = RandBLAS::RNGState();
auto state_constant = state;
// Number of algorithm runs. We only record best times.
int64_t numruns = 1;
Expand All @@ -360,7 +383,7 @@ int main(int argc, char *argv[]) {
n = m_info.cols;

// Allocate basic workspace.
QR_solver_benchmark_data<double> all_data(m, n, tol, d_factor);
QR_solver_benchmark_data<double> all_data(m, n, tol_d, tol_s, d_factor);

// Copy A, x, b over
lapack::lacpy(MatrixType::General, m, n, Axb_buf, m, all_data.A.data(), m);
Expand Down

0 comments on commit 7a4bc7d

Please sign in to comment.