Skip to content

Commit

Permalink
Add Inplace tests
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 30, 2023
1 parent 908a98e commit aedccfb
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[extensions]
SimpleNonlinearSolveNNlibExt = "NNlib"
SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"]
SimpleNonlinearSolveNNlibExt = "NNlib"

[compat]
AbstractDifferentiation = "0.5"
Expand Down
4 changes: 1 addition & 3 deletions src/SimpleNonlinearSolve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ include("batched/utils.jl")
include("batched/raphson.jl")
include("batched/dfsane.jl")
include("batched/broyden.jl")
include("batched/lbroyden.jl")

import PrecompileTools

Expand Down Expand Up @@ -79,7 +78,6 @@ end
# DiffEq styled algorithms
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane,
BatchedLBroyden
export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane

end # module
2 changes: 1 addition & 1 deletion src/batched/dfsane.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
AbstractBatchedNonlinearSolveAlgorithm
AbstractBatchedNonlinearSolveAlgorithm
σₘᵢₙ::T = 1.0f-10
σₘₐₓ::T = 1.0f+10
σ₁::T = 1.0f0
Expand Down
1 change: 0 additions & 1 deletion src/batched/lbroyden.jl

This file was deleted.

52 changes: 52 additions & 0 deletions test/inplace.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
NNlib, AbstractDifferentiation, LinearSolve

# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane
function f!(du::AbstractArray{<:Number, N},
u::AbstractArray{<:Number, N},
p::AbstractVector) where {N}
u_ = reshape(u, :, size(u, N))
du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :),
ntuple(_ -> 1, N - 1)...,
size(u, N))
return du
end

function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector)
du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :)
return du
end

function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector)
du .= sum(abs2, u) .- p
return du
end

@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true),
SimpleDFSane(batched = true))
@testset "T: $T" for T in (Float32, Float64)
p = rand(T, 5)
@testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))
u0 = ones(T, sz)
prob = NonlinearProblem{true}(f!, u0, p)

sol = solve(prob, solver)

@test SciMLBase.successful_retcode(sol.retcode)

@test sol.residzero(sol.resid) atol=5e-3
end

p = rand(T, 1)
@testset "size(u0): $sz" for sz in ((3,), (5,), (10,))
u0 = ones(T, sz)
prob = NonlinearProblem{true}(f!, u0, p)

sol = solve(prob, solver)

@test SciMLBase.successful_retcode(sol.retcode)

@test sol.residzero(sol.resid) atol=5e-3
end
end
end
7 changes: 4 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
using Pkg
using SafeTestsets
const LONGER_TESTS = false

const GROUP = get(ENV, "GROUP", "All")
const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")

@time begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "Basic Tests + Some AD" begin
include("basictests.jl")
end

@time @safetestset "Inplace Tests" begin
include("inplace.jl")
end
end
end

0 comments on commit aedccfb

Please sign in to comment.