From 7adeeb07d533914b97977175809602fd2652737a Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Jun 2024 12:47:15 +0200 Subject: [PATCH 1/4] Add Symbol -> AbstractADType mapping --- Project.toml | 2 +- docs/src/index.md | 6 ++++++ src/ADTypes.jl | 1 + src/symbols.jl | 42 ++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 3 +++ test/symbols.jl | 17 +++++++++++++++++ 6 files changed, 70 insertions(+), 1 deletion(-) create mode 100644 src/symbols.jl create mode 100644 test/symbols.jl diff --git a/Project.toml b/Project.toml index 2945298..8740ef9 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.3.0" +version = "1.3.1" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/docs/src/index.md b/docs/src/index.md index 5154899..9640a33 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -92,3 +92,9 @@ ADTypes.ForwardOrReverseMode ADTypes.ReverseMode ADTypes.SymbolicMode ``` + +## Miscellaneous + +```@docs +ADTypes.Auto +``` diff --git a/src/ADTypes.jl b/src/ADTypes.jl index 7fefb2b..56ebaae 100644 --- a/src/ADTypes.jl +++ b/src/ADTypes.jl @@ -20,6 +20,7 @@ include("mode.jl") include("dense.jl") include("sparse.jl") include("legacy.jl") +include("symbols.jl") if !isdefined(Base, :get_extension) include("../ext/ADTypesChainRulesCoreExt.jl") diff --git a/src/symbols.jl b/src/symbols.jl new file mode 100644 index 0000000..d6f7ab5 --- /dev/null +++ b/src/symbols.jl @@ -0,0 +1,42 @@ +""" + ADTypes.Auto(package::Symbol) + +A shortcut that converts an AD package name into an instance of [`AbstractADType`](@ref), with all parameters set to their default values. + +!!! warning + + This function is type-unstable by design and might lead to suboptimal performance. + In most cases, you should never need it: use the individual backend types directly. + +# Example + +```jldoctest +import ADTypes +backend = ADTypes.Auto(:Zygote) + +# output + +ADTypes.AutoZygote() +``` +""" +Auto(package::Symbol) = Auto(Val(package)) + +Auto(::Val{:Diffractor}) = AutoDiffractor() +Auto(::Val{:Enzyme}) = AutoEnzyme() +Auto(::Val{:FastDifferentiation}) = AutoFastDifferentiation() +Auto(::Val{:FiniteDiff}) = AutoFiniteDiff() +Auto(::Val{:ForwardDiff}) = AutoForwardDiff() +Auto(::Val{:PolyesterForwardDiff}) = AutoPolyesterForwardDiff() +Auto(::Val{:ReverseDiff}) = AutoReverseDiff() +Auto(::Val{:Symbolics}) = AutoSymbolics() +Auto(::Val{:Tapir}) = AutoTapir() +Auto(::Val{:Tracker}) = AutoTracker() +Auto(::Val{:Zygote}) = AutoZygote() + +function Auto(::Val{:ChainRules}) + throw(ArgumentError("ChainRules backend has mandatory arguments")) +end + +function Auto(::Val{:FiniteDifferences}) + throw(ArgumentError("FiniteDifferences backend has mandatory arguments")) +end diff --git a/test/runtests.jl b/test/runtests.jl index 23d0d96..e7d72f9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -66,6 +66,9 @@ end @testset "Sparse" begin include("sparse.jl") end + @testset "Symbols" begin + include("symbols.jl") + end @testset "Legacy" begin include("legacy.jl") end diff --git a/test/symbols.jl b/test/symbols.jl new file mode 100644 index 0000000..9937c9a --- /dev/null +++ b/test/symbols.jl @@ -0,0 +1,17 @@ +using ADTypes +using Test + +@test ADTypes.Auto(:Diffractor) isa AutoDiffractor +@test ADTypes.Auto(:Enzyme) isa AutoEnzyme +@test ADTypes.Auto(:FastDifferentiation) isa AutoFastDifferentiation +@test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff +@test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff +@test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff +@test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff +@test ADTypes.Auto(:Symbolics) isa AutoSymbolics +@test ADTypes.Auto(:Tapir) isa AutoTapir +@test ADTypes.Auto(:Tracker) isa AutoTracker +@test ADTypes.Auto(:Zygote) isa AutoZygote + +@test_throws ArgumentError ADTypes.Auto(:ChainRules) +@test_throws ArgumentError ADTypes.Auto(:FiniteDifferences) From 2fa94c9d29ce4880cff0e8bb222308071e959152 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:00:56 +0200 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Vaibhav Kumar Dixit Co-authored-by: Miles Cranmer --- Project.toml | 2 +- src/symbols.jl | 25 ++++++------------------- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/Project.toml b/Project.toml index 8740ef9..c54c8ca 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.3.1" +version = "1.4.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/symbols.jl b/src/symbols.jl index d6f7ab5..3d4ec8e 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -19,24 +19,11 @@ backend = ADTypes.Auto(:Zygote) ADTypes.AutoZygote() ``` """ -Auto(package::Symbol) = Auto(Val(package)) - -Auto(::Val{:Diffractor}) = AutoDiffractor() -Auto(::Val{:Enzyme}) = AutoEnzyme() -Auto(::Val{:FastDifferentiation}) = AutoFastDifferentiation() -Auto(::Val{:FiniteDiff}) = AutoFiniteDiff() -Auto(::Val{:ForwardDiff}) = AutoForwardDiff() -Auto(::Val{:PolyesterForwardDiff}) = AutoPolyesterForwardDiff() -Auto(::Val{:ReverseDiff}) = AutoReverseDiff() -Auto(::Val{:Symbolics}) = AutoSymbolics() -Auto(::Val{:Tapir}) = AutoTapir() -Auto(::Val{:Tracker}) = AutoTracker() -Auto(::Val{:Zygote}) = AutoZygote() - -function Auto(::Val{:ChainRules}) - throw(ArgumentError("ChainRules backend has mandatory arguments")) -end +Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) -function Auto(::Val{:FiniteDifferences}) - throw(ArgumentError("FiniteDifferences backend has mandatory arguments")) +for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, + :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, + :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) + @eval Auto(::Val{$backend}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) end + From d25eacc4e30c0fa8f0b93a08dd2a39c4ed02b814 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:16:10 +0200 Subject: [PATCH 3/4] Apply suggestions from code review Co-authored-by: Miles Cranmer --- src/symbols.jl | 2 +- test/symbols.jl | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/symbols.jl b/src/symbols.jl index 3d4ec8e..f349e84 100644 --- a/src/symbols.jl +++ b/src/symbols.jl @@ -24,6 +24,6 @@ Auto(package::Symbol, args...; kws...) = Auto(Val(package), args...; kws...) for backend in (:ChainRules, :Diffractor, :Enzyme, :FastDifferentiation, :FiniteDiff, :FiniteDifferences, :ForwardDiff, :PolyesterForwardDiff, :ReverseDiff, :Symbolics, :Tapir, :Tracker, :Zygote) - @eval Auto(::Val{$backend}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) + @eval Auto(::Val{$(QuoteNode(backend))}, args...; kws...) = $(Symbol(:Auto, backend))(args...; kws...) end diff --git a/test/symbols.jl b/test/symbols.jl index 9937c9a..72743bd 100644 --- a/test/symbols.jl +++ b/test/symbols.jl @@ -1,10 +1,12 @@ using ADTypes using Test +@test ADTypes.Auto(:ChainRules, 1) isa AutoChainRules{Int64} @test ADTypes.Auto(:Diffractor) isa AutoDiffractor @test ADTypes.Auto(:Enzyme) isa AutoEnzyme @test ADTypes.Auto(:FastDifferentiation) isa AutoFastDifferentiation @test ADTypes.Auto(:FiniteDiff) isa AutoFiniteDiff +@test ADTypes.Auto(:FiniteDifferences, 1.0) isa AutoFiniteDifferences{Float64} @test ADTypes.Auto(:ForwardDiff) isa AutoForwardDiff @test ADTypes.Auto(:PolyesterForwardDiff) isa AutoPolyesterForwardDiff @test ADTypes.Auto(:ReverseDiff) isa AutoReverseDiff @@ -12,6 +14,3 @@ using Test @test ADTypes.Auto(:Tapir) isa AutoTapir @test ADTypes.Auto(:Tracker) isa AutoTracker @test ADTypes.Auto(:Zygote) isa AutoZygote - -@test_throws ArgumentError ADTypes.Auto(:ChainRules) -@test_throws ArgumentError ADTypes.Auto(:FiniteDifferences) From aa6853e40e213148dbaaf87784cc675ef95c473e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 17 Jun 2024 20:03:14 +0200 Subject: [PATCH 4/4] Add exception tests --- test/symbols.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/symbols.jl b/test/symbols.jl index 72743bd..bc39c7e 100644 --- a/test/symbols.jl +++ b/test/symbols.jl @@ -14,3 +14,7 @@ using Test @test ADTypes.Auto(:Tapir) isa AutoTapir @test ADTypes.Auto(:Tracker) isa AutoTracker @test ADTypes.Auto(:Zygote) isa AutoZygote + +@test_throws MethodError ADTypes.Auto(:ThisPackageDoesNotExist) +@test_throws UndefKeywordError ADTypes.Auto(:ChainRules) +@test_throws UndefKeywordError ADTypes.Auto(:FiniteDifferences)