@@ -2,11 +2,78 @@ function __solve(prob::BVProblem, _alg::MultipleShooting{true}; odesolve_kwargs
2
2
nlsolve_kwargs = (;), ensemblealg = EnsembleThreads (), verbose = true , kwargs... )
3
3
# For TwoPointBVPs there is nothing to do. Forward to general multiple shooting
4
4
prob. problem_type isa TwoPointBVProblem &&
5
- return __solve_internal (prob, _alg; odesolve_kwargs, nlsolve_kwargs, ensemblealg,
6
- verbose, kwargs... )
5
+ return __solve_internal (prob, __without_static_nodes (_alg); odesolve_kwargs,
6
+ nlsolve_kwargs, ensemblealg, verbose, kwargs... )
7
+
8
+ ig, T, N, Nig, u0 = __extract_problem_details (prob; dt = 0.1 )
9
+
10
+ if _unwrap_val (ig) && prob. u0 isa AbstractVector
11
+ if verbose
12
+ @warn " Static Nodes for Multiple-Shooting is not supported when Vector of \
13
+ initial guesses are provided. Falling back to using the generic method!"
14
+ end
15
+ return __solve_internal (prob, __without_static_nodes (_alg); odesolve_kwargs,
16
+ nlsolve_kwargs, ensemblealg, verbose, kwargs... )
17
+ end
18
+
19
+ has_initial_guess = _unwrap_val (ig)
20
+
21
+ bcresid_prototype, resid_size = __get_bcresid_prototype (prob, u0)
22
+ iip, bc, u0, u0_size = isinplace (prob), prob. f. bc, deepcopy (u0), size (u0)
7
23
8
24
# Extract the time-points used in BC
9
- _prob = ODEProblem {isinplace(prob)} (prob. f, prob. u0, prob. tspan, prob. p)
25
+ _prob = ODEProblem {iip} (prob. f, prob. u0, prob. tspan, prob. p)
26
+ _fake_ode_sol = __construct_fake_ode_solution (_prob, _alg. ode_alg)
27
+ if iip
28
+ bc (bcresid_prototype, _fake_ode_sol, prob. p, _fake_ode_sol. sol. t)
29
+ else
30
+ bc (_fake_ode_sol, prob. p, _fake_ode_sol. sol. t)
31
+ end
32
+ __finalize_nodes! (_fake_ode_sol)
33
+
34
+ __alg = concretize_jacobian_algorithm (_alg, prob)
35
+ alg = if has_initial_guess && Nig != __alg. nshoots
36
+ verbose &&
37
+ @warn " Initial guess length != `nshoots + 1`! Adapting to `nshoots = $(Nig) `"
38
+ update_nshoots (__alg, Nig)
39
+ else
40
+ __alg
41
+ end
42
+ nshoots = alg. nshoots
43
+ M = length (bcresid_prototype)
44
+
45
+ internal_ode_kwargs = (; verbose, kwargs... , odesolve_kwargs... , save_end = true )
46
+
47
+ function solve_internal_odes! (resid_nodes:: T1 , us:: T2 , p:: T3 , cur_nshoot:: Int ,
48
+ nodes:: T4 , odecache:: C ) where {T1, T2, T3, T4, C}
49
+ return __multiple_shooting_solve_internal_odes! (resid_nodes, us, cur_nshoot,
50
+ odecache, nodes, u0_size, N, ensemblealg)
51
+ end
52
+
53
+ ode_cache_loss_fn = __multiple_shooting_init_odecache (ensemblealg, prob,
54
+ alg. ode_alg, u0, nshoots; internal_ode_kwargs... )
55
+
56
+ nodes = typeof (first (tspan))[]
57
+ u_at_nodes = __multiple_shooting_initialize! (nodes, prob, alg, ig, nshoots,
58
+ ode_cache_loss_fn; kwargs... , verbose, odesolve_kwargs... ,
59
+ static_nodes = _fake_ode_sol. nodes)
60
+
61
+ __solve_nlproblem! (prob. problem_type, alg, bcresid_prototype, u_at_nodes, nodes,
62
+ nshoots, M, N, prod (resid_size), solve_internal_odes!, bc, prob, prob. f,
63
+ u0_size, u0, ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; verbose,
64
+ kwargs... , nlsolve_kwargs... )
65
+
66
+ if prob. problem_type isa TwoPointBVProblem
67
+ diffmode_shooting = __get_non_sparse_ad (alg. jac_alg. diffmode)
68
+ else
69
+ diffmode_shooting = __get_non_sparse_ad (alg. jac_alg. bc_diffmode)
70
+ end
71
+ shooting_alg = Shooting (alg. ode_alg, alg. nlsolve,
72
+ BVPJacobianAlgorithm (diffmode_shooting))
73
+
74
+ single_shooting_prob = remake (prob; u0 = reshape (@view (u_at_nodes[1 : N]), u0_size))
75
+ return __solve (single_shooting_prob, shooting_alg; odesolve_kwargs, nlsolve_kwargs,
76
+ verbose, kwargs... )
10
77
end
11
78
12
79
function __solve (prob:: BVProblem , _alg:: MultipleShooting{false} ; kwargs... )
@@ -145,10 +212,70 @@ function __solve_nlproblem!(::TwoPointBVProblem, alg::MultipleShooting, bcresid_
145
212
return nothing
146
213
end
147
214
148
- function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting , bcresid_prototype,
149
- u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int , resid_len:: Int ,
150
- solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0, ode_cache_loss_fn,
151
- ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
215
+ function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting{true} ,
216
+ bcresid_prototype, u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int ,
217
+ resid_len:: Int , solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0,
218
+ ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
219
+ if __any_sparse_ad (alg. jac_alg)
220
+ J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
221
+ bcresid_prototype, u0, N, cur_nshoot)
222
+ end
223
+ resid_prototype = vcat (bcresid_prototype, similar (u_at_nodes, cur_nshoot * N))
224
+
225
+ __resid_nodes = resid_prototype[(end - cur_nshoot * N + 1 ): end ]
226
+ resid_nodes = __maybe_allocate_diffcache (__resid_nodes,
227
+ pickchunksize ((cur_nshoot + 1 ) * N), alg. jac_alg. bc_diffmode)
228
+
229
+ loss_fn = (du, u, p) -> __multiple_shooting_mpoint_loss! (du, u, p, cur_nshoot,
230
+ nodes, prob, solve_internal_odes!, resid_len, N, f, bc, u0_size, prob. tspan,
231
+ alg. ode_alg, u0, ode_cache_loss_fn)
232
+
233
+ # ODE Part
234
+ sd_ode = alg. jac_alg. nonbc_diffmode isa AbstractSparseADType ?
235
+ __sparsity_detection_alg (J_proto) : NoSparsityDetection ()
236
+ ode_jac_cache = sparse_jacobian_cache (alg. jac_alg. nonbc_diffmode, sd_ode,
237
+ nothing , similar (u_at_nodes, cur_nshoot * N), u_at_nodes)
238
+ ode_cache_ode_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
239
+ ode_jac_cache, alg. jac_alg. nonbc_diffmode, alg. ode_alg, cur_nshoot, u0;
240
+ internal_ode_kwargs... )
241
+
242
+ # BC Part
243
+ sd_bc = alg. jac_alg. bc_diffmode isa AbstractSparseADType ?
244
+ SymbolicsSparsityDetection () : NoSparsityDetection ()
245
+ bc_jac_cache = sparse_jacobian_cache (alg. jac_alg. bc_diffmode,
246
+ sd_bc, nothing , similar (bcresid_prototype), u_at_nodes)
247
+ ode_cache_bc_jac_fn = __multiple_shooting_init_jacobian_odecache (ensemblealg, prob,
248
+ bc_jac_cache, alg. jac_alg. bc_diffmode, alg. ode_alg, cur_nshoot, u0;
249
+ internal_ode_kwargs... )
250
+
251
+ jac_prototype = vcat (init_jacobian (bc_jac_cache), init_jacobian (ode_jac_cache))
252
+
253
+ # Define the functions now
254
+ ode_fn = (du, u) -> solve_internal_odes! (du, u, prob. p, cur_nshoot, nodes,
255
+ ode_cache_ode_jac_fn)
256
+ bc_fn = (du, u) -> __multiple_shooting_mpoint_loss_bc_static_node! (du, u, prob. p, cur_nshoot, nodes,
257
+ prob, solve_internal_odes!, N, f, bc, u0_size, prob. tspan, alg. ode_alg, u0,
258
+ ode_cache_bc_jac_fn)
259
+
260
+ jac_fn = (J, u, p) -> __multiple_shooting_mpoint_jacobian! (J, u, p,
261
+ similar (bcresid_prototype), resid_nodes, ode_jac_cache, bc_jac_cache,
262
+ ode_fn, bc_fn, alg, N, M)
263
+
264
+ loss_function! = NonlinearFunction {true} (loss_fn; resid_prototype, jac = jac_fn,
265
+ jac_prototype)
266
+
267
+ # NOTE: u_at_nodes is updated inplace
268
+ nlprob = (M != N ? NonlinearLeastSquaresProblem : NonlinearProblem)(loss_function!,
269
+ u_at_nodes, prob. p)
270
+ __solve (nlprob, alg. nlsolve; kwargs... , alias_u0 = true )
271
+
272
+ return nothing
273
+ end
274
+
275
+ function __solve_nlproblem! (:: StandardBVProblem , alg:: MultipleShooting{false} ,
276
+ bcresid_prototype, u_at_nodes, nodes, cur_nshoot:: Int , M:: Int , N:: Int ,
277
+ resid_len:: Int , solve_internal_odes!:: S , bc:: BC , prob, f:: F , u0_size, u0,
278
+ ode_cache_loss_fn, ensemblealg, internal_ode_kwargs; kwargs... ) where {BC, F, S}
152
279
if __any_sparse_ad (alg. jac_alg)
153
280
J_proto = __generate_sparse_jacobian_prototype (alg, prob. problem_type,
154
281
bcresid_prototype, u0, N, cur_nshoot)
348
475
return nothing
349
476
end
350
477
478
+ @views function __multiple_shooting_mpoint_loss_bc_static_node! (resid_bc, us, p,
479
+ cur_nshoots:: Int , nodes, prob, solve_internal_odes!:: S , N, f:: F , bc:: BC , u0_size,
480
+ tspan, ode_alg, u0, ode_cache) where {S, F, BC}
481
+ iip = isinplace (prob)
482
+
483
+ # NOTE: We placed the nodes at the points `bc` is evaluated so we don't need to
484
+ # recompute the solution
485
+ _ts = nodes
486
+ _us = [reshape (us[((i - 1 ) * prod (u0_size) + 1 ): (i * prod (u0_size))], u0_size)
487
+ for i in eachindex (_ts)]
488
+
489
+ odeprob = ODEProblem {iip} (f, u0, tspan, p)
490
+ total_solution = SciMLBase. build_solution (odeprob, ode_alg, _ts, _us)
491
+
492
+ if iip
493
+ eval_bc_residual! (resid_bc, StandardBVProblem (), bc, total_solution, p)
494
+ else
495
+ resid_bc .= eval_bc_residual (StandardBVProblem (), bc, total_solution, p)
496
+ end
497
+
498
+ return nothing
499
+ end
500
+
351
501
@views function __multiple_shooting_mpoint_loss! (resid, us, p, cur_nshoots:: Int , nodes,
352
502
prob, solve_internal_odes!:: S , resid_len, N, f:: F , bc:: BC , u0_size, tspan,
353
503
ode_alg, u0, ode_cache) where {S, F, BC}
@@ -390,12 +540,22 @@ end
390
540
391
541
# No initial guess
392
542
@views function __multiple_shooting_initialize! (nodes, prob, alg:: MultipleShooting ,
393
- :: Val{false} , nshoots:: Int , odecache_; verbose, kwargs... )
543
+ :: Val{false} , nshoots:: Int , odecache_; verbose, static_nodes = nothing , kwargs... )
394
544
@unpack f, u0, tspan, p = prob
395
545
@unpack ode_alg = alg
396
546
397
547
resize! (nodes, nshoots + 1 )
398
548
nodes .= range (tspan[1 ], tspan[2 ]; length = nshoots + 1 )
549
+
550
+ if static_nodes != = nothing
551
+ idx = 1
552
+ for snode in static_nodes
553
+ sidx = searchsortedfirst (nodes[idx: end ], snode)
554
+ nodes[idx + sidx - 1 ] = snode
555
+ idx = sidx + 1
556
+ end
557
+ end
558
+
399
559
N = length (u0)
400
560
401
561
# Ensures type stability in case the parameters are dual numbers
0 commit comments