forked from LupoLab/Luna.jl
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathRK45.jl
410 lines (377 loc) · 13.1 KB
/
RK45.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
module RK45
import Dates
import Logging
import Printf: @sprintf
import Luna.Utils: format_elapsed
#Get Butcher tableau etc from separate file (for convenience of changing if wanted)
include("dopri.jl")
"""
    solve(f!, y0, t, dt, tmax; rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf,
          min_dt=0, locextrap=true, norm=weaknorm, kwargs...)

Integrate the system whose right-hand side is `f!(out, y, t)` from `t` up to `tmax`,
starting at `y0` with initial step size `dt`.

The step-control keywords configure a `Stepper`; every other keyword in `kwargs` is
forwarded unchanged to `solve(stepper, tmax; ...)`, whose result is returned.
"""
function solve(f!, y0, t, dt, tmax;
               rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf, min_dt=0, locextrap=true,
               norm=weaknorm,
               kwargs...)
    settings = (rtol=rtol, atol=atol, safety=safety, max_dt=max_dt, min_dt=min_dt,
                locextrap=locextrap, norm=norm)
    return solve(Stepper(f!, y0, t, dt; settings...), tmax; kwargs...)
end
"""
    solve_precon(f!, linop, y0, t, dt, tmax; rtol=1e-6, atol=1e-10, safety=0.9,
                 max_dt=Inf, min_dt=0, locextrap=true, norm=weaknorm, kwargs...)

Like `solve(f!, y0, t, dt, tmax; ...)`, but builds a `PreconStepper` that handles the
linear part `linop` through a separate propagator (see `make_prop!`/`make_fbar!`).
Remaining `kwargs` are forwarded to `solve(stepper, tmax; ...)`.
"""
function solve_precon(f!, linop, y0, t, dt, tmax;
                      rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf, min_dt=0,
                      locextrap=true, norm=weaknorm,
                      kwargs...)
    settings = (rtol=rtol, atol=atol, safety=safety, max_dt=max_dt, min_dt=min_dt,
                locextrap=locextrap, norm=norm)
    stepper = PreconStepper(f!, linop, y0, t, dt; settings...)
    return solve(stepper, tmax; kwargs...)
end
"""
    _format_eta(eta_in_s)

Format a remaining-time estimate given in seconds as `HH:MM:SS`, rounding up to the next
whole second. Hours are NOT wrapped at 24, so 30 hours prints as `30:00:00`.
Returns `"XX:XX:XX"` when the estimate is not finite (e.g. no progress yet), negative,
or longer than 99 hours (356400 s) and therefore does not fit two hour digits.
"""
function _format_eta(eta_in_s)
    if !isfinite(eta_in_s) || eta_in_s < 0 || eta_in_s > 356400
        return "XX:XX:XX"
    end
    total_s = ceil(Int, eta_in_s)
    hours, rest = divrem(total_s, 3600)
    minutes, seconds = divrem(rest, 60)
    return @sprintf("%02d:%02d:%02d", hours, minutes, seconds)
end

