From f111bd34381680f7166076ef6dc1962a5669f48f Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 25 May 2024 18:29:32 -0700 Subject: [PATCH] Remove the static arrays special casing --- Project.toml | 10 +++++----- ext/SimpleNonlinearSolveStaticArraysExt.jl | 7 ------- src/SimpleNonlinearSolve.jl | 1 + src/ad.jl | 6 ++++-- src/nlsolve/dfsane.jl | 16 ++++++++++------ src/utils.jl | 2 +- test/core/rootfind_tests.jl | 8 ++++---- 7 files changed, 25 insertions(+), 25 deletions(-) delete mode 100644 ext/SimpleNonlinearSolveStaticArraysExt.jl diff --git a/Project.toml b/Project.toml index a7aa31f..1d9cc13 100644 --- a/Project.toml +++ b/Project.toml @@ -18,19 +18,18 @@ MaybeInplace = "bb5d69b7-63fc-4a16-80bd-7e42200c7bdb" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" [weakdeps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] SimpleNonlinearSolveChainRulesCoreExt = "ChainRulesCore" SimpleNonlinearSolveReverseDiffExt = "ReverseDiff" -SimpleNonlinearSolveStaticArraysExt = "StaticArrays" SimpleNonlinearSolveTrackerExt = "Tracker" SimpleNonlinearSolveZygoteExt = "Zygote" @@ -40,7 +39,7 @@ AllocCheck = "0.1.1" Aqua = "0.8" ArrayInterface = "7.9" CUDA = "5.2" -ChainRulesCore = "1.22" +ChainRulesCore = "1.23" ConcreteStructs = "0.2.3" DiffEqBase = "6.149" DiffResults = "1.1" @@ -59,13 +58,14 @@ PrecompileTools = "1.2" Random = "1.10" ReTestItems = "1.23" Reexport = "1.2" -ReverseDiff = "1.15" +ReverseDiff = "1.15.3" SciMLBase = "2.37.0" SciMLSensitivity = "7.58" +Setfield = "1.1.1" StaticArrays = "1.9" StaticArraysCore = "1.4.2" Test = "1.10" -Tracker = "0.2.32" +Tracker = "0.2.33" Zygote = "0.6.69" julia = "1.10" diff --git a/ext/SimpleNonlinearSolveStaticArraysExt.jl b/ext/SimpleNonlinearSolveStaticArraysExt.jl deleted file mode 100644 index c865084..0000000 --- a/ext/SimpleNonlinearSolveStaticArraysExt.jl +++ /dev/null @@ -1,7 +0,0 @@ -module SimpleNonlinearSolveStaticArraysExt - -using SimpleNonlinearSolve: SimpleNonlinearSolve - -@inline SimpleNonlinearSolve.__is_extension_loaded(::Val{:StaticArrays}) = true - -end diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 04a32c9..eba6d99 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -24,6 +24,7 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace, _unwrap_val + using Setfield: @set! using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size end diff --git a/src/ad.jl b/src/ad.jl index ffae7cc..bb5afea 100644 --- a/src/ad.jl +++ b/src/ad.jl @@ -109,10 +109,12 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs.. end else # For small problems, nesting ForwardDiff is actually quite fast - _f = Base.Fix2(prob.f, newprob.p) if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50) # TODO: Remove once DI has the value_and_pullback_split defined - _F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p) + _F = @closure (u, p) -> begin + _f = Base.Fix2(prob.f, p) + return __zygote_compute_nlls_vjp(_f, u, p) + end else _F = @closure (u, p) -> begin _f = Base.Fix2(prob.f, p) diff --git a/src/nlsolve/dfsane.jl b/src/nlsolve/dfsane.jl index 7dd1522..aa432e2 100644 --- a/src/nlsolve/dfsane.jl +++ b/src/nlsolve/dfsane.jl @@ -77,12 +77,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... α_1 = one(T) f_1 = fx_norm - history_f_k = if x isa SArray || - (x isa Number && __is_extension_loaded(Val(:StaticArrays))) - ones(SVector{M, T}) * fx_norm - else - fill(fx_norm, M) - end + history_f_k = x isa SArray ? ones(SVector{M, T}) * fx_norm : + __history_vec(fx_norm, Val(M)) # Generate the cache @bb x_cache = similar(x) @@ -150,6 +146,8 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... # Store function value if history_f_k isa SVector history_f_k = Base.setindex(history_f_k, fx_norm_new, mod1(k, M)) + elseif history_f_k isa NTuple + @set! history_f_k[mod1(k, M)] = fx_norm_new else history_f_k[mod1(k, M)] = fx_norm_new end @@ -158,3 +156,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleDFSane{M}, args... return build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters) end + +@inline @generated function __history_vec(fx_norm, ::Val{M}) where {M} + # Julia can't specialize here + M ≥ 11 && return :(fill(fx_norm, M)) + return :(ntuple(Returns(fx_norm), $(M))) +end diff --git a/src/utils.jl b/src/utils.jl index 0bf027f..81a41f3 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -32,7 +32,7 @@ function value_and_jacobian( if isinplace(prob) if cache isa HasAnalyticJacobian - prob.f.jac(J, x, p) + prob.f.jac(J, x, prob.p) f(y, x) else DI.jacobian!(f, y, J, ad, x, cache) diff --git a/test/core/rootfind_tests.jl b/test/core/rootfind_tests.jl index 1ef0757..f317058 100644 --- a/test/core/rootfind_tests.jl +++ b/test/core/rootfind_tests.jl @@ -42,7 +42,7 @@ end SimpleTrustRegion, (args...; kwargs...) -> SimpleTrustRegion( args...; nlsolve_update_rule = Val(true), kwargs...)) - @testset "AutoDiff: $(nameof(typeof(autodiff))))" for autodiff in ( + @testset "AutoDiff: $(nameof(typeof(autodiff)))" for autodiff in ( AutoFiniteDiff(), AutoForwardDiff(), AutoPolyesterForwardDiff()) @testset "[OOP] u0: $(typeof(u0))" for u0 in ( [1.0, 1.0], @SVector[1.0, 1.0], 1.0) @@ -59,7 +59,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -79,7 +79,7 @@ end end end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0) @@ -104,7 +104,7 @@ end @test all(abs.(sol.u .* sol.u .- 2) .< 1e-9) end - @testset "Termination condition: $(termination_condition) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, + @testset "Termination condition: $(nameof(typeof(termination_condition))) u0: $(nameof(typeof(u0)))" for termination_condition in TERMINATION_CONDITIONS, u0 in (1.0, [1.0, 1.0], @SVector[1.0, 1.0]) probN = NonlinearProblem(quadratic_f, u0, 2.0)