Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit c86f900

Browse files
Merge pull request #173 from SciML/auto-juliaformatter-pr
Automatic JuliaFormatter.jl run
2 parents b73c4b7 + fd0b447 commit c86f900

4 files changed

+85
-82
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ export SteadyStateDiffEqTerminationMode, SimpleNonlinearSolveTerminationMode,
149149
AbsNormTerminationMode, RelSafeTerminationMode, AbsSafeTerminationMode,
150150
RelSafeBestTerminationMode, AbsSafeBestTerminationMode
151151
# Deprecated API
152-
export NLSolveTerminationMode,
153-
NLSolveSafeTerminationOptions, NLSolveTerminationCondition,
152+
export NLSolveTerminationMode, NLSolveSafeTerminationOptions, NLSolveTerminationCondition,
154153
NLSolveSafeTerminationResult
155154

156155
end # module

src/termination_conditions.jl

Lines changed: 64 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,7 @@ const TERM_DOCS = Dict(
5454
:Rel => doc"``all \left(| \Delta u | \leq reltol \times | u | \right)``.",
5555
:RelNorm => doc"``\| \Delta u \| \leq reltol \times \| \Delta u + u \|``.",
5656
:Abs => doc"``all \left( | \Delta u | \leq abstol \right)``.",
57-
:AbsNorm => doc"``\| \Delta u \| \leq abstol``."
58-
)
57+
:AbsNorm => doc"``\| \Delta u \| \leq abstol``.")
5958