"""
    solve(s, tmax; stepfun=donothing!, output=false, outputN=201,
          status_period=1, repeat_limit=10)

Advance the stepper `s` (a `Stepper` or `PreconStepper`) until its time exceeds `tmax`.

After every accepted step, `stepfun(yn, tn, dtn, interpolant)` is called, where
`interpolant(t)` evaluates the dense-output solution inside the step just taken.
A progress message is logged whenever more than `status_period` seconds have passed
since the previous one.

# Keywords
- `output`: if `true`, also record the solution at `outputN` equally spaced points from
  the initial time to `tmax`.
- `repeat_limit`: maximum number of consecutive rejected steps; exceeding it throws.

# Returns
`(t, y, steps)` when `output` is `true` (the output index is the last dimension of `y`),
otherwise `nothing`.
"""
function solve(s, tmax; stepfun=donothing!, output=false, outputN=201,
               status_period=1, repeat_limit=10)
    if output
        yout = Array{eltype(s.y)}(undef, (size(s.y)..., outputN))
        tout = range(s.t, stop=tmax, length=outputN)
        saved = 1
        yout[fill(:, ndims(s.y))..., 1] = s.y
    end
    t0 = s.t # progress and speed are measured from the starting time, which need not be 0
    steps = 0
    repeated = 0     # consecutive rejected steps
    repeated_tot = 0 # total rejected steps
    Logging.@info "Starting propagation"
    start = Dates.now()
    tic = Dates.now()
    while s.tn <= tmax
        ok = step!(s)
        steps += 1
        if Dates.value(Dates.now()-tic) > 1000*status_period
            elapsed_s = Dates.value(Dates.now()-start)/1000
            speed = (s.tn - t0)/elapsed_s
            eta_in_s = (tmax - s.tn)/speed
            progress = (s.tn - t0)/(tmax - t0)*100
            Logging.@info @sprintf("Progress: %.2f %%, ETA: %s, stepsize %.2e, err %.2f, repeated %d",
                                   progress, _format_eta(eta_in_s), s.dt, s.err, repeated_tot)
            tic = Dates.now()
        end
        if ok
            if output
                # Fill every output point passed by this step from the dense-output interpolant.
                while (saved < outputN) && tout[saved+1] < s.tn
                    ti = tout[saved+1]
                    yout[fill(:, ndims(s.y))..., saved+1] .= interpolate(s, ti)
                    saved += 1
                end
            end
            stepfun(s.yn, s.tn, s.dtn, t -> interpolate(s, t))
            repeated = 0
        else
            repeated += 1
            repeated_tot += 1
            if repeated > repeat_limit
                error("Reached limit for step repetition ($repeat_limit)")
            end
        end
    end
    totaltime = Dates.now()-start
    dtstring = format_elapsed(totaltime)
    Logging.@info @sprintf("Propagation finished in %s, %d steps",
                           dtstring, steps)
    if output
        return collect(tout), yout, steps
    else
        return nothing
    end
end
"""
    Stepper{T<:AbstractArray, F, nT}

Mutable state of the adaptive embedded Runge-Kutta integrator advanced by `step!`.

The right-hand side `f!` is called as `f!(out, y, t)` and writes into `out`.
Construct instances with the keyword constructor `Stepper(f!, y0, t, dt; ...)`; the
positional constructor expects the fields below in declaration order.
"""
mutable struct Stepper{T<:AbstractArray, F, nT}
f!::F # RHS function, called as f!(out, y, t)
y::T # Solution at current t
yn::T # Solution at t+dt after a step; also used as stage scratch inside evaluate!
yi::T # Scratch buffer for the weighted stage sum in interpolate()
yerr::T # Solution error estimate (from embedded RK), accumulated in step!
ks::NTuple{7, T} # Stage derivatives k1..k7; ks[1] is refilled from ks[7] after an accepted step
t::Float64 # Current time (propagation variable)
tn::Float64 # Next time; set to t+dt when a step is accepted
dt::Float64 # Size of the step taken from t
dtn::Float64 # Step size proposed by the controller for the next step
rtol::Float64 # Relative tolerance on error
atol::Float64 # Absolute tolerance on error
safety::Float64 # Safety factor for stepsize control
max_dt::Float64 # Upper limit for dtn (default Inf)
min_dt::Float64 # Lower limit for dtn (default 0); when clamped to it, the step is accepted regardless of error
locextrap::Bool # true if using local extrapolation (yn rebuilt from the b5 weights in step!)
ok::Bool # true if the current step was accepted
err::Float64 # Error metric from `norm`; the error test passes when err <= 1
errlast::Float64 # Error of the most recent accepted step (used by stepcontrolPI!)
norm::nT # Error metric, called as norm(yerr, y, yn, rtol, atol); defaults to RK45.weaknorm
end
"""
    Stepper(f!, y0, t, dt; rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf, min_dt=0,
            locextrap=true, norm=weaknorm)

Create a `Stepper` starting from `y0` at time `t` with initial step size `dt`.
`f!(out, y, t)` is evaluated once here to fill the first stage derivative.
`y0` itself is copied, not stored.
"""
function Stepper(f!, y0, t, dt;
                 rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf, min_dt=0,
                 locextrap=true, norm=weaknorm)
    # First stage up front; afterwards step! recycles the last stage into ks[1].
    k_first = similar(y0)
    f!(k_first, y0, t)
    stages = (k_first, ntuple(_ -> similar(k_first), 6)...)
    y = copy(y0)
    yn = copy(y0)
    yi = similar(y0)
    yerr = similar(y0)
    t_f = float(t)
    dt_f = float(dt)
    return Stepper(f!, y, yn, yi, yerr, stages,
                   t_f, t_f, dt_f, dt_f,
                   float(rtol), float(atol), float(safety), float(max_dt), float(min_dt),
                   locextrap, false, 0.0, 0.0, norm)
