Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 908a98e

Browse files
committed
Format
1 parent cb3723f commit 908a98e

File tree

9 files changed

+43
-39
lines changed

9 files changed

+43
-39
lines changed

ext/SimpleNonlinearSolveADLinearSolveExt.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
module SimpleNonlinearSolveADLinearSolveExt
22

3-
using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
3+
using AbstractDifferentiation,
4+
ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
45
SimpleNonlinearSolve, SciMLBase
5-
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
6+
import SimpleNonlinearSolve: _construct_batched_problem_structure,
7+
_get_storage, _result_from_storage, _get_tolerance, @maybeinplace
68

79
const AD = AbstractDifferentiation
810

@@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
2022
# TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
2123
chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size
2224
ad = SciMLBase._unwrap_val(autodiff) ?
23-
AD.ForwardDiffBackend(; chunksize) :
24-
AD.FiniteDifferencesBackend()
25-
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(
26-
ad,
25+
AD.ForwardDiffBackend(; chunksize) :
26+
AD.FiniteDifferencesBackend()
27+
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad,
2728
nothing,
2829
termination_condition)
2930
end
3031

3132
function SciMLBase.__solve(prob::NonlinearProblem,
3233
alg::SimpleBatchedNewtonRaphson;
33-
abstol=nothing,
34-
reltol=nothing,
35-
maxiters=1000,
34+
abstol = nothing,
35+
reltol = nothing,
36+
maxiters = 1000,
3637
kwargs...)
3738
iip = isinplace(prob)
3839
@assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
@@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
5758
alg,
5859
reconstruct(xₙ),
5960
reconstruct(fₙ);
60-
retcode=ReturnCode.Success)
61+
retcode = ReturnCode.Success)
6162

62-
solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...)
63+
solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...)
6364
xₙ .-= δx
6465

6566
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
@@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
8384
alg,
8485
reconstruct(xₙ),
8586
reconstruct(fₙ);
86-
retcode=ReturnCode.MaxIters)
87+
retcode = ReturnCode.MaxIters)
8788
end
8889

8990
end

ext/SimpleNonlinearSolveNNlibExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
module SimpleNonlinearSolveNNlibExt
22

33
using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
4-
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
4+
import SimpleNonlinearSolve: _construct_batched_problem_structure,
5+
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
56

67
function __init__()
78
SimpleNonlinearSolve.NNlibExtLoaded[] = true
@@ -10,9 +11,9 @@ end
1011

1112
@views function SciMLBase.__solve(prob::NonlinearProblem,
1213
alg::BatchedBroyden;
13-
abstol=nothing,
14-
reltol=nothing,
15-
maxiters=1000,
14+
abstol = nothing,
15+
reltol = nothing,
16+
maxiters = 1000,
1617
kwargs...)
1718
iip = isinplace(prob)
1819

@@ -74,7 +75,7 @@ end
7475
alg,
7576
reconstruct(xₙ),
7677
reconstruct(fₙ);
77-
retcode=ReturnCode.MaxIters)
78+
retcode = ReturnCode.MaxIters)
7879
end
7980

8081
end

src/SimpleNonlinearSolve.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonline
2222
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
2323
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
2424
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
25-
abstract type AbstractBatchedNonlinearSolveAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
25+
abstract type AbstractBatchedNonlinearSolveAlgorithm <:
26+
AbstractSimpleNonlinearSolveAlgorithm end
2627

2728
include("utils.jl")
2829
include("bisection.jl")

src/batched/broyden.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct BatchedBroyden{TC <: NLSolveTerminationCondition} <:
2-
AbstractBatchedNonlinearSolveAlgorithm
2+
AbstractBatchedNonlinearSolveAlgorithm
33
termination_condition::TC
44
end
55

src/batched/dfsane.jl

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
1+
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
22
AbstractBatchedNonlinearSolveAlgorithm
33
σₘᵢₙ::T = 1.0f-10
44
σₘₐₓ::T = 1.0f+10
@@ -10,17 +10,17 @@
1010
nₑₓₚ::Int = 2
1111
ηₛ::F = (f₍ₙₒᵣₘ₎₁, n, xₙ, fₙ) -> f₍ₙₒᵣₘ₎₁ ./ n .^ 2
1212
termination_condition::TC = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
13-
abstol=nothing,
14-
reltol=nothing)
13+
abstol = nothing,
14+
reltol = nothing)
1515
max_inner_iterations::Int = 1000
1616
end
1717

