From 59ceff38b010a8793ee2be4164c9941b73f94057 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 27 Jun 2024 06:52:01 +0200 Subject: [PATCH] Fix pretty printing and ReverseDiff constructor (#67) --- Project.toml | 2 +- src/dense.jl | 85 +++++++++++++++++--------------------------------- src/legacy.jl | 2 ++ src/sparse.jl | 10 +++--- test/legacy.jl | 5 +++ test/misc.jl | 38 ++++++++++++++++++++++ 6 files changed, 80 insertions(+), 62 deletions(-) diff --git a/Project.toml b/Project.toml index d9ea062..9557b76 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = [ "Vaibhav Dixit , Guillaume Dalle and contributors", ] -version = "1.5.1" +version = "1.5.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/dense.jl b/src/dense.jl index 7958a67..dd7e8e5 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -20,7 +20,7 @@ end mode(::AutoChainRules) = ForwardOrReverseMode() # specialized in the extension function Base.show(io::IO, backend::AutoChainRules) - print(io, "AutoChainRules(ruleconfig=$(repr(backend.ruleconfig, context=io)))") + print(io, AutoChainRules, "(ruleconfig=", repr(backend.ruleconfig; context = io), ")") end """ @@ -63,11 +63,9 @@ end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension function Base.show(io::IO, backend::AutoEnzyme) - if isnothing(backend.mode) - print(io, "AutoEnzyme()") - else - print(io, "AutoEnzyme(mode=$(repr(backend.mode, context=io)))") - end + print(io, AutoEnzyme, "(") + !isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io)) + print(io, ")") end """ @@ -111,21 +109,14 @@ end mode(::AutoFiniteDiff) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDiff) - s = "AutoFiniteDiff(" - if backend.fdtype != Val(:forward) - s *= "fdtype=$(repr(backend.fdtype, context=io)), " - end - if backend.fdjtype != backend.fdtype - s *= "fdjtype=$(repr(backend.fdjtype, context=io)), " - end - if backend.fdhtype != Val(:hcentral) - s *= "fdhtype=$(repr(backend.fdhtype, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoFiniteDiff, "(") + backend.fdtype != Val(:forward) && + print(io, "fdtype=", repr(backend.fdtype; context = io), ", ") + backend.fdjtype != backend.fdtype && + print(io, "fdjtype=", repr(backend.fdjtype; context = io), ", ") + backend.fdhtype != Val(:hcentral) && + print(io, "fdhtype=", repr(backend.fdhtype; context = io)) + print(io, ")") end """ @@ -150,7 +141,7 @@ end mode(::AutoFiniteDifferences) = ForwardMode() function Base.show(io::IO, backend::AutoFiniteDifferences) - print(io, "AutoFiniteDifferences(fdm=$(repr(backend.fdm, context=io)))") + print(io, AutoFiniteDifferences, "(fdm=", repr(backend.fdm; context = io), ")") end """ @@ -183,18 +174,11 @@ end mode(::AutoForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoForwardDiff{chunksize}) where {chunksize} - s = "AutoForwardDiff(" - if chunksize !== nothing - s *= "chunksize=$chunksize, " - end - if backend.tag !== nothing - s *= "tag=$(repr(backend.tag, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoForwardDiff, "(") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), + (backend.tag !== nothing ? ", " : "")) + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io)) + print(io, ")") end """ @@ -227,18 +211,11 @@ end mode(::AutoPolyesterForwardDiff) = ForwardMode() function Base.show(io::IO, backend::AutoPolyesterForwardDiff{chunksize}) where {chunksize} - s = "AutoPolyesterForwardDiff(" - if chunksize !== nothing - s *= "chunksize=$chunksize, " - end - if backend.tag !== nothing - s *= "tag=$(repr(backend.tag, context=io)), " - end - if endswith(s, ", ") - s = s[1:(end - 2)] - end - s *= ")" - print(io, s) + print(io, AutoPolyesterForwardDiff, "(") + chunksize !== nothing && print(io, "chunksize=", repr(chunksize; context = io), + (backend.tag !== nothing ? ", " : "")) + backend.tag !== nothing && print(io, "tag=", repr(backend.tag; context = io)) + print(io, ")") end """ @@ -277,11 +254,9 @@ end mode(::AutoReverseDiff) = ReverseMode() function Base.show(io::IO, ::AutoReverseDiff{compile}) where {compile} - if !compile - print(io, "AutoReverseDiff()") - else - print(io, "AutoReverseDiff(compile=true)") - end + print(io, AutoReverseDiff, "(") + compile && print(io, "compile=true") + print(io, ")") end """ @@ -321,11 +296,9 @@ end mode(::AutoTapir) = ReverseMode() function Base.show(io::IO, backend::AutoTapir) - if backend.safe_mode - print(io, "AutoTapir()") - else - print(io, "AutoTapir(safe_mode=false)") - end + print(io, AutoTapir, "(") + !(backend.safe_mode) && print(io, "safe_mode=false") + print(io, ")") end """ diff --git a/src/legacy.jl b/src/legacy.jl index 5784399..91d8acc 100644 --- a/src/legacy.jl +++ b/src/legacy.jl @@ -11,6 +11,8 @@ @deprecate AutoSparseZygote() AutoSparse(AutoZygote()) +@deprecate AutoReverseDiff(compile) AutoReverseDiff(; compile) + function mtk_to_symbolics(obj_sparse::Bool, cons_sparse::Bool) if obj_sparse || cons_sparse return AutoSparse(AutoSymbolics()) diff --git a/src/sparse.jl b/src/sparse.jl index 85a8e97..ffd157a 100644 --- a/src/sparse.jl +++ b/src/sparse.jl @@ -155,15 +155,15 @@ function AutoSparse( end function Base.show(io::IO, backend::AutoSparse) - s = "AutoSparse(dense_ad=$(repr(backend.dense_ad, context=io)), " + print(io, AutoSparse, "(dense_ad=", repr(backend.dense_ad, context = io)) if backend.sparsity_detector != NoSparsityDetector() - s *= "sparsity_detector=$(repr(backend.sparsity_detector, context=io)), " + print(io, ", sparsity_detector=", repr(backend.sparsity_detector, context = io)) end if backend.coloring_algorithm != NoColoringAlgorithm() - s *= "coloring_algorithm=$(repr(backend.coloring_algorithm, context=io))), " + print( + io, ", coloring_algorithm=", repr(backend.coloring_algorithm, context = io)) end - s = s[1:(end - 2)] * ")" - print(io, s) + print(io, ")") end """ diff --git a/test/legacy.jl b/test/legacy.jl index 08f09d9..d4f2076 100644 --- a/test/legacy.jl +++ b/test/legacy.jl @@ -58,3 +58,8 @@ end @test ad isa AbstractADType @test dense_ad(ad) isa AutoZygote end + +@testset "AutoReverseDiff without kwarg" begin + ad = @test_deprecated AutoReverseDiff(true) + @test ad.compile +end diff --git a/test/misc.jl b/test/misc.jl index a1b2274..0ca1ddd 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -7,6 +7,7 @@ end @testset "Printing" begin for ad in every_ad_with_options() @test startswith(string(ad), "Auto") + @test contains(string(ad), "(") @test endswith(string(ad), ")") end @@ -19,3 +20,40 @@ end @test contains(string(sparse_backend1), string(AutoForwardDiff())) @test length(string(sparse_backend1)) < length(string(sparse_backend2)) end + +import ADTypes + +struct FakeSparsityDetector <: ADTypes.AbstractSparsityDetector end +struct FakeColoringAlgorithm <: ADTypes.AbstractColoringAlgorithm end + +for backend in [ + # dense + ADTypes.AutoChainRules(; ruleconfig = :rc), + ADTypes.AutoDiffractor(), + ADTypes.AutoEnzyme(), + ADTypes.AutoEnzyme(mode = :forward), + ADTypes.AutoFastDifferentiation(), + ADTypes.AutoFiniteDiff(), + ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh), + ADTypes.AutoFiniteDifferences(; fdm = :fdm), + ADTypes.AutoForwardDiff(), + ADTypes.AutoForwardDiff(chunksize = 3, tag = :tag), + ADTypes.AutoPolyesterForwardDiff(), + ADTypes.AutoPolyesterForwardDiff(chunksize = 3, tag = :tag), + ADTypes.AutoReverseDiff(), + ADTypes.AutoReverseDiff(compile = true), + ADTypes.AutoSymbolics(), + ADTypes.AutoTapir(), + ADTypes.AutoTapir(safe_mode = false), + ADTypes.AutoTracker(), + ADTypes.AutoZygote(), + # sparse + ADTypes.AutoSparse(ADTypes.AutoForwardDiff()), + ADTypes.AutoSparse( + ADTypes.AutoForwardDiff(); + sparsity_detector = FakeSparsityDetector(), + coloring_algorithm = FakeColoringAlgorithm() + ) +] + println(backend) +end