end
"""
    PreconStepper{T<:AbstractArray, F, P, nT}

Mutable state of the adaptive embedded Runge-Kutta integrator for problems whose linear
part is handled by a separate propagator (built by `make_prop!`); the stages are evaluated
with the pre-conditioned RHS built by `make_fbar!`.

Construct instances with the keyword constructor
`PreconStepper(f!, linop, y0, t, dt; ...)`; the positional constructor expects the fields
below in declaration order.
"""
mutable struct PreconStepper{T<:AbstractArray, F, P, nT}
fbar!::F # Pre-conditioned RHS callable, called as fbar!(out, ybar, t1, t2)
prop!::P # Linear propagator callable, called as prop!(y, t1, t2[, bwd])
y::T # Solution at current t
yn::T # Solution at t+dt (propagated to tn at the end of step!); also stage scratch in evaluate!
yi::T # Scratch buffer for the weighted stage sum in interpolate()
yerr::T # Solution error estimate (from embedded RK), accumulated in step!
ks::NTuple{7, T} # Stage derivatives k1..k7; ks[1] is refilled from ks[7] after an accepted step
t::Float64 # Current time (propagation variable)
tn::Float64 # Next time; set to t+dt when a step is accepted
dt::Float64 # Size of the step taken from t
dtn::Float64 # Step size proposed by the controller for the next step
rtol::Float64 # Relative tolerance on error
atol::Float64 # Absolute tolerance on error
safety::Float64 # Safety factor for stepsize control
max_dt::Float64 # Upper limit for dtn (default Inf)
min_dt::Float64 # Lower limit for dtn (default 0); when clamped to it, the step is accepted regardless of error
locextrap::Bool # true if using local extrapolation (yn rebuilt from the b5 weights in step!)
ok::Bool # true if the current step was accepted
err::Float64 # Error metric from `norm`; the error test passes when err <= 1
errlast::Float64 # Error of the most recent accepted step (used by stepcontrolPI!)
norm::nT # Error metric, called as norm(yerr, y, yn, rtol, atol); defaults to RK45.weaknorm
end
"""
    PreconStepper(f!, linop, y0, t, dt; rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf,
                  min_dt=0, locextrap=true, norm=weaknorm)

Create a `PreconStepper` starting from `y0` at time `t` with initial step size `dt`.
`linop` is turned into a propagator by `make_prop!` (an array for a constant operator,
otherwise a callable `linop!(out, t)`), and `f!` is wrapped by `make_fbar!`.
"""
function PreconStepper(f!, linop, y0, t, dt;
                       rtol=1e-6, atol=1e-10, safety=0.9, max_dt=Inf, min_dt=0,
                       locextrap=true, norm=weaknorm)
    prop! = make_prop!(linop, y0)
    fbar! = make_fbar!(f!, prop!, y0)
    # First stage at the initial time: with t1 == t2 the propagations inside fbar! span zero time.
    k_first = similar(y0)
    fbar!(k_first, y0, t, t)
    stages = (k_first, ntuple(_ -> similar(k_first), 6)...)
    yerr = similar(y0)
    t_f = float(t)
    dt_f = float(dt)
    return PreconStepper(fbar!, prop!, copy(y0), copy(y0), similar(y0), yerr, stages,
                         t_f, t_f, dt_f, dt_f, float(rtol), float(atol), float(safety),
                         float(max_dt), float(min_dt), locextrap, false, 0.0, 0.0, norm)
end
"""
    step!(s)

Attempt one step of size `s.dtn` from `s.tn` (`evaluate!` commits these as `s.dt`/`s.t`).

Computes the stages, the new solution `s.yn`, the embedded error estimate and its norm
`s.err`, then lets the PI controller choose the next step size. On acceptance `s.tn`
becomes `s.t + s.dt` and the last stage is recycled into `ks[1]`; on rejection `s.yn`
is reset to `s.y`. Returns `s.ok`.
"""
function step!(s)
    evaluate!(s)
    if s.locextrap
        # Rebuild yn from the b5 weights; otherwise yn keeps the value assembled from
        # the B[6] weights in evaluate!.
        s.yn .= s.y
        for jj in eachindex(s.ks)
            w = b5[jj]
            iszero(w) || (s.yn .+= s.dt*w.*s.ks[jj])
        end
    end
    fill!(s.yerr, 0)
    for jj in eachindex(s.ks)
        e = errest[jj]
        iszero(e) || (@. s.yerr += s.dt*s.ks[jj]*e)
    end
    s.err = s.norm(s.yerr, s.y, s.yn, s.rtol, s.atol)
    s.ok = s.err <= 1
    stepcontrolPI!(s) # may also force s.ok = true when dtn is clamped to min_dt
    if !s.ok
        s.yn .= s.y
    else
        s.tn = s.t + s.dt
        s.ks[1] .= s.ks[end]
    end
    prop!_maybe(s) # propagate to new time to pass correct solution to stepfun
    return s.ok
