Skip to content

Commit 59ba531

Browse files
committed
template pivoting in the linear algebra
now we can do `integrator.linalg_do_pivoting=0` to disable pivoting
1 parent a4d7ab4 commit 59ba531

File tree

5 files changed

+73
-30
lines changed

5 files changed

+73
-30
lines changed

integration/BackwardEuler/be_integrator.H

+15-2
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,27 @@ int single_step (BurnT& state, BeT& be, const Real dt)
9696

9797
int ierr_linpack;
9898
IArray1D pivot;
99-
dgefa<int_neqs>(be.jac, pivot, ierr_linpack);
99+
100+
if (integrator_rp::linalg_do_pivoting == 1) {
101+
constexpr bool allow_pivot{true};
102+
dgefa<int_neqs, allow_pivot>(be.jac, pivot, ierr_linpack);
103+
} else {
104+
constexpr bool allow_pivot{false};
105+
dgefa<int_neqs, allow_pivot>(be.jac, pivot, ierr_linpack);
106+
}
100107

101108
if (ierr_linpack != 0) {
102109
ierr = IERR_LU_DECOMPOSITION_ERROR;
103110
break;
104111
}
105112

106-
dgesl<int_neqs>(be.jac, pivot, b);
113+
if (integrator_rp::linalg_do_pivoting == 1) {
114+
constexpr bool allow_pivot{true};
115+
dgesl<int_neqs, allow_pivot>(be.jac, pivot, b);
116+
} else {
117+
constexpr bool allow_pivot{false};
118+
dgesl<int_neqs, allow_pivot>(be.jac, pivot, b);
119+
}
107120

108121
// update our current guess for the solution
109122

integration/VODE/vode_dvjac.H

+7-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,13 @@ void dvjac (IArray1D& pivot, int& IERPJ, BurnT& state, DvodeT& vstate)
173173
RHS::dgefa(vstate.jac);
174174
IER = 0;
175175
#else
176-
dgefa<int_neqs>(vstate.jac, pivot, IER);
176+
if (integrator_rp::linalg_do_pivoting == 1) {
177+
constexpr bool allow_pivot{true};
178+
dgefa<int_neqs, allow_pivot>(vstate.jac, pivot, IER);
179+
} else {
180+
constexpr bool allow_pivot{false};
181+
dgefa<int_neqs, allow_pivot>(vstate.jac, pivot, IER);
182+
}
177183
#endif
178184

179185
if (IER != 0) {

integration/VODE/vode_dvnlsd.H

+7-1
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,13 @@ Real dvnlsd (IArray1D& pivot, int& NFLAG, BurnT& state, DvodeT& vstate)
114114
#ifdef NEW_NETWORK_IMPLEMENTATION
115115
RHS::dgesl(vstate.jac, vstate.y);
116116
#else
117-
dgesl<int_neqs>(vstate.jac, pivot, vstate.y);
117+
if (integrator_rp::linalg_do_pivoting == 1) {
118+
constexpr bool allow_pivot{true};
119+
dgesl<int_neqs, allow_pivot>(vstate.jac, pivot, vstate.y);
120+
} else {
121+
constexpr bool allow_pivot{false};
122+
dgesl<int_neqs, allow_pivot>(vstate.jac, pivot, vstate.y);
123+
}
118124
#endif
119125

120126
if (vstate.RC != 1.0_rt) {

integration/_parameters

+3
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,6 @@ nse_deriv_dt_factor real 0.05
8888

8989
# for NSE update, do we include the weak rate neutrino losses?
9090
nse_include_enu_weak integer 1
91+
92+
# for the linear algebra, do we allow pivoting?
93+
linalg_do_pivoting integer 1

util/linpack.H

+41-26
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <ArrayUtilities.H>
88

9-
template <int num_eqs>
9+
template <int num_eqs, bool allow_pivot>
1010
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
1111
void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)
1212
{
@@ -17,11 +17,17 @@ void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)
1717
// first solve l * y = b
1818
if (nm1 >= 1) {
1919
for (int k = 1; k <= nm1; ++k) {
20-
int l = pivot(k);
21-
Real t = b(l);
22-
if (l != k) {
23-
b(l) = b(k);
24-
b(k) = t;
20+
21+
Real t{};
22+
if constexpr (allow_pivot) {
23+
int l = pivot(k);
24+
t = b(l);
25+
if (l != k) {
26+
b(l) = b(k);
27+
b(k) = t;
28+
}
29+
} else {
30+
t = b(k);
2531
}
2632

2733
for (int j = k+1; j <= num_eqs; ++j) {
@@ -45,7 +51,7 @@ void dgesl (RArray2D& a, IArray1D& pivot, RArray1D& b)
4551

4652

4753

48-
template <int num_eqs>
54+
template <int num_eqs, bool allow_pivot>
4955
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
5056
void dgefa (RArray2D& a, IArray1D& pivot, int& info)
5157
{
@@ -68,24 +74,29 @@ void dgefa (RArray2D& a, IArray1D& pivot, int& info)
6874

6975
// find l = pivot index
7076
int l = k;
71-
Real dmax = std::abs(a(k,k));
72-
for (int i = k+1; i <= num_eqs; ++i) {
73-
if (std::abs(a(i,k)) > dmax) {
74-
l = i;
75-
dmax = std::abs(a(i,k));
77+
78+
if constexpr (allow_pivot) {
79+
Real dmax = std::abs(a(k,k));
80+
for (int i = k+1; i <= num_eqs; ++i) {
81+
if (std::abs(a(i,k)) > dmax) {
82+
l = i;
83+
dmax = std::abs(a(i,k));
84+
}
7685
}
77-
}
7886

79-
pivot(k) = l;
87+
pivot(k) = l;
88+
}
8089

8190
// zero pivot implies this column already triangularized
8291
if (a(l,k) != 0.0e0_rt) {
8392

84-
// interchange if necessary
85-
if (l != k) {
86-
t = a(l,k);
87-
a(l,k) = a(k,k);
88-
a(k,k) = t;
93+
if constexpr (allow_pivot) {
94+
// interchange if necessary
95+
if (l != k) {
96+
t = a(l,k);
97+
a(l,k) = a(k,k);
98+
a(k,k) = t;
99+
}
89100
}
90101

91102
// compute multipliers
@@ -97,26 +108,30 @@ void dgefa (RArray2D& a, IArray1D& pivot, int& info)
97108
// row elimination with column indexing
98109
for (int j = k+1; j <= num_eqs; ++j) {
99110
t = a(l,j);
100-
if (l != k) {
101-
a(l,j) = a(k,j);
102-
a(k,j) = t;
111+
112+
if constexpr (allow_pivot) {
113+
if (l != k) {
114+
a(l,j) = a(k,j);
115+
a(k,j) = t;
116+
}
103117
}
118+
104119
for (int i = k+1; i <= num_eqs; ++i) {
105120
a(i,j) += t * a(i,k);
106121
}
107122
}
108-
}
109-
else {
110123

124+
} else {
111125
info = k;
112-
113126
}
114127

115128
}
116129

117130
}
118131

119-
pivot(num_eqs) = num_eqs;
132+
if constexpr (allow_pivot) {
133+
pivot(num_eqs) = num_eqs;
134+
}
120135

121136
if (a(num_eqs,num_eqs) == 0.0e0_rt) {
122137
info = num_eqs;

0 commit comments

Comments
 (0)