6059
const __TERM_INTERNALNORM_DOCS = """
6160
where `internalnorm` is the norm to use for the termination condition. Special handling is
@@ -148,9 +147,10 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
148147
function $(struct_name)(f::F = nothing; protective_threshold = nothing,
149148
patience_steps = 100, patience_objective_multiplier = 3,
150149
min_max_factor = 1.3, max_stalled_steps = nothing) where {F}
151-
return new{__norm_type(f), typeof(max_stalled_steps), F,
152-
typeof(protective_threshold), typeof(patience_objective_multiplier),
153-
typeof(min_max_factor)}(f, protective_threshold, patience_steps,
150+
return new{__norm_type(f), typeof(max_stalled_steps),
151+
F, typeof(protective_threshold),
152+
typeof(patience_objective_multiplier), typeof(min_max_factor)}(
153+
f, protective_threshold, patience_steps,
154154
patience_objective_multiplier, min_max_factor, max_stalled_steps)
155155
end
156156
end
@@ -164,8 +164,8 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
164164
end
165165
end
166166

167-
@concrete mutable struct NonlinearTerminationModeCache{dep_retcode,
168-
M <: AbstractNonlinearTerminationMode,
167+
@concrete mutable struct NonlinearTerminationModeCache{
168+
dep_retcode, M <: AbstractNonlinearTerminationMode,
169169
R <: Union{NonlinearSafeTerminationReturnCode.T, ReturnCode.T}}
170170
u
171171
retcode::R
@@ -209,7 +209,8 @@ end
209209
function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T}, T},
210210
mode::AbstractNonlinearTerminationMode, saved_value_prototype...;
211211
use_deprecated_retcodes::Val{D} = Val(true), # Remove in v8, warn in v7
212-
abstol = nothing, reltol = nothing, kwargs...) where {D, T <: Number}
212+
abstol = nothing,
213+
reltol = nothing, kwargs...) where {D, T <: Number}
213214
abstol = _get_tolerance(abstol, T)
214215
reltol = _get_tolerance(reltol, T)
215216
TT = typeof(abstol)
@@ -229,7 +230,8 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
229230
Vector{TT}(undef, mode.max_stalled_steps)
230231
best_value = initial_objective
231232
max_stalled_steps = mode.max_stalled_steps
232-
if ArrayInterface.can_setindex(u_) && !(u_ isa Number) &&
233+
if ArrayInterface.can_setindex(u_) &&
234+
!(u_ isa Number) &&
233235
step_norm_trace !== nothing
234236
u_diff_cache = similar(u_)
235237
else
@@ -249,22 +251,23 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
249251

250252
retcode = ifelse(D, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default)
251253

252-
return NonlinearTerminationModeCache{D}(u_, retcode, abstol, reltol, best_value, mode,
253-
initial_objective, objectives_trace, 0, saved_value_prototype, u0_norm,
254+
return NonlinearTerminationModeCache{D}(
255+
u_, retcode, abstol, reltol, best_value, mode, initial_objective,
256+
objectives_trace, 0, saved_value_prototype, u0_norm,
254257
step_norm_trace, max_stalled_steps, u_diff_cache)
255258
end
256259

257-
function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{dep_retcode}, du,
258-
u, saved_value_prototype...; abstol = nothing, reltol = nothing,
259-
kwargs...) where {dep_retcode}
260+
function SciMLBase.reinit!(
261+
cache::NonlinearTerminationModeCache{dep_retcode}, du, u, saved_value_prototype...;
262+
abstol = nothing, reltol = nothing, kwargs...) where {dep_retcode}
260263
T = eltype(cache.abstol)
261264
length(saved_value_prototype) != 0 && (cache.saved_values = saved_value_prototype)
262265

263266
u_ = cache.mode isa AbstractSafeBestNonlinearTerminationMode ?
264267
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
265268
cache.u = u_
266-
cache.retcode = ifelse(dep_retcode, NonlinearSafeTerminationReturnCode.Default,
267-
ReturnCode.Default)
269+
cache.retcode = ifelse(
270+
dep_retcode, NonlinearSafeTerminationReturnCode.Default, ReturnCode.Default)
268271

269272
cache.abstol = _get_tolerance(abstol, T)
270273
cache.reltol = _get_tolerance(reltol, T)
@@ -293,8 +296,8 @@ end
293296

294297
# This dispatch is needed based on how Terminating Callback works!
295298
# This intentially drops the `abstol` and `reltol` arguments
296-
function (cache::NonlinearTerminationModeCache)(integrator::SciMLBase.AbstractODEIntegrator,
297-
abstol::Number, reltol::Number, min_t)
299+
function (cache::NonlinearTerminationModeCache)(
300+
integrator::SciMLBase.AbstractODEIntegrator, abstol::Number, reltol::Number, min_t)
298301
retval = cache(cache.mode, get_du(integrator), integrator.u, integrator.uprev)
299302
(min_t === nothing || integrator.t min_t) && return retval
300303
return false
@@ -303,8 +306,8 @@ function (cache::NonlinearTerminationModeCache)(du, u, uprev, args...)
303306
return cache(cache.mode, du, u, uprev, args...)
304307
end
305308

306-
function (cache::NonlinearTerminationModeCache)(mode::AbstractNonlinearTerminationMode, du,
307-
u, uprev, args...)
309+
function (cache::NonlinearTerminationModeCache)(
310+
mode::AbstractNonlinearTerminationMode, du, u, uprev, args...)
308311
return check_convergence(mode, du, u, uprev, cache.abstol, cache.reltol)
309312
end
310313

@@ -322,15 +325,17 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
322325

323326
# Protective Break
324327
if isinf(objective) || isnan(objective)
325-
cache.retcode = ifelse(dep_retcode,
326-
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
328+
cache.retcode = ifelse(
329+
dep_retcode, NonlinearSafeTerminationReturnCode.ProtectiveTermination,
330+
ReturnCode.Unstable)
327331
return true
328332
end
329333
## By default we turn this off since it has the potential for false positives
330334
if cache.mode.protective_threshold !== nothing &&
331335
(objective > cache.initial_objective * cache.mode.protective_threshold * length(du))
332-
cache.retcode = ifelse(dep_retcode,
333-
NonlinearSafeTerminationReturnCode.ProtectiveTermination, ReturnCode.Unstable)
336+
cache.retcode = ifelse(
337+
dep_retcode, NonlinearSafeTerminationReturnCode.ProtectiveTermination,
338+
ReturnCode.Unstable)
334339
return true
335340
end
336341

@@ -346,8 +351,8 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
346351

347352
# Main Termination Condition
348353
if objective criteria
349-
cache.retcode = ifelse(dep_retcode,
350-
NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success)
354+
cache.retcode = ifelse(
355+
dep_retcode, NonlinearSafeTerminationReturnCode.Success, ReturnCode.Success)
351356
return true
352357
end
353358

@@ -364,8 +369,8 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
364369
min_obj, max_obj = extrema(cache.objectives_trace)
365370
end
366371
if min_obj < cache.mode.min_max_factor * max_obj
367-
cache.retcode = ifelse(dep_retcode,
368-
NonlinearSafeTerminationReturnCode.PatienceTermination,
372+
cache.retcode = ifelse(
373+
dep_retcode, NonlinearSafeTerminationReturnCode.PatienceTermination,
369374
ReturnCode.Stalled)
370375
return true
371376
end
@@ -391,22 +396,22 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
391396
cache.reltol * (max_step_norm + cache.u0_norm)
392397
end
393398
if stalled_step
394-
cache.retcode = ifelse(dep_retcode,
395-
NonlinearSafeTerminationReturnCode.PatienceTermination,
399+
cache.retcode = ifelse(
400+
dep_retcode, NonlinearSafeTerminationReturnCode.PatienceTermination,
396401
ReturnCode.Stalled)
397402
return true
398403
end
399404
end
400405
end
401406

402-
cache.retcode = ifelse(dep_retcode,
403-
NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure)
407+
cache.retcode = ifelse(
408+
dep_retcode, NonlinearSafeTerminationReturnCode.Failure, ReturnCode.Failure)
404409
return false
405410
end
406411

407412
# Check Convergence
408-
function check_convergence(::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol,
409-
reltol)
413+
function check_convergence(
414+
::SteadyStateDiffEqTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
410415
if __fast_scalar_indexing(duₙ, uₙ)
411416
return all(@closure(xy->begin
412417
x, y = xy
@@ -435,9 +440,9 @@ end
435440
function check_convergence(::RelTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
436441
if __fast_scalar_indexing(duₙ, uₙ)
437442
return all(@closure(xy->begin
438-
x, y = xy
439-
return abs(x) reltol * abs(y)
440-
end), zip(duₙ, uₙ))
443+
x, y = xy
444+
return abs(x) reltol * abs(y)
445+
end), zip(duₙ, uₙ))
441446
else
442447
return all(@. abs(duₙ) reltol * abs(uₙ + duₙ))
443448
end
@@ -454,14 +459,22 @@ end
454459
function check_convergence(
455460
mode::Union{
456461
RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode},
457-
duₙ, uₙ, uₙ₋₁, abstol, reltol)
462+
duₙ,
463+
uₙ,
464+
uₙ₋₁,
465+
abstol,
466+
reltol)
458467
return __apply_termination_internalnorm(mode.internalnorm, duₙ)
459468
reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ)
460469
end
461470
function check_convergence(
462-
mode::Union{AbsNormTerminationMode, AbsSafeTerminationMode,
463-
AbsSafeBestTerminationMode},
464-
duₙ, uₙ, uₙ₋₁, abstol, reltol)
471+
mode::Union{
472+
AbsNormTerminationMode, AbsSafeTerminationMode, AbsSafeBestTerminationMode},
473+
duₙ,
474+
uₙ,
475+
uₙ₋₁,
476+
abstol,
477+
reltol)
465478
return __apply_termination_internalnorm(mode.internalnorm, duₙ) abstol
466479
end
467480

@@ -472,13 +485,11 @@ end
472485

473486
# Nonlinear Solve Norm (norm(_, 2))
474487
NONLINEARSOLVE_DEFAULT_NORM(u::Union{AbstractFloat, Complex}) = @fastmath abs(u)
475-
function NONLINEARSOLVE_DEFAULT_NORM(f::F,
476-
u::Union{AbstractFloat, Complex}) where {F}
488+
function NONLINEARSOLVE_DEFAULT_NORM(f::F, u::Union{AbstractFloat, Complex}) where {F}
477489
return @fastmath abs(f(u))
478490
end
479491

480-
function NONLINEARSOLVE_DEFAULT_NORM(u::Array{
481-
T}) where {T <: Union{AbstractFloat, Complex}}
492+
function NONLINEARSOLVE_DEFAULT_NORM(u::Array{T}) where {T <: Union{AbstractFloat, Complex}}
482493
x = zero(T)
483494
@inbounds @fastmath for ui in u
484495
x += abs2(ui)
@@ -501,9 +512,8 @@ function NONLINEARSOLVE_DEFAULT_NORM(u::StaticArray{
501512
return Base.FastMath.sqrt_fast(real(sum(abs2, u)))
502513
end
503514

504-
function NONLINEARSOLVE_DEFAULT_NORM(f::F,
505-
u::StaticArray{<:Tuple, T}) where {
506-
F, T <: Union{AbstractFloat, Complex}}
515+
function NONLINEARSOLVE_DEFAULT_NORM(
516+
f::F, u::StaticArray{<:Tuple, T}) where {F, T <: Union{AbstractFloat, Complex}}
507517
return Base.FastMath.sqrt_fast(real(sum(abs2 f, u)))
508518
end
509519

@@ -525,9 +535,9 @@ NONLINEARSOLVE_DEFAULT_NORM(f::F, u) where {F} = norm(f.(u))
525535
@inline function __maximum(op::F, x, y) where {F}
526536
if __fast_scalar_indexing(x, y)
527537
return maximum(@closure((xᵢyᵢ)->begin
528-
xᵢ, yᵢ = xᵢyᵢ
529-
return op(xᵢ, yᵢ)
530-
end), zip(x, y))
538+
xᵢ, yᵢ = xᵢyᵢ
539+
return op(xᵢ, yᵢ)
540+
end), zip(x, y))
531541
else
532542
return mapreduce(@closure((xᵢ, yᵢ)->op(xᵢ, yᵢ)), max, x, y)
533543
end
@@ -536,9 +546,9 @@ end
536546
@inline function __norm_op(::typeof(Base.Fix2(norm, 2)), op::F, x, y) where {F}
537547
if __fast_scalar_indexing(x, y)
538548
return sqrt(sum(@closure((xᵢyᵢ)->begin
539-
xᵢ, yᵢ = xᵢyᵢ
540-
return op(xᵢ, yᵢ)^2
541-
end), zip(x, y)))
549+
xᵢ, yᵢ = xᵢyᵢ
550+
return op(xᵢ, yᵢ)^2
551+
end), zip(x, y)))
542552
else
543553
return sqrt(mapreduce(@closure((xᵢ, yᵢ)->(op(xᵢ, yᵢ)^2)), +, x, y))
544554
end

src/termination_conditions_deprecated.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -79,30 +79,28 @@ mutable struct NLSolveSafeTerminationResult{T, uType}
7979
return_code::NLSolveSafeTerminationReturnCode.T
8080
end
8181

82-
function NLSolveSafeTerminationResult(u = nothing; best_objective_value = Inf64,
83-
best_objective_value_iteration = 0,
82+
function NLSolveSafeTerminationResult(
83+
u = nothing; best_objective_value = Inf64, best_objective_value_iteration = 0,
8484
return_code = NLSolveSafeTerminationReturnCode.Failure)
8585
u = u !== nothing ? copy(u) : u
8686
Base.depwarn(
8787
"NLSolveSafeTerminationResult has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
8888
:NLSolveSafeTerminationResult)
89-
return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(u,
90-
best_objective_value, best_objective_value_iteration, return_code)
89+
return NLSolveSafeTerminationResult{typeof(best_objective_value), typeof(u)}(
90+
u, best_objective_value, best_objective_value_iteration, return_code)
9191
end
9292

9393
const BASIC_TERMINATION_MODES = (NLSolveTerminationMode.SteadyStateDefault,
94-
NLSolveTerminationMode.NLSolveDefault,
95-
NLSolveTerminationMode.Norm, NLSolveTerminationMode.Rel,
96-
NLSolveTerminationMode.RelNorm,
94+
NLSolveTerminationMode.NLSolveDefault, NLSolveTerminationMode.Norm,
95+
NLSolveTerminationMode.Rel, NLSolveTerminationMode.RelNorm,
9796
NLSolveTerminationMode.Abs, NLSolveTerminationMode.AbsNorm)
9897

99-
const SAFE_TERMINATION_MODES = (NLSolveTerminationMode.RelSafe,
100-
NLSolveTerminationMode.RelSafeBest,
101-
NLSolveTerminationMode.AbsSafe,
102-
NLSolveTerminationMode.AbsSafeBest)
98+
const SAFE_TERMINATION_MODES = (
99+
NLSolveTerminationMode.RelSafe, NLSolveTerminationMode.RelSafeBest,
100+
NLSolveTerminationMode.AbsSafe, NLSolveTerminationMode.AbsSafeBest)
103101

104-
const SAFE_BEST_TERMINATION_MODES = (NLSolveTerminationMode.RelSafeBest,
105-
NLSolveTerminationMode.AbsSafeBest)
102+
const SAFE_BEST_TERMINATION_MODES = (
103+
NLSolveTerminationMode.RelSafeBest, NLSolveTerminationMode.AbsSafeBest)
106104

107105
@doc doc"""
108106
NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
@@ -146,8 +144,8 @@ Define the termination criteria for the NonlinearProblem or SteadyStateProblem.
146144
!!! warning
147145
This has been deprecated and will be removed in the next major release. Please use the new dispatch based termination conditions API.
148146
"""
149-
struct NLSolveTerminationCondition{mode, T,
150-
S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
147+
struct NLSolveTerminationCondition{
148+
mode, T, S <: Union{<:NLSolveSafeTerminationOptions, Nothing}}
151149
abstol::T
152150
reltol::T
153151
safe_termination_options::S
@@ -168,8 +166,7 @@ get_termination_mode(::NLSolveTerminationCondition{mode}) where {mode} = mode
168166
# Don't specify `mode` since the defaults would depend on the package
169167
function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
170168
protective_threshold = 1e3, patience_steps::Int = 30,
171-
patience_objective_multiplier = 3,
172-
min_max_factor = 1.3) where {T}
169+
patience_objective_multiplier = 3, min_max_factor = 1.3) where {T}
173170
Base.depwarn(
174171
"NLSolveTerminationCondition has been deprecated in favor of the new dispatch based termination conditions. Please use the new termination conditions API!",
175172
:NLSolveTerminationCondition)
@@ -184,16 +181,14 @@ function NLSolveTerminationCondition(mode; abstol::T = 1e-8, reltol::T = 1e-6,
184181
end
185182

186183
function (cond::NLSolveTerminationCondition)(storage::Union{
187-
NLSolveSafeTerminationResult,
188-
Nothing
189-
})
184+
NLSolveSafeTerminationResult, Nothing})
190185
mode = get_termination_mode(cond)
191186
# We need both the dispatches to support solvers that don't use the integrator
192187
# interface like SimpleNonlinearSolve
193188
if mode in BASIC_TERMINATION_MODES
194189
function _termination_condition_closure_basic(integrator, abstol, reltol, min_t)
195-
return _termination_condition_closure_basic(get_du(integrator), integrator.u,
196-
integrator.uprev, abstol, reltol)
190+
return _termination_condition_closure_basic(
191+
get_du(integrator), integrator.u, integrator.uprev, abstol, reltol)
197192
end
198193
function _termination_condition_closure_basic(du, u, uprev, abstol, reltol)
199194
return _has_converged(du, u, uprev, cond, abstol, reltol)
@@ -204,8 +199,8 @@ function (cond::NLSolveTerminationCondition)(storage::Union{
204199
nstep::Int = 0
205200

206201
function _termination_condition_closure_safe(integrator, abstol, reltol, min_t)
207-
return _termination_condition_closure_safe(get_du(integrator), integrator.u,
208-
integrator.uprev, abstol, reltol)
202+
return _termination_condition_closure_safe(
203+
get_du(integrator), integrator.u, integrator.uprev, abstol, reltol)
209204
end
210205
@inbounds function _termination_condition_closure_safe(du, u, uprev, abstol, reltol)
211206
aType = typeof(abstol)

src/utils.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,7 @@ function init_termination_cache(
174174
end
175175

176176
function check_termination(tc_cache, fx, x, xo, prob, alg)
177-
return check_termination(
178-
tc_cache, fx, x, xo, prob, alg, get_termination_mode(tc_cache))
177+
return check_termination(tc_cache, fx, x, xo, prob, alg, get_termination_mode(tc_cache))
179178
end
180179
function check_termination(
181180
tc_cache, fx, x, xo, prob, alg, ::AbstractNonlinearTerminationMode)

0 commit comments

Comments
 (0)