diff --git a/Project.toml b/Project.toml index e9e3b43..75f3d5e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ImplicitDifferentiation" uuid = "57b37032-215b-411a-8a7c-41a003a55207" authors = ["Guillaume Dalle", "Mohamed Tarek and contributors"] -version = "0.5.0" +version = "0.5.1" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" diff --git a/src/linear_solver.jl b/src/linear_solver.jl index 2bc418a..fd39398 100644 --- a/src/linear_solver.jl +++ b/src/linear_solver.jl @@ -16,22 +16,36 @@ abstract type AbstractLinearSolver end """ IterativeLinearSolver -An implementation of `AbstractLinearSolver` using `Krylov.gmres`. +An implementation of `AbstractLinearSolver` using `Krylov.gmres`, set as the default for `ImplicitFunction`. # Fields -- `verbose::Bool`: Whether to throw a warning when the solver fails (defaults to `true`) +- `verbose::Bool`: Whether to display a warning when the solver fails and returns `NaN`s (defaults to `true`) +- `accept_inconsistent::Bool`: Whether to accept approximate least squares solutions for inconsistent systems, or fail and return `NaN`s (defaults to `false`) + +!!! note + If you find that your implicit gradients contains `NaN`s, try using this solver with `accept_inconsistent=true`. + However, beware that the implicit function theorem does not cover the case of inconsistent linear systems `AJ = B`, so it is unclear what the result will mean. """ Base.@kwdef struct IterativeLinearSolver <: AbstractLinearSolver verbose::Bool = true + accept_inconsistent::Bool = false end presolve(::IterativeLinearSolver, A, y) = A function solve(sol::IterativeLinearSolver, A, b) x, stats = gmres(A, b) - if !stats.solved || stats.inconsistent - sol.verbose && @warn "IterativeLinearSolver failed, result contains NaNs" + if sol.accept_inconsistent + success = stats.solved || stats.inconsistent + else + success = stats.solved && !stats.inconsistent + end + if !success + if sol.verbose + @warn "IterativeLinearSolver failed, result contains NaNs" + @show stats + end x .= NaN end return x diff --git a/test/errors.jl b/test/errors.jl index b0040bc..0a729f4 100644 --- a/test/errors.jl +++ b/test/errors.jl @@ -26,8 +26,13 @@ end @testset verbose = true "Derivative NaNs" begin x = zeros(Float32, 2) linear_solvers = ( - IterativeLinearSolver(; verbose=false), DirectLinearSolver(; verbose=false) + IterativeLinearSolver(; verbose=false), # + IterativeLinearSolver(; verbose=false, accept_inconsistent=true), # + DirectLinearSolver(; verbose=false), # ) + function should_give_nan(linear_solver) + return linear_solver isa DirectLinearSolver || !linear_solver.accept_inconsistent + end @testset "Infinite derivative" begin f = x -> sqrt.(x) # nondifferentiable at 0 @@ -37,8 +42,10 @@ end implicit = ImplicitFunction(f, c; linear_solver) J1 = ForwardDiff.jacobian(implicit, x) J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) && eltype(J1) == Float32 - @test all(isnan, J2) && eltype(J2) == Float32 + @test all(isnan, J1) == should_give_nan(linear_solver) + @test all(isnan, J2) == should_give_nan(linear_solver) + @test eltype(J1) == Float32 + @test eltype(J2) == Float32 end end end @@ -51,8 +58,10 @@ end implicit = ImplicitFunction(f, c; linear_solver) J1 = ForwardDiff.jacobian(implicit, x) J2 = Zygote.jacobian(implicit, x)[1] - @test all(isnan, J1) && eltype(J1) == Float32 - @test all(isnan, J2) && eltype(J2) == Float32 + @test all(isnan, J1) == should_give_nan(linear_solver) + @test all(isnan, J2) == should_give_nan(linear_solver) + @test eltype(J1) == Float32 + @test eltype(J2) == Float32 end end end