1818
function SciMLBase.__solve(prob::NonlinearProblem,
1919
alg::SimpleBatchedDFSane,
2020
args...;
21-
abstol=nothing,
22-
reltol=nothing,
23-
maxiters=100,
21+
abstol = nothing,
22+
reltol = nothing,
23+
maxiters = 100,
2424
kwargs...)
2525
iip = isinplace(prob)
2626

@@ -60,7 +60,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
6060
return fₓ
6161
end
6262

63-
@maybeinplace iip fₙ₋₁ = ff!(f₍ₙₒᵣₘ₎ₙ₋₁, xₙ) xₙ
63+
@maybeinplace iip fₙ₋₁=ff!(f₍ₙₒᵣₘ₎ₙ₋₁, xₙ) xₙ
6464
iip && (fₙ = similar(fₙ₋₁))
6565
= repeat(f₍ₙₒᵣₘ₎ₙ₋₁, M, 1)
6666
= similar(ℋ, 1, N)
@@ -79,7 +79,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
7979
fill!(α₋, α₁)
8080
@. xₙ = xₙ₋₁ + α₊ * 𝒹
8181

82-
@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
82+
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
8383

8484
for _ in 1:(alg.max_inner_iterations)
8585
𝒸 = @.+ η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁
@@ -90,15 +90,15 @@ function SciMLBase.__solve(prob::NonlinearProblem,
9090
τₘᵢₙ * α₊,
9191
τₘₐₓ * α₊)
9292
@. xₙ = xₙ₋₁ - α₋ * 𝒹
93-
@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
93+
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
9494

9595
(sum(f₍ₙₒᵣₘ₎ₙ .≤ 𝒸) N ÷ 2) && break
9696

9797
@. α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
9898
τₘᵢₙ * α₋,
9999
τₘₐₓ * α₋)
100100
@. xₙ = xₙ₋₁ + α₊ * 𝒹
101-
@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
101+
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
102102
end
103103

104104
if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
@@ -129,12 +129,12 @@ function SciMLBase.__solve(prob::NonlinearProblem,
129129

130130
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
131131
xₙ = storage.u
132-
@maybeinplace iip fₙ = f(xₙ)
132+
@maybeinplace iip fₙ=f(xₙ)
133133
end
134134

135135
return DiffEqBase.build_solution(prob,
136136
alg,
137137
reconstruct(xₙ),
138138
reconstruct(fₙ);
139-
retcode=ReturnCode.MaxIters)
139+
retcode = ReturnCode.MaxIters)
140140
end

src/batched/lbroyden.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

src/batched/raphson.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <:
2-
AbstractBatchedNonlinearSolveAlgorithm
2+
AbstractBatchedNonlinearSolveAlgorithm
33
autodiff::AD
44
linsolve::LS
55
termination_condition::TC

src/batched/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing)
1+
macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing} = nothing)
22
@assert expr.head == :(=)
33
x1, x2 = expr.args
44
@assert x2.head == :call
@@ -64,7 +64,7 @@ function _result_from_storage(storage::NLSolveSafeTerminationResult, xₙ, fₙ,
6464
return ReturnCode.Success, xₙ, fₙ
6565
else
6666
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
67-
@maybeinplace iip fₙ = f(xₙ)
67+
@maybeinplace iip fₙ=f(xₙ)
6868
return ReturnCode.Terminated, storage.u, fₙ
6969
else
7070
return ReturnCode.Terminated, xₙ, fₙ

src/raphson.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ and static array problems.
3535
"""
3636
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end
3737

38-
function SimpleNewtonRaphson(; batched=false,
38+
function SimpleNewtonRaphson(; batched = false,
3939
chunk_size = Val{0}(),
4040
autodiff = Val{true}(),
4141
diff_type = Val{:forward},
@@ -46,10 +46,10 @@ function SimpleNewtonRaphson(; batched=false,
4646
if batched
4747
@assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
4848
termination_condition = ismissing(termination_condition) ?
49-
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
50-
abstol = nothing,
51-
reltol = nothing) :
52-
termination_condition
49+
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
50+
abstol = nothing,
51+
reltol = nothing) :
52+
termination_condition
5353
return SimpleBatchedNewtonRaphson(; chunk_size,
5454
autodiff,
5555
diff_type,

0 commit comments

Comments
 (0)