Skip to content

Commit 200eb01

Browse files
committed
Finish an initial prototype
1 parent e6e560f commit 200eb01

File tree

4 files changed

+233
-20
lines changed

4 files changed

+233
-20
lines changed

README.md

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,6 @@ Precompilation can be controlled via `Preferences.jl`
5252
- `PrecompileMIRKNLLS` -- Precompile the MIRK2 - MIRK6 algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above).
5353
- `PrecompileShootingNLLS` -- Precompile the single shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded.
5454
- `PrecompileMultipleShootingNLLS` -- Precompile the multiple shooting algorithms for under-determined and over-determined BVPs (default: `true` on Julia Version 1.10 and above). This is triggered when `OrdinaryDiffEq` is loaded.
55-
=======
56-
57-
> > > > > > > 5d53c01 (Add documentation about the solvers)
5855

5956
To set these preferences before loading the package, do the following (replacing `PrecompileShooting` with the preference you want to set, or pass in multiple pairs to set them together):
6057

src/algorithms.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ end
6363
"""
6464
MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
6565
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(),
66-
static_auto_nodes::Val = Val(false))
66+
auto_static_nodes::Val = Val(false))
6767
6868
Multiple Shooting method, reduces BVP to an initial value problem and solves the IVP.
6969
Significantly more stable than Single Shooting.
@@ -98,9 +98,13 @@ Significantly more stable than Single Shooting.
9898
- `Function`: Takes the current number of shooting points and returns the next number
9999
of shooting points. For example, if `nshoots = 10` and
100100
`grid_coarsening = n -> n ÷ 2`, then the grid will be coarsened to `[5, 2]`.
101-
- `static_auto_nodes`: Automatically detect the timepoints used in the boundary condition
101+
102+
## Experimental Features
103+
104+
- `auto_static_nodes`: Automatically detect the timepoints used in the boundary condition
102105
and use a faster version of the algorithm! This particular keyword argument should be
103-
considered experimental and should be used with care!
106+
considered experimental and should be used with care! (Note that we ignore
107+
`grid_coarsening` if this is set to `Val(true)`. We plan to support this in the future.)
104108
105109
!!! note
106110
For type-stability, the chunksizes for ForwardDiff ADTypes in `BVPJacobianAlgorithm`
@@ -125,13 +129,23 @@ function update_nshoots(alg::MultipleShooting, nshoots::Int)
125129
alg.grid_coarsening)
126130
end
127131

132+
function __without_static_nodes(ms::MultipleShooting{S}) where {S}
133+
return MultipleShooting{false}(ms.ode_alg, ms.nlsolve, ms.jac_alg, ms.nshoots,
134+
ms.grid_coarsening)
135+
end
136+
128137
function MultipleShooting(nshoots::Int, ode_alg; nlsolve = NewtonRaphson(),
129-
grid_coarsening = true, jac_alg = BVPJacobianAlgorithm(),
130-
static_auto_nodes::Val{S} = Val(false)) where {S}
131-
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
132-
grid_coarsening isa AbstractVector{<:Integer} ||
133-
grid_coarsening isa NTuple{N, <:Integer} where {N}
134-
@assert S isa Bool
138+
grid_coarsening = missing, jac_alg = BVPJacobianAlgorithm(),
139+
auto_static_nodes::Val{S} = Val(false)) where {S}
140+
@assert S isa Bool "`auto_static_nodes` must be either `Val(true)` or `Val(false)`."
141+
if S
142+
@assert grid_coarsening === missing||(grid_coarsening isa Bool && !grid_coarsening) "`auto_static_nodes` doesn't support grid_coarsening."
143+
else
144+
grid_coarsening === missing && (grid_coarsening = false)
145+
@assert grid_coarsening isa Bool || grid_coarsening isa Function ||
146+
grid_coarsening isa AbstractVector{<:Integer} ||
147+
grid_coarsening isa NTuple{N, <:Integer} where {N}
148+
end
135149
grid_coarsening isa Tuple && (grid_coarsening = Vector(grid_coarsening...))
136150
if grid_coarsening isa AbstractVector
137151
sort!(grid_coarsening; rev = true)

src/solve/multiple_shooting.jl

Lines changed: 168 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,78 @@ function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs
22
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads(), verbose = true, kwargs...)
33
# For TwoPointBVPs there is nothing to do. Forward to general multiple shooting
44
prob.problem_type isa TwoPointBVProblem &&
5-
return __solve_internal(prob, _alg; odesolve_kwargs, nlsolve_kwargs, ensemblealg,
6-
verbose, kwargs...)
5+
return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs,
6+
nlsolve_kwargs, ensemblealg, verbose, kwargs...)
7+
8+
ig, T, N, Nig, u0 = __extract_problem_details(prob; dt = 0.1)
9+
10+
if _unwrap_val(ig) && prob.u0 isa AbstractVector
11+
if verbose
12+
@warn "Static Nodes for Multiple-Shooting is not supported when Vector of \
13+
initial guesses are provided. Falling back to using the generic method!"
14+
end
15+
return __solve_internal(prob, __without_static_nodes(_alg); odesolve_kwargs,
16+
nlsolve_kwargs, ensemblealg, verbose, kwargs...)
17+
end
18+
19+
has_initial_guess = _unwrap_val(ig)
20+
21+
bcresid_prototype, resid_size = __get_bcresid_prototype(prob, u0)
22+
iip, bc, u0, u0_size = isinplace(prob), prob.f.bc, deepcopy(u0), size(u0)
723

824
# Extract the time-points used in BC
9-
_prob = ODEProblem{isinplace(prob)}(prob.f, prob.u0, prob.tspan, prob.p)
25+
_prob = ODEProblem{iip}(prob.f, prob.u0, prob.tspan, prob.p)
26+
_fake_ode_sol = __construct_fake_ode_solution(_prob, _alg.ode_alg)
27+
if iip
28+
bc(bcresid_prototype, _fake_ode_sol, prob.p, _fake_ode_sol.sol.t)
29+
else
30+
bc(_fake_ode_sol, prob.p, _fake_ode_sol.sol.t)
31+
end
32+
__finalize_nodes!(_fake_ode_sol)
33+
34+
__alg = concretize_jacobian_algorithm(_alg, prob)
35+
alg = if has_initial_guess && Nig != __alg.nshoots
36+
verbose &&
37+
@warn "Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig)`"
38+
update_nshoots(__alg, Nig)
39+
else
40+
__alg
41+
end
42+
nshoots = alg.nshoots
43+
M = length(bcresid_prototype)
44+
45+
internal_ode_kwargs = (; verbose, kwargs..., odesolve_kwargs..., save_end = true)
46+
47+
function solve_internal_odes!(resid_nodes::T1, us::T2, p::T3, cur_nshoot::Int,
48+
nodes::T4, odecache::C) where {T1, T2, T3, T4, C}
49+
return __multiple_shooting_solve_internal_odes!(resid_nodes, us, cur_nshoot,
50+
odecache, nodes, u0_size, N, ensemblealg)
51+
end
52+
53+
ode_cache_loss_fn = __multiple_shooting_init_odecache(ensemblealg, prob,
54+
alg.ode_alg, u0, nshoots; internal_ode_kwargs...)
55+
56+
nodes = typeof(first(tspan))[]
57+
u_at_nodes = __multiple_shooting_initialize!(nodes, prob, alg, ig, nshoots,
58+
ode_cache_loss_fn; kwargs..., verbose, odesolve_kwargs...,
59+
static_nodes = _fake_ode_sol.nodes)
60+
61+
__solve_nlproblem!(prob.problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
62+
nshoots, M, N, prod(resid_size), solve_internal_odes!, bc, prob, prob.f,
63+
u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose,
64+
kwargs..., nlsolve_kwargs...)
65+
66+
if prob.problem_type isa TwoPointBVProblem
67+
diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.diffmode)
68+
else
69+
diffmode_shooting = __get_non_sparse_ad(alg.jac_alg.bc_diffmode)
70+
end
71+
shooting_alg = Shooting(alg.ode_alg, alg.nlsolve,
72+
BVPJacobianAlgorithm(diffmode_shooting))
73+
74+
single_shooting_prob = remake(prob; u0 = reshape(@view(u_at_nodes[1:N]), u0_size))
75+
return __solve(single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs,
76+
verbose, kwargs...)
1077
end
1178

1279
function __solve(prob::BVProblem, _alg::MultipleShooting{false}; kwargs...)
@@ -145,10 +212,70 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
145212
return nothing
146213
end
147214

148-
function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting, bcresid_prototype,
149-
u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int, resid_len::Int,
150-
solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0, ode_cache_loss_fn,
151-
ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S}
215+
function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{true},
216+
bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int,
217+
resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0,
218+
ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S}
219+
if __any_sparse_ad(alg.jac_alg)
220+
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
221+
bcresid_prototype, u0, N, cur_nshoot)
222+
end
223+
resid_prototype = vcat(bcresid_prototype, similar(u_at_nodes, cur_nshoot * N))
224+
225+
__resid_nodes = resid_prototype[(end - cur_nshoot * N + 1):end]
226+
resid_nodes = __maybe_allocate_diffcache(__resid_nodes,
227+
pickchunksize((cur_nshoot + 1) * N), alg.jac_alg.bc_diffmode)
228+
229+
loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss!(du, u, p, cur_nshoot,
230+
nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob.tspan,
231+
alg.ode_alg, u0, ode_cache_loss_fn)
232+
233+
# ODE Part
234+
sd_ode = alg.jac_alg.nonbc_diffmode isa AbstractSparseADType ?
235+
__sparsity_detection_alg(J_proto) : NoSparsityDetection()
236+
ode_jac_cache = sparse_jacobian_cache(alg.jac_alg.nonbc_diffmode, sd_ode,
237+
nothing, similar(u_at_nodes, cur_nshoot * N), u_at_nodes)
238+
ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob,
239+
ode_jac_cache, alg.jac_alg.nonbc_diffmode, alg.ode_alg, cur_nshoot, u0;
240+
internal_ode_kwargs...)
241+
242+
# BC Part
243+
sd_bc = alg.jac_alg.bc_diffmode isa AbstractSparseADType ?
244+
SymbolicsSparsityDetection() : NoSparsityDetection()
245+
bc_jac_cache = sparse_jacobian_cache(alg.jac_alg.bc_diffmode,
246+
sd_bc, nothing, similar(bcresid_prototype), u_at_nodes)
247+
ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache(ensemblealg, prob,
248+
bc_jac_cache, alg.jac_alg.bc_diffmode, alg.ode_alg, cur_nshoot, u0;
249+
internal_ode_kwargs...)
250+
251+
jac_prototype = vcat(init_jacobian(bc_jac_cache), init_jacobian(ode_jac_cache))
252+
253+
# Define the functions now
254+
ode_fn = (du, u) -> solve_internal_odes!(du, u, prob.p, cur_nshoot, nodes,
255+
ode_cache_ode_jac_fn)
256+
bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc_static_node!(du, u, prob.p, cur_nshoot, nodes,
257+
prob, solve_internal_odes!, N, f, bc, u0_size, prob.tspan, alg.ode_alg, u0,
258+
ode_cache_bc_jac_fn)
259+
260+
jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian!(J, u, p,
261+
similar(bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
262+
ode_fn, bc_fn, alg, N, M)
263+
264+
loss_function! = NonlinearFunction{true}(loss_fn; resid_prototype, jac = jac_fn,
265+
jac_prototype)
266+
267+
# NOTE: u_at_nodes is updated inplace
268+
nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
269+
u_at_nodes, prob.p)
270+
__solve(nlprob, alg.nlsolve; kwargs..., alias_u0 = true)
271+
272+
return nothing
273+
end
274+
275+
function __solve_nlproblem!(::StandardBVProblem, alg::MultipleShooting{false},
276+
bcresid_prototype, u_at_nodes, nodes, cur_nshoot::Int, M::Int, N::Int,
277+
resid_len::Int, solve_internal_odes!::S, bc::BC, prob, f::F, u0_size, u0,
278+
ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs...) where {BC, F, S}
152279
if __any_sparse_ad(alg.jac_alg)
153280
J_proto = __generate_sparse_jacobian_prototype(alg, prob.problem_type,
154281
bcresid_prototype, u0, N, cur_nshoot)
@@ -348,6 +475,29 @@ end
348475
return nothing
349476
end
350477

478+
@views function __multiple_shooting_mpoint_loss_bc_static_node!(resid_bc, us, p,
479+
cur_nshoots::Int, nodes, prob, solve_internal_odes!::S, N, f::F, bc::BC, u0_size,
480+
tspan, ode_alg, u0, ode_cache) where {S, F, BC}
481+
iip = isinplace(prob)
482+
483+
# NOTE: We placed the nodes at the points `bc` is evaluated so we don't need to
484+
# recompute the solution
485+
_ts = nodes
486+
_us = [reshape(us[((i - 1) * prod(u0_size) + 1):(i * prod(u0_size))], u0_size)
487+
for i in eachindex(_ts)]
488+
489+
odeprob = ODEProblem{iip}(f, u0, tspan, p)
490+
total_solution = SciMLBase.build_solution(odeprob, ode_alg, _ts, _us)
491+
492+
if iip
493+
eval_bc_residual!(resid_bc, StandardBVProblem(), bc, total_solution, p)
494+
else
495+
resid_bc .= eval_bc_residual(StandardBVProblem(), bc, total_solution, p)
496+
end
497+
498+
return nothing
499+
end
500+
351501
@views function __multiple_shooting_mpoint_loss!(resid, us, p, cur_nshoots::Int, nodes,
352502
prob, solve_internal_odes!::S, resid_len, N, f::F, bc::BC, u0_size, tspan,
353503
ode_alg, u0, ode_cache) where {S, F, BC}
@@ -390,12 +540,22 @@ end
390540

391541
# No initial guess
392542
@views function __multiple_shooting_initialize!(nodes, prob, alg::MultipleShooting,
393-
::Val{false}, nshoots::Int, odecache_; verbose, kwargs...)
543+
::Val{false}, nshoots::Int, odecache_; verbose, static_nodes = nothing, kwargs...)
394544
@unpack f, u0, tspan, p = prob
395545
@unpack ode_alg = alg
396546

397547
resize!(nodes, nshoots + 1)
398548
nodes .= range(tspan[1], tspan[2]; length = nshoots + 1)
549+
550+
if static_nodes !== nothing
551+
idx = 1
552+
for snode in static_nodes
553+
sidx = searchsortedfirst(nodes[idx:end], snode)
554+
nodes[idx + sidx - 1] = snode
555+
idx = sidx + 1
556+
end
557+
end
558+
399559
N = length(u0)
400560

401561
# Ensures type stability in case the parameters are dual numbers

src/utils.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,45 @@ function __restructure_sol(sol::Vector{<:AbstractArray}, u_size)
249249
end
250250

251251
# TODO: Add dispatch for a ODESolution Type as well
252+
253+
# Fake ODE Solution to capture calls to the solution object
254+
@concrete struct __FakeODESolution2
255+
sol
256+
nodes
257+
end
258+
259+
__FakeODESolutionXXX = __FakeODESolution2
260+
261+
function __construct_fake_ode_solution(prob::ODEProblem, alg)
262+
nodes = Vector{promote_type(typeof(prob.tspan[1]), typeof(prob.tspan[2]))}()
263+
return __FakeODESolutionXXX(SciMLBase.build_solution(prob, alg,
264+
[prob.tspan[1], prob.tspan[2]], [prob.u0, prob.u0]), nodes)
265+
end
266+
267+
function __finalize_nodes!(sol::__FakeODESolutionXXX)
268+
sort!(sol.nodes)
269+
unique!(sol.nodes)
270+
return sol
271+
end
272+
273+
function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: Number}
274+
push!(s.nodes, t)
275+
return s.sol(t, args...; kwargs...)
276+
end
277+
278+
function (s::__FakeODESolutionXXX)(t::T, args...; kwargs...) where {T <: AbstractVector}
279+
append!(s.nodes, t)
280+
return s.sol(t, args...; kwargs...)
281+
end
282+
283+
function Base.getindex(::__FakeODESolutionXXX, args...)
284+
throw(ArgumentError("`static_auto_nodes = Val(true)` doesn't support indexing into \
285+
the solution object. Please rewrite your code to call the \
286+
solution object with the time points you want to evaluate at \
287+
or use `static_auto_nodes = Val(false)`"))
288+
end
289+
290+
function Base.show(io::IO, sol::__FakeODESolutionXXX)
291+
print(io, "ODESolution evaluated @ nodes: $(sol.nodes)")
292+
return
293+
end

0 commit comments

Comments
 (0)