Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove reinterpret_cast in rhs.H #1435

Merged
merged 9 commits into from
Mar 15, 2024
26 changes: 12 additions & 14 deletions networks/rhs.H
Original file line number Diff line number Diff line change
Expand Up @@ -273,19 +273,17 @@ constexpr int is_jacobian_term_used ()
}

AMREX_GPU_HOST_DEVICE AMREX_INLINE
void dgesl (RArray2D& a1, RArray1D& b1)
void dgesl (const RArray2D& a, RArray1D& b)
{
auto const& a = reinterpret_cast<amrex::Array2D<amrex::Real, 0, INT_NEQS-1, 0, INT_NEQS-1>&>(a1);
auto& b = reinterpret_cast<amrex::Array1D<amrex::Real, 0, INT_NEQS-1>&>(b1);

// solve a * x = b
// first solve l * y = b
constexpr_for<0, INT_NEQS-1>([&] (auto n1)
constexpr_for<1, INT_NEQS>([&] (auto n1)
{
constexpr int k = n1;

amrex::Real t = b(k);
constexpr_for<k+1, INT_NEQS>([&] (auto n2)
constexpr_for<k+1, INT_NEQS+1>([&] (auto n2)
{
constexpr int j = n2;

Expand All @@ -294,47 +292,47 @@ void dgesl (RArray2D& a1, RArray1D& b1)
});

// now solve u * x = y
constexpr_for<0, INT_NEQS>([&] (auto kb)
constexpr_for<1, INT_NEQS+1>([&] (auto kb)
{
constexpr int k = INT_NEQS - kb - 1;
constexpr int k = INT_NEQS + 1 - kb;

b(k) = b(k) / a(k,k);
amrex::Real t = -b(k);

constexpr_for<0, k>([&] (auto j)
constexpr_for<1, k>([&] (auto j)
{
b(j) += t * a(j,k);
});
});
}

AMREX_GPU_HOST_DEVICE AMREX_INLINE
void dgefa (RArray2D& a1)
void dgefa (RArray2D& a)
{
auto& a = reinterpret_cast<amrex::Array2D<amrex::Real, 0, INT_NEQS-1, 0, INT_NEQS-1>&>(a1);

// LU factorization in-place without pivoting.

constexpr_for<0, INT_NEQS-1>([&] (auto n1)
constexpr_for<1, INT_NEQS>([&] (auto n1)
{
[[maybe_unused]] constexpr int k = n1;

// compute multipliers

amrex::Real t = -1.0_rt / a(k,k);
constexpr_for<k+1, INT_NEQS>([&] (auto n2)
constexpr_for<k+1, INT_NEQS+1>([&] (auto n2)
{
[[maybe_unused]] constexpr int j = n2;

a(j,k) *= t;
});

// row elimination with column indexing
constexpr_for<k+1, INT_NEQS>([&] (auto n2)
constexpr_for<k+1, INT_NEQS+1>([&] (auto n2)
{
[[maybe_unused]] constexpr int j = n2;

t = a(k,j);
constexpr_for<k+1, INT_NEQS>([&] (auto n3)
constexpr_for<k+1, INT_NEQS+1>([&] (auto n3)
{
[[maybe_unused]] constexpr int i = n3;

Expand Down