Skip to content

Commit

Permalink
Fix per Riley's comments
Browse files Browse the repository at this point in the history
  • Loading branch information
TeachRaccooon committed Apr 16, 2024
1 parent 2a0e797 commit 8911b6b
Show file tree
Hide file tree
Showing 9 changed files with 28 additions and 30 deletions.
8 changes: 4 additions & 4 deletions benchmark/bench_CQRRP/CQRRP_pivot_quality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ static void R_norm_ratio(
std::iota(all_data.J.begin(), all_data.J.end(), 1);
//RandLAPACK::hqrrp(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data(), b_sz, (d_factor - 1) * b_sz, 0, 0, state, (T*) nullptr);
lapack::geqp3(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data());
std::vector<T> R_norms_HQRRP = get_norms<T>(all_data);
std::vector<T> R_norms_HQRRP = get_norms(all_data);
printf("\nDone with HQRRP\n");

// Clear and re-generate data
Expand All @@ -91,7 +91,7 @@ static void R_norm_ratio(
printf("\nStarting CQRRP\n");
// Running CQRRP
CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state);
std::vector<T> R_norms_CQRRP = get_norms<T>(all_data);
std::vector<T> R_norms_CQRRP = get_norms(all_data);

// Declare a data file
std::fstream file1("data_out/QR_R_norm_ratios_rows_" + std::to_string(m)
Expand Down Expand Up @@ -139,7 +139,7 @@ static void sv_ratio(
lapack::gesdd(Job::NoVec, m, n, all_data.A.data(), m, all_data.S.data(), (T*) nullptr, m, (T*) nullptr, n);

// Clear and re-generate data
data_regen<T>(m_info, all_data, state);
data_regen(m_info, all_data, state);

// Running GEQP3
std::iota(all_data.J.begin(), all_data.J.end(), 1);
Expand All @@ -153,7 +153,7 @@ static void sv_ratio(
file2 << ",\n";

// Clear and re-generate data
data_regen<T>(m_info, all_data, state1);
data_regen(m_info, all_data, state1);

// Running CQRRP
CQRRP_blocked.call(m, n, all_data.A.data(), m, d_factor, all_data.tau.data(), all_data.J.data(), state);
Expand Down
4 changes: 2 additions & 2 deletions benchmark/bench_CQRRP/CQRRP_runtime_breakdown.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ static void data_regen(RandLAPACK::gen::mat_gen_info<T> m_info,
QR_speed_benchmark_data<T> &all_data,
RandBLAS::RNGState<RNG> &state) {

RandLAPACK::gen::mat_gen<double>(m_info, all_data.A.data(), state);
RandLAPACK::gen::mat_gen(m_info, all_data.A.data(), state);
std::fill(all_data.tau.begin(), all_data.tau.end(), 0.0);
std::fill(all_data.J.begin(), all_data.J.end(), 0);
}
Expand Down Expand Up @@ -95,7 +95,7 @@ static std::vector<long> call_all_algs(
state_gen_0 = state;
state_alg_0 = state;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen_0);
data_regen(m_info, all_data, state_gen_0);
}

return inner_timing_best;
Expand Down
10 changes: 4 additions & 6 deletions benchmark/bench_CQRRP/CQRRP_single_precision.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,6 @@ static std::vector<long> call_all_algs(
CQRRP_blocked.nnz = 2;
CQRRP_blocked.num_threads = 48;
// We are not using panel pivoting in performance testing.
int panel_pivoting = 0;

// timing vars
long dur_cqrrp = 0;
long dur_geqrf = 0;
Expand All @@ -87,12 +85,12 @@ static std::vector<long> call_all_algs(
auto start_getrf = high_resolution_clock::now();
lapack::getrf(m, n, all_data_rest.A.data(), m, all_data_rest.J.data());
auto stop_getrf = high_resolution_clock::now();
auto dur_getrf = duration_cast<microseconds>(stop_getrf - start_getrf).count();
dur_getrf = duration_cast<microseconds>(stop_getrf - start_getrf).count();
printf("TOTAL TIME FOR GETRF %ld\n", dur_getrf);
// Update best timing
i == 0 ? t_getrf_best = dur_getrf : (dur_getrf < t_getrf_best) ? t_getrf_best = dur_getrf : NULL;

data_regen<T_rest>(m_info_rest, all_data_rest, state_gen, 0);
data_regen(m_info_rest, all_data_rest, state_gen, 0);
state_gen = state;

// Testing GEQRF
Expand All @@ -105,7 +103,7 @@ static std::vector<long> call_all_algs(
i == 0 ? t_geqrf_best = dur_geqrf : (dur_geqrf < t_geqrf_best) ? t_geqrf_best = dur_geqrf : NULL;

// Clear and re-generate data
data_regen<T_rest>(m_info_rest, all_data_rest, state_gen, 0);
data_regen(m_info_rest, all_data_rest, state_gen, 0);
state_gen = state;

// Testing CQRRP - best setup
Expand All @@ -118,7 +116,7 @@ static std::vector<long> call_all_algs(
i == 0 ? t_cqrrp_best = dur_cqrrp : (dur_cqrrp < t_cqrrp_best) ? t_cqrrp_best = dur_cqrrp : NULL;

// Clear and re-generate data
data_regen<T_cqrrp>(m_info_cqrrp, all_data_cqrrp, state_gen, 1);
data_regen(m_info_cqrrp, all_data_cqrrp, state_gen, 1);
state_gen = state;
state_alg = state;
}
Expand Down
10 changes: 5 additions & 5 deletions benchmark/bench_CQRRP/CQRRP_speed_comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ static std::vector<long> call_all_algs(
auto dur_geqp3 = duration_cast<microseconds>(stop_geqp3 - start_geqp3).count();
printf("TOTAL TIME FOR GEQP3 %ld\n", dur_geqp3);

data_regen<T>(m_info, all_data, state_buf, 0);
data_regen(m_info, all_data, state_buf, 0);

// Testing GEQRF
auto start_geqrf = high_resolution_clock::now();
Expand All @@ -114,7 +114,7 @@ static std::vector<long> call_all_algs(
auto state_gen_1 = state_gen_0;
auto state_alg_1 = state_alg_0;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen_0, 0);
data_regen(m_info, all_data, state_gen_0, 0);

// Testing CQRRP - best setup
auto start_cqrrp = high_resolution_clock::now();
Expand All @@ -128,7 +128,7 @@ static std::vector<long> call_all_algs(
auto state_gen_3 = state_gen_1;
auto state_alg_3 = state_alg_1;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen_1, 1);
data_regen(m_info, all_data, state_gen_1, 1);

// Testing HQRRP with GEQRF
auto start_hqrrp_geqrf = high_resolution_clock::now();
Expand All @@ -143,7 +143,7 @@ static std::vector<long> call_all_algs(
auto state_gen_4 = state_gen_3;
auto state_alg_4 = state_alg_3;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen_3, 1);
data_regen(m_info, all_data, state_gen_3, 1);

// Testing HQRRP with Cholqr
auto start_hqrrp_cholqr = high_resolution_clock::now();
Expand All @@ -159,7 +159,7 @@ static std::vector<long> call_all_algs(
state_alg_0 = state_alg_4;
state_buf = state_gen_4;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen_4, 0);
data_regen(m_info, all_data, state_gen_4, 0);
}

printf("CQRRP takes %ld μs\n", t_cqrrp_best);
Expand Down
10 changes: 5 additions & 5 deletions benchmark/bench_CQRRPT/CQRRPT_pivot_quality.cc
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,14 @@ static void R_norm_ratio(

// Running GEQP3 (results are stored under the HQRRP name)
lapack::geqp3(m, n, all_data.A.data(), m, all_data.J.data(), all_data.tau.data());
std::vector<T> R_norms_HQRRP = get_norms<T>(all_data);
std::vector<T> R_norms_HQRRP = get_norms(all_data);

// Clear and re-generate data
data_regen<T>(m_info, all_data, state);
data_regen(m_info, all_data, state);

// Running CQRRPT
CQRRPT.call(m, n, all_data.A.data(), m, all_data.R.data(), n, all_data.J.data(), d_factor, state);
std::vector<T> R_norms_CQRRPT = get_norms<T>(all_data);
std::vector<T> R_norms_CQRRPT = get_norms(all_data);

// Declare a data file
std::fstream file1("data_out/QR_R_norm_ratios_rows_" + std::to_string(m)
Expand Down Expand Up @@ -131,7 +131,7 @@ static void sv_ratio(
lapack::gesdd(Job::NoVec, m, n, all_data.A.data(), m, all_data.S.data(), (T*) nullptr, m, (T*) nullptr, n);

// Clear and re-generate data
data_regen<T>(m_info, all_data, state);
data_regen(m_info, all_data, state);

// Running GEQP3
std::iota(all_data.J.begin(), all_data.J.end(), 1);
Expand All @@ -143,7 +143,7 @@ static void sv_ratio(
file2 << ",\n";

// Clear and re-generate data
data_regen<T>(m_info, all_data, state1);
data_regen(m_info, all_data, state1);

// Running CQRRPT
CQRRPT.call(m, n, all_data.A.data(), m, all_data.R.data(), n, all_data.J.data(), d_factor, state);
Expand Down
2 changes: 1 addition & 1 deletion benchmark/bench_CQRRPT/CQRRPT_runtime_breakdown.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ static std::vector<long> call_all_algs(
state_alg = state;
state_gen = state;
// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);
}

return inner_timing_best;
Expand Down
10 changes: 5 additions & 5 deletions benchmark/bench_CQRRPT/CQRRPT_speed_comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ static std::vector<long> call_all_algs(
dur_geqp3 = duration_cast<microseconds>(stop_geqp3 - start_geqp3).count();

state_gen = state;
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);

// Testing GEQRF
auto start_geqrf = high_resolution_clock::now();
Expand All @@ -106,7 +106,7 @@ static std::vector<long> call_all_algs(
dur_geqrf = duration_cast<microseconds>(stop_geqrf - start_geqrf).count();

state_gen = state;
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);

// Testing CQRRPT
auto start_cqrrp = high_resolution_clock::now();
Expand All @@ -116,7 +116,7 @@ static std::vector<long> call_all_algs(

state_gen = state;
state_alg = state;
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);

// Testing SCHOLQR3
auto start_scholqr = high_resolution_clock::now();
Expand All @@ -141,7 +141,7 @@ static std::vector<long> call_all_algs(
dur_scholqr = duration_cast<microseconds>(stop_scholqr - start_scholqr).count();

auto state_gen = state;
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);

// Testing GEQR + GEQPT
auto start_geqpt = high_resolution_clock::now();
Expand All @@ -164,7 +164,7 @@ static std::vector<long> call_all_algs(
dur_geqpt = duration_cast<microseconds>(stop_geqpt - start_geqpt).count();

state_gen = state;
data_regen<T>(m_info, all_data, state_gen);
data_regen(m_info, all_data, state_gen);

i == 0 ? t_cqrrpt_best = dur_cqrrpt : (dur_cqrrpt < t_cqrrpt_best) ? t_cqrrpt_best = dur_cqrrpt : NULL;
i == 0 ? t_geqpt_best = dur_geqpt : (dur_geqpt < t_geqpt_best) ? t_geqpt_best = dur_geqpt : NULL;
Expand Down
2 changes: 1 addition & 1 deletion benchmark/bench_RBKI/RBKI_runtime_breakdown.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ static void call_all_algs(
file << "\n";

// Clear and re-generate data
data_regen<T>(m_info, all_data, state_gen, 0);
data_regen(m_info, all_data, state_gen, 0);
state_gen = state;
}
}
Expand Down
2 changes: 1 addition & 1 deletion benchmark/bench_RBKI/RBKI_speed_comparisons.cc
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ static void call_all_algs(
std::ofstream file(output_filename, std::ios::app);
file << b_sz << ", " << RBKI.max_krylov_iters << ", " << target_rank << ", " << custom_rank << ", " << residual_err_target << ", " << residual_err_custom << ", " << dur_rbki << ", " << dur_svd << ",\n";
state_gen = state;
data_regen<T>(m_info, all_data, state_gen, 0);
data_regen(m_info, all_data, state_gen, 0);
}
}

Expand Down

0 comments on commit 8911b6b

Please sign in to comment.