end
"""
    evaluate!(s::Stepper)

Commit the pending time/step (`t = tn`, `dt = dtn`, `y = yn`) and compute stages
`ks[2]`..`ks[7]` from the tableau rows `B[1]`..`B[6]` and `nodes`.
"""
function evaluate!(s::Stepper)
    # The commit happens here rather than at the end of step! because interpolate()
    # still needs the old t, dt and y after a step has finished.
    s.dt = s.dtn
    s.t = s.tn
    s.y .= s.yn
    for stage in 1:6
        weights = B[stage]
        s.yn .= s.y # yn doubles as the stage input buffer
        for jj in 1:stage
            iszero(weights[jj]) || (s.yn .+= s.dt*weights[jj].*s.ks[jj])
        end
        s.f!(s.ks[stage+1], s.yn, s.t + nodes[stage]*s.dt)
    end
end
"""
    evaluate!(s::PreconStepper)

Commit the pending time/step and compute stages `ks[2]`..`ks[7]` with the pre-conditioned
RHS `s.fbar!`, after propagating the recycled first stage `ks[1]` from the old `t` to `tn`.
"""
function evaluate!(s::PreconStepper)
    # The commit happens here rather than at the end of step! because interpolate()
    # still needs the old t, dt and y after a step has finished.
    s.y .= s.yn
    # Must use the *old* t and tn, so this precedes the time update below.
    s.prop!(s.ks[1], s.t, s.tn)
    s.dt = s.dtn
    s.t = s.tn
    for stage in 1:6
        weights = B[stage]
        s.yn .= s.y # yn doubles as the stage input buffer
        for jj in 1:stage
            iszero(weights[jj]) || (s.yn .+= s.dt*weights[jj].*s.ks[jj])
        end
        s.fbar!(s.ks[stage+1], s.yn, s.t, s.t + nodes[stage]*s.dt)
    end
end
"Propagate `s.yn` with the stepper's linear propagator from `s.t` to `s.tn`."
function prop!_maybe(s::PreconStepper)
    return s.prop!(s.yn, s.t, s.tn)
end
"Fallback for steppers without a linear propagator: do nothing."
prop!_maybe(::Any) = nothing
"""
    interpolate(s::Stepper, ti)

Interpolate solution, aka dense output: evaluate the continuous extension of the last
step at time `ti`, which must lie in `[s.t, s.tn]`.

Always returns a newly allocated array (never one of the stepper's internal buffers), so
the result is not changed by later steps. Throws an `ErrorException` if `ti` lies outside
the last step.
"""
function interpolate(s::Stepper, ti::Real)
    if ti > s.tn || ti < s.t
        error("Attempting to extrapolate!")
    end
    # Copy at the endpoints: s.y/s.yn are overwritten by the next step.
    ti == s.t && return copy(s.y)
    ti == s.tn && return copy(s.yn)
    σ = (ti - s.t)/s.dt
    σp = map(p -> σ^p, range(1, stop=4))
    b = sum(σp.*interpC, dims=1) # interpolation weight for each of the 7 stages
    fill!(s.yi, 0)
    for ii = 1:7
        s.yi .+= s.ks[ii].*b[ii]
    end
    return @. s.y + s.dt.*s.yi
end
"""
    interpolate(s::PreconStepper, ti)

Interpolate solution, aka dense output: evaluate the continuous extension of the last
step at time `ti`, which must lie in `[s.t, s.tn]`. The interior value is propagated
from `s.t` to `ti` with the linear propagator `s.prop!`.

Always returns a newly allocated array (never one of the stepper's internal buffers), so
the result is not changed by later steps. Throws an `ErrorException` if `ti` lies outside
the last step.
"""
function interpolate(s::PreconStepper, ti::Real)
    if ti > s.tn || ti < s.t
        error("Attempting to extrapolate!")
    end
    # Copy at the endpoints: s.y/s.yn are overwritten by the next step.
    ti == s.t && return copy(s.y)
    ti == s.tn && return copy(s.yn)
    σ = (ti - s.t)/s.dt
    σp = map(p -> σ^p, range(1, stop=4))
    b = sum(σp.*interpC, dims=1) # interpolation weight for each of the 7 stages
    fill!(s.yi, 0)
    for ii = 1:7
        s.yi .+= s.ks[ii].*b[ii]
    end
    out = @. s.y + s.dt.*s.yi
    s.prop!(out, s.t, ti)
    return out
