Skip to content

Commit

Permalink
Format
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 28, 2023
1 parent cb3723f commit 908a98e
Show file tree
Hide file tree
Showing 9 changed files with 43 additions and 39 deletions.
25 changes: 13 additions & 12 deletions ext/SimpleNonlinearSolveADLinearSolveExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
module SimpleNonlinearSolveADLinearSolveExt

using AbstractDifferentiation, ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
using AbstractDifferentiation,
ArrayInterface, DiffEqBase, LinearAlgebra, LinearSolve,
SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _result_from_storage, _get_tolerance, @maybeinplace
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _result_from_storage, _get_tolerance, @maybeinplace

const AD = AbstractDifferentiation

Expand All @@ -20,19 +22,18 @@ function SimpleNonlinearSolve.SimpleBatchedNewtonRaphson(; chunk_size = Val{0}()
# TODO: Use `diff_type`. FiniteDiff.jl is currently not available in AD.jl
chunksize = SciMLBase._unwrap_val(chunk_size) == 0 ? nothing : chunk_size
ad = SciMLBase._unwrap_val(autodiff) ?
AD.ForwardDiffBackend(; chunksize) :
AD.FiniteDifferencesBackend()
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(
ad,
AD.ForwardDiffBackend(; chunksize) :
AD.FiniteDifferencesBackend()
return SimpleBatchedNewtonRaphson{typeof(ad), Nothing, typeof(termination_condition)}(ad,
nothing,
termination_condition)
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::SimpleBatchedNewtonRaphson;
abstol=nothing,
reltol=nothing,
maxiters=1000,
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)
@assert !iip "SimpleBatchedNewtonRaphson currently only supports out-of-place nonlinear problems."
Expand All @@ -57,9 +58,9 @@ function SciMLBase.__solve(prob::NonlinearProblem,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.Success)
retcode = ReturnCode.Success)

solve(LinearProblem(𝓙, vec(fₙ); u0=vec(δx)), alg.linsolve; kwargs...)
solve(LinearProblem(𝓙, vec(fₙ); u0 = vec(δx)), alg.linsolve; kwargs...)
xₙ .-= δx

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
Expand All @@ -83,7 +84,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.MaxIters)
retcode = ReturnCode.MaxIters)
end

end
11 changes: 6 additions & 5 deletions ext/SimpleNonlinearSolveNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
module SimpleNonlinearSolveNNlibExt

using ArrayInterface, DiffEqBase, LinearAlgebra, NNlib, SimpleNonlinearSolve, SciMLBase
import SimpleNonlinearSolve: _construct_batched_problem_structure, _get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace
import SimpleNonlinearSolve: _construct_batched_problem_structure,
_get_storage, _init_𝓙, _result_from_storage, _get_tolerance, @maybeinplace

function __init__()
SimpleNonlinearSolve.NNlibExtLoaded[] = true
Expand All @@ -10,9 +11,9 @@ end

@views function SciMLBase.__solve(prob::NonlinearProblem,
alg::BatchedBroyden;
abstol=nothing,
reltol=nothing,
maxiters=1000,
abstol = nothing,
reltol = nothing,
maxiters = 1000,
kwargs...)
iip = isinplace(prob)

Expand Down Expand Up @@ -74,7 +75,7 @@ end
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.MaxIters)
retcode = ReturnCode.MaxIters)
end

end
3 changes: 2 additions & 1 deletion src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ abstract type AbstractSimpleNonlinearSolveAlgorithm <: SciMLBase.AbstractNonline
abstract type AbstractBracketingAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractNewtonAlgorithm{CS, AD, FDT} <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractImmutableNonlinearSolver <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractBatchedNonlinearSolveAlgorithm <: AbstractSimpleNonlinearSolveAlgorithm end
abstract type AbstractBatchedNonlinearSolveAlgorithm <:
AbstractSimpleNonlinearSolveAlgorithm end

include("utils.jl")
include("bisection.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/batched/broyden.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct BatchedBroyden{TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
AbstractBatchedNonlinearSolveAlgorithm
termination_condition::TC
end

Expand Down
24 changes: 12 additions & 12 deletions src/batched/dfsane.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
σₘᵢₙ::T = 1.0f-10
σₘₐₓ::T = 1.0f+10
Expand All @@ -10,17 +10,17 @@
nₑₓₚ::Int = 2
ηₛ::F = (f₍ₙₒᵣₘ₎₁, n, xₙ, fₙ) -> f₍ₙₒᵣₘ₎₁ ./ n .^ 2
termination_condition::TC = NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol=nothing,
reltol=nothing)
abstol = nothing,
reltol = nothing)
max_inner_iterations::Int = 1000
end

function SciMLBase.__solve(prob::NonlinearProblem,
alg::SimpleBatchedDFSane,
args...;
abstol=nothing,
reltol=nothing,
maxiters=100,
abstol = nothing,
reltol = nothing,
maxiters = 100,
kwargs...)
iip = isinplace(prob)

