-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathruntime.mc
356 lines (335 loc) · 10.7 KB
/
runtime.mc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
include "math.mc"
include "sundials/sundials.mc"
include "sundials/ida.mc"
include "arg.mc"
include "option.mc"
include "ext/array-ext.mc"
-- Global switch for debug logging. It is `false` until some code flips it
-- with `modref`.
let _debug = ref false

-- `debugPrint msg` writes a debug line to stderr when `_debug` is set.
-- `msg` is a thunk producing the text, so the message is only built when
-- debugging is enabled. The line is prefixed with "LOG DEBUG:\t", followed by
-- a newline, and stderr is flushed afterwards.
let debugPrint = lam msg.
  match deref _debug with true then
    let text = msg () in
    printError (join ["LOG DEBUG:\t", text]);
    printError "\n";
    flushStderr ()
  else ()
-- `doLoop n f` calls `f i` for i = 0, 1, ..., n-1, in increasing order.
-- Nothing is executed when `n` is zero or negative.
let doLoop : Int -> (Int -> ()) -> ()
  = lam n. lam f.
    recursive let work = lam i.
      -- `geqi` instead of `eqi`: with an equality test a negative `n` is never
      -- reached when counting upwards from 0, so the loop would not terminate.
      if geqi i n then ()
      else f i; work (succ i)
    in
    work 0
-- Vectors and matrices are both represented as float tensors; the two aliases
-- only document intent at use sites.
type Vec = Tensor[Float]
type Mat = Tensor[Float]
-- `vget v i` / `vset v i x`: element access through the linear-index tensor
-- intrinsics (the `Exn` suffix suggests they fail on invalid indices --
-- TODO confirm against the tensor intrinsics documentation).
let vget = tensorLinearGetExn
let vset = tensorLinearSetExn
-- `vlength v`: the first entry of the tensor's shape.
let vlength = lam v. head (tensorShape v)
-- `viteri f v`: passes `tensorIterSlice` a callback that reads each slice as a
-- scalar (empty index list) and calls `f i x` with the slice index and value.
let viteri = lam f. lam v. tensorIterSlice (lam i. lam v. f i (tensorGetExn v [])) v
-- `vcreate n f`: creates a one-dimensional float tensor of shape [n] whose
-- entries are given by `f` applied to the (single) index.
let vcreate = lam n. lam f. tensorCreateCArrayFloat [n] (lam idx. f (head idx))
-- let mget = lam m. lam i. lam j. tensorGetExn m [i, j]
-- `mset` / `mupdate`: aliases for the Sundials dense-matrix set and update
-- operations; in this file they are called as `mset m i j x` and
-- `mupdate m i j g`, where `g` is a function of the current entry.
let mset = sundialsMatrixDenseSet
let mupdate = sundialsMatrixDenseUpdate
-- `vecToString name v` renders the vector `v` as text, one line per element,
-- each of the form `name[i] = value` terminated by a newline.
let vecToString = lam name. lam v.
  let formatEntry = lam i.
    join [name, "[", int2string i, "] = ", float2string (vget v i), "\n"]
  in
  join (create (vlength v) formatEntry)
-- `stateToString y yp` renders the state vector under the name "y" and the
-- derivative vector under the name "y'", separated by an empty line.
let stateToString = lam y. lam yp.
  let yText = vecToString "y" y in
  let ypText = vecToString "y'" yp in
  join [yText, "\n", ypText]
-- Usage line for a positional `interval stepsize` command line.
-- NOTE: a second top-level `usage` (taking the program name) is defined
-- further down in this file and shadows this one for all code after it; only
-- `errorExit` below refers to this definition.
let usage = lam.
  let prog = get argv 0 in
  strJoin " " ["USAGE:", prog, "interval", "stepsize"]

-- Prints the usage line above followed by a newline (on stdout) and
-- terminates the program with exit code 1.
let errorExit = lam.
  let text = usage () in
  print text;
  print "\n";
  exit 1
