From 38fd4969bb690e466317ee11b2226ea78d3c2ff5 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 27 Mar 2024 16:35:37 +0100 Subject: [PATCH] Exploit more of FiniteDifferences (#114) * Exploit more of FiniteDifferences * Remove exception --- ...rentiationInterfaceFiniteDifferencesExt.jl | 68 +++++++++++++++---- src/backends.jl | 1 - src/gradient.jl | 4 +- src/jacobian.jl | 43 ++++++------ 4 files changed, 79 insertions(+), 37 deletions(-) diff --git a/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl b/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl index 2132e35ed..bae8e7d95 100644 --- a/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl +++ b/ext/DifferentiationInterfaceFiniteDifferencesExt/DifferentiationInterfaceFiniteDifferencesExt.jl @@ -3,7 +3,7 @@ module DifferentiationInterfaceFiniteDifferencesExt using ADTypes: AutoFiniteDifferences import DifferentiationInterface as DI using FillArrays: OneElement -using FiniteDifferences: FiniteDifferences, jvp, j′vp +using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp using LinearAlgebra: dot DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported() @@ -12,22 +12,62 @@ function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (ht return FiniteDifferences.to_vec(collect(a)) end -function DI.value_and_pushforward( - f, backend::AutoFiniteDifferences{fdm}, x, dx, extras::Nothing -) where {fdm} - y = f(x) - return y, jvp(backend.fdm, f, (x, dx)) +## Pushforward + +function DI.pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing) + return jvp(backend.fdm, f, (x, dx)) +end + +function DI.value_and_pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing) + return f(x), DI.pushforward(f, backend, x, dx, extras) +end + +## Pullback + +function DI.pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing) + return only(j′vp(backend.fdm, f, dy, x)) end -#= -# TODO: why does this fail? +function DI.value_and_pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing) + return f(x), DI.pullback(f, backend, x, dy, extras) +end + +## Gradient + +function DI.gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing) + return only(grad(backend.fdm, f, x)) +end + +function DI.value_and_gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing) + return f(x), DI.gradient(f, backend, x, extras) +end + +function DI.gradient!!(f, grad, backend::AutoFiniteDifferences, x, extras::Nothing) + return DI.gradient(f, backend, x, extras) +end + +function DI.value_and_gradient!!( + f, grad, backend::AutoFiniteDifferences, x, extras::Nothing +) + return DI.value_and_gradient(f, backend, x) +end + +## Jacobian + +function DI.jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing) + return only(jacobian(backend.fdm, f, x)) +end + +function DI.value_and_jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing) + return f(x), DI.jacobian(f, backend, x, extras) +end + +function DI.jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing) + return DI.jacobian(f, backend, x, extras) +end -function DI.value_and_pullback( - f, backend::AutoFiniteDifferences{fdm}, x, dy, extras::Nothing -) where {fdm} - y = f(x) - return y, j′vp(backend.fdm, f, x, dy)[1] +function DI.value_and_jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing) + return DI.value_and_jacobian(f, backend, x) end -=# end diff --git a/src/backends.jl b/src/backends.jl index 127101271..ee77d26b3 100644 --- a/src/backends.jl +++ b/src/backends.jl @@ -12,7 +12,6 @@ function check_available(backend::AbstractADType) return true catch exception @warn "Backend $backend not available" exception - throw(exception) if exception isa MethodError return false else diff --git a/src/gradient.jl b/src/gradient.jl index 7aad3959c..05e385c0f 100644 --- a/src/gradient.jl +++ b/src/gradient.jl @@ -6,7 +6,7 @@ function value_and_gradient( f, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x) ) - return value_and_pullback(f, backend, x, true, extras) + return value_and_pullback(f, backend, x, one(eltype(x)), extras) end """ @@ -15,7 +15,7 @@ end function value_and_gradient!!( f, grad, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x) ) - return value_and_pullback!!(f, grad, backend, x, true, extras) + return value_and_pullback!!(f, grad, backend, x, one(eltype(x)), extras) end """ diff --git a/src/jacobian.jl b/src/jacobian.jl index 3d924f46b..3e55cd410 100644 --- a/src/jacobian.jl +++ b/src/jacobian.jl @@ -33,11 +33,7 @@ end value_and_jacobian!!(f, jac, backend, x, [extras]) -> (y, jac) """ function value_and_jacobian!!( - f, - jac::AbstractMatrix, - backend::AbstractADType, - x, - extras=prepare_jacobian(f, backend, x), + f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x) ) return value_and_jacobian_aux!!( f, jac, backend, x, extras, pushforward_performance(backend) @@ -45,7 +41,7 @@ function value_and_jacobian!!( end function value_and_jacobian_aux!!( - f, jac, backend, x::AbstractArray, extras, ::PushforwardFast + f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardFast ) y = f(x) for (k, j) in enumerate(CartesianIndices(x)) @@ -59,7 +55,7 @@ function value_and_jacobian_aux!!( end function value_and_jacobian_aux!!( - f, jac, backend, x::AbstractArray, extras, ::PushforwardSlow + f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardSlow ) y = f(x) for (k, i) in enumerate(CartesianIndices(y)) @@ -83,11 +79,7 @@ end jacobian!!(f, jac, backend, x, [extras]) -> jac """ function jacobian!!( - f, - jac::AbstractMatrix, - backend::AbstractADType, - x, - extras=prepare_jacobian(f, backend, x), + f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x) ) return value_and_jacobian!!(f, jac, backend, x, extras)[2] end @@ -98,19 +90,22 @@ end value_and_jacobian!!(f!, y, jac, backend, x, [extras]) -> (y, jac) """ function value_and_jacobian!!( - f!, - y::AbstractArray, - jac::AbstractMatrix, - backend::AbstractADType, - x::AbstractArray, - extras=prepare_jacobian(f!, backend, y, x), + f!, y, jac, backend::AbstractADType, x, extras=prepare_jacobian(f!, backend, y, x) ) return value_and_jacobian_aux!!( f!, y, jac, backend, x, extras, pushforward_performance(backend) ) end -function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardFast) +function value_and_jacobian_aux!!( + f!, + y::AbstractArray, + jac::AbstractMatrix, + backend, + x::AbstractArray, + extras, + ::PushforwardFast, +) f!(y, x) for (k, j) in enumerate(CartesianIndices(x)) dx_j = basis(backend, x, j) @@ -124,7 +119,15 @@ function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardF return y, jac end -function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardSlow) +function value_and_jacobian_aux!!( + f!, + y::AbstractArray, + jac::AbstractMatrix, + backend, + x::AbstractArray, + extras, + ::PushforwardSlow, +) f!(y, x) for (k, i) in enumerate(CartesianIndices(y)) dy_i = basis(backend, y, i)