Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 009dc15

Browse files
committed
Remove custom tag
1 parent ecbad27 commit 009dc15

File tree

2 files changed

+9
-15
lines changed

2 files changed

+9
-15
lines changed

src/nlsolve/raphson.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ function SciMLBase.__solve(prob::Union{NonlinearProblem, NonlinearLeastSquaresPr
3939
fx, dfx = value_and_jacobian(autodiff, prob.f, fx, x, prob.p, jac_cache; J)
4040

4141
if i == 1
42-
iszero(fx) && build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
42+
all(iszero(fx)) &&
43+
build_solution(prob, alg, x, fx; retcode = ReturnCode.Success)
4344
else
4445
# Termination Checks
4546
tc_sol = check_termination(tc_cache, fx, x, xo, prob, alg)

src/utils.jl

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,3 @@
1-
struct SimpleNonlinearSolveTag end
2-
3-
function ForwardDiff.checktag(::Type{<:ForwardDiff.Tag{<:SimpleNonlinearSolveTag, <:T}},
4-
f::F, x::AbstractArray{T}) where {T, F}
5-
return true
6-
end
7-
81
"""
92
__prevfloat_tdir(x, x0, x1)
103
@@ -26,9 +19,9 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
2619
"""
2720
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))
2821

29-
__standard_tag(::Nothing, x) = ForwardDiff.Tag(SimpleNonlinearSolveTag(), eltype(x))
30-
__standard_tag(tag::ForwardDiff.Tag, _) = tag
31-
__standard_tag(tag, x) = ForwardDiff.Tag(tag, eltype(x))
22+
__standard_tag(::Nothing, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(f, T)
23+
__standard_tag(tag::ForwardDiff.Tag, f::F, x::AbstractArray{T}) where {F, T} = tag
24+
__standard_tag(tag, f::F, x::AbstractArray{T}) where {F, T} = ForwardDiff.Tag(tag, T)
3225

3326
__pick_forwarddiff_chunk(x) = ForwardDiff.Chunk(length(x))
3427
function __pick_forwarddiff_chunk(x::StaticArray)
@@ -42,12 +35,12 @@ end
4235

4336
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f::F, x) where {F, CS}
4437
ck = (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
45-
tag = __standard_tag(ad.tag, x)
38+
tag = __standard_tag(ad.tag, f, x)
4639
return __forwarddiff_jacobian_config(f, x, ck, tag)
4740
end
4841
function __get_jacobian_config(ad::AutoForwardDiff{CS}, f!::F, y, x) where {F, CS}
4942
ck = (CS === nothing || CS 0) ? __pick_forwarddiff_chunk(x) : ForwardDiff.Chunk{CS}()
50-
tag = __standard_tag(ad.tag, x)
43+
tag = __standard_tag(ad.tag, f, x)
5144
return ForwardDiff.JacobianConfig(f!, y, x, ck, tag)
5245
end
5346

@@ -127,12 +120,12 @@ function value_and_jacobian(ad, f::F, y, x::Number, p, cache; J = nothing) where
127120
if DiffEqBase.has_jac(f)
128121
return f(x, p), f.jac(x, p)
129122
elseif ad isa AutoForwardDiff
130-
T = typeof(__standard_tag(ad.tag, x))
123+
T = typeof(__standard_tag(ad.tag, f, x))
131124
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
132125
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
133126
elseif ad isa AutoPolyesterForwardDiff
134127
# Just use ForwardDiff
135-
T = typeof(__standard_tag(nothing, x))
128+
T = typeof(__standard_tag(nothing, f, x))
136129
out = f(ForwardDiff.Dual{T}(x, one(x)), p)
137130
return ForwardDiff.value(out), ForwardDiff.extract_derivative(T, out)
138131
elseif ad isa AutoFiniteDiff

0 commit comments

Comments
 (0)