-- Initial values; `daeRuntimeRun` destructures this as `(y0, yp0)` and uses
-- the two sequences to initialise the state and derivative vectors.
type DAEInit = ([Float], [Float])
-- Residual function; called in this file as `resf y yp r`, where `r` is the
-- tensor the residual values are read back from afterwards.
type DAEResf = Array Float -> Array Float -> Tensor[Float] -> ()
-- Jacobian entries: each element pairs a matrix position `(i, j)` with a
-- thunk computing the value for that position.
type DAEJacVals = [((Int, Int), () -> Float)]
-- Jacobian function; called as `jacf y yp`. Both components of the returned
-- pair are processed the same way by the callers in this file.
type DAEJacf = Array Float -> Array Float -> (DAEJacVals, DAEJacVals)
-- Output callback; called as `outf y yp`.
type DAEOutf = Array Float -> Array Float -> ()
-- Runtime options as produced by command-line parsing.
--
-- The field set mirrors the record actually built by `defaultOptions` and
-- updated by `argConfig`. The previous declaration listed
-- `benchmarkResidual`, `benchmarkJacobian` and `printStats`, which no code in
-- this file sets, and lacked `dumpInfo`, which is read via `opt.dumpInfo`.
type Options = {
  -- Simulation interval (used as the solver stop time).
  interval : Float,
  -- Distance between two requested output points.
  stepSize : Float,
  -- Relative and absolute solver tolerances.
  rtol : Float,
  atol : Float,
  -- Only report the final solution instead of every output point.
  outputOnlyLast : Bool,
  -- Enable debug logging.
  debug : Bool,
  -- Print evaluation counts and timings after the simulation.
  dumpInfo : Bool,
  -- Seed for the random number generator.
  seed : Int
}
-- Option values used when the corresponding command-line flag is absent.
let defaultOptions = {
  -- Diagnostics
  debug = false,
  dumpInfo = false,
  -- Output control
  outputOnlyLast = false,
  -- Solver tolerances
  rtol = 1e-4,
  atol = 1e-6,
  -- Time grid
  interval = 20.,
  stepSize = 0.1,
  -- Random number generator
  seed = 0
}
-- Command-line option table handed to `argParse` / `argHelpOptions`.
-- Each entry consists of the flag spelling(s), a help text, and a function
-- that updates the options record from the parse state.
let argConfig = [
  -- Time grid
  ([("--interval", " ", "<value>")],
   "Simulation interval. ",
   lam st. { st.options with interval = argToFloatMin st 0. }),
  ([("--step-size", " ", "<value>")],
   "Interval where to output solution. ",
   lam st. { st.options with stepSize = argToFloatMin st 0. }),
  -- Solver tolerances
  ([("--rtol", " ", "<value>")],
   "Relative tolerance. ",
   lam st. { st.options with rtol = argToFloatMin st 0. }),
  ([("--atol", " ", "<value>")],
   "Absolute tolerance. ",
   lam st. { st.options with atol = argToFloatMin st 0. }),
  -- Output control
  ([("--output-only-last", "", "")],
   "Output only the solution after the last time-step. ",
   lam st. { st.options with outputOnlyLast = true }),
  -- Diagnostics
  ([("--debug", "", "")],
   "Debug runtime. ",
   lam st. { st.options with debug = true }),
  ([("--dump-info", "", "")],
   "Dump solver info to stderr. ",
   lam st. { st.options with dumpInfo = true }),
  -- Random number generator
  ([("--seed", " ", "<value>")],
   "Random seed. ",
   lam st. { st.options with seed = argToIntMin st 0 })
]
-- `usage prog` builds the help text: a usage line naming the program `prog`
-- followed by the list of options generated from `argConfig`.
let usage = lam prog.
  let header = join ["Usage: ", prog, " [OPTION]\n\n"] in
  let options = join ["Options:\n", argHelpOptions argConfig, "\n"] in
  join [header, options]
-- `_randState y` overwrites every element of the float array `y` in place
-- with `int2float (randIntU 0 10)`, i.e. a random integer-valued float.
let _randState = lam y.
  let len = arrayLength y in
  let fill = lam i. arraySet y i (int2float (randIntU 0 10)) in
  doLoop len fill
-- `_benchUsage prog` is the usage line of the benchmark entry points: the
-- program name followed by the single expected INTEGER argument.
let _benchUsage = lam prog.
  concat prog " INTEGER\n"