Expand Down Expand Up @@ -60,7 +60,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
return fₓ
end

@maybeinplace iip fₙ₋₁ = ff!(f₍ₙₒᵣₘ₎ₙ₋₁, xₙ) xₙ
@maybeinplace iip fₙ₋₁=ff!(f₍ₙₒᵣₘ₎ₙ₋₁, xₙ) xₙ
iip && (fₙ = similar(fₙ₋₁))
= repeat(f₍ₙₒᵣₘ₎ₙ₋₁, M, 1)
= similar(ℋ, 1, N)
Expand All @@ -79,7 +79,7 @@ function SciMLBase.__solve(prob::NonlinearProblem,
fill!(α₋, α₁)
@. xₙ = xₙ₋₁ + α₊ * 𝒹

@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)

for _ in 1:(alg.max_inner_iterations)
𝒸 = @.+ η - γ * α₊^2 * f₍ₙₒᵣₘ₎ₙ₋₁
Expand All @@ -90,15 +90,15 @@ function SciMLBase.__solve(prob::NonlinearProblem,
τₘᵢₙ * α₊,
τₘₐₓ * α₊)
@. xₙ = xₙ₋₁ - α₋ * 𝒹
@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)

(sum(f₍ₙₒᵣₘ₎ₙ .≤ 𝒸) N ÷ 2) && break

@. α₋ = clamp(α₋^2 * f₍ₙₒᵣₘ₎ₙ₋₁ / (f₍ₙₒᵣₘ₎ₙ + (T(2) * α₋ - T(1)) * f₍ₙₒᵣₘ₎ₙ₋₁),
τₘᵢₙ * α₋,
τₘₐₓ * α₋)
@. xₙ = xₙ₋₁ + α₊ * 𝒹
@maybeinplace iip fₙ = ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
@maybeinplace iip fₙ=ff!(f₍ₙₒᵣₘ₎ₙ, xₙ)
end

if termination_condition(fₙ, xₙ, xₙ₋₁, atol, rtol)
Expand Down Expand Up @@ -129,12 +129,12 @@ function SciMLBase.__solve(prob::NonlinearProblem,

if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
xₙ = storage.u
@maybeinplace iip fₙ = f(xₙ)
@maybeinplace iip fₙ=f(xₙ)
end

return DiffEqBase.build_solution(prob,
alg,
reconstruct(xₙ),
reconstruct(fₙ);
retcode=ReturnCode.MaxIters)
retcode = ReturnCode.MaxIters)
end
1 change: 1 addition & 0 deletions src/batched/lbroyden.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

2 changes: 1 addition & 1 deletion src/batched/raphson.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
struct SimpleBatchedNewtonRaphson{AD, LS, TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
AbstractBatchedNonlinearSolveAlgorithm
autodiff::AD
linsolve::LS
termination_condition::TC
Expand Down
4 changes: 2 additions & 2 deletions src/batched/utils.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing}=nothing)
macro maybeinplace(iip::Symbol, expr::Expr, u0::Union{Symbol, Nothing} = nothing)
@assert expr.head == :(=)
x1, x2 = expr.args
@assert x2.head == :call
Expand Down Expand Up @@ -64,7 +64,7 @@ function _result_from_storage(storage::NLSolveSafeTerminationResult, xₙ, fₙ,
return ReturnCode.Success, xₙ, fₙ
else
if mode DiffEqBase.SAFE_BEST_TERMINATION_MODES
@maybeinplace iip fₙ = f(xₙ)
@maybeinplace iip fₙ=f(xₙ)
return ReturnCode.Terminated, storage.u, fₙ
else
return ReturnCode.Terminated, xₙ, fₙ
Expand Down
10 changes: 5 additions & 5 deletions src/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ and static array problems.
"""
struct SimpleNewtonRaphson{CS, AD, FDT} <: AbstractNewtonAlgorithm{CS, AD, FDT} end

function SimpleNewtonRaphson(; batched=false,
function SimpleNewtonRaphson(; batched = false,
chunk_size = Val{0}(),
autodiff = Val{true}(),
diff_type = Val{:forward},
Expand All @@ -46,10 +46,10 @@ function SimpleNewtonRaphson(; batched=false,
if batched
@assert ADLinearSolveExtLoaded[] "Please install and load `LinearSolve.jl` and `AbstractDifferentiation.jl` to use batched Newton-Raphson."
termination_condition = ismissing(termination_condition) ?
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing) :
termination_condition
NLSolveTerminationCondition(NLSolveTerminationMode.NLSolveDefault;
abstol = nothing,
reltol = nothing) :
termination_condition
return SimpleBatchedNewtonRaphson(; chunk_size,
autodiff,
diff_type,
Expand Down

0 comments on commit 908a98e

Please sign in to comment.