From ece6030a2f2981c874a61af50cd879671a39f496 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 18:29:32 -0700 Subject: [PATCH] Remove the static arrays special casing --- Project.toml | 6 +++--- ext/SimpleNonlinearSolveStaticArraysExt.jl | 7 ------- src/SimpleNonlinearSolve.jl | 1 + src/nlsolve/dfsane.jl | 16 ++++++++++------ src/utils.jl | 2 +- test/core/rootfind_tests.jl | 6 +++--- 6 files changed, 18 insertions(+), 20 deletions(-) delete mode 100644 ext/SimpleNonlinearSolveStaticArraysExt.jl diff --git a/Project.toml b/Project.toml index a7aa31f..d46bb15 100644 --- a/Project.toml +++ b/Project.toml @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" -SimpleNonlinearSolveStaticArraysExt = "StaticArrays" SimpleNonlinearSolveTrackerExt = "Tracker" SimpleNonlinearSolveZygoteExt = "Zygote" @@ -62,10 +61,11 @@ Reexport = "1.2" ReverseDiff = "1.15" SciMLBase = "2.37.0" SciMLSensitivity = "7.58" +Setfield = "1.1.1" StaticArrays = "1.9" StaticArraysCore = "1.4.2" Test = "1.10" -Tracker = "0.2.32" +Tracker = "0.2.33" Zygote = "0.6.69" julia = "1.10" diff --git a/ext/SimpleNonlinearSolveStaticArraysExt.jl b/ext/SimpleNonlinearSolveStaticArraysExt.jl deleted file mode 100644 index c865084..0000000 --- a/ext/SimpleNonlinearSolveStaticArraysExt.jl +++ /dev/null @@ -1,7 +0,0 @@ -module SimpleNonlinearSolveStaticArraysExt - -using SimpleNonlinearSolve: SimpleNonlinearSolve - -@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true - -end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 04a32c9..eba6d99 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val + using Setfield: @set! using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size end diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 7dd1522..aa432e2 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... α_1 = one(T) f_1 = fx_norm - history_f_k = if x isa SArray || - (x isa Number && __is_extension_loaded(Val(:StaticArrays))) - ones(SVector{M, T}) * fx_norm - else - fill(fx_norm, M) - end + history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm : + __history_vec(fx_norm, Val(M)) # Generate the cache @bb x_cache = similar(x) @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... # Store function value if history_f_k isa SVector history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M)) + elseif history_f_k isa NTuple + @set! history_f_k[mod1(k, M)] = fx_norm_new else history_f_k[mod1(k, M)] = fx_norm_new end @@ -158,3 +156,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) end + +@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M} + # Julia can't specialize here + M ≥ 11 && return :(fill(fx_norm, M)) + return :(ntuple(Returns(fx_norm), $(M))) +end diff --git a/src/utils.jl b/src/utils.jl index 0bf027f..81a41f3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -32,7 +32,7 @@ function value_and_jacobian( if isinplace(prob) if cache isa HasAnalyticJacobian - prob.f.jac(J, x, p) + prob.f.jac(J, x, prob.p) f(y, x) else DI.jacobian!(f, y, J, ad, x, cache) diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 1ef0757..48d6fcb 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -59,7 +59,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -79,7 +79,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -104,7 +104,7 @@ end @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0)