-
Notifications
You must be signed in to change notification settings - Fork 13
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement DifferentiateWith to translate between backends (#218)
* Implement DifferentiateWith to translate between backends * Fix parsing
- Loading branch information
Showing
22 changed files
with
276 additions
and
121 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x) | ||
(; f, backend) = dw | ||
y, dy = DI.value_and_pushforward(f, backend, x, dx) | ||
return y, dy | ||
end | ||
|
||
function ChainRulesCore.rrule(dw::DifferentiateWith, x) | ||
(; f, backend) = dw | ||
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x) | ||
pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy)) | ||
return y, pullbackfunc_adjusted | ||
end |
27 changes: 27 additions & 0 deletions
27
DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
## Pullback | ||
|
||
DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras() | ||
|
||
function DI.value_and_pullback_split( | ||
f, backend::AutoReverseChainRules, x, ::NoPullbackExtras | ||
) | ||
rc = ruleconfig(backend) | ||
y, pullback = rrule_via_ad(rc, f, x) | ||
pullbackfunc(dy) = last(pullback(dy)) | ||
return y, pullbackfunc | ||
end | ||
|
||
function DI.value_and_pullback!_split( | ||
f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras | ||
) | ||
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) | ||
pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) | ||
return y, pullbackfunc! | ||
end | ||
|
||
function DI.value_and_pullback( | ||
f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras | ||
) | ||
y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) | ||
return y, pullbackfunc(dy) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
54 changes: 54 additions & 0 deletions
54
DifferentiationInterface/src/translation/differentiate_with.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
""" | ||
DifferentiateWith | ||
Callable function wrapper that enforces differentiation with a specified (inner) backend. | ||
This works by defining new rules overriding the behavior of the outer backend that would normally be used. | ||
!!! warning | ||
This is an experimental functionality, whose API cannot yet be considered stable. | ||
At the moment, it only supports one-argument functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends. | ||
# Fields | ||
- `f`: the function in question | ||
- `backend::AbstractADType`: the inner backend to use for differentiation | ||
# Constructor | ||
DifferentiateWith(f, backend) | ||
# Example | ||
```@repl | ||
using DifferentiationInterface | ||
import ForwardDiff, Zygote | ||
function f(x) | ||
a = Vector{eltype(x)}(undef, 1) | ||
a[1] = sum(x) # mutation that breaks Zygote | ||
return a[1] | ||
end | ||
dw = DifferentiateWith(f, AutoForwardDiff()); | ||
gradient(dw, AutoZygote(), [1.0, 2.0]) # works because it calls ForwardDiff instead | ||
gradient(f, AutoZygote(), [1.0, 2.0]) # fails | ||
``` | ||
""" | ||
struct DifferentiateWith{F,B<:AbstractADType} | ||
f::F | ||
backend::B | ||
end | ||
|
||
""" | ||
(dw::DifferentiateWith)(x) | ||
Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper. | ||
""" | ||
(dw::DifferentiateWith)(x) = dw.f(x) | ||
|
||
function Base.show(io::IO, dw::DifferentiateWith) | ||
(; f, backend) = dw | ||
return print(io, "$f differentiated with $(backend_str(backend))") | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,9 @@ | ||
using DifferentiationInterface: pick_chunksize, DEFAULT_CHUNKSIZE | ||
|
||
@test pick_chunksize.(1:DEFAULT_CHUNKSIZE) == 1:DEFAULT_CHUNKSIZE | ||
@test DI.pick_chunksize.(1:(DI.DEFAULT_CHUNKSIZE)) == 1:(DI.DEFAULT_CHUNKSIZE) | ||
@test all( | ||
pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .<= DEFAULT_CHUNKSIZE | ||
DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .<= | ||
DI.DEFAULT_CHUNKSIZE, | ||
) | ||
@test all( | ||
pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .>= DEFAULT_CHUNKSIZE / 2 | ||
DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .>= | ||
DI.DEFAULT_CHUNKSIZE / 2, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,16 @@ | ||
using ADTypes: column_coloring, row_coloring, symmetric_coloring | ||
using DifferentiationInterface: | ||
GreedyColoringAlgorithm, | ||
check_structurally_orthogonal_columns, | ||
check_structurally_orthogonal_rows, | ||
check_symmetrically_structurally_orthogonal | ||
using LinearAlgebra | ||
using SparseArrays | ||
using Test | ||
|
||
alg = GreedyColoringAlgorithm() | ||
alg = DI.GreedyColoringAlgorithm() | ||
|
||
A = sprand(Bool, 100, 200, 0.1) | ||
|
||
column_colors = column_coloring(A, alg) | ||
@test check_structurally_orthogonal_columns(A, column_colors) | ||
column_colors = ADTypes.column_coloring(A, alg) | ||
@test DI.check_structurally_orthogonal_columns(A, column_colors) | ||
@test maximum(column_colors) < size(A, 2) ÷ 2 | ||
|
||
row_colors = row_coloring(A, alg) | ||
@test check_structurally_orthogonal_rows(A, row_colors) | ||
row_colors = ADTypes.row_coloring(A, alg) | ||
@test DI.check_structurally_orthogonal_rows(A, row_colors) | ||
@test maximum(row_colors) < size(A, 1) ÷ 2 | ||
|
||
S = Symmetric(sprand(Bool, 100, 100, 0.1)) + I | ||
symmetric_colors = symmetric_coloring(S, alg) | ||
@test check_symmetrically_structurally_orthogonal(S, symmetric_colors) | ||
symmetric_colors = ADTypes.symmetric_coloring(S, alg) | ||
@test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors) | ||
@test maximum(symmetric_colors) < size(A, 2) ÷ 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
function zygote_breaking_scenarios() | ||
onearg_scens = filter(default_scenarios()) do scen | ||
DIT.nb_args(scen) == 1 | ||
end | ||
bad_onearg_scens = map(onearg_scens) do scen | ||
function bad_f(x) | ||
a = Vector{eltype(x)}(undef, 1) | ||
a[1] = sum(x) | ||
return scen.f(x) | ||
end | ||
wrapped_bad_f = DifferentiateWith(bad_f, AutoForwardDiff()) | ||
bad_scen = DIT.change_function(scen, wrapped_bad_f) | ||
return bad_scen | ||
end | ||
return bad_onearg_scens | ||
end | ||
|
||
test_differentiation( | ||
AutoZygote(), | ||
zygote_breaking_scenarios(); | ||
second_order=false, | ||
logging=logging = get(ENV, "CI", "false") == "false", | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.