Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

[WIP] Testing out a vmap implementation #141

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SimpleNonlinearSolve"
uuid = "727e6d20-b764-4bd8-a329-72de5adea6c7"
authors = ["SciML"]
version = "1.8.0"
version = "1.8.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
4 changes: 2 additions & 2 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
fx_norm_new = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp

while k < maxiters
Bool(fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)) && break
all(fx_norm_new ≤ (f_bar + η - γ * α_p^2 * fx_norm)) && break

α_tp = α_p^2 * fx_norm / (fx_norm_new + (T(2) * α_p - T(1)) * fx_norm)
@bb @. x_cache = x - α_m * d

fx = __eval_f(prob, fx, x_cache)
fx_norm_new = NONLINEARSOLVE_DEFAULT_NORM(fx)^nexp

Bool(fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break
all(fx_norm_new ≤ (f_bar + η - γ * α_m^2 * fx_norm)) && break

α_tm = α_m^2 * fx_norm / (fx_norm_new + (T(2) * α_m - T(1)) * fx_norm)
α_p = clamp(α_tp, τ_min * α_p, τ_max * α_p)
Expand Down
3 changes: 1 addition & 2 deletions src/nlsolve/halley.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
cᵢ = _restructure(cᵢ, cᵢ_)

if i == 1
if iszero(fx)
all(iszero(fx)) &&
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
end
else
# Termination Checks
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
@bb δx² = similar(x)

for _ in 1:maxiters
any(iszero, J) && (J = __init_identity_jacobian!!(J))
any(any(iszero, J)) && (J = __init_identity_jacobian!!(J))

@bb @. δx = fprev / J

Expand Down
3 changes: 2 additions & 1 deletion src/nlsolve/raphson.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)

if i == 1
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
all(iszero(fx)) &&
build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
else
# Termination Checks
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)
Expand Down
2 changes: 1 addition & 1 deletion src/nlsolve/trustRegion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleTrustRegion, args.
termination_condition)

# Set default trust region radius if not specified by user.
Δₘₐₓ == 0 && (Δₘₐₓ = max(norm_fx, maximum(x) - minimum(x)))
Δₘₐₓ = ifelse(iszero(Δₘₐₓ), max(norm_fx, maximum(x) - minimum(x)), Δₘₐₓ)
if Δ == 0
if _unwrap_val(alg.nlsolve_update_rule)
norm_x = norm(x)
Expand Down
36 changes: 16 additions & 20 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
struct SimpleNonlinearSolveTag end

function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SimpleNonlinearSolveTag, <:T}},
f::F, x::AbstractArray{T}) where {T, F}
return true
end

"""
__prevfloat_tdir(x, x0, x1)

Expand All @@ -26,9 +19,12 @@
"""
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))

__standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x))
__standard_tag(tag::ForwardDiff.Tag, _) = tag
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))
__standard_tag(::Nothing, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(f, T)
__standard_tag(tag::ForwardDiff.Tag, f::F, x::AbstractArray{T}) where {F, T} = tag
__standard_tag(tag, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(tag, T)

Check warning on line 24 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
__standard_tag(::Nothing, f::F, x::T) where {F, T <: Number} = ForwardDiff.Tag(f, T)
__standard_tag(tag::ForwardDiff.Tag, f::F, x::T) where {F, T <: Number} = tag
__standard_tag(tag, f::F, x::T) where {F, T <: Number} = ForwardDiff.Tag(tag, T)

Check warning on line 27 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L26-L27

Added lines #L26 - L27 were not covered by tests

__pick_forwarddiff_chunk(x) = ForwardDiff.Chunk(length(x))
function __pick_forwarddiff_chunk(x::StaticArray)
Expand All @@ -42,12 +38,12 @@

function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS}
ck = (CS === nothing || CS ≤ 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
tag = __standard_tag(ad.tag, x)
tag = __standard_tag(ad.tag, f, x)
return __forwarddiff_jacobian_config(f, x, ck, tag)
end
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS}
ck = (CS === nothing || CS ≤ 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
tag = __standard_tag(ad.tag, x)
tag = __standard_tag(ad.tag, f!, x)
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
end

Expand Down Expand Up @@ -83,7 +79,7 @@
return y, J
elseif ad isa AutoForwardDiff
res = DiffResults.DiffResult(y, J)
ForwardDiff.jacobian!(res, _f, y, x, cache)
ForwardDiff.jacobian!(res, _f, y, x, cache, Val(false))
return DiffResults.value(res), DiffResults.jacobian(res)
elseif ad isa AutoFiniteDiff
FiniteDiff.finite_difference_jacobian!(J, _f, x, cache)
Expand All @@ -102,10 +98,10 @@
elseif ad isa AutoForwardDiff
if ArrayInterface.can_setindex(x)
res = DiffResults.DiffResult(y, J)
ForwardDiff.jacobian!(res, _f, x, cache)
ForwardDiff.jacobian!(res, _f, x, cache, Val(false))
return DiffResults.value(res), DiffResults.jacobian(res)
else
J_fd = ForwardDiff.jacobian(_f, x, cache)
J_fd = ForwardDiff.jacobian(_f, x, cache, Val(false))
return _f(x), J_fd
end
elseif ad isa AutoFiniteDiff
Expand All @@ -127,12 +123,12 @@
if DiffEqBase.has_jac(f)
return f(x, p), f.jac(x, p)
elseif ad isa AutoForwardDiff
T = typeof(__standard_tag(ad.tag, x))
T = typeof(__standard_tag(ad.tag, f, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoPolyesterForwardDiff
# Just use ForwardDiff
T = typeof(__standard_tag(nothing, x))
T = typeof(__standard_tag(nothing, f, x))
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
elseif ad isa AutoFiniteDiff
Expand Down Expand Up @@ -321,19 +317,19 @@
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractNonlinearTerminationMode)
tc_cache(fx, x, xo) &&
all(tc_cache(fx, x, xo)) &&
return build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
return nothing
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractSafeNonlinearTerminationMode)
tc_cache(fx, x, xo) &&
all(tc_cache(fx, x, xo)) &&
return build_solution(prob, alg, x, fx; retcode = tc_cache.retcode)
return nothing
end
function check_termination(tc_cache, fx, x, xo, prob, alg,
::AbstractSafeBestNonlinearTerminationMode)
if tc_cache(fx, x, xo)
if all(tc_cache(fx, x, xo))
if isinplace(prob)
prob.f(fx, x, prob.p)
else
Expand Down
Loading