Skip to content

Commit

Permalink
Exploit more of FiniteDifferences (#114)
Browse files Browse the repository at this point in the history
* Exploit more of FiniteDifferences

* Remove exception
  • Loading branch information
gdalle authored Mar 27, 2024
1 parent 6155b2a commit 38fd496
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module DifferentiationInterfaceFiniteDifferencesExt
using ADTypes: AutoFiniteDifferences
import DifferentiationInterface as DI
using FillArrays: OneElement
using FiniteDifferences: FiniteDifferences, jvp, j′vp
using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
using LinearAlgebra: dot

DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported()
Expand All @@ -12,22 +12,62 @@ function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (ht
return FiniteDifferences.to_vec(collect(a))
end

function DI.value_and_pushforward(
f, backend::AutoFiniteDifferences{fdm}, x, dx, extras::Nothing
) where {fdm}
y = f(x)
return y, jvp(backend.fdm, f, (x, dx))
## Pushforward

function DI.pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing)
return jvp(backend.fdm, f, (x, dx))
end

function DI.value_and_pushforward(f, backend::AutoFiniteDifferences, x, dx, extras::Nothing)
return f(x), DI.pushforward(f, backend, x, dx, extras)
end

## Pullback

function DI.pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing)
return only(j′vp(backend.fdm, f, dy, x))
end

#=
# TODO: why does this fail?
function DI.value_and_pullback(f, backend::AutoFiniteDifferences, x, dy, extras::Nothing)
return f(x), DI.pullback(f, backend, x, dy, extras)
end

## Gradient

function DI.gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing)
return only(grad(backend.fdm, f, x))
end

function DI.value_and_gradient(f, backend::AutoFiniteDifferences, x, extras::Nothing)
return f(x), DI.gradient(f, backend, x, extras)
end

function DI.gradient!!(f, grad, backend::AutoFiniteDifferences, x, extras::Nothing)
return DI.gradient(f, backend, x, extras)
end

function DI.value_and_gradient!!(
f, grad, backend::AutoFiniteDifferences, x, extras::Nothing
)
return DI.value_and_gradient(f, backend, x)
end

## Jacobian

function DI.jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing)
return only(jacobian(backend.fdm, f, x))
end

function DI.value_and_jacobian(f, backend::AutoFiniteDifferences, x, extras::Nothing)
return f(x), DI.jacobian(f, backend, x, extras)
end

function DI.jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing)
return DI.jacobian(f, backend, x, extras)
end

function DI.value_and_pullback(
f, backend::AutoFiniteDifferences{fdm}, x, dy, extras::Nothing
) where {fdm}
y = f(x)
return y, j′vp(backend.fdm, f, x, dy)[1]
function DI.value_and_jacobian!!(f, jac, backend::AutoFiniteDifferences, x, extras::Nothing)
return DI.value_and_jacobian(f, backend, x)
end
=#

end
1 change: 0 additions & 1 deletion src/backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ function check_available(backend::AbstractADType)
return true
catch exception
@warn "Backend $backend not available" exception
throw(exception)
if exception isa MethodError
return false
else
Expand Down
4 changes: 2 additions & 2 deletions src/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
function value_and_gradient(
f, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x)
)
return value_and_pullback(f, backend, x, true, extras)
return value_and_pullback(f, backend, x, one(eltype(x)), extras)
end

"""
Expand All @@ -15,7 +15,7 @@ end
function value_and_gradient!!(
f, grad, backend::AbstractADType, x, extras=prepare_gradient(f, backend, x)
)
return value_and_pullback!!(f, grad, backend, x, true, extras)
return value_and_pullback!!(f, grad, backend, x, one(eltype(x)), extras)
end

"""
Expand Down
43 changes: 23 additions & 20 deletions src/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,19 +33,15 @@ end
value_and_jacobian!!(f, jac, backend, x, [extras]) -> (y, jac)
"""
function value_and_jacobian!!(
f,
jac::AbstractMatrix,
backend::AbstractADType,
x,
extras=prepare_jacobian(f, backend, x),
f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x)
)
return value_and_jacobian_aux!!(
f, jac, backend, x, extras, pushforward_performance(backend)
)
end

function value_and_jacobian_aux!!(
f, jac, backend, x::AbstractArray, extras, ::PushforwardFast
f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardFast
)
y = f(x)
for (k, j) in enumerate(CartesianIndices(x))
Expand All @@ -59,7 +55,7 @@ function value_and_jacobian_aux!!(
end

function value_and_jacobian_aux!!(
f, jac, backend, x::AbstractArray, extras, ::PushforwardSlow
f, jac::AbstractMatrix, backend, x::AbstractArray, extras, ::PushforwardSlow
)
y = f(x)
for (k, i) in enumerate(CartesianIndices(y))
Expand All @@ -83,11 +79,7 @@ end
jacobian!!(f, jac, backend, x, [extras]) -> jac
"""
function jacobian!!(
f,
jac::AbstractMatrix,
backend::AbstractADType,
x,
extras=prepare_jacobian(f, backend, x),
f, jac, backend::AbstractADType, x, extras=prepare_jacobian(f, backend, x)
)
return value_and_jacobian!!(f, jac, backend, x, extras)[2]
end
Expand All @@ -98,19 +90,22 @@ end
value_and_jacobian!!(f!, y, jac, backend, x, [extras]) -> (y, jac)
"""
function value_and_jacobian!!(
f!,
y::AbstractArray,
jac::AbstractMatrix,
backend::AbstractADType,
x::AbstractArray,
extras=prepare_jacobian(f!, backend, y, x),
f!, y, jac, backend::AbstractADType, x, extras=prepare_jacobian(f!, backend, y, x)
)
return value_and_jacobian_aux!!(
f!, y, jac, backend, x, extras, pushforward_performance(backend)
)
end

function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardFast)
function value_and_jacobian_aux!!(
f!,
y::AbstractArray,
jac::AbstractMatrix,
backend,
x::AbstractArray,
extras,
::PushforwardFast,
)
f!(y, x)
for (k, j) in enumerate(CartesianIndices(x))
dx_j = basis(backend, x, j)
Expand All @@ -124,7 +119,15 @@ function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardF
return y, jac
end

function value_and_jacobian_aux!!(f!, y, jac, backend, x, extras, ::PushforwardSlow)
function value_and_jacobian_aux!!(
f!,
y::AbstractArray,
jac::AbstractMatrix,
backend,
x::AbstractArray,
extras,
::PushforwardSlow,
)
f!(y, x)
for (k, i) in enumerate(CartesianIndices(y))
dy_i = basis(backend, y, i)
Expand Down

0 comments on commit 38fd496

Please sign in to comment.