@@ -54,8 +54,7 @@ const TERM_DOCS = Dict(
54
54
:Rel => doc " ``all \l eft(| \D elta u | \l eq reltol \t imes | u | \r ight)``." ,
55
55
:RelNorm => doc " ``\| \D elta u \| \l eq reltol \t imes \| \D elta u + u \| ``." ,
56
56
:Abs => doc " ``all \l eft( | \D elta u | \l eq abstol \r ight)``." ,
57
- :AbsNorm => doc " ``\| \D elta u \| \l eq abstol``."
58
- )
57
+ :AbsNorm => doc " ``\| \D elta u \| \l eq abstol``." )
59
58
60
59
const __TERM_INTERNALNORM_DOCS = """
61
60
where `internalnorm` is the norm to use for the termination condition. Special handling is
@@ -148,9 +147,10 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
148
147
function $ (struct_name)(f:: F = nothing ; protective_threshold = nothing ,
149
148
patience_steps = 100 , patience_objective_multiplier = 3 ,
150
149
min_max_factor = 1.3 , max_stalled_steps = nothing ) where {F}
151
- return new{__norm_type (f), typeof (max_stalled_steps), F,
152
- typeof (protective_threshold), typeof (patience_objective_multiplier),
153
- typeof (min_max_factor)}(f, protective_threshold, patience_steps,
150
+ return new{__norm_type (f), typeof (max_stalled_steps),
151
+ F, typeof (protective_threshold),
152
+ typeof (patience_objective_multiplier), typeof (min_max_factor)}(
153
+ f, protective_threshold, patience_steps,
154
154
patience_objective_multiplier, min_max_factor, max_stalled_steps)
155
155
end
156
156
end
@@ -164,8 +164,8 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
164
164
end
165
165
end
166
166
167
- @concrete mutable struct NonlinearTerminationModeCache{dep_retcode,
168
- M <: AbstractNonlinearTerminationMode ,
167
+ @concrete mutable struct NonlinearTerminationModeCache{
168
+ dep_retcode, M <: AbstractNonlinearTerminationMode ,
169
169
R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T} }
170
170
u
171
171
retcode:: R
209
209
function SciMLBase. init (du:: Union{AbstractArray{T}, T} , u:: Union{AbstractArray{T}, T} ,
210
210
mode:: AbstractNonlinearTerminationMode , saved_value_prototype... ;
211
211
use_deprecated_retcodes:: Val{D} = Val (true ), # Remove in v8, warn in v7
212
- abstol = nothing , reltol = nothing , kwargs... ) where {D, T <: Number }
212
+ abstol = nothing ,
213
+ reltol = nothing , kwargs... ) where {D, T <: Number }
213
214
abstol = _get_tolerance (abstol, T)
214
215
reltol = _get_tolerance (reltol, T)
215
216
TT = typeof (abstol)
@@ -229,7 +230,8 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
229
230
Vector {TT} (undef, mode. max_stalled_steps)
230
231
best_value = initial_objective
231
232
max_stalled_steps = mode. max_stalled_steps
232
- if ArrayInterface. can_setindex (u_) && ! (u_ isa Number) &&
233
+ if ArrayInterface. can_setindex (u_) &&
234
+ ! (u_ isa Number) &&
233
235
step_norm_trace != = nothing
234
236
u_diff_cache = similar (u_)
235
237
else
@@ -249,22 +251,23 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
249
251
250
252
retcode = ifelse (D, NonlinearSafeTerminationReturnCode. Default, ReturnCode. Default)
251
253
252
- return NonlinearTerminationModeCache {D} (u_, retcode, abstol, reltol, best_value, mode,
253
- initial_objective, objectives_trace, 0 , saved_value_prototype, u0_norm,
254
+ return NonlinearTerminationModeCache {D} (
255
+ u_, retcode, abstol, reltol, best_value, mode, initial_objective,
256
+ objectives_trace, 0 , saved_value_prototype, u0_norm,
254
257
step_norm_trace, max_stalled_steps, u_diff_cache)
255
258
end
256
259
257
- function SciMLBase. reinit! (cache :: NonlinearTerminationModeCache{dep_retcode} , du,
258
- u, saved_value_prototype... ; abstol = nothing , reltol = nothing ,
259
- kwargs... ) where {dep_retcode}
260
+ function SciMLBase. reinit! (
261
+ cache :: NonlinearTerminationModeCache{dep_retcode} , du, u, saved_value_prototype... ;
262
+ abstol = nothing , reltol = nothing , kwargs... ) where {dep_retcode}
260
263
T = eltype (cache. abstol)
261
264
length (saved_value_prototype) != 0 && (cache. saved_values = saved_value_prototype)
262
265
263
266
u_ = cache. mode isa AbstractSafeBestNonlinearTerminationMode ?
264
267
(ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
265
268
cache. u = u_
266
- cache. retcode = ifelse (dep_retcode, NonlinearSafeTerminationReturnCode . Default,
267
- ReturnCode. Default)
269
+ cache. retcode = ifelse (
270
+ dep_retcode, NonlinearSafeTerminationReturnCode . Default, ReturnCode. Default)
268
271
269
272
cache. abstol = _get_tolerance (abstol, T)
270
273
cache. reltol = _get_tolerance (reltol, T)
293
296
294
297
# This dispatch is needed based on how Terminating Callback works!
295
298
# This intentially drops the `abstol` and `reltol` arguments
296
- function (cache:: NonlinearTerminationModeCache )(integrator :: SciMLBase.AbstractODEIntegrator ,
297
- abstol:: Number , reltol:: Number , min_t)
299
+ function (cache:: NonlinearTerminationModeCache )(
300
+ integrator :: SciMLBase.AbstractODEIntegrator , abstol:: Number , reltol:: Number , min_t)
298
301
retval = cache (cache. mode, get_du (integrator), integrator. u, integrator. uprev)
299
302
(min_t === nothing || integrator. t ≥ min_t) && return retval
300
303
return false
@@ -303,8 +306,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
303
306
return cache (cache. mode, du, u, uprev, args... )
304
307
end
305
308
306
- function (cache:: NonlinearTerminationModeCache )(mode :: AbstractNonlinearTerminationMode , du,
307
- u, uprev, args... )
309
+ function (cache:: NonlinearTerminationModeCache )(
310
+ mode :: AbstractNonlinearTerminationMode , du, u, uprev, args... )
308
311
return check_convergence (mode, du, u, uprev, cache. abstol, cache. reltol)
309
312
end
310
313
@@ -322,15 +325,17 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
322
325
323
326
# Protective Break
324
327
if isinf (objective) || isnan (objective)
325
- cache. retcode = ifelse (dep_retcode,
326
- NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
328
+ cache. retcode = ifelse (
329
+ dep_retcode, NonlinearSafeTerminationReturnCode. ProtectiveTermination,
330
+ ReturnCode. Unstable)
327
331
return true
328
332
end
329
333
# # By default we turn this off since it has the potential for false positives
330
334
if cache. mode. protective_threshold != = nothing &&
331
335
(objective > cache. initial_objective * cache. mode. protective_threshold * length (du))
332
- cache. retcode = ifelse (dep_retcode,
333
- NonlinearSafeTerminationReturnCode. ProtectiveTermination, ReturnCode. Unstable)
336
+ cache. retcode = ifelse (
337
+ dep_retcode, NonlinearSafeTerminationReturnCode. ProtectiveTermination,
338
+ ReturnCode. Unstable)
334
339
return true
335
340
end
336
341
@@ -346,8 +351,8 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
346
351
347
352
# Main Termination Condition
348
353
if objective ≤ criteria
349
- cache. retcode = ifelse (dep_retcode,
350
- NonlinearSafeTerminationReturnCode. Success, ReturnCode. Success)
354
+ cache. retcode = ifelse (
355
+ dep_retcode, NonlinearSafeTerminationReturnCode. Success, ReturnCode. Success)
351
356
return true
352
357
end
353
358
@@ -364,8 +369,8 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
364
369
min_obj, max_obj = extrema (cache. objectives_trace)
365
370
end
366
371
if min_obj < cache. mode. min_max_factor * max_obj
367
- cache. retcode = ifelse (dep_retcode,
368
- NonlinearSafeTerminationReturnCode. PatienceTermination,
372
+ cache. retcode = ifelse (
373
+ dep_retcode, NonlinearSafeTerminationReturnCode. PatienceTermination,
369
374
ReturnCode. Stalled)
370
375
return true
371
376
end
@@ -391,22 +396,22 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
391
396
cache. reltol * (max_step_norm + cache. u0_norm)
392
397
end
393
398
if stalled_step
394
- cache. retcode = ifelse (dep_retcode,
395
- NonlinearSafeTerminationReturnCode. PatienceTermination,
399
+ cache. retcode = ifelse (
400
+ dep_retcode, NonlinearSafeTerminationReturnCode. PatienceTermination,
396
401
ReturnCode. Stalled)
397
402
return true
398
403
end
399
404
end
400
405
end
401
406
402
- cache. retcode = ifelse (dep_retcode,
403
- NonlinearSafeTerminationReturnCode. Failure, ReturnCode. Failure)
407
+ cache. retcode = ifelse (
408
+ dep_retcode, NonlinearSafeTerminationReturnCode. Failure, ReturnCode. Failure)
404
409
return false
405
410
end
406
411
407
412
# Check Convergence
408
- function check_convergence (:: SteadyStateDiffEqTerminationMode , duₙ, uₙ, uₙ₋₁, abstol,
409
- reltol)
413
+ function check_convergence (
414
+ :: SteadyStateDiffEqTerminationMode , duₙ, uₙ, uₙ₋₁, abstol, reltol)
410
415
if __fast_scalar_indexing (duₙ, uₙ)
411
416
return all (@closure (xy-> begin
412
417
x, y = xy
435
440
function check_convergence (:: RelTerminationMode , duₙ, uₙ, uₙ₋₁, abstol, reltol)
436
441
if __fast_scalar_indexing (duₙ, uₙ)
437
442
return all (@closure (xy-> begin
438
- x, y = xy
439
- return abs (x) ≤ reltol * abs (y)
440
- end ), zip (duₙ, uₙ))
443
+ x, y = xy
444
+ return abs (x) ≤ reltol * abs (y)
445
+ end ), zip (duₙ, uₙ))
441
446
else
442
447
return all (@. abs (duₙ) ≤ reltol * abs (uₙ + duₙ))
443
448
end
@@ -454,14 +459,22 @@ end
454
459
function check_convergence (
455
460
mode:: Union {
456
461
RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode},
457
- duₙ, uₙ, uₙ₋₁, abstol, reltol)
462
+ duₙ,
463
+ uₙ,
464
+ uₙ₋₁,
465
+ abstol,
466
+ reltol)
458
467
return __apply_termination_internalnorm (mode. internalnorm, duₙ) ≤
459
468
reltol * __add_and_norm (mode. internalnorm, duₙ, uₙ)
460
469
end
461
470
function check_convergence (
462
- mode:: Union {AbsNormTerminationMode, AbsSafeTerminationMode,
463
- AbsSafeBestTerminationMode},
464
- duₙ, uₙ, uₙ₋₁, abstol, reltol)
471
+ mode:: Union {
472
+ AbsNormTerminationMode, AbsSafeTerminationMode, AbsSafeBestTerminationMode},
473
+ duₙ,
474
+ uₙ,
475
+ uₙ₋₁,
476
+ abstol,
477
+ reltol)
465
478
return __apply_termination_internalnorm (mode. internalnorm, duₙ) ≤ abstol
466
479
end
467
480
@@ -472,13 +485,11 @@ end
472
485
473
486
# Nonlinear Solve Norm (norm(_, 2))
474
487
NONLINEARSOLVE_DEFAULT_NORM (u:: Union{AbstractFloat, Complex} ) = @fastmath abs (u)
475
- function NONLINEARSOLVE_DEFAULT_NORM (f:: F ,
476
- u:: Union{AbstractFloat, Complex} ) where {F}
488
+ function NONLINEARSOLVE_DEFAULT_NORM (f:: F , u:: Union{AbstractFloat, Complex} ) where {F}
477
489
return @fastmath abs (f (u))
478
490
end
479
491
480
- function NONLINEARSOLVE_DEFAULT_NORM (u:: Array {
481
- T}) where {T <: Union{AbstractFloat, Complex} }
492
+ function NONLINEARSOLVE_DEFAULT_NORM (u:: Array{T} ) where {T <: Union{AbstractFloat, Complex} }
482
493
x = zero (T)
483
494
@inbounds @fastmath for ui in u
484
495
x += abs2 (ui)
@@ -501,9 +512,8 @@ function NONLINEARSOLVE_DEFAULT_NORM(u::StaticArray{
501
512
return Base. FastMath. sqrt_fast (real (sum (abs2, u)))
502
513
end
503
514
504
- function NONLINEARSOLVE_DEFAULT_NORM (f:: F ,
505
- u:: StaticArray{<:Tuple, T} ) where {
506
- F, T <: Union{AbstractFloat, Complex} }
515
+ function NONLINEARSOLVE_DEFAULT_NORM (
516
+ f:: F , u:: StaticArray{<:Tuple, T} ) where {F, T <: Union{AbstractFloat, Complex} }
507
517
return Base. FastMath. sqrt_fast (real (sum (abs2 ∘ f, u)))
508
518
end
509
519
@@ -525,9 +535,9 @@ NONLINEARSOLVE_DEFAULT_NORM(f::F, u) where {F} = norm(f.(u))
525
535
@inline function __maximum (op:: F , x, y) where {F}
526
536
if __fast_scalar_indexing (x, y)
527
537
return maximum (@closure ((xᵢyᵢ)-> begin
528
- xᵢ, yᵢ = xᵢyᵢ
529
- return op (xᵢ, yᵢ)
530
- end ), zip (x, y))
538
+ xᵢ, yᵢ = xᵢyᵢ
539
+ return op (xᵢ, yᵢ)
540
+ end ), zip (x, y))
531
541
else
532
542
return mapreduce (@closure ((xᵢ, yᵢ)-> op (xᵢ, yᵢ)), max, x, y)
533
543
end
536
546
@inline function __norm_op (:: typeof (Base. Fix2 (norm, 2 )), op:: F , x, y) where {F}
537
547
if __fast_scalar_indexing (x, y)
538
548
return sqrt (sum (@closure ((xᵢyᵢ)-> begin
539
- xᵢ, yᵢ = xᵢyᵢ
540
- return op (xᵢ, yᵢ)^ 2
541
- end ), zip (x, y)))
549
+ xᵢ, yᵢ = xᵢyᵢ
550
+ return op (xᵢ, yᵢ)^ 2
551
+ end ), zip (x, y)))
542
552
else
543
553
return sqrt (mapreduce (@closure ((xᵢ, yᵢ)-> (op (xᵢ, yᵢ)^ 2 )), + , x, y))
544
554
end
0 commit comments