From 59f430d6929f9f3c3ed408dad79ee14d9bd141af Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 26 Apr 2024 07:24:29 +0200 Subject: [PATCH 1/2] Implement DifferentiateWith to translate between backends --- DifferentiationInterface/docs/src/api.md | 6 ++ DifferentiationInterface/docs/src/backends.md | 4 +- .../docs/src/overloads.md | 2 +- DifferentiationInterface/docs/src/overview.md | 6 ++ ...fferentiationInterfaceChainRulesCoreExt.jl | 39 +++----- .../differentiate_with.jl | 12 +++ .../reverse_onearg.jl | 27 ++++++ .../src/DifferentiationInterface.jl | 4 + .../src/translation/differentiate_with.jl | 54 +++++++++++ .../src/utils/exceptions.jl | 2 +- .../src/utils/printing.jl | 13 ++- DifferentiationInterface/test/chunk.jl | 10 +-- DifferentiationInterface/test/coloring.jl | 24 ++--- .../test/differentiate_with.jl | 23 +++++ DifferentiationInterface/test/runtests.jl | 21 +++-- DifferentiationInterface/test/sparsity.jl | 4 +- .../test/test_exceptions.jl | 10 +-- DifferentiationInterface/test/test_imports.jl | 5 +- .../src/DifferentiationInterfaceTest.jl | 2 +- .../src/scenarios/scenario.jl | 37 ++++++++ .../src/test_differentiation.jl | 90 ++++++++++--------- .../src/tests/benchmark.jl | 2 +- 22 files changed, 276 insertions(+), 121 deletions(-) create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl create mode 100644 DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl create mode 100644 DifferentiationInterface/src/translation/differentiate_with.jl create mode 100644 DifferentiationInterface/test/differentiate_with.jl diff --git a/DifferentiationInterface/docs/src/api.md b/DifferentiationInterface/docs/src/api.md index e3cb90652..07fcddcd9 100644 --- a/DifferentiationInterface/docs/src/api.md +++ b/DifferentiationInterface/docs/src/api.md @@ -91,6 +91,12 @@ check_twoarg check_hessian ``` +## Translation + +```@docs +DifferentiateWith +``` + ## Internals This is not part of the public API. diff --git a/DifferentiationInterface/docs/src/backends.md b/DifferentiationInterface/docs/src/backends.md index 9d7299b64..75560516a 100644 --- a/DifferentiationInterface/docs/src/backends.md +++ b/DifferentiationInterface/docs/src/backends.md @@ -5,7 +5,7 @@ CollapsedDocStrings = true ```@setup backends using DifferentiationInterface -using DifferentiationInterface: backend_string +using DifferentiationInterface: backend_str import Markdown import Diffractor, Enzyme, FastDifferentiation, FiniteDiff, FiniteDifferences, ForwardDiff, PolyesterForwardDiff, ReverseDiff, Tapir, Tracker, Zygote @@ -37,7 +37,7 @@ println(io, "|:--------|:------------:|:----------------------:|:--------------- for example in backend_examples b = eval(Meta.parse(example)) # backend - join(io, [backend_string(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|') + join(io, [backend_str(b), unicode_check_available(b), unicode_check_twoarg(b), unicode_check_hessian(b), "`$example`"], '|') println(io, '|' ) end backend_table = Markdown.parse(String(take!(io))) diff --git a/DifferentiationInterface/docs/src/overloads.md b/DifferentiationInterface/docs/src/overloads.md index 60b3f7918..e31b824d6 100644 --- a/DifferentiationInterface/docs/src/overloads.md +++ b/DifferentiationInterface/docs/src/overloads.md @@ -24,7 +24,7 @@ Each cell can have three values: ```@setup overloads using ADTypes: AbstractADType using DifferentiationInterface -using DifferentiationInterface: backend_string, mutation_support, MutationSupported +using DifferentiationInterface: backend_str, mutation_support, MutationSupported using Markdown: Markdown using Diffractor: Diffractor using Enzyme: Enzyme diff --git a/DifferentiationInterface/docs/src/overview.md b/DifferentiationInterface/docs/src/overview.md index 3c4ad558b..afaabb42d 100644 --- a/DifferentiationInterface/docs/src/overview.md +++ b/DifferentiationInterface/docs/src/overview.md @@ -123,6 +123,12 @@ We make this available for all backends with the following operators: | :--------------------------------- | :---------------------------------- | | [`value_and_pullback_split`](@ref) | [`value_and_pullback!_split`](@ref) | +## Translation + +The wrapper [`DifferentiateWith`](@ref) allows you to take a function and specify that it should be differentiated with the backend of your choice. +In other words, when you try to differentiate `dw = DifferentiateWith(f, backend1)` with `backend2`, then `backend1` steps in and `backend2` does nothing. +At the moment it only works when `backend2` supports [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl). + ## Going further ### Non-standard types diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl index dc69f3156..0f105b6c9 100644 --- a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/DifferentiationInterfaceChainRulesCoreExt.jl @@ -2,9 +2,15 @@ module DifferentiationInterfaceChainRulesCoreExt using ADTypes: ADTypes, AutoChainRules using ChainRulesCore: - HasForwardsMode, HasReverseMode, NoTangent, RuleConfig, frule_via_ad, rrule_via_ad + ChainRulesCore, + HasForwardsMode, + HasReverseMode, + NoTangent, + RuleConfig, + frule_via_ad, + rrule_via_ad import DifferentiationInterface as DI -using DifferentiationInterface: NoPullbackExtras, NoPushforwardExtras +using DifferentiationInterface: DifferentiateWith, NoPullbackExtras, NoPushforwardExtras ruleconfig(backend::AutoChainRules) = backend.ruleconfig @@ -14,32 +20,7 @@ const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}} DI.check_available(::AutoChainRules) = true DI.mutation_support(::AutoChainRules) = DI.MutationNotSupported() -## Pullback - -DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras() - -function DI.value_and_pullback_split( - f, backend::AutoReverseChainRules, x, ::NoPullbackExtras -) - rc = ruleconfig(backend) - y, pullback = rrule_via_ad(rc, f, x) - pullbackfunc(dy) = last(pullback(dy)) - return y, pullbackfunc -end - -function DI.value_and_pullback!_split( - f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras -) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) - return y, pullbackfunc! -end - -function DI.value_and_pullback( - f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras -) - y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) - return y, pullbackfunc(dy) -end +include("reverse_onearg.jl") +include("differentiate_with.jl") end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl new file mode 100644 index 000000000..12071a43a --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/differentiate_with.jl @@ -0,0 +1,12 @@ +function ChainRulesCore.frule((_, dx), dw::DifferentiateWith, x) + (; f, backend) = dw + y, dy = DI.value_and_pushforward(f, backend, x, dx) + return y, dy +end + +function ChainRulesCore.rrule(dw::DifferentiateWith, x) + (; f, backend) = dw + y, pullbackfunc = DI.value_and_pullback_split(f, backend, x) + pullbackfunc_adjusted(dy) = (NoTangent(), pullbackfunc(dy)) + return y, pullbackfunc_adjusted +end diff --git a/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl new file mode 100644 index 000000000..706af6583 --- /dev/null +++ b/DifferentiationInterface/ext/DifferentiationInterfaceChainRulesCoreExt/reverse_onearg.jl @@ -0,0 +1,27 @@ +## Pullback + +DI.prepare_pullback(f, ::AutoReverseChainRules, x, dy) = NoPullbackExtras() + +function DI.value_and_pullback_split( + f, backend::AutoReverseChainRules, x, ::NoPullbackExtras +) + rc = ruleconfig(backend) + y, pullback = rrule_via_ad(rc, f, x) + pullbackfunc(dy) = last(pullback(dy)) + return y, pullbackfunc +end + +function DI.value_and_pullback!_split( + f, backend::AutoReverseChainRules, x, extras::NoPullbackExtras +) + y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) + pullbackfunc!(dx, dy) = copyto!(dx, pullbackfunc(dy)) + return y, pullbackfunc! +end + +function DI.value_and_pullback( + f, backend::AutoReverseChainRules, x, dy, extras::NoPullbackExtras +) + y, pullbackfunc = DI.value_and_pullback_split(f, backend, x, extras) + return y, pullbackfunc(dy) +end diff --git a/DifferentiationInterface/src/DifferentiationInterface.jl b/DifferentiationInterface/src/DifferentiationInterface.jl index d2ff4a8e4..f19dcb58a 100644 --- a/DifferentiationInterface/src/DifferentiationInterface.jl +++ b/DifferentiationInterface/src/DifferentiationInterface.jl @@ -60,6 +60,8 @@ include("sparse/fallbacks.jl") include("sparse/jacobian.jl") include("sparse/hessian.jl") +include("translation/differentiate_with.jl") + export SecondOrder export value_and_pushforward!, value_and_pushforward @@ -87,6 +89,8 @@ export prepare_second_derivative, prepare_hvp, prepare_hessian export check_available, check_twoarg, check_hessian +export DifferentiateWith + # Re-export backends from ADTypes export AutoChainRules export AutoDiffractor diff --git a/DifferentiationInterface/src/translation/differentiate_with.jl b/DifferentiationInterface/src/translation/differentiate_with.jl new file mode 100644 index 000000000..fe551c8f5 --- /dev/null +++ b/DifferentiationInterface/src/translation/differentiate_with.jl @@ -0,0 +1,54 @@ +""" + DifferentiateWith + +Callable function wrapper that enforces differentiation with a specified (inner) backend. + +This works by defining new rules overriding the behavior of the outer backend that would normally be used. + +!!! warning + This is an experimental functionality, whose API cannot yet be considered stable. + At the moment, it only supports one-argument functions, and rules are only defined for [ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl)-compatible outer backends. + +# Fields + +- `f`: the function in question +- `backend::AbstractADType`: the inner backend to use for differentiation + +# Constructor + + DifferentiateWith(f, backend) + +# Example + +```@repl +using DifferentiationInterface +import ForwardDiff, Zygote + +function f(x) + a = Vector{eltype(x)}(undef, 1) + a[1] = sum(x) # mutation that breaks Zygote + return a[1] +end + +dw = DifferentiateWith(f, AutoForwardDiff()); + +gradient(dw, AutoZygote(), [1.0, 2.0]) # works because it calls ForwardDiff instead +gradient(f, AutoZygote(), [1.0, 2.0]) # fails +``` +""" +struct DifferentiateWith{F,B<:AbstractADType} + f::F + backend::B +end + +""" + (dw::DifferentiateWith)(x) + +Call the underlying function `dw.f` of a [`DifferentiateWith`](@ref) wrapper. +""" +(dw::DifferentiateWith)(x) = dw.f(x) + +function Base.show(io::IO, dw::DifferentiateWith) + (; f, backend) = dw + return print(io, "$f differentiated with $(backend_str(backend))") +end diff --git a/DifferentiationInterface/src/utils/exceptions.jl b/DifferentiationInterface/src/utils/exceptions.jl index c6509e78f..f0dd08809 100644 --- a/DifferentiationInterface/src/utils/exceptions.jl +++ b/DifferentiationInterface/src/utils/exceptions.jl @@ -2,7 +2,7 @@ struct MissingBackendError <: Exception backend::AbstractADType end function Base.showerror(io::IO, e::MissingBackendError) - println(io, "failed to use $(backend_string(e.backend)) backend.") + println(io, "failed to use $(backend_str(e.backend)) backend.") if !check_available(e.backend) print( io, diff --git a/DifferentiationInterface/src/utils/printing.jl b/DifferentiationInterface/src/utils/printing.jl index 6a71c414c..63b9f5e46 100644 --- a/DifferentiationInterface/src/utils/printing.jl +++ b/DifferentiationInterface/src/utils/printing.jl @@ -15,11 +15,8 @@ backend_package_name(::AutoTracker) = "Tracker" backend_package_name(::AutoZygote) = "Zygote" backend_package_name(::AutoReverseDiff) = "ReverseDiff" -backend_string_aux(b::AbstractADType) = backend_package_name(b) -backend_string_aux(b::AutoReverseDiff) = "ReverseDiff$(b.compile ? "{compiled}" : "")" - -function backend_string(backend::AbstractADType) - bs = backend_string_aux(backend) +function backend_str(backend::AbstractADType) + bs = backend_package_name(backend) if mode(backend) isa ForwardMode return "$bs (forward)" elseif mode(backend) isa ReverseMode @@ -33,8 +30,8 @@ function backend_string(backend::AbstractADType) end end -backend_string(backend::AutoSparse) = "Sparse $(backend_string(dense_ad(backend)))" +backend_str(backend::AutoSparse) = "Sparse $(backend_str(dense_ad(backend)))" -function backend_string(backend::SecondOrder) - return "$(backend_string(outer(backend))) / $(backend_string(inner(backend)))" +function backend_str(backend::SecondOrder) + return "$(backend_str(outer(backend))) / $(backend_str(inner(backend)))" end diff --git a/DifferentiationInterface/test/chunk.jl b/DifferentiationInterface/test/chunk.jl index d7e1501d6..d52bbd3eb 100644 --- a/DifferentiationInterface/test/chunk.jl +++ b/DifferentiationInterface/test/chunk.jl @@ -1,9 +1,9 @@ -using DifferentiationInterface: pick_chunksize, DEFAULT_CHUNKSIZE - -@test pick_chunksize.(1:DEFAULT_CHUNKSIZE) == 1:DEFAULT_CHUNKSIZE +@test DI.pick_chunksize.(1:(DI.DEFAULT_CHUNKSIZE)) == 1:(DI.DEFAULT_CHUNKSIZE) @test all( - pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .<= DEFAULT_CHUNKSIZE + DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .<= + DI.DEFAULT_CHUNKSIZE, ) @test all( - pick_chunksize.((DEFAULT_CHUNKSIZE + 1):(5DEFAULT_CHUNKSIZE)) .>= DEFAULT_CHUNKSIZE / 2 + DI.pick_chunksize.((DI.DEFAULT_CHUNKSIZE + 1):(5DI.DEFAULT_CHUNKSIZE)) .>= + DI.DEFAULT_CHUNKSIZE / 2, ) diff --git a/DifferentiationInterface/test/coloring.jl b/DifferentiationInterface/test/coloring.jl index 6e3c2d3f6..ddea1231d 100644 --- a/DifferentiationInterface/test/coloring.jl +++ b/DifferentiationInterface/test/coloring.jl @@ -1,26 +1,16 @@ -using ADTypes: column_coloring, row_coloring, symmetric_coloring -using DifferentiationInterface: - GreedyColoringAlgorithm, - check_structurally_orthogonal_columns, - check_structurally_orthogonal_rows, - check_symmetrically_structurally_orthogonal -using LinearAlgebra -using SparseArrays -using Test - -alg = GreedyColoringAlgorithm() +alg = DI.GreedyColoringAlgorithm() A = sprand(Bool, 100, 200, 0.1) -column_colors = column_coloring(A, alg) -@test check_structurally_orthogonal_columns(A, column_colors) +column_colors = ADTypes.column_coloring(A, alg) +@test DI.check_structurally_orthogonal_columns(A, column_colors) @test maximum(column_colors) < size(A, 2) ÷ 2 -row_colors = row_coloring(A, alg) -@test check_structurally_orthogonal_rows(A, row_colors) +row_colors = ADTypes.row_coloring(A, alg) +@test DI.check_structurally_orthogonal_rows(A, row_colors) @test maximum(row_colors) < size(A, 1) ÷ 2 S = Symmetric(sprand(Bool, 100, 100, 0.1)) + I -symmetric_colors = symmetric_coloring(S, alg) -@test check_symmetrically_structurally_orthogonal(S, symmetric_colors) +symmetric_colors = ADTypes.symmetric_coloring(S, alg) +@test DI.check_symmetrically_structurally_orthogonal(S, symmetric_colors) @test maximum(symmetric_colors) < size(A, 2) ÷ 2 diff --git a/DifferentiationInterface/test/differentiate_with.jl b/DifferentiationInterface/test/differentiate_with.jl new file mode 100644 index 000000000..487cea7d7 --- /dev/null +++ b/DifferentiationInterface/test/differentiate_with.jl @@ -0,0 +1,23 @@ +function zygote_breaking_scenarios() + onearg_scens = filter(default_scenarios()) do scen + DIT.nb_args(scen) == 1 + end + bad_onearg_scens = map(onearg_scens) do scen + function bad_f(x) + a = Vector{eltype(x)}(undef, 1) + a[1] = sum(x) + return scen.f(x) + end + wrapped_bad_f = DifferentiateWith(bad_f, AutoForwardDiff()) + bad_scen = DIT.change_function(scen, wrapped_bad_f) + return bad_scen + end + return bad_onearg_scens +end + +test_differentiation( + AutoZygote(), + zygote_breaking_scenarios(); + second_order=false, + logging=logging=get(ENV, "CI", "false") == "false",, +) diff --git a/DifferentiationInterface/test/runtests.jl b/DifferentiationInterface/test/runtests.jl index 03eeafef4..10e9a749e 100644 --- a/DifferentiationInterface/test/runtests.jl +++ b/DifferentiationInterface/test/runtests.jl @@ -23,9 +23,6 @@ include("test_imports.jl") Documenter.doctest(DifferentiationInterface) - @testset verbose = true "Exception handling" begin - include("test_exceptions.jl") - end @testset verbose = true "First order" begin include("first_order.jl") end @@ -34,14 +31,14 @@ include("test_imports.jl") include("second_order.jl") end - @testset verbose = true "Coloring" begin - include("coloring.jl") - end - @testset verbose = true "Sparsity" begin include("sparsity.jl") end + @testset verbose = true "DifferentiateWith" begin + include("differentiate_with.jl") + end + @testset verbose = true "Bonus round" begin @testset "Type stability" begin include("type_stability.jl") @@ -50,9 +47,19 @@ include("test_imports.jl") @testset "Weird arrays" begin include("weird_arrays.jl") end + end + + @testset verbose = true "Internals" begin + @testset verbose = true "Exception handling" begin + include("test_exceptions.jl") + end @testset "Chunks" begin include("chunk.jl") end + + @testset verbose = true "Coloring" begin + include("coloring.jl") + end end end; diff --git a/DifferentiationInterface/test/sparsity.jl b/DifferentiationInterface/test/sparsity.jl index 065cb6cf9..73f5ee750 100644 --- a/DifferentiationInterface/test/sparsity.jl +++ b/DifferentiationInterface/test/sparsity.jl @@ -1,5 +1,5 @@ -coloring_algorithm = DifferentiationInterface.GreedyColoringAlgorithm() -sparsity_detector = DifferentiationInterface.SymbolicsSparsityDetector() +coloring_algorithm = DI.GreedyColoringAlgorithm() +sparsity_detector = DI.SymbolicsSparsityDetector() sparse_backends = [ AutoSparse(AutoFastDifferentiation()), diff --git a/DifferentiationInterface/test/test_exceptions.jl b/DifferentiationInterface/test/test_exceptions.jl index 5ac9cea60..5fa719f95 100644 --- a/DifferentiationInterface/test/test_exceptions.jl +++ b/DifferentiationInterface/test/test_exceptions.jl @@ -1,5 +1,3 @@ -using DifferentiationInterface: MissingBackendError - """ AutoBrokenForward <: ADTypes.AbstractADType @@ -25,9 +23,9 @@ DifferentiationInterface.check_available(::AutoBrokenReverse) = true f(x::AbstractArray) = sum(abs2, x) x = [1.0, 2.0, 3.0] - @test_throws MissingBackendError gradient(f, AutoBrokenForward(), x) - @test_throws MissingBackendError gradient(f, AutoBrokenReverse(), x) + @test_throws DI.MissingBackendError gradient(f, AutoBrokenForward(), x) + @test_throws DI.MissingBackendError gradient(f, AutoBrokenReverse(), x) - @test_throws MissingBackendError hvp(f, AutoBrokenForward(), x, x) - @test_throws MissingBackendError hvp(f, AutoBrokenReverse(), x, x) + @test_throws DI.MissingBackendError hvp(f, AutoBrokenForward(), x, x) + @test_throws DI.MissingBackendError hvp(f, AutoBrokenReverse(), x, x) end diff --git a/DifferentiationInterface/test/test_imports.jl b/DifferentiationInterface/test/test_imports.jl index 7c0a3cc15..7b8595274 100644 --- a/DifferentiationInterface/test/test_imports.jl +++ b/DifferentiationInterface/test/test_imports.jl @@ -9,6 +9,8 @@ Pkg.develop( using ADTypes using DifferentiationInterface using DifferentiationInterfaceTest +import DifferentiationInterface as DI +import DifferentiationInterfaceTest as DIT using Aqua: Aqua using Documenter: Documenter @@ -16,7 +18,8 @@ using JET: JET using JuliaFormatter: JuliaFormatter using Test -using SparseArrays: SparseArrays +using LinearAlgebra +using SparseArrays ## diff --git a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl index ae28c6e00..09d2f6688 100644 --- a/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl +++ b/DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl @@ -22,7 +22,7 @@ using Chairmarks: @be, Benchmark, Sample using ComponentArrays: ComponentVector using DifferentiationInterface using DifferentiationInterface: - backend_string, + backend_str, inner, mode, outer, diff --git a/DifferentiationInterfaceTest/src/scenarios/scenario.jl b/DifferentiationInterfaceTest/src/scenarios/scenario.jl index a172921db..8e6594b47 100644 --- a/DifferentiationInterfaceTest/src/scenarios/scenario.jl +++ b/DifferentiationInterfaceTest/src/scenarios/scenario.jl @@ -54,6 +54,13 @@ function compatible(backend::AbstractADType, scen::AbstractScenario) return true end +function group_by_scen_type(scenarios) + return Dict( + st => filter(s -> scen_type(s) == st, scenarios) for + st in unique(scen_type.(scenarios)) + ) +end + function Base.string(scen::S) where {args,op,F,X,Y,S<:AbstractScenario{args,op,F,X,Y}} return "$(S.name.name){$args,$op} $(string(scen.f)) : $X -> $Y" end @@ -220,6 +227,16 @@ for S in ( end return ($S){args,operator,F,X,typeof(y),R}(f, x, y, ref) end + + function change_function(s::($S), f) + return ($S)( + f; + x=s.x, + y=(nb_args(s) == 1 ? nothing : s.y), + ref=s.ref, + operator=operator_place(s), + ) + end end end @@ -239,6 +256,16 @@ for S in (:PushforwardScenario, :HVPScenario) end return ($S){args,operator,F,X,typeof(y),typeof(dx),R}(f, x, y, dx, ref) end + + function change_function(s::($S), f) + return ($S)( + f; + x=s.x, + y=(nb_args(s) == 1 ? nothing : s.y), + ref=s.ref, + operator=operator_place(s), + ) + end end end @@ -258,5 +285,15 @@ for S in (:PullbackScenario,) end return ($S){args,operator,F,X,typeof(y),typeof(dy),R}(f, x, y, dy, ref) end + + function change_function(s::($S), f) + return ($S)( + f; + x=s.x, + y=(nb_args(s) == 1 ? nothing : s.y), + ref=s.ref, + operator=operator_place(s), + ) + end end end diff --git a/DifferentiationInterfaceTest/src/test_differentiation.jl b/DifferentiationInterfaceTest/src/test_differentiation.jl index 09b22a064..0a5ffcda8 100644 --- a/DifferentiationInterfaceTest/src/test_differentiation.jl +++ b/DifferentiationInterfaceTest/src/test_differentiation.jl @@ -67,33 +67,37 @@ function test_differentiation( prog = ProgressUnknown(; desc="$title", spinner=true, enabled=logging) @testset verbose = true "$title" begin - @testset verbose = detailed "$(backend_string(backend))" for (i, backend) in - enumerate(backends) + @testset verbose = detailed "$(backend_str(backend))" for (i, backend) in + enumerate(backends) filtered_scenarios = filter(s -> compatible(backend, s), scenarios) - @testset "$scen" for (j, scen) in enumerate(filtered_scenarios) - next!( - prog; - showvalues=[ - (:backend, "$(backend_string(backend)) - $i/$(length(backends))"), - ( - :scenario, - "$(scen_type(scen)) - $j/$(length(filtered_scenarios))", - ), - (:arguments, nb_args(scen)), - (:operator, operator_place(scen)), - (:function, scen.f), - (:input, typeof(scen.x)), - (:output, typeof(scen.y)), - ], - ) - correctness && @testset "Correctness" begin - test_correctness(backend, scen; isapprox, atol, rtol, ref_backend) - end - type_stability && @testset "Type stability" begin - test_jet(backend, scen; ref_backend) - end - sparsity && @testset "Sparsity" begin - test_sparsity(backend, scen; ref_backend) + grouped_scenarios = group_by_scen_type(filtered_scenarios) + @testset verbose = detailed "$st" for (j, (st, st_group)) in + enumerate(pairs(grouped_scenarios)) + @testset "$scen" for (k, scen) in enumerate(st_group) + next!( + prog; + showvalues=[ + (:backend, "$(backend_str(backend)) - $i/$(length(backends))"), + (:scenario_type, "$st - $j/$(length(grouped_scenarios))"), + (:scenario, "$k/$(length(st_group))"), + (:arguments, nb_args(scen)), + (:operator, operator_place(scen)), + (:function, scen.f), + (:input_type, typeof(scen.x)), + (:input_size, size(scen.x)), + (:output_type, typeof(scen.y)), + (:output_size, size(scen.y)), + ], + ) + correctness && @testset "Correctness" begin + test_correctness(backend, scen; isapprox, atol, rtol, ref_backend) + end + type_stability && @testset "Type stability" begin + test_jet(backend, scen; ref_backend) + end + sparsity && @testset "Sparsity" begin + test_sparsity(backend, scen; ref_backend) + end end end end @@ -140,20 +144,26 @@ function benchmark_differentiation( prog = ProgressUnknown(; desc="Benchmarking", spinner=true, enabled=logging) for (i, backend) in enumerate(backends) filtered_scenarios = filter(s -> compatible(backend, s), scenarios) - for (j, scen) in enumerate(filtered_scenarios) - next!( - prog; - showvalues=[ - (:backend, "$(backend_string(backend)) - $i/$(length(backends))"), - (:scenario, "$(scen_type(scen)) - $j/$(length(filtered_scenarios))"), - (:arguments, nb_args(scen)), - (:operator, operator_place(scen)), - (:function, scen.f), - (:input, typeof(scen.x)), - (:output, typeof(scen.y)), - ], - ) - run_benchmark!(benchmark_data, backend, scen) + grouped_scenarios = group_by_scen_type(filtered_scenarios) + for (j, (st, st_group)) in enumerate(pairs(grouped_scenarios)) + for (k, scen) in enumerate(st_group) + next!( + prog; + showvalues=[ + (:backend, "$(backend_str(backend)) - $i/$(length(backends))"), + (:scenario_type, "$st - $j/$(length(grouped_scenarios))"), + (:scenario, "$k/$(length(st_group))"), + (:arguments, nb_args(scen)), + (:operator, operator_place(scen)), + (:function, scen.f), + (:input_type, typeof(scen.x)), + (:input_size, size(scen.x)), + (:output_type, typeof(scen.y)), + (:output_size, size(scen.y)), + ], + ) + run_benchmark!(benchmark_data, backend, scen) + end end end return benchmark_data diff --git a/DifferentiationInterfaceTest/src/tests/benchmark.jl b/DifferentiationInterfaceTest/src/tests/benchmark.jl index 88560fa31..806673de3 100644 --- a/DifferentiationInterfaceTest/src/tests/benchmark.jl +++ b/DifferentiationInterfaceTest/src/tests/benchmark.jl @@ -96,7 +96,7 @@ function record!( ) bench_min = minimum(bench) row = BenchmarkDataRow(; - backend=backend_string(backend), + backend=backend_str(backend), mode=mode(backend), scenario=typeof(scenario).name.name, operator=Symbol(operator), From 340bb958646d84628288174dea6d44d8115a9e77 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 26 Apr 2024 07:45:41 +0200 Subject: [PATCH 2/2] Fix parsing --- DifferentiationInterface/test/differentiate_with.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DifferentiationInterface/test/differentiate_with.jl b/DifferentiationInterface/test/differentiate_with.jl index 487cea7d7..1a2b8eaea 100644 --- a/DifferentiationInterface/test/differentiate_with.jl +++ b/DifferentiationInterface/test/differentiate_with.jl @@ -19,5 +19,5 @@ test_differentiation( AutoZygote(), zygote_breaking_scenarios(); second_order=false, - logging=logging=get(ENV, "CI", "false") == "false",, + logging=logging = get(ENV, "CI", "false") == "false", )