Pushforwards and pullbacks for everyone (#113)
gdalle authored Mar 27, 2024
1 parent fb07d7f commit 6155b2a
Showing 22 changed files with 320 additions and 315 deletions.
33 changes: 23 additions & 10 deletions docs/src/tutorial.md
@@ -71,7 +71,8 @@ Note the double exclamation mark, which is a convention telling you that `grad`
@btime gradient!!($f, _grad, $backend, $x) evals=1 setup=(_grad=similar($x));
```

For some reason the in-place version is slower than our first attempt, but as you can see it has one less allocation, corresponding to the gradient vector.
For some reason the in-place version is not much better than our first attempt.
However, as you can see, it performs one less allocation: the one for the gradient vector, which we now provide ourselves.
Don't worry, we're not done yet.
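To see the allocation count for yourself, here is a minimal sketch using Julia's built-in `@allocated` macro (assuming `f`, `backend`, and `x` from the setup above):

```julia
grad = similar(x)
gradient!!(f, grad, backend, x)             # run once so compilation is not measured
@allocated gradient!!(f, grad, backend, x)  # bytes allocated by the in-place call
```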

## Preparing for multiple gradients
@@ -133,36 +134,48 @@ It's blazingly fast.
And you know what's even better?
You didn't need to look at the docs of either ForwardDiff.jl or Enzyme.jl to achieve top performance with both, or to compare them.

## Testing and benchmarking
## Testing

DifferentiationInterface.jl also provides some utilities for more involved comparison between backends.
They are gathered in a submodule called [`DifferentiationInterfaceTest`](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/lib/DifferentiationInterfaceTest).
They are gathered in a submodule called `DifferentiationInterfaceTest`, located [in its own subdirectory](https://github.com/gdalle/DifferentiationInterface.jl/tree/main/lib/DifferentiationInterfaceTest) of the repository.

```@repl tuto
using DifferentiationInterfaceTest
```

The main entry point is [`test_differentiation`](@ref), which is used as follows:
For testing, you can use [`test_differentiation`](@ref) as follows:

```@repl tuto
data = test_differentiation(
test_differentiation(
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)], # backends to compare
[gradient], # operators to try
[Scenario(f; x=x)]; # test scenario
[gradient, pullback], # operators to try
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))]; # test scenarios
correctness=AutoZygote(), # compare results to a "ground truth" from Zygote
benchmark=true, # measure runtime and allocations too
detailed=true, # print detailed test set
);
```
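For a quick one-off comparison outside the test suite, you can also call the operators directly on both backends and check the results yourself (a minimal sketch, assuming `f` from the setup and the out-of-place `gradient` operator):

```julia
x = rand(3)
g_fd = gradient(f, AutoForwardDiff(), x)           # ForwardDiff through the common interface
g_en = gradient(f, AutoEnzyme(Enzyme.Reverse), x)  # Enzyme reverse mode through the same interface
isapprox(g_fd, g_en; rtol=1e-6)                    # should hold if both backends agree
```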

The output of `test_differentiation` when `benchmark=true` can be converted to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
## Benchmarking

Once you have ascertained correctness, performance will be your next concern.
The interface of [`benchmark_differentiation`](@ref) is very similar to the one we've just seen, but this time it returns a data object.

```@repl tuto
data = benchmark_differentiation(
[AutoForwardDiff(), AutoEnzyme(Enzyme.Reverse)],
[gradient, pullback],
[Scenario(f; x=rand(3)), Scenario(f; x=rand(3,3))];
);
```

The `BenchmarkData` object is just a struct of vectors, and you can easily convert it to a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):

```@repl tuto
df = DataFrames.DataFrame(pairs(data)...)
```
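This works because `Base.pairs` on a `BenchmarkData` yields `fieldname => column` pairs (as defined by `Base.pairs(::BenchmarkData)` in `tests/benchmark.jl`), which is exactly what the `DataFrame` constructor accepts. A tiny sketch of inspecting them by hand, assuming the `data` object from above:

```julia
for (name, column) in pairs(data)
    # each pair holds a field name and the vector of values recorded for it
    println(name, " => ", length(column), " entries")
end
```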

Here's what the resulting `DataFrame` looks like with all its columns.
Note that the results may be slightly different from the ones presented above (we use [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) internally instead of BenchmarkTools.jl, and measure slightly different operators).
Note that the results may vary from the ones presented above (we use [Chairmarks.jl](https://github.com/LilithHafner/Chairmarks.jl) internally instead of BenchmarkTools.jl, and measure slightly different operators).

```@example tuto
import Markdown, PrettyTables # hide
@@ -43,7 +43,7 @@ DI.mode(::AutoForwardEnzyme) = ADTypes.AbstractForwardMode
DI.mode(::AutoReverseEnzyme) = ADTypes.AbstractReverseMode

# Enzyme's `Duplicated(x, dx)` expects both arguments to be of the same type
function DI.basisarray(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
function DI.basis(::AutoEnzyme, a::AbstractArray{T}, i::CartesianIndex) where {T}
b = zero(a)
b[i] = one(T)
return b
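For context, here is a minimal sketch of why the basis array has to share the primal's array type, assuming Enzyme's `Duplicated` wrapper:

```julia
using Enzyme: Duplicated

x = rand(3)
dx = zero(x)
dx[CartesianIndex(2)] = 1.0   # same one-hot construction as `DI.basis` above
shadow = Duplicated(x, dx)    # primal and tangent must be arrays of the same type
```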
@@ -15,7 +15,7 @@ using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction

DI.mode(::AutoFastDifferentiation) = ADTypes.AbstractSymbolicDifferentiationMode
DI.supports_mutation(::AutoFastDifferentiation) = DI.MutationNotSupported()
DI.supports_pullback(::AutoFastDifferentiation) = DI.PullbackNotSupported()
DI.pullback_performance(::AutoFastDifferentiation) = DI.PullbackSlow()

myvec(x::Number) = [x]
myvec(x::AbstractArray) = vec(x)
@@ -3,7 +3,7 @@ module DifferentiationInterfaceFiniteDifferencesExt
using ADTypes: AutoFiniteDifferences
import DifferentiationInterface as DI
using FillArrays: OneElement
using FiniteDifferences: FiniteDifferences, jvp
using FiniteDifferences: FiniteDifferences, jvp, j′vp
using LinearAlgebra: dot

DI.supports_mutation(::AutoFiniteDifferences) = DI.MutationNotSupported()
@@ -19,4 +19,15 @@ function DI.value_and_pushforward(
return y, jvp(backend.fdm, f, (x, dx))
end

#=
# TODO: why does this fail?
function DI.value_and_pullback(
f, backend::AutoFiniteDifferences{fdm}, x, dy, extras::Nothing
) where {fdm}
y = f(x)
return y, j′vp(backend.fdm, f, x, dy)[1]
end
=#

end
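For reference, here is a minimal sketch of calling FiniteDifferences.jl directly, assuming its usual argument order `j′vp(fdm, f, ȳ, x)` (note that the commented-out method above passes `x` before `dy`, which may be why it fails):

```julia
using FiniteDifferences: central_fdm, jvp, j′vp

fdm = central_fdm(5, 1)          # fifth-order central finite differences
g(x) = sum(abs2, x)
x, dx, dy = rand(3), rand(3), 1.0

jvp(fdm, g, (x, dx))             # pushforward: directional derivative of g along dx
j′vp(fdm, g, dy, x)[1]           # pullback: cotangent with respect to x, scaled by dy
```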
@@ -22,8 +22,8 @@ using DifferentiationInterface:
mode,
outer,
supports_mutation,
supports_pushforward,
supports_pullback
pushforward_performance,
pullback_performance
using DocStringExtensions
import DifferentiationInterface as DI
using JET: @test_call, @test_opt
@@ -42,17 +42,19 @@ include("utils/zero.jl")
include("utils/compatibility.jl")
include("utils/printing.jl")
include("utils/misc.jl")
include("utils/filter.jl")

include("tests/correctness.jl")
include("tests/type_stability.jl")
include("tests/call_count.jl")
include("tests/benchmark.jl")
include("tests/test.jl")

export all_operators
export Scenario
export default_scenarios
export static_scenarios, component_scenarios, gpu_scenarios
export BenchmarkData, record!
export all_operators, test_differentiation
export BenchmarkData
export test_differentiation, benchmark_differentiation

end
43 changes: 43 additions & 0 deletions lib/DifferentiationInterfaceTest/src/tests/benchmark.jl
@@ -39,6 +39,49 @@ function Base.pairs(data::BenchmarkData)
return ns .=> getfield.(Ref(data), ns)
end

"""
benchmark_differentiation(backends, [operators, scenarios]; [kwargs...])
Benchmark a list of `backends` for a list of `operators` on a list of `scenarios`.
# Keyword arguments
- filtering: same filtering keyword arguments as [`test_differentiation`](@ref).
- `logging=false`: whether to log progress
"""
function benchmark_differentiation(
backends::Vector{<:AbstractADType},
operators::Vector{<:Function}=all_operators(),
scenarios::Vector{<:Scenario}=default_scenarios();
# filtering
input_type::Type=Any,
output_type::Type=Any,
allocating=true,
mutating=true,
first_order=true,
second_order=true,
excluded::Vector{<:Function}=Function[],
# options
logging=false,
)
operators = filter_operators(operators; first_order, second_order, excluded)
scenarios = filter_scenarios(scenarios; input_type, output_type, allocating, mutating)

benchmark_data = BenchmarkData()
for backend in backends
for op in operators
for scen in filter(scenarios) do scen
compatible(backend, op, scen)
end
logging &&
@info "Benchmarking: $(backend_string(backend)) - $op - $(string(scen))"
run_benchmark!(benchmark_data, backend, op, scen; allocations=false)
end
end
end
return benchmark_data
end

function record!(data, tup::NamedTuple)
for n in fieldnames(typeof(tup))
push!(getfield(data, n), getfield(tup, n))
115 changes: 21 additions & 94 deletions lib/DifferentiationInterfaceTest/src/tests/test.jl
@@ -1,58 +1,7 @@
"""
all_operators()
List all operators that can be tested with [`test_differentiation`](@ref).
"""
function all_operators()
return [
pushforward,
pullback,
derivative,
gradient,
jacobian,
second_derivative,
hvp,
hessian,
]
end

function filter_operators(
operators::Vector{<:Function};
first_order::Bool,
second_order::Bool,
excluded::Vector{<:Function},
)
!first_order && (
operators = filter(
!in([pushforward, pullback, derivative, gradient, jacobian]), operators
)
)
!second_order && (operators = filter(!in([second_derivative, hvp, hessian]), operators))
operators = filter(!in(excluded), operators)
return operators
end

function filter_scenarios(
scenarios::Vector{<:Scenario};
input_type::Type,
output_type::Type,
allocating::Bool,
mutating::Bool,
)
scenarios = filter(scenarios) do scen
typeof(scen.x) <: input_type && typeof(scen.y) <: output_type
end
!allocating && (scenarios = filter(is_mutating, scenarios))
!mutating && (scenarios = filter(!is_mutating, scenarios))
return scenarios
end

"""
test_differentiation(backends, [operators, scenarios]; [kwargs...])
Cross-test a list of `backends` for a list of `operators` on a list of `scenarios`, running a variety of different tests.
If `benchmark=true`, return a [`BenchmarkData`](@ref) object, otherwise return `nothing`.
Test a list of `backends` for a list of `operators` on a list of `scenarios`.
# Default arguments
@@ -66,9 +15,7 @@ Testing:
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario. If a backend object like `correctness=AutoForwardDiff()` is passed instead of a boolean, the results will be compared using that reference backend as the ground truth.
- `call_count=false`: whether to check that the function is called the right number of times
- `type_stability=false`: whether to check type stability with JET.jl (thanks to `@test_opt`)
- `benchmark=false`: whether to run and return a benchmark suite with Chairmarks.jl
- `allocations=false`: whether to check that the benchmarks are allocation-free
- `detailed=false`: whether to print a detailed test set (by scenario) or condensed test set (by operator)
- `detailed=false`: whether to print a detailed or condensed test log
Filtering:
@@ -82,6 +29,7 @@ Filtering:
Options:
- `logging=false`: whether to log progress
- `isapprox=isapprox`: function used to compare objects; it only needs to be set for complicated cases beyond arrays / scalars
- `rtol=1e-3`: precision for correctness testing (when comparing to the reference outputs)
"""
@@ -93,8 +41,6 @@ function test_differentiation(
correctness::Union{Bool,AbstractADType}=true,
type_stability::Bool=false,
call_count::Bool=false,
benchmark::Bool=false,
allocations::Bool=false,
detailed=false,
# filtering
input_type::Type=Any,
@@ -105,64 +51,45 @@ second_order=true,
second_order=true,
excluded::Vector{<:Function}=Function[],
# options
logging=false,
isapprox=isapprox,
rtol=1e-3,
)
operators = filter_operators(operators; first_order, second_order, excluded)
scenarios = filter_scenarios(scenarios; input_type, output_type, allocating, mutating)

benchmark_data = BenchmarkData()
if correctness isa AbstractADType
scenarios = change_ref.(scenarios, Ref(correctness))
end

title =
"Differentiation tests -" *
(correctness != false ? " correctness" : "") *
(call_count ? " calls" : "") *
(type_stability ? " types" : "") *
(benchmark ? " benchmark" : "") *
(allocations ? " allocations" : "")
(type_stability ? " types" : "")

@testset verbose = detailed "$(backend_string(backend))" for backend in backends
@testset verbose = detailed "$op" for op in operators
@testset "$scen" for scen in filter(scenarios) do scen
compatible(backend, op, scen)
end
if correctness != false
@testset "Correctness" begin
if correctness isa AbstractADType
test_correctness(
backend, op, change_ref(scen, correctness); isapprox, rtol
)
else
test_correctness(backend, op, scen; isapprox, rtol)
end
end
@testset verbose = true "$title" begin
@testset verbose = detailed "$(backend_string(backend))" for backend in backends
@testset verbose = detailed "$op" for op in operators
@testset "$scen" for scen in filter(scenarios) do scen
compatible(backend, op, scen)
end
if call_count
@testset "Call count" begin
logging &&
@info "Testing: $(backend_string(backend)) - $op - $(string(scen))"
correctness != false && @testset "Correctness" begin
test_correctness(backend, op, scen; isapprox, rtol)
end
call_count && @testset "Call count" begin
test_call_count(backend, op, scen)
end
end
if type_stability
@testset "Type stability" begin
type_stability && @testset "Type stability" begin
test_jet(backend, op, scen)
end
end
if benchmark || allocations
@testset "Allocations" begin
run_benchmark!(
benchmark_data, backend, op, scen; allocations=allocations
)
end
end
end
end
end

if benchmark
return benchmark_data
else
return nothing
end
return nothing
end

"""
8 changes: 0 additions & 8 deletions lib/DifferentiationInterfaceTest/src/utils/compatibility.jl
@@ -4,14 +4,6 @@ function compatible(::AbstractADType, ::Function)
return true
end

function compatible(backend::AbstractADType, ::typeof(pushforward))
return Bool(supports_pushforward(backend))
end

function compatible(backend::AbstractADType, ::typeof(pullback))
return Bool(supports_pullback(backend))
end

## Backend-scenario

function compatible(::AbstractADType, ::Scenario{false})