Skip to content

Commit

Permalink
More type stability tests (#543)
Browse files Browse the repository at this point in the history
* More type stability tests

* Remove old kwarg

* Don't test everything for now

* Use fill!

* LTS JET
  • Loading branch information
gdalle authored Oct 5, 2024
1 parent 84378d7 commit 88c48c1
Show file tree
Hide file tree
Showing 6 changed files with 248 additions and 94 deletions.
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/misc/zero_backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ end

(rz::ReturnZero)(i) = zero(rz.template)

_zero!(x::AbstractArray) = x .= zero(eltype(x))
_zero!(x::AbstractArray) = fill!(x, zero(eltype(x)))

## Forward

Expand Down
3 changes: 1 addition & 2 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ test_differentiation(
test_differentiation(
AutoForwardDiff(; chunksize=5);
correctness=false,
type_stability=true,
preparation_type_stability=true,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
logging=LOGGING,
);

Expand Down
22 changes: 14 additions & 8 deletions DifferentiationInterface/test/Misc/ZeroBackends/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,19 @@ end
## Type stability

test_differentiation(
zero_backends,
default_scenarios(; include_constantified=true);
AutoZeroForward(),
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
type_stability=true,
preparation_type_stability=true,
logging=LOGGING,
)

test_differentiation(
AutoZeroReverse(),
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
# TODO: set unprepared_op=true after ignoring DataFrames
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
logging=LOGGING,
)

Expand All @@ -32,10 +40,9 @@ test_differentiation(
SecondOrder(AutoZeroForward(), AutoZeroReverse()),
SecondOrder(AutoZeroReverse(), AutoZeroForward()),
],
default_scenarios();
default_scenarios(; include_batchified=false, include_constantified=true);
correctness=false,
type_stability=true,
preparation_type_stability=true,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=true),
first_order=false,
logging=LOGGING,
)
Expand All @@ -44,8 +51,7 @@ test_differentiation(
AutoSparse.(zero_backends, coloring_algorithm=GreedyColoringAlgorithm()),
default_scenarios(; include_constantified=true);
correctness=false,
type_stability=true,
preparation_type_stability=true,
type_stability=(; preparation=true, prepared_op=true, unprepared_op=false),
excluded=[:pushforward, :pullback, :gradient, :derivative, :hvp, :second_derivative],
logging=LOGGING,
)
Expand Down
22 changes: 12 additions & 10 deletions DifferentiationInterfaceTest/src/test_differentiation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Cross-test a list of `backends` on a list of `scenarios`, running a variety of d
Testing:
- `correctness=true`: whether to compare the differentiation results with the theoretical values specified in each scenario
- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`)
- `type_stability=false`: whether to check type stability of operators with JET.jl (thanks to `JET.@test_opt`). It can be either a `Bool` or a more detailed named tuple `(; preparation, prepared_op, unprepared_op)` to specify which variants should be analyzed.
- `sparsity`: whether to check sparsity of the jacobian / hessian
- `detailed=false`: whether to print a detailed or condensed test log
Expand All @@ -51,8 +51,7 @@ function test_differentiation(
scenarios::Vector{<:Scenario}=default_scenarios();
# testing
correctness::Bool=true,
type_stability::Bool=false,
preparation_type_stability::Bool=false,
type_stability=false,
call_count::Bool=false,
sparsity::Bool=false,
detailed=false,
Expand All @@ -73,10 +72,12 @@ function test_differentiation(
scenarios; first_order, second_order, input_type, output_type, excluded
)

bool_type_stability = (type_stability == true || type_stability isa NamedTuple)

title_additions =
(correctness != false ? " + correctness" : "") *
(call_count ? " + calls" : "") *
(type_stability ? " + types" : "") *
(bool_type_stability ? " + type stability" : "") *
(sparsity ? " + sparsity" : "")
title = "Testing" * title_additions[3:end]

Expand Down Expand Up @@ -115,13 +116,14 @@ function test_differentiation(
adapted_backend, scen; isapprox, atol, rtol, scenario_intact
)
end
type_stability && @testset "Type stability" begin
kwargs_type_stability = if type_stability isa NamedTuple
type_stability
else
(; preparation=false, prepared_op=type_stability, unprepared_op=false)
end
bool_type_stability && @testset "Type stability" begin
@static if VERSION >= v"1.7"
test_jet(
adapted_backend,
scen;
test_preparation=preparation_type_stability,
)
test_jet(adapted_backend, scen; kwargs_type_stability...)
end
end
sparsity && @testset "Sparsity" begin
Expand Down
Loading

0 comments on commit 88c48c1

Please sign in to comment.