From 013b68fd1665e9d397e60808e509e14b82d6a8b5 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 14 Aug 2024 15:28:57 +0200 Subject: [PATCH] Add back support for ADTypes < 1.5 (#35) --- Project.toml | 2 +- ext/LogDensityProblemsADADTypesExt.jl | 14 ++++++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 316ee49..520cb34 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LogDensityProblemsAD" uuid = "996a588d-648d-4e1f-a8f0-a84b347e47b1" authors = ["Tamás K. Papp "] -version = "1.9.1" +version = "1.9.2" [deps] DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" diff --git a/ext/LogDensityProblemsADADTypesExt.jl b/ext/LogDensityProblemsADADTypesExt.jl index 24f4562..fbedfb6 100644 --- a/ext/LogDensityProblemsADADTypesExt.jl +++ b/ext/LogDensityProblemsADADTypesExt.jl @@ -37,8 +37,18 @@ function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoForwardDiff{C}, ℓ) wh end end -function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ) where {T} - return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(T)) +# ADTypes 1.5 introduced a type parameter for `AutoReverseDiff` and deprecated the +# `compile` field +# Since Julia < 1.9 uses Requires which does not respect the ADTypes compat entry, +# we keep the version for ADTypes < 1.5 as well +@static if ADTypes.AutoReverseDiff isa UnionAll + function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff{T}, ℓ) where {T} + return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(T)) + end +else + function LogDensityProblemsAD.ADgradient(ad::ADTypes.AutoReverseDiff, ℓ) + return LogDensityProblemsAD.ADgradient(Val(:ReverseDiff), ℓ; compile = Val(ad.compile)) + end end function LogDensityProblemsAD.ADgradient(::ADTypes.AutoTracker, ℓ)