Skip to content

Commit

Permalink
More lenient iterative linear solver (#119)
Browse files Browse the repository at this point in the history
* More lenient iterative linear solver

* Typo

* Fix bools

* Fix tests and add docs
  • Loading branch information
gdalle authored Sep 9, 2023
1 parent f58c705 commit 028d3fb
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 10 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ImplicitDifferentiation"
uuid = "57b37032-215b-411a-8a7c-41a003a55207"
authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"]
version = "0.5.0"
version = "0.5.1"

[deps]
AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
Expand Down
22 changes: 18 additions & 4 deletions src/linear_solver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,36 @@ abstract type AbstractLinearSolver end
"""
IterativeLinearSolver
An implementation of `AbstractLinearSolver` using `Krylov.gmres`.
An implementation of `AbstractLinearSolver` using `Krylov.gmres`, set as the default for `ImplicitFunction`.
# Fields
- `verbose::Bool`: Whether to throw a warning when the solver fails (defaults to `true`)
- `verbose::Bool`: Whether to display a warning when the solver fails and returns `NaN`s (defaults to `true`)
- `accept_inconsistent::Bool`: Whether to accept approximate least squares solutions for inconsistent systems, or fail and return `NaN`s (defaults to `false`)
!!! note
If you find that your implicit gradients contain `NaN`s, try using this solver with `accept_inconsistent=true`.
However, beware that the implicit function theorem does not cover the case of inconsistent linear systems `AJ = B`, so it is unclear what the result will mean.
"""
Base.@kwdef struct IterativeLinearSolver <: AbstractLinearSolver
# When `true`, a warning is displayed if the solve is considered a failure and the result is filled with `NaN`s.
verbose::Bool = true
# When `true`, a run that the solver statistics flag as inconsistent is still counted as a success;
# when `false`, the run must be flagged solved and not inconsistent to be counted as a success.
accept_inconsistent::Bool = false
end

# No preprocessing is done for the iterative solver: `A` is handed back unchanged and `y` is ignored.
function presolve(::IterativeLinearSolver, A, y)
    return A
end

function solve(sol::IterativeLinearSolver, A, b)
x, stats = gmres(A, b)
if !stats.solved || stats.inconsistent
sol.verbose && @warn "IterativeLinearSolver failed, result contains NaNs"
if sol.accept_inconsistent
success = stats.solved || stats.inconsistent
else
success = stats.solved && !stats.inconsistent
end
if !success
if sol.verbose
@warn "IterativeLinearSolver failed, result contains NaNs"
@show stats
end
x .= NaN
end
return x
Expand Down
19 changes: 14 additions & 5 deletions test/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,13 @@ end
@testset verbose = true "Derivative NaNs" begin
x = zeros(Float32, 2)
linear_solvers = (
IterativeLinearSolver(; verbose=false), DirectLinearSolver(; verbose=false)
IterativeLinearSolver(; verbose=false), #
IterativeLinearSolver(; verbose=false, accept_inconsistent=true), #
DirectLinearSolver(; verbose=false), #
)
# Tell whether the Jacobian computed with `linear_solver` is expected to be entirely `NaN`:
# always for a `DirectLinearSolver`, otherwise only when inconsistent systems are not accepted.
function should_give_nan(linear_solver)
    if linear_solver isa DirectLinearSolver
        return true
    end
    return !linear_solver.accept_inconsistent
end

@testset "Infinite derivative" begin
f = x -> sqrt.(x) # nondifferentiable at 0
Expand All @@ -37,8 +42,10 @@ end
implicit = ImplicitFunction(f, c; linear_solver)
J1 = ForwardDiff.jacobian(implicit, x)
J2 = Zygote.jacobian(implicit, x)[1]
@test all(isnan, J1) && eltype(J1) == Float32
@test all(isnan, J2) && eltype(J2) == Float32
@test all(isnan, J1) == should_give_nan(linear_solver)
@test all(isnan, J2) == should_give_nan(linear_solver)
@test eltype(J1) == Float32
@test eltype(J2) == Float32
end
end
end
Expand All @@ -51,8 +58,10 @@ end
implicit = ImplicitFunction(f, c; linear_solver)
J1 = ForwardDiff.jacobian(implicit, x)
J2 = Zygote.jacobian(implicit, x)[1]
@test all(isnan, J1) && eltype(J1) == Float32
@test all(isnan, J2) && eltype(J2) == Float32
@test all(isnan, J1) == should_give_nan(linear_solver)
@test all(isnan, J2) == should_give_nan(linear_solver)
@test eltype(J1) == Float32
@test eltype(J2) == Float32
end
end
end
Expand Down

0 comments on commit 028d3fb

Please sign in to comment.