Implement DifferentiateWith to translate between backends #218

Merged
merged 2 commits into from
Apr 26, 2024
6 changes: 6 additions & 0 deletions DifferentiationInterface/docs/src/api.md
@@ -91,6 +91,12 @@ check_twoarg
check_hessian
```

## Translation

```@docs
DifferentiateWith
```

## Internals

This is not part of the public API.
4 changes: 2 additions & 2 deletions DifferentiationInterface/docs/src/backends.md
@@ -5,7 +5,7 @@ CollapsedDocStrings = true

```@setup backends
using DifferentiationInterface
-using DifferentiationInterface: backend_string
+using DifferentiationInterface: backend_str
import Markdown
import Diffractor, Enzyme, FastDifferentiation, FiniteDiff, FiniteDifferences, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tapir, Tracker, Zygote

@@ -37,7 +37,7 @@ println(io, "|:--------|:------------:|:----------------------:|:---------------

for example in backend_examples
b = eval(Meta.parse(example)) # backend
-join(io, [backend_string(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|')
+join(io, [backend_str(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|')
println(io, '|' )
end
backend_table = Markdown.parse(String(take!(io)))
2 changes: 1 addition & 1 deletion DifferentiationInterface/docs/src/overloads.md
@@ -24,7 +24,7 @@ Each cell can have three values:
```@setup overloads
using ADTypes: AbstractADType
using DifferentiationInterface
-using DifferentiationInterface: backend_string, mutation_support, MutationSupported
+using DifferentiationInterface: backend_str, mutation_support, MutationSupported
using Markdown: Markdown
using Diffractor: Diffractor
using Enzyme: Enzyme
6 changes: 6 additions & 0 deletions DifferentiationInterface/docs/src/overview.md
@@ -123,6 +123,12 @@ We make this available for all backends with the following operators:
| :--------------------------------- | :---------------------------------- |
| [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) |

## Translation

The wrapper [`DifferentiateWith`](@ref) lets you take a function and specify that it should always be differentiated with the backend of your choice.
In other words, when you differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, `backend1` steps in to compute the derivative and `backend2` does nothing.
At the moment, this only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl).
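
As a rough illustration of the paragraph above, the following sketch compares differentiating a function directly with the inner backend against differentiating its wrapper with an outer backend. The function `sq_norm` and the input `x` are made up for this example, and it assumes ForwardDiff.jl and Zygote.jl are installed.

```julia
using DifferentiationInterface
import ForwardDiff, Zygote

# Hypothetical test function, chosen only for illustration
sq_norm(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]

# Wrap the function so that ForwardDiff (the inner backend) is always used
dw = DifferentiateWith(sq_norm, AutoForwardDiff())

# Reference gradient, computed directly with the inner backend
g_inner = gradient(sq_norm, AutoForwardDiff(), x)

# Zygote (the outer backend) picks up the ChainRules rule defined for `dw`
# and defers to ForwardDiff instead of tracing through `sq_norm`
g_outer = gradient(dw, AutoZygote(), x)

# Both routes should agree on the gradient, which is 2x for this function
@assert g_inner ≈ g_outer
```

The wrapper is still an ordinary callable, so `dw(x)` returns the same value as `sq_norm(x)`.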

## Going further

### Non-standard types
@@ -2,9 +2,15 @@ module DifferentiationInterfaceChainRulesCoreExt

using ADTypes: ADTypes, AutoChainRules
using ChainRulesCore:
-HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad
+ChainRulesCore,
+HasForwardsMode,
+HasReverseMode,
+NoTangent,
+RuleConfig,
+frule_via_ad,
+rrule_via_ad
import DifferentiationInterface as DI
-using DifferentiationInterface: NoPullbackExtras, NoPushforwardExtras
+using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras

ruleconfig(backend::AutoChainRules) = backend.ruleconfig

@@ -14,32 +20,7 @@ const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}
DI.check_available(::AutoChainRules) = true
DI.mutation_support(::AutoChainRules) = DI.MutationNotSupported()

-## Pullback
-
-DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()
-
-function DI.value_and_pullback_split(
-f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
-)
-rc = ruleconfig(backend)
-y, pullback = rrule_via_ad(rc, f, x)
-pullbackfunc(dy) = last(pullback(dy))
-return y, pullbackfunc
-end
-
-function DI.value_and_pullback!_split(
-f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
-)
-y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
-return y, pullbackfunc!
-end
-
-function DI.value_and_pullback(
-f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
-)
-y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
-return y, pullbackfunc(dy)
-end
+include("reverse_onearg.jl")
+include("differentiate_with.jl")

end
@@ -0,0 +1,12 @@
function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x)
(; f, backend) = dw
y, dy = DI.value_and_pushforward(f, backend, x, dx)
return y, dy
end

function ChainRulesCore.rrule(dw::DifferentiateWith, x)
(; f, backend) = dw
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x)
pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy))
return y, pullbackfunc_adjusted
end
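
To make the mechanism of the `rrule` above more concrete, here is a minimal standalone sketch of the same pattern that does not depend on DifferentiationInterface. The wrapper type `GradientWithForwardDiff` is hypothetical and not part of this PR; it is hard-wired to ForwardDiff.jl and to functions mapping a vector to a scalar, whereas the real rule dispatches through `DI.value_and_pullback_split` for any inner backend.

```julia
using ChainRulesCore: ChainRulesCore, NoTangent
import ForwardDiff

# Hypothetical stand-in for DifferentiateWith (illustration only):
# always uses ForwardDiff and only handles vector -> scalar functions
struct GradientWithForwardDiff{F}
    f::F
end

# Calling the wrapper just calls the wrapped function
(w::GradientWithForwardDiff)(x) = w.f(x)

# Reverse rule: any ChainRules-compatible outer backend that meets `w(x)`
# uses this instead of differentiating through `w.f` itself
function ChainRulesCore.rrule(w::GradientWithForwardDiff, x::AbstractVector)
    y = w.f(x)
    grad = ForwardDiff.gradient(w.f, x)  # the inner backend does the work
    # First slot is the tangent for the wrapper itself, which we ignore
    pullback(dy) = (NoTangent(), dy * grad)
    return y, pullback
end

w = GradientWithForwardDiff(x -> sum(abs2, x))
y, pullback = ChainRulesCore.rrule(w, [1.0, 2.0])
_, dx = pullback(1.0)
```

As in the PR's rule, the pullback returns `NoTangent()` for the callable, so the outer backend never looks inside the wrapped function.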
@@ -0,0 +1,27 @@
## Pullback

DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras()

function DI.value_and_pullback_split(
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras
)
rc = ruleconfig(backend)
y, pullback = rrule_via_ad(rc, f, x)
pullbackfunc(dy) = last(pullback(dy))
return y, pullbackfunc
end

function DI.value_and_pullback!_split(
f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras
)
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy))
return y, pullbackfunc!
end

function DI.value_and_pullback(
f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras
)
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras)
return y, pullbackfunc(dy)
end
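
The pullback code above relies on `rrule_via_ad` returning a value together with a pullback whose last output is the cotangent of `x`. The sketch below shows that convention in isolation, assuming Zygote.jl is installed and using its `ZygoteRuleConfig` as the rule configuration; the function `sq_norm` is made up for the example.

```julia
using ChainRulesCore: rrule_via_ad
import Zygote

# Hypothetical test function, chosen only for illustration
sq_norm(x) = sum(abs2, x)

# A rule config with reverse-mode support, provided by Zygote
rc = Zygote.ZygoteRuleConfig()

# Returns the primal value and a pullback closure
y, pullback = rrule_via_ad(rc, sq_norm, [1.0, 2.0])

# The pullback returns one cotangent per argument, including the function
# itself, so the cotangent with respect to `x` is the last entry
dx = last(pullback(1.0))
```

This is the same `last(pullback(dy))` step that `pullbackfunc` performs in `DI.value_and_pullback_split`.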
4 changes: 4 additions & 0 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -60,6 +60,8 @@ include("sparse/fallbacks.jl")
include("sparse/jacobian.jl")
include("sparse/hessian.jl")

include("translation/differentiate_with.jl")

export SecondOrder

export value_and_pushforward!, value_and_pushforward
@@ -87,6 +89,8 @@ export prepare_second_derivative, prepare_hvp, prepare_hessian

export check_available, check_twoarg, check_hessian

export DifferentiateWith

# Re-export backends from ADTypes
export AutoChainRules
export AutoDiffractor
54 changes: 54 additions & 0 deletions DifferentiationInterface/src/translation/differentiate_with.jl
@@ -0,0 +1,54 @@
"""
DifferentiateWith

Callable function wrapper that enforces differentiation with a specified (inner) backend.

This works by defining new rules overriding the behavior of the outer backend that would normally be used.

!!! warning
This is experimental functionality, and its API cannot yet be considered stable.
At the moment, it only supports one-argument functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends.

# Fields

- `f`: the function in question
- `backend::AbstractADType`: the inner backend to use for differentiation

# Constructor

DifferentiateWith(f, backend)

# Example

```@repl
using DifferentiationInterface
import ForwardDiff, Zygote

function f(x)
a = Vector{eltype(x)}(undef, 1)
a[1] = sum(x) # mutation that breaks Zygote
return a[1]
end

dw = DifferentiateWith(f, AutoForwardDiff());

gradient(dw, AutoZygote(), [1.0, 2.0]) # works because it calls ForwardDiff instead
gradient(f, AutoZygote(), [1.0, 2.0]) # fails
```
"""
struct DifferentiateWith{F,B<:AbstractADType}
f::F
backend::B
end

"""
(dw::DifferentiateWith)(x)

Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper.
"""
(dw::DifferentiateWith)(x) = dw.f(x)

function Base.show(io::IO, dw::DifferentiateWith)
(; f, backend) = dw
return print(io, "$f differentiated with $(backend_str(backend))")
end
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/utils/exceptions.jl
@@ -2,7 +2,7 @@ struct MissingBackendError <: Exception
backend::AbstractADType
end
function Base.showerror(io::IO, e::MissingBackendError)
-println(io, "failed to use $(backend_string(e.backend)) backend.")
+println(io, "failed to use $(backend_str(e.backend)) backend.")
if !check_available(e.backend)
print(
io,
13 changes: 5 additions & 8 deletions DifferentiationInterface/src/utils/printing.jl
@@ -15,11 +15,8 @@ backend_package_name(::AutoTracker) = "Tracker"
backend_package_name(::AutoZygote) = "Zygote"
backend_package_name(::AutoReverseDiff) = "ReverseDiff"

-backend_string_aux(b::AbstractADType) = backend_package_name(b)
-backend_string_aux(b::AutoReverseDiff) = "ReverseDiff$(b.compile ? "{compiled}" : "")"
-
-function backend_string(backend::AbstractADType)
-bs = backend_string_aux(backend)
+function backend_str(backend::AbstractADType)
+bs = backend_package_name(backend)
if mode(backend) isa ForwardMode
return "$bs (forward)"
elseif mode(backend) isa ReverseMode
@@ -33,8 +30,8 @@ function backend_string(backend::AbstractADType)
end
end

-backend_string(backend::AutoSparse) = "Sparse $(backend_string(dense_ad(backend)))"
+backend_str(backend::AutoSparse) = "Sparse $(backend_str(dense_ad(backend)))"

-function backend_string(backend::SecondOrder)
-return "$(backend_string(outer(backend))) / $(backend_string(inner(backend)))"
+function backend_str(backend::SecondOrder)
+return "$(backend_str(outer(backend))) / $(backend_str(inner(backend)))"
end
10 changes: 5 additions & 5 deletions DifferentiationInterface/test/chunk.jl
@@ -1,9 +1,9 @@
-using DifferentiationInterface: pick_chunksize, DEFAULT_CHUNKSIZE
-
-@test pick_chunksize.(1:DEFAULT_CHUNKSIZE) == 1:DEFAULT_CHUNKSIZE
+@test DI.pick_chunksize.(1:(DI.DEFAULT_CHUNKSIZE)) == 1:(DI.DEFAULT_CHUNKSIZE)
@test all(
-pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .<= DEFAULT_CHUNKSIZE
+DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .<=
+DI.DEFAULT_CHUNKSIZE,
)
@test all(
-pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .>= DEFAULT_CHUNKSIZE / 2
+DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .>=
+DI.DEFAULT_CHUNKSIZE / 2,
)
24 changes: 7 additions & 17 deletions DifferentiationInterface/test/coloring.jl
@@ -1,26 +1,16 @@
-using ADTypes: column_coloring, row_coloring, symmetric_coloring
-using DifferentiationInterface:
-GreedyColoringAlgorithm,
-check_structurally_orthogonal_columns,
-check_structurally_orthogonal_rows,
-check_symmetrically_structurally_orthogonal
-using LinearAlgebra
-using SparseArrays
-using Test
-
-alg = GreedyColoringAlgorithm()
+alg = DI.GreedyColoringAlgorithm()

A = sprand(Bool, 100, 200, 0.1)

-column_colors = column_coloring(A, alg)
-@test check_structurally_orthogonal_columns(A, column_colors)
+column_colors = ADTypes.column_coloring(A, alg)
+@test DI.check_structurally_orthogonal_columns(A, column_colors)
@test maximum(column_colors) < size(A, 2) ÷ 2

-row_colors = row_coloring(A, alg)
-@test check_structurally_orthogonal_rows(A, row_colors)
+row_colors = ADTypes.row_coloring(A, alg)
+@test DI.check_structurally_orthogonal_rows(A, row_colors)
@test maximum(row_colors) < size(A, 1) ÷ 2

S = Symmetric(sprand(Bool, 100, 100, 0.1)) + I
-symmetric_colors = symmetric_coloring(S, alg)
-@test check_symmetrically_structurally_orthogonal(S, symmetric_colors)
+symmetric_colors = ADTypes.symmetric_coloring(S, alg)
+@test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors)
@test maximum(symmetric_colors) < size(A, 2) ÷ 2
23 changes: 23 additions & 0 deletions DifferentiationInterface/test/differentiate_with.jl
@@ -0,0 +1,23 @@
function zygote_breaking_scenarios()
onearg_scens = filter(default_scenarios()) do scen
DIT.nb_args(scen) == 1
end
bad_onearg_scens = map(onearg_scens) do scen
function bad_f(x)
a = Vector{eltype(x)}(undef, 1)
a[1] = sum(x)
return scen.f(x)
end
wrapped_bad_f = DifferentiateWith(bad_f, AutoForwardDiff())
bad_scen = DIT.change_function(scen, wrapped_bad_f)
return bad_scen
end
return bad_onearg_scens
end

test_differentiation(
AutoZygote(),
zygote_breaking_scenarios();
second_order=false,
logging=get(ENV, "CI", "false") == "false",
)
21 changes: 14 additions & 7 deletions DifferentiationInterface/test/runtests.jl
@@ -23,9 +23,6 @@ include("test_imports.jl")

Documenter.doctest(DifferentiationInterface)

-@testset verbose = true "Exception handling" begin
-include("test_exceptions.jl")
-end
@testset verbose = true "First order" begin
include("first_order.jl")
end
@@ -34,14 +31,14 @@ include("test_imports.jl")
include("second_order.jl")
end

-@testset verbose = true "Coloring" begin
-include("coloring.jl")
-end
-
@testset verbose = true "Sparsity" begin
include("sparsity.jl")
end

+@testset verbose = true "DifferentiateWith" begin
+include("differentiate_with.jl")
+end
+
@testset verbose = true "Bonus round" begin
@testset "Type stability" begin
include("type_stability.jl")
@@ -50,9 +47,19 @@ include("test_imports.jl")
@testset "Weird arrays" begin
include("weird_arrays.jl")
end
+end
+
+@testset verbose = true "Internals" begin
+@testset verbose = true "Exception handling" begin
+include("test_exceptions.jl")
+end

@testset "Chunks" begin
include("chunk.jl")
end
+
+@testset verbose = true "Coloring" begin
+include("coloring.jl")
+end
end
end;
4 changes: 2 additions & 2 deletions DifferentiationInterface/test/sparsity.jl
@@ -1,5 +1,5 @@
-coloring_algorithm = DifferentiationInterface.GreedyColoringAlgorithm()
-sparsity_detector = DifferentiationInterface.SymbolicsSparsityDetector()
+coloring_algorithm = DI.GreedyColoringAlgorithm()
+sparsity_detector = DI.SymbolicsSparsityDetector()

sparse_backends = [
AutoSparse(AutoFastDifferentiation()),