Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental ReverseDiff support #123

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
*playground.jl
playground.jl
scratchpad.jl
*.csv
*.png
*.pdf
Expand Down
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
ImplicitDifferentiationChainRulesCoreExt = "ChainRulesCore"
ImplicitDifferentiationForwardDiffExt = "ForwardDiff"
ImplicitDifferentiationReverseDiffExt = ["ChainRulesCore", "ReverseDiff"]
ImplicitDifferentiationStaticArraysExt = "StaticArrays"
ImplicitDifferentiationZygoteExt = "Zygote"

Expand Down
23 changes: 22 additions & 1 deletion docs/src/faq.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,35 @@

## Supported autodiff backends

### For the function itself

To differentiate an `ImplicitFunction`, the following backends are supported.

| Backend | Forward mode | Reverse mode |
| ---------------------------------------------------------------------- | ------------ | ------------ |
| [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) | yes | - |
| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | soon | yes |
| [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible | soon | yes |
| [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) | someday | someday |

If you want to use [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl), you cannot differentiate the callable `ImplicitFunction` object, which is why we provide a workaround with `call_implicit_function`.
In addition, you have to declare the custom chain rule to ReverseDiff.jl beforehand.
You can adjust the following syntax depending on the additional `args` and `kwargs` of your special case:

```julia
import ImplicitDifferentiation as ID

ReverseDiff.@grad_from_chainrules ID.call_implicit_function(
implicit::ImplicitFunction, x::TrackedArray, arg1, arg2; kwarg1, kwarg2
)

# this will fail
ReverseDiff.jacobian(implicit, x, arg1, arg2; kwarg1, kwarg2)
# this will work
ReverseDiff.jacobian(_x -> ID.call_implicit_function(implicit, _x, arg1, arg2; kwarg1, kwarg2), x)
```

### For the conditions

By default, the conditions are differentiated with the same backend as the `ImplicitFunction` that contains them.
However, this can be switched to any backend compatible with [AbstractDifferentiation.jl](https://github.com/JuliaDiff/AbstractDifferentiation.jl) (i.e. a subtype of `AD.AbstractBackend`).
You can specify it with the `conditions_backend` keyword argument when constructing an `ImplicitFunction`.
Expand Down
35 changes: 24 additions & 11 deletions ext/ImplicitDifferentiationChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
module ImplicitDifferentiationChainRulesCoreExt

using AbstractDifferentiation: AbstractBackend, ReverseRuleConfigBackend, ruleconfig
using ChainRulesCore: ChainRulesCore, NoTangent, RuleConfig
using ChainRulesCore: rrule, rrule_via_ad, unthunk, @not_implemented
using ImplicitDifferentiation: ImplicitDifferentiation
using ImplicitDifferentiation: ImplicitFunction
using ImplicitDifferentiation: conditions_reverse_operators
using ImplicitDifferentiation: get_output, presolve, solve
using ChainRulesCore:
ChainRulesCore, NoTangent, RuleConfig, rrule, rrule_via_ad, unthunk, @not_implemented
using ImplicitDifferentiation:
ImplicitDifferentiation,
ImplicitFunction,
call_implicit_function,
conditions_reverse_operators,
get_output,
presolve,
reverse_conditions_backend,
solve
using LinearAlgebra: mul!
using SimpleUnPack: @unpack

Expand All @@ -21,8 +26,13 @@ We compute the vector-Jacobian product `Jᵀv` by solving `Aᵀu = v` and settin
Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`.
"""
function ChainRulesCore.rrule(
rc::RuleConfig, implicit::ImplicitFunction, x::X, args...; kwargs...
) where {R,X<:AbstractArray{R}}
rc::RuleConfig,
::typeof(call_implicit_function),
implicit::ImplicitFunction,
x::AbstractArray,
args...;
kwargs...,
)
linear_solver = implicit.linear_solver
y_or_yz = implicit(x, args...; kwargs...)
backend = reverse_conditions_backend(rc, implicit)
Expand All @@ -39,13 +49,13 @@ function ChainRulesCore.rrule(
return y_or_yz, implicit_pullback
end

function reverse_conditions_backend(
function ImplicitDifferentiation.reverse_conditions_backend(
rc::RuleConfig, ::ImplicitFunction{F,C,L,Nothing}
) where {F,C,L}
return ReverseRuleConfigBackend(rc)
end

function reverse_conditions_backend(
function ImplicitDifferentiation.reverse_conditions_backend(
::RuleConfig, implicit::ImplicitFunction{F,C,L,<:AbstractBackend}
) where {F,C,L}
return implicit.conditions_backend
Expand Down Expand Up @@ -88,8 +98,11 @@ function apply_implicit_pullback(
dc_vec = solve(linear_solver, Aᵀ_vec, -dy_vec)
dx_vec = similar(x_vec)
mul!(dx_vec, Bᵀ_vec, dc_vec)
dcall = NoTangent()
dimplicit = NoTangent()
dx = reshape(dx_vec, x_size)
return (NoTangent(), dx, ntuple(unimplemented_tangent, nbargs)...)
dargs = ntuple(unimplemented_tangent, nbargs)
return (dcall, dimplicit, dx, dargs...)
end

end
21 changes: 14 additions & 7 deletions ext/ImplicitDifferentiationForwardDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,22 @@ else
end

using AbstractDifferentiation: AbstractBackend, ForwardDiffBackend
using ImplicitDifferentiation: ImplicitFunction, DirectLinearSolver, IterativeLinearSolver
using ImplicitDifferentiation: conditions_forward_operators
using ImplicitDifferentiation: get_output, get_byproduct, presolve, solve
using ImplicitDifferentiation: identity_break_autodiff
using ImplicitDifferentiation:
ImplicitDifferentiation,
ImplicitFunction,
DirectLinearSolver,
IterativeLinearSolver,
conditions_forward_operators,
identity_break_autodiff,
get_output,
get_byproduct,
presolve,
solve
using LinearAlgebra: mul!
using PrecompileTools: @compile_workload

"""
implicit(x_and_dx::AbstractArray{<:Dual}, args...; kwargs...)
call_implicit_function(implicit, x_and_dx::AbstractArray{<:Dual}, args...; kwargs...)

Overload an [`ImplicitFunction`](@ref) on dual numbers to ensure compatibility with forward mode autodiff.

Expand All @@ -24,8 +31,8 @@ This is only available if ForwardDiff.jl is loaded (extension).
We compute the Jacobian-vector product `Jv` by solving `Au = -Bv` and setting `Jv = u`.
Positional and keyword arguments are passed to both `implicit.forward` and `implicit.conditions`.
"""
function (implicit::ImplicitFunction)(
x_and_dx::AbstractArray{Dual{T,R,N}}, args...; kwargs...
function ImplicitDifferentiation.call_implicit_function(
implicit::ImplicitFunction, x_and_dx::AbstractArray{Dual{T,R,N}}, args...; kwargs...
) where {T,R,N}
linear_solver = implicit.linear_solver

Expand Down
58 changes: 58 additions & 0 deletions ext/ImplicitDifferentiationReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
module ImplicitDifferentiationReverseDiffExt

@static if isdefined(Base, :get_extension)
using ReverseDiff: TrackedArray, TrackedReal, @grad_from_chainrules, jacobian
else
using ..ReverseDiff: TrackedArray, TrackedReal, @grad_from_chainrules, jacobian
end

using AbstractDifferentiation: ReverseDiffBackend
using ChainRulesCore: ChainRulesCore, RuleConfig, HasReverseMode, NoForwardsMode, rrule
using ImplicitDifferentiation:
ImplicitDifferentiation,
ImplicitFunction,
DirectLinearSolver,
IterativeLinearSolver,
call_implicit_function,
check_valid_output,
identity_break_autodiff
using LinearAlgebra: mul!
using PrecompileTools: @compile_workload

struct MyReverseDiffRuleConfig <: RuleConfig{Union{HasReverseMode,NoForwardsMode}} end

function ImplicitDifferentiation.reverse_conditions_backend(
::MyReverseDiffRuleConfig, ::ImplicitFunction{F,C,L,Nothing}
) where {F,C,L}
return ReverseDiffBackend()
end

function ChainRulesCore.rrule(
::typeof(call_implicit_function),
implicit::ImplicitFunction,
x::AbstractArray,
args...;
kwargs...,
)
# The macro ReverseDiff.@grad_from_chainrules calls ChainRulesCore.rrule without a ruleconfig
rc = MyReverseDiffRuleConfig()
return rrule(rc, call_implicit_function, implicit, x, args...; kwargs...)
end

@grad_from_chainrules ImplicitDifferentiation.call_implicit_function(
implicit::ImplicitFunction, x::TrackedArray
)

@compile_workload begin
forward(x) = sqrt.(identity_break_autodiff(x))
conditions(x, y) = y .^ 2 .- x
for linear_solver in (DirectLinearSolver(), IterativeLinearSolver())
implicit = ImplicitFunction(forward, conditions; linear_solver)
x = rand(2)
implicit(x)
# TODO: the following line kills Julia during precompilation, so fast that I cannot see the error
# jacobian(_x -> call_implicit_function(implicit, _x), x)
end
end

end
4 changes: 2 additions & 2 deletions ext/ImplicitDifferentiationZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ else
using ..Zygote: jacobian
end

using ImplicitDifferentiation: ImplicitFunction, identity_break_autodiff
using ImplicitDifferentiation: DirectLinearSolver, IterativeLinearSolver
using ImplicitDifferentiation:
ImplicitFunction, identity_break_autodiff, DirectLinearSolver, IterativeLinearSolver
using PrecompileTools: @compile_workload

@compile_workload begin
Expand Down
5 changes: 5 additions & 0 deletions src/ImplicitDifferentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ export AbstractLinearSolver, IterativeLinearSolver, DirectLinearSolver
@require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
include("../ext/ImplicitDifferentiationForwardDiffExt.jl")
end
@require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" begin
include("../ext/ImplicitDifferentiationReverseDiffExt.jl")
end
end
@require StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" begin
include("../ext/ImplicitDifferentiationStaticArraysExt.jl")
end
Expand Down
19 changes: 19 additions & 0 deletions src/implicit_function.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,24 @@ Return `implicit.forward(x, args...; kwargs...)`, which can be either an array `
This call is differentiable.
"""
function (implicit::ImplicitFunction)(x::AbstractArray, args...; kwargs...)
return call_implicit_function(implicit, x, args...; kwargs...)
end

"""
call_implicit_function(implicit::ImplicitFunction, x::AbstractArray, args...; kwargs...)

Under the hood, that is what `implicit(x, args...; kwargs...)` calls.

This is relevant for users who need to avoid callable structs, e.g. when working with ReverseDiff.jl.
"""
function call_implicit_function(
implicit::ImplicitFunction, x::AbstractArray, args...; kwargs...
)
y_or_yz = implicit.forward(x, args...; kwargs...)
return check_valid_output(y_or_yz)
end

function check_valid_output(y_or_yz)
valid = (
y_or_yz isa AbstractArray || #
(y_or_yz isa Tuple && length(y_or_yz) == 2 && y_or_yz[1] isa AbstractArray)
Expand All @@ -85,3 +102,5 @@ end
get_output(y::AbstractArray) = y
get_output(yz::Tuple) = yz[1]
get_byproduct(yz::Tuple) = yz[2]

function reverse_conditions_backend end
12 changes: 10 additions & 2 deletions test/errors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using ChainRulesCore
using ChainRulesTestUtils
using ForwardDiff
using ImplicitDifferentiation
using ImplicitDifferentiation: call_implicit_function
using Test
using Zygote

Expand Down Expand Up @@ -75,6 +76,13 @@ end
y, z = implicit(x)
dy = similar(y)
rc = Zygote.ZygoteRuleConfig()
test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, 0))
@test_skip test_rrule(rc, implicit, x; atol=1e-2, output_tangent=(dy, NoTangent()))
test_rrule(rc, call_implicit_function, implicit, x; atol=1e-2, output_tangent=(dy, 0))
@test_skip test_rrule(
rc,
call_implicit_function,
implicit,
x;
atol=1e-2,
output_tangent=(dy, NoTangent()),
)
end
Loading
Loading