Skip to content

Commit 004e934

Browse files
authored
Revamp test scenarios (#323)
* Revamp test scenarios * Fix * Fixes * Fixes * Fix * Refix * Refix * Fixes * Fix * Fix * Coverage * Appease JET
1 parent cb4605a commit 004e934

File tree

31 files changed

+1179
-1572
lines changed

31 files changed

+1179
-1572
lines changed

.github/workflows/Test.yml

+1
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ jobs:
123123
- Formalities
124124
- Zero
125125
- ForwardDiff
126+
- Zygote
126127
exclude:
127128
- version: '1.6'
128129
group: Formalities

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/DifferentiationInterfaceForwardDiffExt.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ using DifferentiationInterface:
1212
NoDerivativeExtras,
1313
NoSecondDerivativeExtras,
1414
PushforwardExtras
15-
using ForwardDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
15+
using ForwardDiff.DiffResults:
16+
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
1617
using ForwardDiff:
1718
Chunk,
1819
Dual,

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

+15-13
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,11 @@ function DI.value_and_gradient!(
132132
end
133133

134134
function DI.value_and_gradient(
135-
f::F, backend::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
135+
f::F, ::AutoForwardDiff, x, extras::ForwardDiffGradientExtras
136136
) where {F}
137-
grad = similar(x)
138-
return DI.value_and_gradient!(f, grad, backend, x, extras)
137+
result = GradientResult(x)
138+
result = gradient!(result, f, x, extras.config)
139+
return DiffResults.value(result), DiffResults.gradient(result)
139140
end
140141

141142
function DI.gradient!(
@@ -189,19 +190,22 @@ end
189190

190191
## Hessian
191192

192-
struct ForwardDiffHessianExtras{C1,C2} <: HessianExtras
193+
struct ForwardDiffHessianExtras{C1,C2,C3} <: HessianExtras
193194
array_config::C1
194-
result_config::C2
195+
manual_result_config::C2
196+
auto_result_config::C3
195197
end
196198

197199
function DI.prepare_hessian(f, backend::AutoForwardDiff, x)
198-
example_result = MutableDiffResult(
200+
manual_result = MutableDiffResult(
199201
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
200202
)
203+
auto_result = HessianResult(x)
201204
chunk = choose_chunk(backend, x)
202205
array_config = HessianConfig(f, x, chunk)
203-
result_config = HessianConfig(f, example_result, x, chunk)
204-
return ForwardDiffHessianExtras(array_config, result_config)
206+
manual_result_config = HessianConfig(f, manual_result, x, chunk)
207+
auto_result_config = HessianConfig(f, auto_result, x, chunk)
208+
return ForwardDiffHessianExtras(array_config, manual_result_config, auto_result_config)
205209
end
206210

207211
function DI.hessian!(
@@ -218,7 +222,7 @@ function DI.value_gradient_and_hessian!(
218222
f::F, grad, hess, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras
219223
) where {F}
220224
result = MutableDiffResult(one(eltype(x)), (grad, hess))
221-
result = hessian!(result, f, x, extras.result_config)
225+
result = hessian!(result, f, x, extras.manual_result_config)
222226
return (
223227
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
224228
)
@@ -227,10 +231,8 @@ end
227231
function DI.value_gradient_and_hessian(
228232
f::F, ::AutoForwardDiff, x, extras::ForwardDiffHessianExtras
229233
) where {F}
230-
result = MutableDiffResult(
231-
one(eltype(x)), (similar(x), similar(x, length(x), length(x)))
232-
)
233-
result = hessian!(result, f, x, extras.result_config)
234+
result = HessianResult(x)
235+
result = hessian!(result, f, x, extras.auto_result_config)
234236
return (
235237
DiffResults.value(result), DiffResults.gradient(result), DiffResults.hessian(result)
236238
)

DifferentiationInterface/test/Double/ChainRulesCore-Zygote/test.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ end
1111

1212
test_differentiation(
1313
AutoChainRules(ZygoteRuleConfig());
14-
excluded=[SecondDerivativeScenario],
14+
excluded=[:second_derivative],
1515
second_order=VERSION >= v"1.10",
1616
logging=LOGGING,
1717
);

DifferentiationInterface/test/Single/Enzyme/test.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,14 @@ test_differentiation(
4848
SecondOrder(AutoEnzyme(; mode=Enzyme.Forward), AutoEnzyme(; mode=Enzyme.Reverse)),
4949
];
5050
first_order=false,
51-
excluded=[SecondDerivativeScenario],
51+
excluded=[:second_derivative],
5252
logging=LOGGING,
5353
);
5454

5555
test_differentiation(
5656
[AutoEnzyme(; mode=nothing), AutoEnzyme(; mode=Enzyme.Forward)];
5757
first_order=false,
58-
excluded=[HessianScenario, HVPScenario],
58+
excluded=[:hessian, :hvp],
5959
logging=LOGGING,
6060
);
6161

@@ -72,7 +72,7 @@ test_differentiation(
7272
test_differentiation(
7373
sparse_backends,
7474
default_scenarios();
75-
excluded=[DerivativeScenario, GradientScenario, PullbackScenario, PushforwardScenario],
75+
excluded=[:derivative, :gradient, :pullback, :pushforward],
7676
second_order=false,
7777
logging=LOGGING,
7878
);

DifferentiationInterface/test/Single/FiniteDiff/test.jl

+1-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,4 @@ for backend in [AutoFiniteDiff()]
88
@test check_hessian(backend)
99
end
1010

11-
test_differentiation(
12-
AutoFiniteDiff(); excluded=[SecondDerivativeScenario, HVPScenario], logging=LOGGING
13-
);
11+
test_differentiation(AutoFiniteDiff(); excluded=[:second_derivative, :hvp], logging=LOGGING);

DifferentiationInterface/test/Single/ForwardDiff/efficiency.jl

-87
This file was deleted.

DifferentiationInterface/test/Single/ForwardDiff/test.jl

+2-9
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ test_differentiation(
3636
# ForwardDiff accesses individual indices
3737
vcat(component_scenarios(), static_scenarios());
3838
# jacobian is super slow for some reason
39-
excluded=[JacobianScenario],
39+
excluded=[:jacobian],
4040
second_order=false,
4141
logging=LOGGING,
4242
);
@@ -46,14 +46,7 @@ test_differentiation(
4646
test_differentiation(
4747
sparse_backends,
4848
default_scenarios();
49-
excluded=[
50-
DerivativeScenario,
51-
GradientScenario,
52-
HVPScenario,
53-
PullbackScenario,
54-
PushforwardScenario,
55-
SecondDerivativeScenario,
56-
],
49+
excluded=[:derivative, :gradient, :hvp, :pullback, :pushforward, :second_derivative],
5750
logging=LOGGING,
5851
);
5952

DifferentiationInterface/test/Single/Zygote/test.jl

+8-17
Original file line numberDiff line numberDiff line change
@@ -21,32 +21,23 @@ end
2121

2222
## Dense backends
2323

24-
test_differentiation(AutoZygote(); excluded=[SecondDerivativeScenario], logging=LOGGING);
25-
26-
test_differentiation(
27-
AutoZygote(),
28-
vcat(component_scenarios(), static_scenarios());
29-
second_order=false,
30-
logging=LOGGING,
31-
)
24+
test_differentiation(AutoZygote(); excluded=[:second_derivative], logging=LOGGING);
3225

3326
if VERSION >= v"1.10"
34-
test_differentiation(AutoZygote(), gpu_scenarios(); second_order=false, logging=LOGGING)
27+
test_differentiation(
28+
AutoZygote(),
29+
vcat(component_scenarios(), gpu_scenarios(), static_scenarios());
30+
second_order=false,
31+
logging=LOGGING,
32+
)
3533
end
3634

3735
## Sparse backends
3836

3937
test_differentiation(
4038
sparse_backends,
4139
default_scenarios();
42-
excluded=[
43-
DerivativeScenario,
44-
GradientScenario,
45-
HVPScenario,
46-
PullbackScenario,
47-
PushforwardScenario,
48-
SecondDerivativeScenario,
49-
],
40+
excluded=[:derivative, :gradient, :hvp, :pullback, :pushforward, :second_derivative],
5041
logging=LOGGING,
5142
);
5243

DifferentiationInterfaceTest/Project.toml

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.4.4"
4+
version = "0.5.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
99
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1010
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
11+
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1112
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1213
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
1314
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
@@ -24,6 +25,7 @@ ADTypes = "1.0.0"
2425
Chairmarks = "1.2.1"
2526
Compat = "4"
2627
ComponentArrays = "0.15"
28+
DataFrames = "1.6.1"
2729
DifferentiationInterface = "0.5.4"
2830
DocStringExtensions = "0.9"
2931
JET = "0.4 - 0.8, 0.9"
@@ -50,6 +52,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
5052
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
5153
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
5254
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
55+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
5356

5457
[targets]
55-
test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterface", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test"]
58+
test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterface", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test", "Zygote"]

DifferentiationInterfaceTest/docs/src/api.md

+2-3
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
# API reference
2+
13
```@meta
24
CurrentModule = Main
35
CollapsedDocStrings = true
46
```
57

6-
# API reference
7-
88
```@docs
99
DifferentiationInterfaceTest
1010
```
@@ -30,7 +30,6 @@ static_scenarios
3030
## Scenario types
3131

3232
```@docs
33-
AbstractScenario
3433
PushforwardScenario
3534
PullbackScenario
3635
DerivativeScenario

DifferentiationInterfaceTest/docs/src/tutorial.md

+12-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ We present a typical workflow with DifferentiationInterfaceTest.jl, building on
55
```@repl tuto
66
using DifferentiationInterface, DifferentiationInterfaceTest
77
import ForwardDiff, Enzyme
8-
import DataFrames, Markdown, PrettyTables, Printf
8+
import Markdown, PrettyTables, Printf
99
```
1010

1111
## Introduction
@@ -31,15 +31,18 @@ Of course we know the true gradient mapping:
3131
DifferentiationInterfaceTest.jl relies with so-called "scenarios", in which you encapsulate the information needed for your test:
3232

3333
- the function `f`
34-
- the input `x` (and output `y` for mutating functions)
35-
- optionally a reference `ref` to check against
34+
- the input `x` and output `y`
35+
- the number of arguments for `f` (either `1` or `2`)
36+
- the behavior of the operator (either `:inplace` or `:outofplace`)
3637

37-
There is one scenario per operator, and so here we will use [`GradientScenario`](@ref):
38+
There is one scenario constructor per operator, and so here we will use [`GradientScenario`](@ref):
3839

3940
```@example tuto
41+
xv = rand(Float32, 3)
42+
xm = rand(Float64, 3, 2)
4043
scenarios = [
41-
GradientScenario(f; x=rand(Float32, 3), ref=∇f, place=:inplace),
42-
GradientScenario(f; x=rand(Float64, 3, 2), ref=∇f, place=:inplace)
44+
GradientScenario(f; x=xv, y=f(xv), nb_args=1, place=:inplace),
45+
GradientScenario(f; x=xm, y=f(xm), nb_args=1, place=:inplace)
4346
];
4447
nothing # hide
4548
```
@@ -67,17 +70,11 @@ Once you are confident that your backends give the correct answers, you probably
6770
This is made easy by the [`benchmark_differentiation`](@ref) function, whose syntax should feel familiar:
6871

6972
```@example tuto
70-
benchmark_result = benchmark_differentiation(backends, scenarios);
73+
df = benchmark_differentiation(backends, scenarios);
7174
```
7275

73-
The resulting object is a `Vector` of [`DifferentiationBenchmarkDataRow`](@ref), which can easily be converted into a `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl):
74-
75-
```@example tuto
76-
df = DataFrames.DataFrame(benchmark_result)
77-
```
78-
79-
Here's what the resulting `DataFrame` looks like with all its columns.
80-
Note that we only compare (possibly) in-place operators, because they are always more efficient.
76+
The resulting object is `DataFrame` from [DataFrames.jl](https://github.com/JuliaData/DataFrames.jl), whose columns correspond to the fields of [`DifferentiationBenchmarkDataRow`](@ref):
77+
Here's what it looks like with all of its columns.
8178

8279
```@example tuto
8380
table = PrettyTables.pretty_table(

0 commit comments

Comments
 (0)