Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit aedccfb

Browse files
committed
Add Inplace tests
1 parent 908a98e commit aedccfb

File tree

6 files changed

+59
-9
lines changed

6 files changed

+59
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ LinearSolve = "7ed4a6bd-45f5-4d41-b270-4a48e9bafcae"
2121
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2222

2323
[extensions]
24-
SimpleNonlinearSolveNNlibExt = "NNlib"
2524
SimpleNonlinearSolveADLinearSolveExt = ["AbstractDifferentiation", "LinearSolve"]
25+
SimpleNonlinearSolveNNlibExt = "NNlib"
2626

2727
[compat]
2828
AbstractDifferentiation = "0.5"

src/SimpleNonlinearSolve.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ include("batched/utils.jl")
4545
include("batched/raphson.jl")
4646
include("batched/dfsane.jl")
4747
include("batched/broyden.jl")
48-
include("batched/lbroyden.jl")
4948

5049
import PrecompileTools
5150

@@ -79,7 +78,6 @@ end
7978
# DiffEq styled algorithms
8079
export Bisection, Brent, Broyden, LBroyden, SimpleDFSane, Falsi, Halley, Klement,
8180
Ridder, SimpleNewtonRaphson, SimpleTrustRegion, Alefeld
82-
export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane,
83-
BatchedLBroyden
81+
export BatchedBroyden, SimpleBatchedNewtonRaphson, SimpleBatchedDFSane
8482

8583
end # module

src/batched/dfsane.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
Base.@kwdef struct SimpleBatchedDFSane{T, F, TC <: NLSolveTerminationCondition} <:
2-
AbstractBatchedNonlinearSolveAlgorithm
2+
AbstractBatchedNonlinearSolveAlgorithm
33
σₘᵢₙ::T = 1.0f-10
44
σₘₐₓ::T = 1.0f+10
55
σ₁::T = 1.0f0

src/batched/lbroyden.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

test/inplace.jl

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using SimpleNonlinearSolve, StaticArrays, BenchmarkTools, DiffEqBase, LinearAlgebra, Test,
2+
NNlib, AbstractDifferentiation, LinearSolve
3+
4+
# Supported Solvers: BatchedBroyden, SimpleBatchedDFSane
5+
function f!(du::AbstractArray{<:Number, N},
6+
u::AbstractArray{<:Number, N},
7+
p::AbstractVector) where {N}
8+
u_ = reshape(u, :, size(u, N))
9+
du .= reshape(sum(abs2, u_; dims = 1) .- reshape(p, 1, :),
10+
ntuple(_ -> 1, N - 1)...,
11+
size(u, N))
12+
return du
13+
end
14+
15+
function f!(du::AbstractMatrix, u::AbstractMatrix, p::AbstractVector)
16+
du .= sum(abs2, u; dims = 1) .- reshape(p, 1, :)
17+
return du
18+
end
19+
20+
function f!(du::AbstractVector, u::AbstractVector, p::AbstractVector)
21+
du .= sum(abs2, u) .- p
22+
return du
23+
end
24+
25+
@testset "Solver: $(nameof(typeof(solver)))" for solver in (Broyden(batched = true),
26+
SimpleDFSane(batched = true))
27+
@testset "T: $T" for T in (Float32, Float64)
28+
p = rand(T, 5)
29+
@testset "size(u0): $sz" for sz in ((2, 5), (1, 5), (2, 3, 5))
30+
u0 = ones(T, sz)
31+
prob = NonlinearProblem{true}(f!, u0, p)
32+
33+
sol = solve(prob, solver)
34+
35+
@test SciMLBase.successful_retcode(sol.retcode)
36+
37+
@test sol.residzero(sol.resid) atol=5e-3
38+
end
39+
40+
p = rand(T, 1)
41+
@testset "size(u0): $sz" for sz in ((3,), (5,), (10,))
42+
u0 = ones(T, sz)
43+
prob = NonlinearProblem{true}(f!, u0, p)
44+
45+
sol = solve(prob, solver)
46+
47+
@test SciMLBase.successful_retcode(sol.retcode)
48+
49+
@test sol.residzero(sol.resid) atol=5e-3
50+
end
51+
end
52+
end

test/runtests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
1-
using Pkg
21
using SafeTestsets
3-
const LONGER_TESTS = false
42

53
const GROUP = get(ENV, "GROUP", "All")
6-
const is_APPVEYOR = Sys.iswindows() && haskey(ENV, "APPVEYOR")
74

85
@time begin
96
if GROUP == "All" || GROUP == "Core"
107
@time @safetestset "Basic Tests + Some AD" begin
118
include("basictests.jl")
129
end
10+
11+
@time @safetestset "Inplace Tests" begin
12+
include("inplace.jl")
13+
end
1314
end
1415
end

0 commit comments

Comments
 (0)