Skip to content

Commit

Permalink
Handle Klement
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 4, 2024
1 parent 009dc15 commit ffdc60c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/nlsolve/klement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleKlement, args...;
@bb δx² = similar(x)

for _ in 1:maxiters
any(iszero, J) && (J = __init_identity_jacobian!!(J))
any(any(iszero, J)) && (J = __init_identity_jacobian!!(J))

@bb @. δx = fprev / J

Expand Down
11 changes: 7 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ __max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))
__standard_tag(::Nothing, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(f, T)
__standard_tag(tag::ForwardDiff.Tag, f::F, x::AbstractArray{T}) where {F, T} = tag
__standard_tag(tag, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(tag, T)

Check warning on line 24 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
__standard_tag(::Nothing, f::F, x::T) where {F, T <: Number} = ForwardDiff.Tag(f, T)
__standard_tag(tag::ForwardDiff.Tag, f::F, x::T) where {F, T <: Number} = tag
__standard_tag(tag, f::F, x::T) where {F, T <: Number} = ForwardDiff.Tag(tag, T)

Check warning on line 27 in src/utils.jl

View check run for this annotation

Codecov / codecov/patch

src/utils.jl#L26-L27

Added lines #L26 - L27 were not covered by tests

__pick_forwarddiff_chunk(x) = ForwardDiff.Chunk(length(x))
function __pick_forwarddiff_chunk(x::StaticArray)
Expand All @@ -40,7 +43,7 @@ function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS}
end
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS}
ck = (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
tag = __standard_tag(ad.tag, f, x)
tag = __standard_tag(ad.tag, f!, x)
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
end

Expand Down Expand Up @@ -76,7 +79,7 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
return y, J
elseif ad isa AutoForwardDiff
res = DiffResults.DiffResult(y, J)
ForwardDiff.jacobian!(res, _f, y, x, cache)
ForwardDiff.jacobian!(res, _f, y, x, cache, Val(false))
return DiffResults.value(res), DiffResults.jacobian(res)
elseif ad isa AutoFiniteDiff
FiniteDiff.finite_difference_jacobian!(J, _f, x, cache)
Expand All @@ -95,10 +98,10 @@ function value_and_jacobian(ad, f::F, y, x::X, p, cache; J = nothing) where {F,
elseif ad isa AutoForwardDiff
if ArrayInterface.can_setindex(x)
res = DiffResults.DiffResult(y, J)
ForwardDiff.jacobian!(res, _f, x, cache)
ForwardDiff.jacobian!(res, _f, x, cache, Val(false))
return DiffResults.value(res), DiffResults.jacobian(res)
else
J_fd = ForwardDiff.jacobian(_f, x, cache)
J_fd = ForwardDiff.jacobian(_f, x, cache, Val(false))
return _f(x), J_fd
end
elseif ad isa AutoFiniteDiff
Expand Down

0 comments on commit ffdc60c

Please sign in to comment.