-- `parseArgs n` parses the command line against `argConfig`, starting from
-- `defaultOptions`, and returns `(options, strings)` where `strings` are the
-- non-option arguments.
-- The program exits with code 1 when parsing fails (after printing the parse
-- error) or when the number of non-option arguments is not exactly `n`
-- (after printing the usage text).
let parseArgs = lam n.
  let parsed = argParse defaultOptions argConfig in
  match parsed with ParseOK r then
    if eqi (length r.strings) n then (r.options, r.strings)
    else
      -- Wrong number of positional arguments: show the menu and give up
      print (usage (get argv 0));
      exit 1
  else
    argPrintError parsed;
    exit 1
-- Benchmarks the residual function `resf` of a problem with `n` states.
--
-- Expects exactly one positional command-line argument: the number of
-- evaluations. For each evaluation the state array is filled with random
-- values, `resf` is called with that array as both state and derivative, and
-- all residual entries are added to a running sum (printed at the end
-- together with the elapsed wall time). If the argument is not an integer the
-- benchmark usage line is printed and the program exits with code 1.
let daeRuntimeBenchmarkRes : Int -> DAEResf -> ()
  = lam n. lam resf.
    let bail = lam. print (_benchUsage (get argv 0)); exit 1 in
    match parseArgs 1 with (opt, [nevalStr]) then
      -- Set seed
      randSetSeed opt.seed;
      if stringIsInt nevalStr then
        let state = arrayCreateFloat n in
        let res = vcreate n (lam. 0.) in
        let neval = string2int nevalStr in
        let acc = ref 0. in
        let tStart = wallTimeMs () in
        let addEntry = lam i. modref acc (addf (deref acc) (vget res i)) in
        doLoop neval (lam.
          _randState state;
          resf state state res;
          doLoop n addEntry);
        let elapsed = subf (wallTimeMs ()) tStart in
        print (join [
          "Executed the residual ",
          int2string neval,
          " times in ",
          float2string elapsed,
          " ms, accumulating the residual value ",
          float2string (deref acc),
          "\n"
        ])
      else bail ()
    else bail ()
-- Benchmarks the Jacobian functions `jacYf` and `jacYpf` of a problem with
-- `n` states.
--
-- Expects exactly one positional command-line argument: the number of
-- evaluations. For each evaluation the state array is filled with random
-- values, both Jacobian functions are called with that array as state and
-- derivative, and every entry thunk they return is forced and added to a
-- running sum (printed at the end together with the elapsed wall time). If
-- the argument is not an integer the benchmark usage line is printed and the
-- program exits with code 1.
let daeRuntimeBenchmarkJac : Int -> DAEJacf -> DAEJacf -> ()
  = lam n. lam jacYf. lam jacYpf.
    let bail = lam. print (_benchUsage (get argv 0)); exit 1 in
    match parseArgs 1 with (opt, [nevalStr]) then
      -- Set seed
      randSetSeed opt.seed;
      if stringIsInt nevalStr then
        let state = arrayCreateFloat n in
        let neval = string2int nevalStr in
        let acc = ref 0. in
        let tStart = wallTimeMs () in
        -- Forces every entry thunk of a Jacobian value list and sums the results
        let accumulate =
          iter (lam entry. modref acc (addf (deref acc) (entry.1 ())))
        in
        doLoop neval
          (lam.
            _randState state;
            match jacYf state state with (jyFst, jySnd) in
            accumulate jyFst;
            accumulate jySnd;
            match jacYpf state state with (jypFst, jypSnd) in
            accumulate jypFst;
            accumulate jypSnd;
            ());
        let elapsed = subf (wallTimeMs ()) tStart in
        print (join [
          "Executed the Jacobian ",
          int2string neval,
          " times in ",
          float2string elapsed,
          " ms, accumulating the value ",
          float2string (deref acc),
          "\n"
        ])
      else bail ()
    else bail ()
