Skip to content

Commit

Permalink
Remove the static arrays special casing
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed May 26, 2024
1 parent ae0bf10 commit d4662f5
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore"
SimpleNonlinearSolveReverseDiffExt = "ReverseDiff"
SimpleNonlinearSolveStaticArraysExt = "StaticArrays"
SimpleNonlinearSolveTrackerExt = "Tracker"
SimpleNonlinearSolveZygoteExt = "Zygote"

Expand Down Expand Up @@ -62,6 +61,7 @@ Reexport = "1.2"
ReverseDiff = "1.15"
SciMLBase = "2.37.0"
SciMLSensitivity = "7.58"
Setfield = "1.1.1"
StaticArrays = "1.9"
StaticArraysCore = "1.4.2"
Test = "1.10"
Expand Down
7 changes: 0 additions & 7 deletions ext/SimpleNonlinearSolveStaticArraysExt.jl

This file was deleted.

1 change: 1 addition & 0 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
build_solution, isinplace, _unwrap_val
using Setfield: @set!
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
end

Expand Down
16 changes: 10 additions & 6 deletions src/nlsolve/dfsane.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
α_1 = one(T)
f_1 = fx_norm

history_f_k = if x isa SArray ||
(x isa Number && __is_extension_loaded(Val(:StaticArrays)))
ones(SVector{M, T}) * fx_norm
else
fill(fx_norm, M)
end
history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm :
__history_vec(fx_norm, Val(M))

# Generate the cache
@bb x_cache = similar(x)
Expand Down Expand Up @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...
# Store function value
if history_f_k isa SVector
history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M))
elseif history_f_k isa NTuple
@set! history_f_k[mod1(k, M)] = fx_norm_new
else
history_f_k[mod1(k, M)] = fx_norm_new
end
Expand All @@ -158,3 +156,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args...

return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
end

@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M}

Check warning on line 160 in src/nlsolve/dfsane.jl

View check run for this annotation

Codecov / codecov/patch

src/nlsolve/dfsane.jl#L160

Added line #L160 was not covered by tests
# Julia can't specialize here
M 11 && return :(fill(fx_norm, M))
return :(ntuple(Returns(fx_norm), $(M)))
end

0 comments on commit d4662f5

Please sign in to comment.