Skip to content

Commit c75486b

Browse files
authored
Use autodiff in templated network rate evaluation (#1614)
Note: this fixes an issue with powerlaw, which previously did not have any derivatives computed
1 parent 3f62e99 commit c75486b

File tree

18 files changed

+1478
-2557
lines changed

18 files changed

+1478
-2557
lines changed

interfaces/rhs_type.H

+7-7
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,13 @@ struct rate_tab_t
9595
}
9696
};
9797

98-
// number_t is currently only used in the screening routines
99-
template <typename number_t = amrex::Real>
98+
// number_t is currently used in the screening routines and templated network
99+
// rate evaluation
100+
template <typename number_t>
100101
struct rhs_state_t
101102
{
102103
amrex::Real rho;
103-
tf_t tf;
104+
tf_t<number_t> tf;
104105
rate_tab_t tab;
105106
#ifdef SCREENING
106107
plasma_state_t<number_t> pstate;
@@ -110,12 +111,11 @@ struct rhs_state_t
110111
amrex::Array1D<amrex::Real, 1, NumSpec> y;
111112
};
112113

114+
template <typename number_t>
113115
struct rate_t
114116
{
115-
amrex::Real fr;
116-
amrex::Real rr;
117-
amrex::Real frdt;
118-
amrex::Real rrdt;
117+
number_t fr;
118+
number_t rr;
119119
};
120120

121121
} // namespace RHS

interfaces/tfactors.H

+73-69
Original file line numberDiff line numberDiff line change
@@ -2,76 +2,80 @@
22
#define TFACTORS_H
33

44
#include <AMReX.H>
5+
#include <AMReX_REAL.H>
6+
#include <microphysics_autodiff.H>
57
#include <cmath>
68

79
using namespace amrex::literals;
810

11+
template <typename number_t>
912
struct tf_t {
10-
amrex::Real temp;
11-
amrex::Real t9;
12-
amrex::Real t92;
13-
amrex::Real t93;
14-
// amrex::Real t94;
15-
amrex::Real t95;
16-
// amrex::Real t96;
17-
amrex::Real t912;
18-
amrex::Real t932;
19-
amrex::Real t952;
20-
amrex::Real t972;
21-
amrex::Real t913;
22-
amrex::Real t923;
23-
amrex::Real t943;
24-
amrex::Real t953;
25-
// amrex::Real t973;
26-
// amrex::Real t9113;
27-
// amrex::Real t914;
28-
// amrex::Real t934;
29-
// amrex::Real t954;
30-
// amrex::Real t974;
31-
// amrex::Real t915;
32-
// amrex::Real t935;
33-
// amrex::Real t945;
34-
// amrex::Real t965;
35-
// amrex::Real t917;
36-
// amrex::Real t927;
37-
// amrex::Real t947;
38-
// amrex::Real t918;
39-
// amrex::Real t938;
40-
// amrex::Real t958;
41-
amrex::Real t9i;
42-
amrex::Real t9i2;
43-
// amrex::Real t9i3;
44-
amrex::Real t9i12;
45-
amrex::Real t9i32;
46-
// amrex::Real t9i52;
47-
// amrex::Real t9i72;
48-
amrex::Real t9i13;
49-
amrex::Real t9i23;
50-
amrex::Real t9i43;
51-
amrex::Real t9i53;
52-
// amrex::Real t9i14;
53-
// amrex::Real t9i34;
54-
// amrex::Real t9i54;
55-
// amrex::Real t9i15;
56-
// amrex::Real t9i35;
57-
// amrex::Real t9i45;
58-
// amrex::Real t9i65;
59-
// amrex::Real t9i17;
60-
// amrex::Real t9i27;
61-
// amrex::Real t9i47;
62-
// amrex::Real t9i18;
63-
// amrex::Real t9i38;
64-
// amrex::Real t9i58;
65-
// amrex::Real t916;
66-
// amrex::Real t976;
67-
// amrex::Real t9i76;
68-
amrex::Real lnt9;
13+
number_t temp;
14+
number_t t9;
15+
number_t t92;
16+
number_t t93;
17+
// number_t t94;
18+
number_t t95;
19+
// number_t t96;
20+
number_t t912;
21+
number_t t932;
22+
number_t t952;
23+
number_t t972;
24+
number_t t913;
25+
number_t t923;
26+
number_t t943;
27+
number_t t953;
28+
// number_t t973;
29+
// number_t t9113;
30+
// number_t t914;
31+
// number_t t934;
32+
// number_t t954;
33+
// number_t t974;
34+
// number_t t915;
35+
// number_t t935;
36+
// number_t t945;
37+
// number_t t965;
38+
// number_t t917;
39+
// number_t t927;
40+
// number_t t947;
41+
// number_t t918;
42+
// number_t t938;
43+
// number_t t958;
44+
number_t t9i;
45+
number_t t9i2;
46+
// number_t t9i3;
47+
number_t t9i12;
48+
number_t t9i32;
49+
// number_t t9i52;
50+
// number_t t9i72;
51+
number_t t9i13;
52+
number_t t9i23;
53+
number_t t9i43;
54+
number_t t9i53;
55+
// number_t t9i14;
56+
// number_t t9i34;
57+
// number_t t9i54;
58+
// number_t t9i15;
59+
// number_t t9i35;
60+
// number_t t9i45;
61+
// number_t t9i65;
62+
// number_t t9i17;
63+
// number_t t9i27;
64+
// number_t t9i47;
65+
// number_t t9i18;
66+
// number_t t9i38;
67+
// number_t t9i58;
68+
// number_t t916;
69+
// number_t t976;
70+
// number_t t9i76;
71+
number_t lnt9;
6972
};
7073

