@@ -152,20 +152,32 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
152
152
153
153
nlequation! = @closure (out, u, p) -> begin
154
154
update_coefficients! (M, u, p, t)
155
- # M * (u-u0)/dt - f(u,p,t)
155
+ # f(u,p,t) + M * (u0 - u)/dt
156
156
tmp = isAD ? PreallocationTools. get_tmp (_tmp, u) : _tmp
157
- @. tmp = (u - u0 ) / dt
157
+ @. tmp = (u0 - u ) / dt
158
158
mul! (_vec (out), M, _vec (tmp))
159
159
f (tmp, u, p, t)
160
- out .- = tmp
160
+ out .+ = tmp
161
161
nothing
162
162
end
163
163
164
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
164
+ jac = if isnothing (f. jac)
165
+ f. jac
166
+ else
167
+ @closure (J, u, p) -> begin
168
+ # f(u,p,t) + M * (u0 - u)/dt
169
+ # df(u,p,t)/du - M/dt
170
+ f. jac (J, u, p, t)
171
+ J .- = M .* inv (dt)
172
+ nothing
173
+ end
174
+ end
165
175
166
176
nlfunc = NonlinearFunction (nlequation!;
167
- jac_prototype = f. jac_prototype)
177
+ jac_prototype = f. jac_prototype,
178
+ jac = jac)
168
179
nlprob = NonlinearProblem (nlfunc, integrator. u, p)
180
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
169
181
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
170
182
reltol = integrator. opts. reltol)
171
183
integrator. u .= nlsol. u
@@ -227,10 +239,19 @@ function _initialize_dae!(integrator, prob::ODEProblem, alg::ShampineCollocation
227
239
M * (u - u0) / dt - f (u, p, t)
228
240
end
229
241
242
+ jac = if isnothing (f. jac)
243
+ f. jac
244
+ else
245
+ @closure (u, p) -> begin
246
+ return M * (u .- u0) ./ dt .- f. jac (u, p, t)
247
+ end
248
+ end
249
+
230
250
nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
231
251
232
252
nlfunc = NonlinearFunction (nlequation_oop;
233
- jac_prototype = f. jac_prototype)
253
+ jac_prototype = f. jac_prototype,
254
+ jac = jac)
234
255
nlprob = NonlinearProblem (nlfunc, u0)
235
256
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
236
257
reltol = integrator. opts. reltol)
@@ -281,10 +302,20 @@ function _initialize_dae!(integrator, prob::DAEProblem,
281
302
nothing
282
303
end
283
304
284
- nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
305
+ jac = if isnothing (f. jac)
306
+ f. jac
307
+ else
308
+ @closure (J, u, p) -> begin
309
+ f. jac (J, u, p, inv (dt), t)
310
+ nothing
311
+ end
312
+ end
285
313
286
- nlfunc = NonlinearFunction (nlequation!; jac_prototype = f. jac_prototype)
314
+ nlfunc = NonlinearFunction (nlequation!;
315
+ jac_prototype = f. jac_prototype,
316
+ jac = jac)
287
317
nlprob = NonlinearProblem (nlfunc, u0, p)
318
+ nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0, isAD)
288
319
nlsol = solve (nlprob, nlsolve; abstol = integrator. opts. abstol,
289
320
reltol = integrator. opts. reltol)
290
321
@@ -318,6 +349,16 @@ function _initialize_dae!(integrator, prob::DAEProblem,
318
349
resid = f (integrator. du, u0, p, t)
319
350
integrator. opts. internalnorm (resid, t) <= integrator. opts. abstol && return
320
351
352
+ jac = if isnothing (f. jac)
353
+ f. jac
354
+ else
355
+ @closure (u, p) -> begin
356
+ return f. jac (u, p, inv (dt), t)
357
+ end
358
+ end
359
+ nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype,
360
+ jac = jac)
361
+ nlprob = NonlinearProblem (nlfunc, u0)
321
362
nlsolve = default_nlsolve (alg. nlsolve, isinplace, u0)
322
363
323
364
nlfunc = NonlinearFunction (nlequation; jac_prototype = f. jac_prototype)
0 commit comments