Skip to content

Commit 3698dbe

Browse files
authored
[BREAKING] Improve type stability tests and benchmarking (#560)
* Improve type stability tests and benchmarking * Remove `first_order` and `second_order` * Docs * Zero allocs * Fixes * Call count * Fix * Fix * Add count calls * Default count calls * Fix
1 parent bd62b64 commit 3698dbe

File tree

26 files changed

+910
-635
lines changed

26 files changed

+910
-635
lines changed

.github/workflows/Test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
actions: write
2626
contents: read
2727
strategy:
28-
fail-fast: true
28+
fail-fast: false # TODO: toggle
2929
matrix:
3030
version:
3131
- "1.10"

DifferentiationInterface/test/Back/ChainRulesBackends/chainrules_zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,6 @@ test_differentiation(
2323
test_differentiation(
2424
AutoChainRules(ZygoteRuleConfig()),
2525
default_scenarios(; include_normal=false, include_constantified=true);
26-
second_order=false,
26+
excluded=SECOND_ORDER,
2727
logging=LOGGING,
2828
);

DifferentiationInterface/test/Back/ChainRulesBackends/diffractor.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,8 @@ for backend in [AutoDiffractor()]
1313
end
1414

1515
test_differentiation(
16-
AutoDiffractor(), default_scenarios(; linalg=false); second_order=false, logging=LOGGING
16+
AutoDiffractor(),
17+
default_scenarios(; linalg=false);
18+
excluded=SECOND_ORDER,
19+
logging=LOGGING,
1720
);

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,19 @@ end;
3131

3232
## First order
3333

34-
test_differentiation(backends, default_scenarios(); second_order=false, logging=LOGGING);
34+
test_differentiation(backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING);
3535

3636
test_differentiation(
3737
backends[1:3],
3838
default_scenarios(; include_normal=false, include_constantified=true);
39-
second_order=false,
39+
excluded=SECOND_ORDER,
4040
logging=LOGGING,
4141
);
4242

4343
test_differentiation(
4444
duplicated_backends,
4545
default_scenarios(; include_normal=false, include_closurified=true);
46-
second_order=false,
46+
excluded=SECOND_ORDER,
4747
logging=LOGGING,
4848
);
4949

@@ -54,8 +54,8 @@ test_differentiation(
5454
AutoEnzyme(; mode=Enzyme.Forward), # TODO: add more
5555
default_scenarios(; include_batchified=false);
5656
correctness=false,
57-
type_stability=true,
58-
second_order=false,
57+
type_stability=:prepared,
58+
excluded=SECOND_ORDER,
5959
logging=LOGGING,
6060
);
6161
=#
@@ -65,27 +65,24 @@ test_differentiation(
6565
test_differentiation(
6666
AutoEnzyme(),
6767
default_scenarios(; include_constantified=true);
68-
first_order=false,
68+
excluded=FIRST_ORDER,
6969
logging=LOGGING,
7070
);
7171

7272
test_differentiation(
7373
AutoEnzyme(; mode=Enzyme.Forward);
74-
first_order=false,
75-
excluded=[:hessian, :hvp],
74+
excluded=vcat(FIRST_ORDER, [:hessian, :hvp]),
7675
logging=LOGGING,
7776
);
7877

7978
test_differentiation(
8079
AutoEnzyme(; mode=Enzyme.Reverse);
81-
first_order=false,
82-
excluded=[:second_derivative],
80+
excluded=vcat(FIRST_ORDER, [:second_derivative]),
8381
logging=LOGGING,
8482
);
8583

8684
test_differentiation(
8785
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward));
88-
first_order=false,
8986
logging=LOGGING,
9087
);
9188

DifferentiationInterface/test/Back/FiniteDifferences/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ end
1515
test_differentiation(
1616
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
1717
default_scenarios(; include_constantified=true);
18-
second_order=false,
18+
excluded=SECOND_ORDER,
1919
logging=LOGGING,
2020
);

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,21 @@ test_differentiation(
2323
);
2424

2525
test_differentiation(
26-
AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING
26+
AutoForwardDiff(); correctness=false, type_stability=:prepared, logging=LOGGING
2727
);
2828

2929
test_differentiation(
3030
AutoForwardDiff(; chunksize=5);
3131
correctness=false,
32-
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
32+
type_stability=:full,
33+
excluded=[:hessian],
3334
logging=LOGGING,
3435
);
3536

3637
test_differentiation(
3738
backends,
3839
vcat(component_scenarios(), static_scenarios()); # FD accesses individual indices
39-
excluded=[:jacobian], # jacobian is super slow for some reason
40-
second_order=false,
40+
excluded=vcat(SECOND_ORDER, [:jacobian]), # jacobian is super slow for some reason
4141
logging=LOGGING,
4242
);
4343

DifferentiationInterface/test/Back/Mooncake/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,6 @@ end
1717
test_differentiation(
1818
backends,
1919
default_scenarios(; include_constantified=true);
20-
second_order=false,
20+
excluded=SECOND_ORDER,
2121
logging=LOGGING,
2222
);

DifferentiationInterface/test/Back/Tracker/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@ end
1515
test_differentiation(
1616
AutoTracker(),
1717
default_scenarios(; include_constantified=true);
18-
second_order=false,
18+
excluded=SECOND_ORDER,
1919
logging=LOGGING,
2020
);

DifferentiationInterface/test/Back/Zygote/test.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ test_differentiation(
2828
logging=LOGGING,
2929
);
3030

31-
test_differentiation(second_order_backends; first_order=false, logging=LOGGING);
31+
test_differentiation(second_order_backends; logging=LOGGING);
3232

3333
test_differentiation(
3434
backends[1],
3535
vcat(component_scenarios(), gpu_scenarios());
36-
second_order=false,
36+
excluded=SECOND_ORDER,
3737
logging=LOGGING,
3838
)
3939

DifferentiationInterface/test/Misc/DifferentiateWith/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,6 @@ end
2424
test_differentiation(
2525
[AutoForwardDiff(), AutoZygote()],
2626
differentiatewith_scenarios();
27-
second_order=false,
27+
excluded=SECOND_ORDER,
2828
logging=LOGGING,
2929
)

DifferentiationInterface/test/Misc/ZeroBackends/test.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ test_differentiation(
2222
AutoZeroForward(),
2323
default_scenarios(; include_batchified=false, include_constantified=true);
2424
correctness=false,
25-
type_stability=true,
25+
type_stability=:full,
2626
logging=LOGGING,
2727
)
2828

2929
test_differentiation(
3030
AutoZeroReverse(),
3131
default_scenarios(; include_batchified=false, include_constantified=true);
3232
correctness=false,
33-
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
33+
type_stability=:full,
3434
logging=LOGGING,
3535
)
3636

@@ -41,16 +41,15 @@ test_differentiation(
4141
],
4242
default_scenarios(; include_batchified=false, include_constantified=true);
4343
correctness=false,
44-
type_stability=(; preparation=true, prepared_op=true, unprepared_op=true),
45-
first_order=false,
44+
type_stability=:full,
4645
logging=LOGGING,
4746
)
4847

