From aedccfb4bd54d62a2e245e91c685adc2affe3cf5 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 30 Jun 2023 15:15:54 -0400 Subject: [PATCH] Add Inplace tests --- Project.toml | 2 +- src/SimpleNonlinearSolve.jl | 4 +-- src/batched/dfsane.jl | 2 +- src/batched/lbroyden.jl | 1 - test/inplace.jl | 52 +++++++++++++++++++++++++++++++++++++ test/runtests.jl | 7 ++--- 6 files changed, 59 insertions(+), 9 deletions(-) delete mode 100644 src/batched/lbroyden.jl create mode 100644 test/inplace.jl diff --git a/Project.toml b/Project.toml index f1f472d..258f009 100644 --- a/Project.toml +++ b/Project.toml @@ -21,8 +21,8 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" [extensions] -SimpleNonlinearSolveNNlibExt = "NNlib" SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"] +SimpleNonlinearSolveNNlibExt = "NNlib" [compat] AbstractDifferentiation = "0.5" diff --git a/src/SimpleNonlinearSolve.jl b/src/SimpleNonlinearSolve.jl index 23b332f..cd48556 100644 --- a/src/SimpleNonlinearSolve.jl +++ b/src/SimpleNonlinearSolve.jl @@ -45,7 +45,6 @@ include("batched/utils.jl") include("batched/raphson.jl") include("batched/dfsane.jl") include("batched/broyden.jl") -include("batched/lbroyden.jl") import PrecompileTools @@ -79,7 +78,6 @@ end # DiffEq styled algorithms export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement, Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld -export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane, - BatchedLBroyden +export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane end # module diff --git a/src/batched/dfsane.jl b/src/batched/dfsane.jl index 09fc37f..a394517 100644 --- a/src/batched/dfsane.jl +++ b/src/batched/dfsane.jl @@ -1,5 +1,5 @@ Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <: - AbstractBatchedNonlinearSolveAlgorithm + AbstractBatchedNonlinearSolveAlgorithm σₘᵢₙ::T = 1.0f-10 σₘₐₓ::T = 1.0f+10 σ₁::T = 1.0f0 diff --git a/src/batched/lbroyden.jl b/src/batched/lbroyden.jl deleted file mode 100644 index 8b13789..0000000 --- a/src/batched/lbroyden.jl +++ /dev/null @@ -1 +0,0 @@ - diff --git a/test/inplace.jl b/test/inplace.jl new file mode 100644 index 0000000..4c43a1d --- /dev/null +++ b/test/inplace.jl @@ -0,0 +1,52 @@ +using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test, + NNlib, AbstractDifferentiation, LinearSolve + +# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane +function f!(du::AbstractArray{<:Number, N}, + u::AbstractArray{<:Number, N}, + p::AbstractVector) where {N} + u_ = reshape(u, :, size(u, N)) + du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :), + ntuple(_ -> 1, N - 1)..., + size(u, N)) + return du +end + +function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector) + du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :) + return du +end + +function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector) + du .= sum(abs2, u) .- p + return du +end + +@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true), + SimpleDFSane(batched = true)) + @testset "T: $T" for T in (Float32, Float64) + p = rand(T, 5) + @testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.resid≈zero(sol.resid) atol=5e-3 + end + + p = rand(T, 1) + @testset "size(u0): $sz" for sz in ((3,), (5,), (10,)) + u0 = ones(T, sz) + prob = NonlinearProblem{true}(f!, u0, p) + + sol = solve(prob, solver) + + @test SciMLBase.successful_retcode(sol.retcode) + + @test sol.resid≈zero(sol.resid) atol=5e-3 + end + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 94a0086..bea57ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,14 +1,15 @@ -using Pkg using SafeTestsets -const LONGER_TESTS = false const GROUP = get(ENV, "GROUP", "All") -const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR") @time begin if GROUP == "All" || GROUP == "Core" @time @safetestset "Basic Tests + Some AD" begin include("basictests.jl") end + + @time @safetestset "Inplace Tests" begin + include("inplace.jl") + end end end