Skip to content

Commit 73bcbf0

Browse files
Merge pull request #1944 from oscardssmith/jacobian-for-initialization
use jac for the ShampineCollocationInit nlsolve
2 parents 874db54 + d6c44da commit 73bcbf0

File tree

1 file changed

+49
-8
lines changed

1 file changed

+49
-8
lines changed

src/initialize_dae.jl

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -152,20 +152,32 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
152152

153153
nlequation! = @closure (out, u, p) -> begin
154154
update_coefficients!(M, u, p, t)
155-
#M * (u-u0)/dt - f(u,p,t)
155+
# f(u,p,t) + M * (u0 - u)/dt
156156
tmp = isAD ? PreallocationTools.get_tmp(_tmp, u) : _tmp
157-
@. tmp = (u - u0) / dt
157+
@. tmp = (u0 - u) / dt
158158
mul!(_vec(out), M, _vec(tmp))
159159
f(tmp, u, p, t)
160-
out .-= tmp
160+
out .+= tmp
161161
nothing
162162
end
163163

164-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
164+
jac = if isnothing(f.jac)
165+
f.jac
166+
else
167+
@closure (J, u, p) -> begin
168+
# f(u,p,t) + M * (u0 - u)/dt
169+
# df(u,p,t)/du - M/dt
170+
f.jac(J, u, p, t)
171+
J .-= M .* inv(dt)
172+
nothing
173+
end
174+
end
165175

166176
nlfunc = NonlinearFunction(nlequation!;
167-
jac_prototype = f.jac_prototype)
177+
jac_prototype = f.jac_prototype,
178+
jac = jac)
168179
nlprob = NonlinearProblem(nlfunc, integrator.u, p)
180+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
169181
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
170182
reltol = integrator.opts.reltol)
171183
integrator.u .= nlsol.u
@@ -227,10 +239,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
227239
M * (u - u0) / dt - f(u, p, t)
228240
end
229241

242+
jac = if isnothing(f.jac)
243+
f.jac
244+
else
245+
@closure (u, p) -> begin
246+
return M * (u .- u0) ./ dt .- f.jac(u, p, t)
247+
end
248+
end
249+
230250
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)
231251

232252
nlfunc = NonlinearFunction(nlequation_oop;
233-
jac_prototype = f.jac_prototype)
253+
jac_prototype = f.jac_prototype,
254+
jac = jac)
234255
nlprob = NonlinearProblem(nlfunc, u0)
235256
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
236257
reltol = integrator.opts.reltol)
@@ -281,10 +302,20 @@ function _initialize_dae!(integrator, prob::DAEProblem,
281302
nothing
282303
end
283304

284-
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
305+
jac = if isnothing(f.jac)
306+
f.jac
307+
else
308+
@closure (J, u, p) -> begin
309+
f.jac(J, u, p, inv(dt), t)
310+
nothing
311+
end
312+
end
285313

286-
nlfunc = NonlinearFunction(nlequation!; jac_prototype = f.jac_prototype)
314+
nlfunc = NonlinearFunction(nlequation!;
315+
jac_prototype = f.jac_prototype,
316+
jac = jac)
287317
nlprob = NonlinearProblem(nlfunc, u0, p)
318+
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0, isAD)
288319
nlsol = solve(nlprob, nlsolve; abstol = integrator.opts.abstol,
289320
reltol = integrator.opts.reltol)
290321

@@ -318,6 +349,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
318349
resid = f(integrator.du, u0, p, t)
319350
integrator.opts.internalnorm(resid, t) <= integrator.opts.abstol && return
320351

352+
jac = if isnothing(f.jac)
353+
f.jac
354+
else
355+
@closure (u, p) -> begin
356+
return f.jac(u, p, inv(dt), t)
357+
end
358+
end
359+
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype,
360+
jac = jac)
361+
nlprob = NonlinearProblem(nlfunc, u0)
321362
nlsolve = default_nlsolve(alg.nlsolve, isinplace, u0)
322363

323364
nlfunc = NonlinearFunction(nlequation; jac_prototype = f.jac_prototype)

0 commit comments

Comments
 (0)