4948
test_differentiation(
5049
AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()),
5150
default_scenarios(; include_constantified=true);
5251
correctness=false,
53-
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
52+
type_stability=:full,
5453
excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative],
5554
logging=LOGGING,
5655
)

DifferentiationInterfaceTest/Project.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,26 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.7.1"
4+
version = "0.8.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
88
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
99
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
1010
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1111
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
12-
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1312
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
1413
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1514
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
1615
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1716
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
18-
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
1917
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2018

2119
[weakdeps]
2220
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2321
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2422
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
23+
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
2524
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2625
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
2726
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -31,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3130

3231
[extensions]
3332
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
34-
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"]
33+
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux", "Functors"]
3534
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
3635
DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "ForwardDiff", "Lux", "LuxTestUtils"]
3736
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"

DifferentiationInterfaceTest/docs/src/api.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,14 @@ DifferentiationInterfaceTest
1515
Scenario
1616
test_differentiation
1717
benchmark_differentiation
18-
DifferentiationBenchmarkDataRow
18+
FIRST_ORDER
19+
SECOND_ORDER
20+
```
21+
22+
## Utilities
23+
24+
```@docs
25+
DifferentiationInterfaceTest.DifferentiationBenchmarkDataRow
1926
```
2027

2128
## Pre-made scenario lists

DifferentiationInterfaceTest/docs/src/tutorial.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ test_differentiation(
5555
backends, # the backends you want to compare
5656
scenarios, # the scenarios you defined,
5757
correctness=true, # compares values against the reference
58-
type_stability=false, # checks type stability with JET.jl
58+
type_stability=:none, # checks type stability with JET.jl
5959
detailed=true, # prints a detailed test set
6060
)
6161
```

DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,15 +51,29 @@ using DifferentiationInterface:
5151
Rewrap
5252
import DifferentiationInterface as DI
5353
using DocStringExtensions
54-
using Functors: fmap
55-
using JET: JET
54+
using JET: @test_opt
5655
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
5756
using ProgressMeter: ProgressUnknown, next!
5857
using Random: AbstractRNG, default_rng, rand!
59-
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
60-
import SparseMatrixColorings as SMC
58+
using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, spdiagm
6159
using Test: @testset, @test
6260

61+
"""
62+
FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian]
63+
64+
List of all first-order operators, to facilitate exclusion during tests.
65+
"""
66+
const FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian]
67+
68+
"""
69+
SECOND_ORDER = [:hvp, :second_derivative, :hessian]
70+
71+
List of all second-order operators, to facilitate exclusion during tests.
72+
"""
73+
const SECOND_ORDER = [:hvp, :second_derivative, :hessian]
74+
75+
const ALL_OPS = vcat(FIRST_ORDER, SECOND_ORDER)
76+
6377
include("utils.jl")
6478

6579
include("scenarios/scenario.jl")
@@ -71,11 +85,11 @@ include("scenarios/extensions.jl")
7185

7286
include("tests/correctness_eval.jl")
7387
include("tests/type_stability_eval.jl")
74-
include("tests/sparsity.jl")
7588
include("tests/benchmark.jl")
7689
include("tests/benchmark_eval.jl")
7790
include("test_differentiation.jl")
7891

92+
export FIRST_ORDER, SECOND_ORDER
7993
export Scenario
8094
export default_scenarios, sparse_scenarios
8195
export test_differentiation, benchmark_differentiation

DifferentiationInterfaceTest/src/scenarios/default.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ struct NumToArr{A} end
7878
NumToArr(::Type{A}) where {A} = NumToArr{A}()
7979
Base.eltype(::NumToArr{A}) where {A} = eltype(A)
8080

81+
Base.show(io::IO, ::NumToArr{A}) where {A} = print(io, "num_to_arr{$A}")
82+
8183
function (f::NumToArr{A})(x::Number) where {A}
8284
a = multiplicator(A)
8385
return sin.(x .* a)

DifferentiationInterfaceTest/src/scenarios/scenario.jl

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,3 @@
1-
const ALL_OPS = (
2-
:pushforward,
3-
:pullback,
4-
:derivative,
5-
:gradient,
6-
:jacobian,
7-
:hessian,
8-
:hvp,
9-
:second_derivative,
10-
)
11-
121
"""
132
Scenario{op,pl_op,pl_fun}
143

0 commit comments

Comments
 (0)