Skip to content

Commit

Permalink
[BREAKING] Improve type stability tests and benchmarking (#560)
Browse files Browse the repository at this point in the history
* Improve type stability tests and benchmarking

* Remove `first_order` and `second_order`

* Docs

* Zero allocs

* Fixes

* Call count

* Fix

* Fix

* Add count calls

* Default count calls

* Fix
  • Loading branch information
gdalle authored Oct 10, 2024
1 parent bd62b64 commit 3698dbe
Show file tree
Hide file tree
Showing 26 changed files with 910 additions and 635 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
actions: write
contents: read
strategy:
fail-fast: true
fail-fast: false # TODO: toggle
matrix:
version:
- "1.10"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,6 @@ test_differentiation(
test_differentiation(
AutoChainRules(ZygoteRuleConfig()),
default_scenarios(; include_normal=false, include_constantified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,8 @@ for backend in [AutoDiffractor()]
end

test_differentiation(
AutoDiffractor(), default_scenarios(; linalg=false); second_order=false, logging=LOGGING
AutoDiffractor(),
default_scenarios(; linalg=false);
excluded=SECOND_ORDER,
logging=LOGGING,
);
19 changes: 8 additions & 11 deletions DifferentiationInterface/test/Back/Enzyme/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,19 +31,19 @@ end;

## First order

test_differentiation(backends, default_scenarios(); second_order=false, logging=LOGGING);
test_differentiation(backends, default_scenarios(); excluded=SECOND_ORDER, logging=LOGGING);

test_differentiation(
backends[1:3],
default_scenarios(; include_normal=false, include_constantified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);

test_differentiation(
duplicated_backends,
default_scenarios(; include_normal=false, include_closurified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);

Expand All @@ -54,8 +54,8 @@ test_differentiation(
AutoEnzyme(; mode=Enzyme.Forward), # TODO: add more
default_scenarios(; include_batchified=false);
correctness=false,
type_stability=true,
second_order=false,
type_stability=:prepared,
excluded=SECOND_ORDER,
logging=LOGGING,
);
=#
Expand All @@ -65,27 +65,24 @@ test_differentiation(
test_differentiation(
AutoEnzyme(),
default_scenarios(; include_constantified=true);
first_order=false,
excluded=FIRST_ORDER,
logging=LOGGING,
);

test_differentiation(
AutoEnzyme(; mode=Enzyme.Forward);
first_order=false,
excluded=[:hessian, :hvp],
excluded=vcat(FIRST_ORDER, [:hessian, :hvp]),
logging=LOGGING,
);

test_differentiation(
AutoEnzyme(; mode=Enzyme.Reverse);
first_order=false,
excluded=[:second_derivative],
excluded=vcat(FIRST_ORDER, [:second_derivative]),
logging=LOGGING,
);

test_differentiation(
SecondOrder(AutoEnzyme(; mode=Enzyme.Reverse), AutoEnzyme(; mode=Enzyme.Forward));
first_order=false,
logging=LOGGING,
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ end
test_differentiation(
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
default_scenarios(; include_constantified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);
8 changes: 4 additions & 4 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@ test_differentiation(
);

test_differentiation(
AutoForwardDiff(); correctness=false, type_stability=true, logging=LOGGING
AutoForwardDiff(); correctness=false, type_stability=:prepared, logging=LOGGING
);

test_differentiation(
AutoForwardDiff(; chunksize=5);
correctness=false,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
type_stability=:full,
excluded=[:hessian],
logging=LOGGING,
);

test_differentiation(
backends,
vcat(component_scenarios(), static_scenarios()); # FD accesses individual indices
excluded=[:jacobian], # jacobian is super slow for some reason
second_order=false,
excluded=vcat(SECOND_ORDER, [:jacobian]), # jacobian is super slow for some reason
logging=LOGGING,
);

Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/Mooncake/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,6 @@ end
test_differentiation(
backends,
default_scenarios(; include_constantified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);
2 changes: 1 addition & 1 deletion DifferentiationInterface/test/Back/Tracker/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ end
test_differentiation(
AutoTracker(),
default_scenarios(; include_constantified=true);
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
);
4 changes: 2 additions & 2 deletions DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@ test_differentiation(
logging=LOGGING,
);

test_differentiation(second_order_backends; first_order=false, logging=LOGGING);
test_differentiation(second_order_backends; logging=LOGGING);

test_differentiation(
backends[1],
vcat(component_scenarios(), gpu_scenarios());
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@ end
test_differentiation(
[AutoForwardDiff(), AutoZygote()],
differentiatewith_scenarios();
second_order=false,
excluded=SECOND_ORDER,
logging=LOGGING,
)
9 changes: 4 additions & 5 deletions DifferentiationInterface/test/Misc/ZeroBackends/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ test_differentiation(
AutoZeroForward(),
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
type_stability=true,
type_stability=:full,
logging=LOGGING,
)

test_differentiation(
AutoZeroReverse(),
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
type_stability=:full,
logging=LOGGING,
)

Expand All @@ -41,16 +41,15 @@ test_differentiation(
],
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=true),
first_order=false,
type_stability=:full,
logging=LOGGING,
)

test_differentiation(
AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()),
default_scenarios(; include_constantified=true);
correctness=false,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
type_stability=:full,
excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative],
logging=LOGGING,
)
Expand Down
7 changes: 3 additions & 4 deletions DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
name = "DifferentiationInterfaceTest"
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.7.1"
version = "0.8.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand All @@ -31,7 +30,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux"]
DifferentiationInterfaceTestFluxExt = ["FiniteDifferences", "Flux", "Functors"]
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
DifferentiationInterfaceTestLuxExt = ["ComponentArrays", "ForwardDiff", "Lux", "LuxTestUtils"]
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"
Expand Down
9 changes: 8 additions & 1 deletion DifferentiationInterfaceTest/docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,14 @@ DifferentiationInterfaceTest
Scenario
test_differentiation
benchmark_differentiation
DifferentiationBenchmarkDataRow
FIRST_ORDER
SECOND_ORDER
```

## Utilities

```@docs
DifferentiationInterfaceTest.DifferentiationBenchmarkDataRow
```

## Pre-made scenario lists
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/docs/src/tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ test_differentiation(
backends, # the backends you want to compare
scenarios, # the scenarios you defined,
correctness=true, # compares values against the reference
type_stability=false, # checks type stability with JET.jl
type_stability=:none, # checks type stability with JET.jl
detailed=true, # prints a detailed test set
)
```
Expand Down
24 changes: 19 additions & 5 deletions DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,29 @@ using DifferentiationInterface:
Rewrap
import DifferentiationInterface as DI
using DocStringExtensions
using Functors: fmap
using JET: JET
using JET: @test_opt
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
using ProgressMeter: ProgressUnknown, next!
using Random: AbstractRNG, default_rng, rand!
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
import SparseMatrixColorings as SMC
using SparseArrays: SparseArrays, AbstractSparseMatrix, SparseMatrixCSC, nnz, spdiagm
using Test: @testset, @test

"""
FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian]
List of all first-order operators, to facilitate exclusion during tests.
"""
const FIRST_ORDER = [:pushforward, :pullback, :derivative, :gradient, :jacobian]

"""
SECOND_ORDER = [:hvp, :second_derivative, :hessian]
List of all second-order operators, to facilitate exclusion during tests.
"""
const SECOND_ORDER = [:hvp, :second_derivative, :hessian]

const ALL_OPS = vcat(FIRST_ORDER, SECOND_ORDER)

include("utils.jl")

include("scenarios/scenario.jl")
Expand All @@ -71,11 +85,11 @@ include("scenarios/extensions.jl")

include("tests/correctness_eval.jl")
include("tests/type_stability_eval.jl")
include("tests/sparsity.jl")
include("tests/benchmark.jl")
include("tests/benchmark_eval.jl")
include("test_differentiation.jl")

export FIRST_ORDER, SECOND_ORDER
export Scenario
export default_scenarios, sparse_scenarios
export test_differentiation, benchmark_differentiation
Expand Down
2 changes: 2 additions & 0 deletions DifferentiationInterfaceTest/src/scenarios/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ struct NumToArr{A} end
NumToArr(::Type{A}) where {A} = NumToArr{A}()
Base.eltype(::NumToArr{A}) where {A} = eltype(A)

Base.show(io::IO, ::NumToArr{A}) where {A} = print(io, "num_to_arr{$A}")

function (f::NumToArr{A})(x::Number) where {A}
a = multiplicator(A)
return sin.(x .* a)
Expand Down
11 changes: 0 additions & 11 deletions DifferentiationInterfaceTest/src/scenarios/scenario.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,3 @@
const ALL_OPS = (
:pushforward,
:pullback,
:derivative,
:gradient,
:jacobian,
:hessian,
:hvp,
:second_derivative,
)

"""
Scenario{op,pl_op,pl_fun}
Expand Down
Loading

0 comments on commit 3698dbe

Please sign in to comment.