74+
template <typename number_t>
7175
AMREX_GPU_HOST_DEVICE inline
72-
tf_t get_tfactors(amrex::Real temp)
76+
tf_t<number_t> get_tfactors(number_t temp)
7377
{
74-
tf_t tf;
78+
tf_t<number_t> tf;
7579

7680
tf.temp = temp;
7781

@@ -83,38 +87,38 @@ tf_t get_tfactors(amrex::Real temp)
8387
tf.t95 = tf.t92*tf.t93;
8488
// tf.t96 = tf.t9*tf.t95;
8589

86-
tf.t912 = std::sqrt(tf.t9);
90+
tf.t912 = admath::sqrt(tf.t9);
8791
tf.t932 = tf.t9*tf.t912;
8892
tf.t952 = tf.t9*tf.t932;
8993
// tf.t972 = tf.t9*tf.t952;
9094
tf.t972 = tf.t92*tf.t932;
9195

92-
tf.t913 = std::cbrt(tf.t9);
96+
tf.t913 = admath::cbrt(tf.t9);
9397
tf.t923 = tf.t913*tf.t913;
9498
tf.t943 = tf.t9*tf.t913;
9599
tf.t953 = tf.t9*tf.t923;
96100
// tf.t973 = tf.t953*tf.t923;
97101
// tf.t9113 = tf.t973*tf.t943;
98102

99-
// tf.t914 = std::pow(tf.t9, 0.25e0_rt);
103+
// tf.t914 = admath::pow(tf.t9, 0.25e0_rt);
100104
// tf.t934 = tf.t914*tf.t914*tf.t914;
101105
// tf.t954 = tf.t9*tf.t914;
102106
// tf.t974 = tf.t9*tf.t934;
103107

104-
// tf.t915 = std::pow(tf.t9, 0.2_rt);
108+
// tf.t915 = admath::pow(tf.t9, 0.2_rt);
105109
// tf.t935 = tf.t915*tf.t915*tf.t915;
106110
// tf.t945 = tf.t915 * tf.t935;
107111
// tf.t965 = tf.t9 * tf.t915;
108112

109-
// tf.t916 = std::pow(tf.t9, 1.0_rt/6.0_rt);
113+
// tf.t916 = admath::pow(tf.t9, 1.0_rt/6.0_rt);
110114
// tf.t976 = tf.t9 * tf.t916;
111115
// tf.t9i76 = 1.0e0_rt/tf.t976;
112116

113-
// tf.t917 = std::pow(tf.t9, 1.0_rt/7.0_rt);
117+
// tf.t917 = admath::pow(tf.t9, 1.0_rt/7.0_rt);
114118
// tf.t927 = tf.t917*tf.t917;
115119
// tf.t947 = tf.t927*tf.t927;
116120

117-
// tf.t918 = std::sqrt(tf.t914);
121+
// tf.t918 = admath::sqrt(tf.t914);
118122
// tf.t938 = tf.t918*tf.t918*tf.t918;
119123
// tf.t958 = tf.t938*tf.t918*tf.t918;
120124

@@ -149,7 +153,7 @@ tf_t get_tfactors(amrex::Real temp)
149153
// tf.t9i38 = tf.t9i18*tf.t9i18*tf.t9i18;
150154
// tf.t9i58 = tf.t9i38*tf.t9i18*tf.t9i18;
151155

152-
tf.lnt9 = std::log(tf.t9);
156+
tf.lnt9 = admath::log(tf.t9);
153157

154158
return tf;
155159
}

0 commit comments

Comments
 (0)