end
"""
    make_prop!(linop::AbstractArray, y0)

Make propagator for the case of a constant linear operator.

Returns `prop!(y, t1, t2, bwd=false)`, which multiplies `y` in place by
`exp.(linop .* (t2 - t1))`, or by `exp.(linop .* (t1 - t2))` when `bwd` is `true`.
`y0` is not used by this method.
"""
function make_prop!(linop::AbstractArray, y0)
    function prop!(y, t1, t2, bwd=false)
        Δt = bwd ? (t1 - t2) : (t2 - t1)
        @. y *= exp(linop*Δt)
    end
    return prop!
end
"""
    make_prop!(linop!, y0)

Make propagator for the case of a non-constant linear operator.

`linop!(out, t)` fills `out` with the operator at time `t`. Returns
`prop!(y, t1, t2, bwd=false)`, which evaluates the operator at `t2` and multiplies `y` in
place by `exp.(op .* (t2 - t1))`, or by `exp.(op .* (t1 - t2))` when `bwd` is `true`.
"""
function make_prop!(linop!, y0)
    linop_int = similar(y0)           # operator values at the most recent t2
    t2_cached = Ref(typemin(Float64)) # t2 at which linop_int was last evaluated
    function prop!(y, t1, t2, bwd=false)
        # The operator is always taken at the later time t2, even when propagating
        # backwards, so consecutive calls frequently reuse the same t2 -- skip the refill.
        if t2_cached[] != t2
            linop!(linop_int, t2)
        end
        t2_cached[] = t2
        Δt = bwd ? (t1 - t2) : (t2 - t1)
        @. y *= exp(linop_int*Δt)
    end
    return prop!
end
"""
    make_fbar!(f!, prop!, y0)

Make closure for the pre-conditioned RHS function.

Returns `fbar!(out, ybar, t1, t2)`: `ybar` is copied into a scratch buffer and propagated
from `t1` to `t2` with `prop!`; `f!(out, y, t2)` is evaluated there; finally `out` is
propagated back with `prop!(out, t1, t2, true)`. `ybar` itself is not modified. A single
scratch buffer (shaped like `y0`) is allocated once and reused by every call.
"""
function make_fbar!(f!, prop!, y0)
    ybuf = similar(y0)
    function fbar!(out, ybar, t1, t2)
        ybuf .= ybar
        prop!(ybuf, t1, t2)            # forward to t2
        f!(out, ybuf, t2)              # RHS at t2
        return prop!(out, t1, t2, true) # back to t1
    end
    return fbar!
end
"""
    maxnorm(yerr, y, yn, rtol, atol)

Max-ish error norm (from Dane Austin's code; origin unknown).

Returns the largest `abs(yerr[i])` divided by `atol + rtol*m`, where `m` is the largest
of `abs(y[i])` and `abs(yn[i])` over the indices of `yerr`. Empty input gives zero.
"""
function maxnorm(yerr, y, yn, rtol, atol)
    # Typed accumulators keep the loop type-stable (a literal 0 would start as an Int).
    maxerr = zero(float(real(eltype(yerr))))
    maxy = zero(float(real(promote_type(eltype(y), eltype(yn)))))
    for ii in eachindex(yerr)
        maxerr = max(maxerr, abs(yerr[ii]))
        maxy = max(maxy, max(abs(y[ii]), abs(yn[ii])))
    end
    return maxerr/(atol + rtol*maxy)
end
"""
    maxnorm_ratio(yerr, y, yn, rtol, atol)

Alternative form of the max-ish norm: the largest elementwise ratio
`abs(yerr[i]) / (atol + rtol*max(abs(y[i]), abs(yn[i])))`. Empty input gives zero of
the same floating-point type as non-empty input.
"""
function maxnorm_ratio(yerr, y, yn, rtol, atol)
    # Float-typed start value keeps the loop (and the empty-input result) type-stable.
    m = zero(promote_type(float(real(eltype(yerr))), typeof(atol), typeof(rtol)))
    for ii in eachindex(yerr)
        den = atol + rtol*max(abs(y[ii]), abs(yn[ii]))
        m = max(abs(yerr[ii])/den, m)
    end
    return m
