Skip to content

Commit

Permalink
Added detailed (maybe too much) time profiling in RBKI, fixed openmp bug, fixed norm bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Jan 16, 2024
1 parent dd5f190 commit 963fa0e
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 31 deletions.
95 changes: 76 additions & 19 deletions RandLAPACK/drivers/rl_rbki.hh
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,16 @@ int RBKI<T, RNG>::call(
high_resolution_clock::time_point qr_t_stop;
high_resolution_clock::time_point gemm_A_t_start;
high_resolution_clock::time_point gemm_A_t_stop;
high_resolution_clock::time_point main_loop_t_start;
high_resolution_clock::time_point main_loop_t_stop;
high_resolution_clock::time_point sketching_t_start;
high_resolution_clock::time_point sketching_t_stop;
high_resolution_clock::time_point r_cpy_t_start;
high_resolution_clock::time_point r_cpy_t_stop;
high_resolution_clock::time_point s_cpy_t_start;
high_resolution_clock::time_point s_cpy_t_stop;
high_resolution_clock::time_point norm_t_start;
high_resolution_clock::time_point norm_t_stop;
high_resolution_clock::time_point total_t_start;
high_resolution_clock::time_point total_t_stop;

Expand All @@ -102,6 +112,11 @@ int RBKI<T, RNG>::call(
long reorth_t_dur = 0;
long qr_t_dur = 0;
long gemm_A_t_dur = 0;
long main_loop_t_dur = 0;
long sketching_t_dur = 0;
long r_cpy_t_dur = 0;
long s_cpy_t_dur = 0;
long norm_t_dur = 0;
long total_t_dur = 0;

if(this -> timing) {
Expand Down Expand Up @@ -157,12 +172,18 @@ int RBKI<T, RNG>::call(
T sq_tol = std::pow(this->tol, 2);
T threshold = std::sqrt(1 - sq_tol) * norm_A;

if(this -> timing)
sketching_t_start = high_resolution_clock::now();

// Generate a dense Gaussian random matrix.
RandBLAS::DenseDist D(n, k);
state = RandBLAS::fill_dense(D, Y_i, state).second;

if(this -> timing)
if(this -> timing) {
sketching_t_stop = high_resolution_clock::now();
sketching_t_dur = duration_cast<microseconds>(sketching_t_stop - sketching_t_start).count();
gemm_A_t_start = high_resolution_clock::now();
}

// [X_ev, ~] = qr(A * Y_i, 0)
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, n, 1.0, A, m, Y_i, n, 0.0, X_i, m);
Expand Down Expand Up @@ -197,7 +218,11 @@ int RBKI<T, RNG>::call(
++iter;

// Iterate until an in-loop termination criterion is met.

while((iter_ev + iter_od) < max_iters) {
if(this -> timing)
main_loop_t_start = high_resolution_clock::now();

if (iter % 2 != 0) {

if(this -> timing)
Expand All @@ -217,12 +242,13 @@ int RBKI<T, RNG>::call(
if (iter != 1) {
// R_i' = Y_i' * Y_od
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, iter_ev * k, n, 1.0, Y_i, n, Y_od, n, 0.0, R_i, n);

if(this -> timing)
reorth_t_start = high_resolution_clock::now();

// Y_i = Y_i - Y_od * R_i
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, n, k, iter_ev * k, -1.0, Y_od, n, R_i, n, 1.0, Y_i, n);

if(this -> timing)
reorth_t_start = high_resolution_clock::now();

// Reorthogonalization
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, k, iter_ev * k, n, 1.0, Y_i, n, Y_od, n, 0.0, Y_orth_buf, k);
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::Trans, n, k, iter_ev * k, -1.0, Y_od, n, Y_orth_buf, k, 1.0, Y_i, n);
Expand All @@ -244,15 +270,18 @@ int RBKI<T, RNG>::call(
if(this -> timing) {
qr_t_stop = high_resolution_clock::now();
qr_t_dur += duration_cast<microseconds>(qr_t_stop - qr_t_start).count();
r_cpy_t_start = high_resolution_clock::now();
}

// Copy R_ii over to R's (in transposed format).
#pragma omp parallel for
for(i = 0; i < k; ++i)
blas::copy(i + 1, &Y_i[i * n], 1, &R_ii[i], n);

if(this -> timing)
ungqr_t_start = high_resolution_clock::now();
if(this -> timing) {
r_cpy_t_stop = high_resolution_clock::now();
r_cpy_t_dur += duration_cast<microseconds>(r_cpy_t_stop - r_cpy_t_start).count();
ungqr_t_start = high_resolution_clock::now();
}

// Convert Y_i into an explicit form. It is now stored in Y_odd as it should be.
lapack::ungqr(n, k, k, Y_i, n, tau);
Expand Down Expand Up @@ -286,19 +315,20 @@ int RBKI<T, RNG>::call(

if(this -> timing) {
gemm_A_t_stop = high_resolution_clock::now();
gemm_A_t_dur =+ duration_cast<microseconds>(gemm_A_t_stop - gemm_A_t_start).count();
gemm_A_t_dur += duration_cast<microseconds>(gemm_A_t_stop - gemm_A_t_start).count();
}

// Move the Y_i pointer;
Y_i = &Y_i[n * k];

// S_i = X_ev' * X_i
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, iter_od * k, k, m, 1.0, X_ev, m, X_i, m, 0.0, S_i, n + k);
//X_i = X_i - X_ev * S_i;
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, iter_od * k, -1.0, X_ev, m, S_i, n + k, 1.0, X_i, m);


if(this -> timing)
reorth_t_start = high_resolution_clock::now();

//X_i = X_i - X_ev * S_i;
blas::gemm(Layout::ColMajor, Op::NoTrans, Op::NoTrans, m, k, iter_od * k, -1.0, X_ev, m, S_i, n + k, 1.0, X_i, m);

// Reorthogonalization
blas::gemm(Layout::ColMajor, Op::Trans, Op::NoTrans, iter_od * k, k, m, 1.0, X_ev, m, X_i, m, 0.0, X_orth_buf, n + k);
Expand All @@ -320,13 +350,17 @@ int RBKI<T, RNG>::call(
if(this -> timing) {
qr_t_stop = high_resolution_clock::now();
qr_t_dur += duration_cast<microseconds>(qr_t_stop - qr_t_start).count();
s_cpy_t_start = high_resolution_clock::now();
}

// Copy S_ii over to S's space under S_i (offset down by iter_od * k)
lapack::lacpy(MatrixType::Upper, k, k, X_i, m, S_ii, n + k);

if(this -> timing)
ungqr_t_start = high_resolution_clock::now();
if(this -> timing) {
s_cpy_t_stop = high_resolution_clock::now();
s_cpy_t_dur += duration_cast<microseconds>(s_cpy_t_stop - s_cpy_t_start).count();
ungqr_t_start = high_resolution_clock::now();
}

// Convert X_i into an explicit form. It is now stored in X_ev as it should be
lapack::ungqr(m, k, k, X_i, m, tau);
Expand All @@ -349,9 +383,22 @@ int RBKI<T, RNG>::call(
// Advance odd iteration count;
++iter_od;
}

if(this -> timing)
norm_t_start = high_resolution_clock::now();

// This is only changed on odd iters
if (iter % 2 != 0)
norm_R = lapack::lantr(Norm::Fro, Uplo::Upper, Diag::NonUnit, iter_ev * k, iter_ev * k, R, n);

if(this -> timing) {
norm_t_stop = high_resolution_clock::now();
norm_t_dur += duration_cast<microseconds>(norm_t_stop - norm_t_start).count();
main_loop_t_stop = high_resolution_clock::now();
main_loop_t_dur += duration_cast<microseconds>(main_loop_t_stop - main_loop_t_start).count();
}

++iter;
norm_R = lapack::lantr(Norm::Fro, Uplo::Upper, Diag::NonUnit, n, n, R, n);

//norm(R, 'fro') > sqrt(1 - sq_tol) * norm_A
if(norm_R > threshold) {
break;
Expand Down Expand Up @@ -412,9 +459,9 @@ int RBKI<T, RNG>::call(
if(this -> timing) {
total_t_stop = high_resolution_clock::now();
total_t_dur = duration_cast<microseconds>(total_t_stop - total_t_start).count();
long t_rest = total_t_dur - (allocation_t_dur + get_factors_t_dur + ungqr_t_dur + reorth_t_dur + qr_t_dur + gemm_A_t_dur);
this -> times.resize(8);
this -> times = {allocation_t_dur, get_factors_t_dur, ungqr_t_dur, reorth_t_dur, qr_t_dur, gemm_A_t_dur, t_rest, total_t_dur};
long t_rest = total_t_dur - (allocation_t_dur + get_factors_t_dur + ungqr_t_dur + reorth_t_dur + qr_t_dur + gemm_A_t_dur + sketching_t_dur + r_cpy_t_dur + s_cpy_t_dur + norm_t_dur);
this -> times.resize(11);
this -> times = {allocation_t_dur, get_factors_t_dur, ungqr_t_dur, reorth_t_dur, qr_t_dur, gemm_A_t_dur, main_loop_t_dur, sketching_t_dur, r_cpy_t_dur, s_cpy_t_dur, norm_t_dur, t_rest, total_t_dur};

if (this -> verbosity) {
printf("\n\n/------------RBKI TIMING RESULTS BEGIN------------/\n");
Expand All @@ -426,14 +473,24 @@ int RBKI<T, RNG>::call(
printf("Reorthogonalization time: %25ld μs,\n", reorth_t_dur);
printf("QR time: %25ld μs,\n", qr_t_dur);
printf("GEMM A time: %25ld μs,\n", gemm_A_t_dur);
printf("Sketching time: %25ld μs,\n", sketching_t_dur);
printf("R_ii cpy time: %25ld μs,\n", r_cpy_t_dur);
printf("S_ii cpy time: %25ld μs,\n", s_cpy_t_dur);
printf("Norm R time: %25ld μs,\n", norm_t_dur);

printf("\nAllocation takes %22.2f%% of runtime.\n", 100 * ((T) allocation_t_dur / (T) total_t_dur));
printf("\nAllocation takes %22.2f%% of runtime.\n", 100 * ((T) allocation_t_dur / (T) total_t_dur));
printf("Factors takes %22.2f%% of runtime.\n", 100 * ((T) get_factors_t_dur / (T) total_t_dur));
printf("Ungqr takes %22.2f%% of runtime.\n", 100 * ((T) ungqr_t_dur / (T) total_t_dur));
printf("Reorth takes %22.2f%% of runtime.\n", 100 * ((T) reorth_t_dur / (T) total_t_dur));
printf("QR takes %22.2f%% of runtime.\n", 100 * ((T) qr_t_dur / (T) total_t_dur));
printf("GEMM A takes %22.2f%% of runtime.\n", 100 * ((T) gemm_A_t_dur / (T) total_t_dur));
printf("Sketching takes %22.2f%% of runtime.\n", 100 * ((T) sketching_t_dur / (T) total_t_dur));
printf("R_ii cpy takes %22.2f%% of runtime.\n", 100 * ((T) r_cpy_t_dur / (T) total_t_dur));
printf("S_ii cpy takes %22.2f%% of runtime.\n", 100 * ((T) s_cpy_t_dur / (T) total_t_dur));
printf("Norm R takes %22.2f%% of runtime.\n", 100 * ((T) norm_t_dur / (T) total_t_dur));
printf("Rest takes %22.2f%% of runtime.\n", 100 * ((T) t_rest / (T) total_t_dur));

printf("\nMain loop takes %22.2f%% of runtime.\n", 100 * ((T) main_loop_t_dur / (T) total_t_dur));
printf("/-------------RBKI TIMING RESULTS END-------------/\n\n");
}
}
Expand Down
34 changes: 22 additions & 12 deletions benchmark/bench_RBKI/RBKI_speed_comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ static void update_best_time(int iter, long &t_best, long &t_curr, T* S1, T* S2,
blas::copy(k, S1, 1, S2, 1);
}
if (timing)
blas::copy(8, break_out, 1, break_in, 1);
blas::copy(13, break_out, 1, break_in, 1);
}
/*
template <typename T>
Expand Down Expand Up @@ -88,7 +88,7 @@ static void call_all_algs(
T err_rbki;
T err_lan;
int64_t k_lanc = std::min((int64_t) (num_krylov_iters / (T) 2), k);
bool time_subroutines = true;
bool time_subroutines = false;

// Set the threshold for Lanczos
// Setting up Lanczos - RBKI with k = 1.
Expand All @@ -112,8 +112,8 @@ static void call_all_algs(
//auto state_alg = state;

// Timing breakdown vectors;
std::vector<long> Lanc_timing_breakdown (8, 0.0);
std::vector<long> RBKI_timing_breakdown (8, 0.0);
std::vector<long> Lanc_timing_breakdown (13, 0.0);
std::vector<long> RBKI_timing_breakdown (13, 0.0);

for (i = 0; i < numruns; ++i) {
printf("Iteration %d start.\n", i);
Expand Down Expand Up @@ -180,14 +180,24 @@ static void call_all_algs(
printf("Reorthogonalization time: %25ld μs,\n", RBKI_timing_breakdown[3]);
printf("QR time: %25ld μs,\n", RBKI_timing_breakdown[4]);
printf("GEMM A time: %25ld μs,\n", RBKI_timing_breakdown[5]);

printf("\nAllocation takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[0] / (T) RBKI_timing_breakdown[7]));
printf("Factors takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[1] / (T) RBKI_timing_breakdown[7]));
printf("Ungqr takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[2] / (T) RBKI_timing_breakdown[7]));
printf("Reorth takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[3] / (T) RBKI_timing_breakdown[7]));
printf("QR takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[4] / (T) RBKI_timing_breakdown[7]));
printf("GEMM A takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[5] / (T) RBKI_timing_breakdown[7]));
printf("Rest takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[6] / (T) RBKI_timing_breakdown[7]));
printf("Sketching time: %25ld μs,\n", RBKI_timing_breakdown[7]);
printf("R_ii cpy time: %25ld μs,\n", RBKI_timing_breakdown[8]);
printf("S_ii cpy time: %25ld μs,\n", RBKI_timing_breakdown[9]);
printf("Norm time: %25ld μs,\n", RBKI_timing_breakdown[10]);

printf("\nAllocation takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[0] / (T) RBKI_timing_breakdown[12]));
printf("Factors takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[1] / (T) RBKI_timing_breakdown[12]));
printf("Ungqr takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[2] / (T) RBKI_timing_breakdown[12]));
printf("Reorth takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[3] / (T) RBKI_timing_breakdown[12]));
printf("QR takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[4] / (T) RBKI_timing_breakdown[12]));
printf("GEMM A takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[5] / (T) RBKI_timing_breakdown[12]));
printf("Sketching takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[7] / (T) RBKI_timing_breakdown[12]));
printf("R_ii cpy takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[8] / (T) RBKI_timing_breakdown[12]));
printf("S_ii cpy takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[9] / (T) RBKI_timing_breakdown[12]));
printf("Norm R takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[10] / (T) RBKI_timing_breakdown[12]));
printf("Rest takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[11] / (T) RBKI_timing_breakdown[12]));

printf("\nMain loop takes %22.2f%% of runtime.\n", 100 * ((T) RBKI_timing_breakdown[6] / (T) RBKI_timing_breakdown[12]));
printf("/-------------RBKI TIMING RESULTS END-------------/\n\n");
}

Expand Down

0 comments on commit 963fa0e

Please sign in to comment.