diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index 5f88f4058..80439ad8a 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -41,6 +41,7 @@ include("utils/basis.jl") include("utils/printing.jl") include("utils/chunk.jl") include("utils/check.jl") +include("utils/exceptions.jl") include("first_order/pushforward.jl") include("first_order/pullback.jl") diff --git a/DifferentiationInterface/src/first_order/pullback.jl b/DifferentiationInterface/src/first_order/pullback.jl index add2dca62..b2a3cff92 100644 --- a/DifferentiationInterface/src/first_order/pullback.jl +++ b/DifferentiationInterface/src/first_order/pullback.jl @@ -84,6 +84,10 @@ function prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackSlow) return PushforwardPullbackExtras(pushforward_extras) end +# Throw error if backend is missing +prepare_pullback_aux(f, backend, x, dy, ::PullbackFast) = throw(MissingBackendError(backend)) +prepare_pullback_aux(f!, y, backend, x, dy, ::PullbackFast) = throw(MissingBackendError(backend)) + ## One argument function value_and_pullback( diff --git a/DifferentiationInterface/src/first_order/pushforward.jl b/DifferentiationInterface/src/first_order/pushforward.jl index 4b99aed77..f600f533e 100644 --- a/DifferentiationInterface/src/first_order/pushforward.jl +++ b/DifferentiationInterface/src/first_order/pushforward.jl @@ -73,6 +73,10 @@ function prepare_pushforward_aux(f!, y, backend, x, dx, ::PushforwardSlow) return PullbackPushforwardExtras(pullback_extras) end +# Throw error if backend is missing +prepare_pushforward_aux(f, backend, x, dy, ::PushforwardFast) = throw(MissingBackendError(backend)) +prepare_pushforward_aux(f!, y, backend, x, dy, ::PushforwardFast) = throw(MissingBackendError(backend)) + ## One argument function value_and_pushforward( diff --git a/DifferentiationInterface/src/utils/exceptions.jl b/DifferentiationInterface/src/utils/exceptions.jl new file mode 100644 index 000000000..c6509e78f --- /dev/null +++ b/DifferentiationInterface/src/utils/exceptions.jl @@ -0,0 +1,20 @@ +struct MissingBackendError <: Exception + backend::AbstractADType +end +function Base.showerror(io::IO, e::MissingBackendError) + println(io, "failed to use $(backend_string(e.backend)) backend.") + if !check_available(e.backend) + print( + io, + """Backend package is not loaded. To fix, run + + using $(backend_package_name(e.backend)) + """, + ) + else + print( + io, + "Please open an issue: https://github.com/gdalle/DifferentiationInterface.jl/issues/new", + ) + end +end diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index 8502224df..f5bd7ad79 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -1,18 +1,21 @@ -backend_string_aux(b::AbstractADType) = string(b) +backend_package_name(b::AbstractADType) = strip(string(b), ['(', ')']) -backend_string_aux(::AutoChainRules) = "ChainRules" -backend_string_aux(::AutoDiffractor) = "Diffractor" -backend_string_aux(::AutoEnzyme) = "Enzyme" -backend_string_aux(::AutoFastDifferentiation) = "FastDifferentiation" -backend_string_aux(::AutoFiniteDiff) = "FiniteDiff" -backend_string_aux(::AutoFiniteDifferences) = "FiniteDifferences" -backend_string_aux(::AutoForwardDiff) = "ForwardDiff" -backend_string_aux(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" +backend_package_name(::AutoChainRules) = "ChainRules" +backend_package_name(::AutoDiffractor) = "Diffractor" +backend_package_name(::AutoEnzyme) = "Enzyme" +backend_package_name(::AutoFastDifferentiation) = "FastDifferentiation" +backend_package_name(::AutoFiniteDiff) = "FiniteDiff" +backend_package_name(::AutoFiniteDifferences) = "FiniteDifferences" +backend_package_name(::AutoForwardDiff) = "ForwardDiff" +backend_package_name(::AutoPolyesterForwardDiff) = "PolyesterForwardDiff" +backend_package_name(::AutoSymbolics) = "Symbolics" +backend_package_name(::AutoTapir) = "Tapir" +backend_package_name(::AutoTracker) = "Tracker" +backend_package_name(::AutoZygote) = "Zygote" +backend_package_name(::AutoReverseDiff) = "ReverseDiff" + +backend_string_aux(b::AbstractADType) = backend_package_name(b) backend_string_aux(b::AutoReverseDiff) = "ReverseDiff$(b.compile ? "{compiled}" : "")" -backend_string_aux(::AutoSymbolics) = "Symbolics" -backend_string_aux(::AutoTapir) = "Tapir" -backend_string_aux(::AutoTracker) = "Tracker" -backend_string_aux(::AutoZygote) = "Zygote" function backend_string(backend::AbstractADType) bs = backend_string_aux(backend) diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index a35ed0817..03eeafef4 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -23,6 +23,9 @@ include("test_imports.jl") Documenter.doctest(DifferentiationInterface) + @testset verbose = true "Exception handling" begin + include("test_exceptions.jl") + end @testset verbose = true "First order" begin include("first_order.jl") end diff --git a/DifferentiationInterface/test/test_exceptions.jl b/DifferentiationInterface/test/test_exceptions.jl new file mode 100644 index 000000000..5ac9cea60 --- /dev/null +++ b/DifferentiationInterface/test/test_exceptions.jl @@ -0,0 +1,33 @@ +using DifferentiationInterface: MissingBackendError + +""" + AutoBrokenForward <: ADTypes.AbstractADType + +Available forward-mode backend with no pushforward implementation. +Used to test error messages. +""" +struct AutoBrokenForward <: AbstractADType end +ADTypes.mode(::AutoBrokenForward) = ADTypes.ForwardMode() +DifferentiationInterface.check_available(::AutoBrokenForward) = true + +""" + AutoBrokenReverse <: ADTypes.AbstractADType + +Available reverse-mode backend with no pullback implementation. +Used to test error messages. +""" +struct AutoBrokenReverse <: AbstractADType end +ADTypes.mode(::AutoBrokenReverse) = ADTypes.ReverseMode() +DifferentiationInterface.check_available(::AutoBrokenReverse) = true + +## Test exceptions +@testset "MissingBackendError" begin + f(x::AbstractArray) = sum(abs2, x) + x = [1.0, 2.0, 3.0] + + @test_throws MissingBackendError gradient(f, AutoBrokenForward(), x) + @test_throws MissingBackendError gradient(f, AutoBrokenReverse(), x) + + @test_throws MissingBackendError hvp(f, AutoBrokenForward(), x, x) + @test_throws MissingBackendError hvp(f, AutoBrokenReverse(), x, x) +end