Skip to content

Commit

Permalink
Unprepared operators for ForwardDiff (#414)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Aug 18, 2024
1 parent 9182912 commit 8077dce
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 9 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.5.13"
version = "0.5.14"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,30 @@ end

## Gradient

### Unprepared

function DI.value_and_gradient!(f::F, grad, ::AutoForwardDiff, x) where {F}
result = MutableDiffResult(zero(eltype(x)), (grad,))
result = gradient!(result, f, x)
return DiffResults.value(result), DiffResults.gradient(result)
end

function DI.value_and_gradient(f::F, ::AutoForwardDiff, x) where {F}
result = GradientResult(x)
result = gradient!(result, f, x)
return DiffResults.value(result), DiffResults.gradient(result)
end

function DI.gradient!(f::F, grad, ::AutoForwardDiff, x) where {F}
return gradient!(grad, f, x)
end

function DI.gradient(f::F, ::AutoForwardDiff, x) where {F}
return gradient(f, x)
end

### Prepared

struct ForwardDiffGradientExtras{C} <: GradientExtras
config::C
end
Expand Down Expand Up @@ -130,6 +154,29 @@ end

## Jacobian

### Unprepared

function DI.value_and_jacobian!(f::F, jac, ::AutoForwardDiff, x) where {F}
y = f(x)
result = MutableDiffResult(y, (jac,))
result = jacobian!(result, f, x)
return DiffResults.value(result), DiffResults.jacobian(result)
end

function DI.value_and_jacobian(f::F, ::AutoForwardDiff, x) where {F}
return f(x), jacobian(f, x)
end

function DI.jacobian!(f::F, jac, ::AutoForwardDiff, x) where {F}
return jacobian!(jac, f, x)
end

function DI.jacobian(f::F, ::AutoForwardDiff, x) where {F}
return jacobian(f, x)
end

### Prepared

struct ForwardDiffOneArgJacobianExtras{C} <: JacobianExtras
config::C
end
Expand Down Expand Up @@ -219,6 +266,34 @@ end

## Hessian

### Unprepared

function DI.hessian!(f::F, hess, ::AutoForwardDiff, x) where {F}
return hessian!(hess, f, x)
end

function DI.hessian(f::F, ::AutoForwardDiff, x) where {F}
return hessian(f, x)
end

function DI.value_gradient_and_hessian!(f::F, grad, hess, ::AutoForwardDiff, x) where {F}
result = MutableDiffResult(one(eltype(x)), (grad, hess))
result = hessian!(result, f, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end

function DI.value_gradient_and_hessian(f::F, ::AutoForwardDiff, x) where {F}
result = HessianResult(x)
result = hessian!(result, f, x)
return (
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
)
end

### Prepared

struct ForwardDiffHessianExtras{C1,C2,C3} <: HessianExtras
array_config::C1
manual_result_config::C2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,30 @@ end

## Derivative

### Unprepared

function DI.value_and_derivative(f!::F, y, ::AutoForwardDiff, x) where {F}
result = MutableDiffResult(y, (similar(y),))
result = derivative!(result, f!, y, x)
return DiffResults.value(result), DiffResults.derivative(result)
end

function DI.value_and_derivative!(f!::F, y, der, ::AutoForwardDiff, x) where {F}
result = MutableDiffResult(y, (der,))
result = derivative!(result, f!, y, x)
return DiffResults.value(result), DiffResults.derivative(result)
end

function DI.derivative(f!::F, y, ::AutoForwardDiff, x) where {F}
return derivative(f!, y, x)
end

function DI.derivative!(f!::F, y, der, ::AutoForwardDiff, x) where {F}
return derivative!(der, f!, y, x)
end

### Prepared

struct ForwardDiffTwoArgDerivativeExtras{C} <: DerivativeExtras
config::C
end
Expand Down Expand Up @@ -133,19 +157,42 @@ end
function DI.derivative(
f!::F, y, ::AutoForwardDiff, x, extras::ForwardDiffTwoArgDerivativeExtras
) where {F}
der = derivative(f!, y, x, extras.config)
return der
return derivative(f!, y, x, extras.config)
end

function DI.derivative!(
f!::F, y, der, ::AutoForwardDiff, x, extras::ForwardDiffTwoArgDerivativeExtras
) where {F}
der = derivative!(der, f!, y, x, extras.config)
return der
return derivative!(der, f!, y, x, extras.config)
end

## Jacobian

### Unprepared

function DI.value_and_jacobian(f!::F, y, ::AutoForwardDiff, x) where {F}
jac = similar(y, length(y), length(x))
result = MutableDiffResult(y, (jac,))
result = jacobian!(result, f!, y, x)
return DiffResults.value(result), DiffResults.jacobian(result)
end

function DI.value_and_jacobian!(f!::F, y, jac, ::AutoForwardDiff, x) where {F}
result = MutableDiffResult(y, (jac,))
result = jacobian!(result, f!, y, x)
return DiffResults.value(result), DiffResults.jacobian(result)
end

function DI.jacobian(f!::F, y, ::AutoForwardDiff, x) where {F}
return jacobian(f!, y, x)
end

function DI.jacobian!(f!::F, y, jac, ::AutoForwardDiff, x) where {F}
return jacobian!(jac, f!, y, x)
end

### Prepared

struct ForwardDiffTwoArgJacobianExtras{C} <: JacobianExtras
config::C
end
Expand Down Expand Up @@ -176,13 +223,11 @@ end
function DI.jacobian(
f!::F, y, ::AutoForwardDiff, x, extras::ForwardDiffTwoArgJacobianExtras
) where {F}
jac = jacobian(f!, y, x, extras.config)
return jac
return jacobian(f!, y, x, extras.config)
end

function DI.jacobian!(
f!::F, y, jac, ::AutoForwardDiff, x, extras::ForwardDiffTwoArgJacobianExtras
) where {F}
jac = jacobian!(jac, f!, y, x, extras.config)
return jac
return jacobian!(jac, f!, y, x, extras.config)
end

0 comments on commit 8077dce

Please sign in to comment.