Skip to content

Commit

Permalink
Remove unnecessary getArrayMinMax3D function
Browse files Browse the repository at this point in the history
  • Loading branch information
maddyscientist committed Nov 14, 2024
1 parent 1918764 commit 1ddda7c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 23 deletions.
4 changes: 0 additions & 4 deletions include/eigensolve_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,6 @@ namespace quda
// Local enum for the LU axpy block type
enum blockType { PENCIL, LOWER_TRI, UPPER_TRI };

// Local enum for the TRLM3D array min/mx sum
enum extremumType { MIN, MAX };

class EigenSolver
{
using range = std::pair<int, int>;
Expand Down Expand Up @@ -559,7 +556,6 @@ namespace quda
*/
void computeEvals(std::vector<ColorSpinorField> &evecs, std::vector<Complex> &evals, int size = 0) override;

template <extremumType min_max, typename T> T getArrayMinMax3D(const std::vector<T> &array);
};

/**
Expand Down
25 changes: 6 additions & 19 deletions lib/eig_trlm_3d.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ namespace quda
while (restart_iter < max_restarts && !converged) {

// Get min step
int step_min = getArrayMinMax3D<MIN>(num_locked_3D);
int step_min = *std::min_element(num_locked_3D.begin(), num_locked_3D.end());
comm_allreduce_min(step_min);

for (int step = step_min; step < n_kr; step++) lanczosStep3D(kSpace, step);
iter += (n_kr - step_min);

Expand Down Expand Up @@ -595,7 +597,9 @@ namespace quda
// Compute spectral radius estimate
std::vector<double> inner_products(ortho_dim_size, 0.0);
blas3d::reDotProduct(inner_products, out, in);
double result = getArrayMinMax3D<MAX>(inner_products);

auto result = *std::max_element(inner_products.begin(), inner_products.end());
comm_allreduce_max(result);
logQuda(QUDA_VERBOSE, "Chebyshev max %e\n", result);

// Increase final result by 10% for safety
Expand Down Expand Up @@ -657,21 +661,4 @@ namespace quda
}
}

template <extremumType min_max, typename T>
T TRLM3D::getArrayMinMax3D(const std::vector<T> &array)
{
T ret_val;
if constexpr (min_max == MIN) {
ret_val = *std::min_element(array.begin(), array.end());
comm_allreduce_min(ret_val);
} else if constexpr (min_max == MAX) {
ret_val = *std::max_element(array.begin(), array.end());
comm_allreduce_max(ret_val);
} else {
errorQuda("Unknown extremumType %d", min_max);
}

return ret_val;
}

} // namespace quda

0 comments on commit 1ddda7c

Please sign in to comment.