-- Simulates a DAE with the IDA solver and reports the solution through `outf`.
--
-- Arguments:
--   numjac   - if true the dense linear solver is created without a user
--              Jacobian (`idaDlsSolver`); otherwise `jacYf`/`jacYpf` are used
--              to fill the Jacobian (`idaDlsSolverJacf`).
--   varids   - one flag per state: true marks the state as differential,
--              false as algebraic. Its length is the problem size.
--   initVals - initial values `(y0, yp0)` of the states and derivatives.
--   resf     - residual function, called as `resf y yp r`.
--   jacYf    - Jacobian entries w.r.t. the states.
--   jacYpf   - Jacobian entries w.r.t. the derivatives; these are scaled by
--              the solver supplied factor `jacargs.c` and added.
--   outf     - output callback, called as `outf y yp`.
--
-- Options are read from the command line (no positional arguments allowed).
-- Solver failures and root events are reported on stderr and end the
-- simulation loop.
let daeRuntimeRun
  : Bool -> [Bool] -> DAEInit -> DAEResf -> DAEJacf -> DAEJacf -> DAEOutf -> ()
  = lam numjac. lam varids. lam initVals. lam resf. lam jacYf. lam jacYpf. lam outf.
    match parseArgs 0 with (opt, _) in
    let n = length varids in
    modref _debug opt.debug;
    -- Statistics reported by --dump-info
    let resEvalCount = ref 0 in
    let jacEvalCount = ref 0 in
    let resTimeCount = ref 0. in
    let jacTimeCount = ref 0. in
    -- Set seed
    randSetSeed opt.seed;
    -- Initialize
    match initVals with (y0, yp0) in
    let tol = idaSSTolerances opt.rtol opt.atol in
    let y = vcreate n (get y0) in
    let yp = vcreate n (get yp0) in
    -- Time at which the solver is initialised; slightly before 0 so that the
    -- initial-condition calculation below can target t = 0.
    let t0 = negf 1.e-4 in
    -- Pre-allocate residual states
    -- OPT(oerikss, 2023-10-21): We read faster from these compared to BigArrays
    -- (i.e. Tensors)
    let ay = arrayCreateFloat n in
    let ayp = arrayCreateFloat n in
    -- Residual callback in the shape expected by IDA (`t y yp r`). It copies
    -- the tensor state into the pre-allocated arrays, calls the user residual
    -- and records evaluation count and time. The time `t` is not used.
    let idaResf = lam t. lam y. lam yp. lam r.
      -- gather statistics
      modref resEvalCount (succ (deref resEvalCount));
      -- compute residual
      doLoop n (lam i.
        arraySet ay i (vget y i);
        arraySet ayp i (vget yp i));
      let ws = wallTimeMs () in
      resf ay ayp r;
      let we = wallTimeMs () in
      modref resTimeCount (addf (deref resTimeCount) (subf we ws));
      ()
    in
    let r = vcreate n (lam. 0.) in
    -- The residual is evaluated inside the debug thunk with all four
    -- arguments. Applying the callback to only `y yp r` (as was done before)
    -- is a partial application that never runs the residual, so `r` was
    -- printed without ever being computed. Evaluating it lazily keeps
    -- non-debug runs (and their statistics) unchanged.
    debugPrint
      (lam.
        idaResf t0 y yp r;
        strJoin "\n"
          ["Initial residual:", stateToString y yp, vecToString "r" r]);
    let v = nvectorSerialWrap y in
    let vp = nvectorSerialWrap yp in
    let m = sundialsMatrixDense n in
    let nlsolver = sundialsNonlinearSolverNewtonMake v in
    let lsolver =
      if numjac then idaDlsSolver (idaDlsDense v m)
      else
        -- Pre-allocate Jacobian states
        -- OPT(oerikss, 2022-10-21): See the comment for the residual function
        let ay = arrayCreateFloat n in
        let ayp = arrayCreateFloat n in
        let jacf = lam jacargs : IdaJacArgs. lam m : SundialsMatrixDense.
          -- gather statistics
          modref jacEvalCount (succ (deref jacEvalCount));
          -- compute Jacobian
          doLoop n (lam i.
            arraySet ay i (vget jacargs.y i);
            arraySet ayp i (vget jacargs.yp i));
          -- let m = sundialsMatrixDenseUnwrap m in
          let ws = wallTimeMs () in
          -- Entries w.r.t. y are written directly ...
          let jy =
            iter
              (lam ijf.
                match ijf with ((i, j), f) in
                -- m is in column-major format
                mset m i j (f ()))
          in
          let fs = jacYf ay ayp in
          jy fs.0;
          jy fs.1;
          -- ... and entries w.r.t. y' are scaled by `c` and added on top.
          let jyp =
            iter
              (lam ijf.
                match ijf with ((i, j), f) in
                -- m is in column-major format
                mupdate m i j (addf (mulf jacargs.c (f ()))))
          in
          let fs = jacYpf ay ayp in
          jyp fs.0;
          jyp fs.1;
          let we = wallTimeMs () in
          modref jacTimeCount (addf (deref jacTimeCount) (subf we ws));
          ()
        in
        idaDlsSolverJacf jacf (idaDlsDense v m)
    in
    let varid =
      nvectorSerialWrap
        (vcreate n
           (lam i. if get varids i then idaVarIdDifferential
                 else idaVarIdAlgebraic))
    in
    let s = idaInit {
      tol = tol,
      nlsolver = nlsolver,
      lsolver = lsolver,
      resf = idaResf,
      varid = varid,
      roots = idaNoRoots,
      t = t0,
      y = v,
      yp = vp
    } in
    idaCalcICYaYd s { tend = 0., y = v, yp = vp };
    -- Same as for the initial residual: evaluate with the full argument list,
    -- and only when debugging.
    debugPrint
      (lam.
        idaResf t0 y yp r;
        strJoin "\n"
          ["After idaCalcICYaYd:", stateToString y yp, vecToString "r" r]);
    idaSetStopTime s opt.interval;
    -- Solve
    -- Pre-allocate output states
    let ay = arrayCreateFloat n in
    let ayp = arrayCreateFloat n in
    -- Copies the current solver state into the output arrays. `y`/`yp` are
    -- read after solver calls on `v`/`vp`, i.e. this relies on the wrapped
    -- vectors sharing storage with the tensors they wrap.
    let syncOutput = lam.
      doLoop n (lam i.
        arraySet ay i (vget y i);
        arraySet ayp i (vget yp i))
    in
    recursive let recur = lam t.
      syncOutput ();
      (if opt.outputOnlyLast then () else outf ay ayp);
      if gtf t opt.interval then ()
      else
        switch idaSolveNormal s { tend = addf t opt.stepSize, y = v, yp = vp }
        case (tend, IdaSuccess _) then recur tend
        case (_, IdaStopTimeReached _) then
          -- Refresh the output arrays first: they still hold the state from
          -- before this final solve, so without the copy the state of the
          -- previous output point would be reported as the last solution.
          (if opt.outputOnlyLast then
             syncOutput ();
             outf ay ayp
           else ()); ()
        case (tend, IdaRootsFound _) then
          printError (join ["Roots found at t = ", float2string tend]);
          flushStderr ()
        case (tend, IdaSolveError _) then
          printError (join ["Solver error at t = ", float2string tend]);
          flushStderr ()
        end
    in
    recur 0.;
    -- NOTE(review): the --dump-info help text says the info goes to stderr,
    -- but `print` is used here -- confirm which stream is intended.
    (if opt.dumpInfo then
       print (join ["resvals: ", int2string (deref resEvalCount), "\n"]);
       print (join ["jacevals: ", int2string (deref jacEvalCount), "\n"]);
       print (join ["restime: ", float2string (deref resTimeCount), "\n"]);
       print (join ["jactime: ", float2string (deref jacTimeCount), "\n"])
     else ());
    ()
-- Program body. It performs no simulation itself; it only mentions the
-- runtime entry points and helper functions so that they count as used.
mexpr
-- Hack that allows us to parse this file with dead-code elimination
-- Every name listed below is referenced solely to keep it alive; the
-- `dprint` calls are not meant to produce meaningful output.
dprint [
dprint [daeRuntimeBenchmarkRes],
dprint [daeRuntimeBenchmarkJac],
dprint [daeRuntimeRun],
dprint [sin, cos, exp, sqrt],
dprint [pow],
dprint [arrayGet],
dprint [cArray1Set],
dprint [sundialsMatrixDenseSet],
dprint [sundialsMatrixDenseUpdate]
]