end
"""
    normnorm(yerr, y, yn, rtol, atol)

Semi-norm as used in DifferentialEquations.jl, see Hairer, Solving Ordinary Differential
Equations: Nonstiff Problems, eq. (4.11) (p.168 of the second revised edition):
the root-mean-square of `yerr[i] / (atol + rtol*max(abs(y[i]), abs(yn[i])))`.

Empty input returns zero instead of the NaN that 0/0 would produce.
"""
function normnorm(yerr, y, yn, rtol, atol)
    T = promote_type(float(real(eltype(yerr))), typeof(atol), typeof(rtol))
    isempty(yerr) && return zero(T)
    s = zero(T) # typed accumulator keeps the loop type-stable
    for ii in eachindex(yerr)
        s += abs2(yerr[ii]/(atol + rtol*max(abs(y[ii]), abs(yn[ii]))))
    end
    return sqrt(s/length(yerr))
end
"""
    weaknorm(yerr, y, yn, rtol, atol)

'Weak' norm as used in fnfep: the 2-norm of `yerr`, divided by `rtol` and by the larger
of the 2-norms of `y` and `yn` (that weight is floored at `atol`).
"""
function weaknorm(yerr, y, yn, rtol, atol)
    # Float-typed accumulators keep this hot loop type-stable (literal 0 would be an Int).
    sy = zero(float(real(eltype(y))))
    syn = zero(float(real(eltype(yn))))
    syerr = zero(float(real(eltype(yerr))))
    for ii in eachindex(yerr)
        sy += abs2(y[ii])
        syn += abs2(yn[ii])
        syerr += abs2(yerr[ii])
    end
    errwt = max(max(sqrt(sy), sqrt(syn)), atol)
    return sqrt(syerr)/rtol/errwt
end
"""
    stepcontrolP!(s)

Simple proportional error controller, see e.g. Hairer eq. (4.13).

Sets `s.dtn` from the current error `s.err` (growth capped at 5x, shrinkage at 0.1x)
and then applies the user limits via `steplims!`.
"""
function stepcontrolP!(s)
    if !s.ok
        # A NaN/Inf error means serious trouble: simply halve the step.
        s.dtn = isfinite(s.err) ? s.dt * max(0.1, s.safety*(s.err)^(-1/5)) : s.dt/2
    elseif s.err == 0
        # Zero error means no nonlinearity: grow the step substantially.
        s.dtn = 1.5*s.dt
    else
        s.dtn = s.dt * min(5, s.safety*(s.err)^(-1/5))
    end
    steplims!(s)
end
"""
    stepcontrolPI!(s)

Proportional-integral error controller, aka Lund stabilisation.
See G. Söderlind and L. Wang, J. Comput. Appl. Math. 185, 225 (2006).

On an accepted step the new size uses both the current error and that of the previous
accepted step (`s.errlast`); on a rejected step a proportional reduction is used.
User limits are applied afterwards via `steplims!`.
"""
function stepcontrolPI!(s)
    β1 = 3/5 / 5
    β2 = -1/5 / 5
    ε = 0.8
    if !s.ok
        if isfinite(s.err)
            s.dtn = s.dt * max(0.1, s.safety*(s.err)^(-1/5))
        else
            s.dtn = s.dt/2 # NaN/Inf error: serious trouble, halve the step
        end
        return steplims!(s)
    end
    if s.errlast == 0
        s.errlast = s.err # no usable history yet: fall back to the current error
    end
    # Zero error means no nonlinearity: grow the step substantially.
    fac = s.err == 0 ? 1.5 : s.safety * (ε/s.err)^β1 * (ε/s.errlast)^β2
    s.dtn = fac * s.dt
    s.errlast = s.err
    return steplims!(s)
end
"""
    steplims!(s)

Apply user-defined limits on step size: clamp `s.dtn` to `[s.min_dt, s.max_dt]`.
If the lower limit is hit, the current step is force-accepted (`s.ok = true`) because
the step size may not shrink any further.
"""
function steplims!(s)
    if s.dtn > s.max_dt
        return (s.dtn = s.max_dt)
    end
    if s.dtn < s.min_dt
        s.dtn = s.min_dt
        return (s.ok = true)
    end
    return nothing
end
"Default `stepfun` for `solve`: ignores all of its arguments and returns `nothing`."
donothing!(y, z, dz, interpolant) = nothing
end