diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 224f81d4..21b3b6a3 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -19,7 +19,7 @@ jobs:
       fail-fast: false
       matrix:
         version:
-          - '1.7'
+          #- '1.7'
           - '1.10'
         os:
           - ubuntu-latest
diff --git a/Project.toml b/Project.toml
index fff721f2..d218a160 100644
--- a/Project.toml
+++ b/Project.toml
@@ -25,6 +25,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [extensions]
@@ -32,6 +33,7 @@ AdvancedVIBijectorsExt = "Bijectors"
 AdvancedVIEnzymeExt = "Enzyme"
 AdvancedVIForwardDiffExt = "ForwardDiff"
 AdvancedVIReverseDiffExt = "ReverseDiff"
+AdvancedVITapirExt = "Tapir"
 AdvancedVIZygoteExt = "Zygote"
 
 [compat]
@@ -55,6 +57,7 @@ Requires = "1.0"
 ReverseDiff = "1.15.1"
 SimpleUnPack = "1.1.0"
 StatsBase = "0.32, 0.33, 0.34"
+Tapir = "0.2.34"
 Zygote = "0.6.63"
 julia = "1.7"
 
@@ -64,6 +67,7 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
+Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl
index 6904fa7a..d9d255be 100644
--- a/ext/AdvancedVIForwardDiffExt.jl
+++ b/ext/AdvancedVIForwardDiffExt.jl
@@ -15,8 +15,9 @@ getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize
 
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoForwardDiff,
+    ::Any,
     f,
-    x::AbstractVector{<:Real},
+    x::AbstractVector,
     out::DiffResults.MutableDiffResult,
 )
     chunk_size = getchunksize(ad)
@@ -31,12 +32,13 @@ end
 
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoForwardDiff,
+    st_ad,
     f,
     x::AbstractVector,
     aux,
     out::DiffResults.MutableDiffResult,
 )
-    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl
index 9cde91a1..c0a87502 100644
--- a/ext/AdvancedVIReverseDiffExt.jl
+++ b/ext/AdvancedVIReverseDiffExt.jl
@@ -13,7 +13,8 @@ end
 # ReverseDiff without compiled tape
 function AdvancedVI.value_and_gradient!(
-    ad::ADTypes.AutoReverseDiff,
+    ::ADTypes.AutoReverseDiff,
+    ::Any,
     f,
     x::AbstractVector{<:Real},
     out::DiffResults.MutableDiffResult,
 )
@@ -25,12 +26,13 @@ end
 
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoReverseDiff,
+    st_ad,
     f,
     x::AbstractVector{<:Real},
     aux,
     out::DiffResults.MutableDiffResult,
 )
-    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl
new file mode 100644
index 00000000..b2e00af9
--- /dev/null
+++ b/ext/AdvancedVITapirExt.jl
@@ -0,0 +1,47 @@
+
+module AdvancedVITapirExt
+
+if isdefined(Base, :get_extension)
+    using AdvancedVI
+    using AdvancedVI: ADTypes, DiffResults
+    using Tapir
+else
+    using ..AdvancedVI
+    using ..AdvancedVI: ADTypes, DiffResults
+    using ..Tapir
+end
+
+AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x)
+
+function AdvancedVI.value_and_gradient!(
+    ::ADTypes.AutoTapir,
+    st_ad,
+    f,
+    x::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
+)
+    rule = st_ad
+    y, g = Tapir.value_and_gradient!!(rule, f, x)
+    DiffResults.value!(out, y)
+    DiffResults.gradient!(out, last(g))
+    return out
+end
+
+AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux)
+
+function AdvancedVI.value_and_gradient!(
+    ::ADTypes.AutoTapir,
+    st_ad,
+    f,
+    x::AbstractVector{<:Real},
+    aux,
+    out::DiffResults.MutableDiffResult,
+)
+    rule = st_ad
+    y, g = Tapir.value_and_gradient!!(rule, f, x, aux)
+    DiffResults.value!(out, y)
+    DiffResults.gradient!(out, g[2])
+    return out
+end
+
+end
diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl
index 2cdd8392..3a477a24 100644
--- a/ext/AdvancedVIZygoteExt.jl
+++ b/ext/AdvancedVIZygoteExt.jl
@@ -14,7 +14,11 @@ else
 end
 
 function AdvancedVI.value_and_gradient!(
-    ::ADTypes.AutoZygote, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult
+    ::ADTypes.AutoZygote,
+    ::Any,
+    f,
+    x::AbstractVector{<:Real},
+    out::DiffResults.MutableDiffResult,
 )
     y, back = Zygote.pullback(f, x)
     ∇x = back(one(y))
@@ -25,12 +29,13 @@ end
 
 function AdvancedVI.value_and_gradient!(
     ad::ADTypes.AutoZygote,
+    st_ad,
     f,
     x::AbstractVector{<:Real},
     aux,
     out::DiffResults.MutableDiffResult,
 )
-    return AdvancedVI.value_and_gradient!(ad, x′ -> f(x′, aux), x, out)
+    return AdvancedVI.value_and_gradient!(ad, st_ad, x′ -> f(x′, aux), x, out)
 end
 
 end
diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 8ac1b645..8b9d7921 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -25,14 +25,15 @@ using StatsBase
 
 # derivatives
 """
-    value_and_gradient!(ad, f, x, out)
-    value_and_gradient!(ad, f, x, aux, out)
+    value_and_gradient!(adtype, ad_st, f, x, out)
+    value_and_gradient!(adtype, ad_st, f, x, aux, out)
 
-Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation backend `ad` and store the result in `out`.
+Evaluate the value and gradient of a function `f` at `x` using the automatic differentiation (AD) backend `adtype` and store the result in `out`.
 `f` may receive auxiliary input as `f(x,aux)`.
 
 # Arguments
-- `ad::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `adtype::ADTypes.AbstractADType`: AD backend.
+- `ad_st`: State used by the AD backend. (This will often be pre-compiled tapes/caches.)
 - `f`: Function subject to differentiation.
 - `x`: The point to evaluate the gradient.
 - `aux`: Auxiliary input passed to `f`.
@@ -40,6 +41,21 @@ Evaluate the value and gradient of a function `f` at `x` using the automatic dif
 """
 function value_and_gradient! end
 
+"""
+    init_adbackend(adtype, f, x)
+    init_adbackend(adtype, f, x, aux)
+Initialize the AD backend and set up any necessary state.
+# Arguments
+- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `f`: Function subject to differentiation.
+- `x`: The point to evaluate the gradient.
+- `aux`: Auxiliary input passed to `f`.
+# Returns
+- `ad_st`: State of the AD backend. (This will often be pre-compiled tapes/caches.)
+"""
+init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any) = nothing
+init_adbackend(::ADTypes.AbstractADType, ::Any, ::Any, ::Any) = nothing
+
 """
     restructure_ad_forward(adtype, restructure, params)
 
@@ -95,18 +111,22 @@ If the estimator is stateful, it can implement `init` to initialize the state.
 abstract type AbstractVariationalObjective end
 
 """
-    init(rng, obj, prob, params, restructure)
+    init(rng, obj, adtype, prob, params, restructure)
 
-Initialize a state of the variational objective `obj` given the initial variational parameters `λ`.
+Initialize a state of the variational objective `obj`.
 This function needs to be implemented only if `obj` is stateful.
+The state of the AD backend `adtype` should also be initialized here.
 
 # Arguments
 - `rng::Random.AbstractRNG`: Random number generator.
 - `obj::AbstractVariationalObjective`: Variational objective.
+- `adtype::ADTypes.AbstractADType`: Automatic differentiation backend.
+- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface.
 - `params`: Initial variational parameters.
 - `restructure`: Function that reconstructs the variational approximation from `λ`.
 """
-init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any) = nothing
+init(::Random.AbstractRNG, ::AbstractVariationalObjective, ::Any, ::Any, ::Any, ::Any) =
+    nothing
 
 """
     estimate_objective([rng,] obj, q, prob; kwargs...)
diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl
index e6f04ae8..da9250ff 100644
--- a/src/objectives/elbo/repgradelbo.jl
+++ b/src/objectives/elbo/repgradelbo.jl
@@ -101,6 +101,20 @@ function estimate_repgradelbo_ad_forward(params′, aux)
     return -elbo
 end
 
+function init(
+    rng::Random.AbstractRNG,
+    obj::RepGradELBO,
+    adtype::ADTypes.AbstractADType,
+    prob,
+    params,
+    restructure,
+)
+    q_stop = restructure(params)
+    aux = (rng=rng, obj=obj, problem=prob, restructure=restructure, q_stop=q_stop)
+    ad_st = init_adbackend(adtype, estimate_repgradelbo_ad_forward, params, aux)
+    return (ad_st=ad_st,)
+end
+
 function estimate_gradient!(
     rng::Random.AbstractRNG,
     obj::RepGradELBO,
@@ -123,5 +137,5 @@ function estimate_gradient!(
-    value_and_gradient!(adtype, estimate_repgradelbo_ad_forward, params, aux, out)
+    value_and_gradient!(adtype, state.ad_st, estimate_repgradelbo_ad_forward, params, aux, out)
     nelbo = DiffResults.value(out)
     stat = (elbo=-nelbo,)
-    return out, nothing, stat
+    return out, state, stat
 end
diff --git a/src/optimize.jl b/src/optimize.jl
index eb462ff5..77cdf756 100644
--- a/src/optimize.jl
+++ b/src/optimize.jl
@@ -66,7 +66,9 @@ function optimize(
 )
     params, restructure = Optimisers.destructure(deepcopy(q_init))
     opt_st = maybe_init_optimizer(state_init, optimizer, params)
-    obj_st = maybe_init_objective(state_init, rng, objective, problem, params, restructure)
+    obj_st = maybe_init_objective(
+        state_init, rng, adtype, objective, problem, params, restructure
+    )
     avg_st = maybe_init_averager(state_init, averager, params)
     grad_buf = DiffResults.DiffResult(zero(eltype(params)), similar(params))
     stats = NamedTuple[]
diff --git a/src/utils.jl b/src/utils.jl
index 11618677..d0c19f73 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -24,6 +24,7 @@ end
 function maybe_init_objective(
     state_init::NamedTuple,
     rng::Random.AbstractRNG,
+    adtype::ADTypes.AbstractADType,
     objective::AbstractVariationalObjective,
     problem,
     params,
@@ -32,7 +33,7 @@ function maybe_init_objective(
     if haskey(state_init, :objective)
         state_init.objective
     else
-        init(rng, objective, problem, params, restructure)
+        init(rng, objective, adtype, problem, params, restructure)
     end
 end
 
diff --git a/test/Project.toml b/test/Project.toml
index 251869e7..0b4aa731 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -18,6 +18,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a"
 StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
"10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -42,7 +43,8 @@ ReverseDiff = "1.15.1" SimpleUnPack = "1.1.0" StableRNGs = "1.0.0" Statistics = "1" +Tapir = "0.2.23" Test = "1" Tracker = "0.2.20" Zygote = "0.6.63" -julia = "1.6" +julia = "1.7" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 3a458e38..8702fa8d 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -5,6 +5,7 @@ AD_distributionsad = if VERSION >= v"1.10" #:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(false), ) else Dict( diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index b0007706..cccbee4d 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -5,6 +5,7 @@ AD_locationscale = if VERSION >= v"1.10" :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(false), ) else Dict( diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 79c81c52..54103186 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -5,6 +5,7 @@ AD_locationscale_bijectors = if VERSION >= v"1.10" :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(false), ) else Dict( diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 380c2b9b..9898ee48 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -2,18 +2,21 @@ using Test @testset "ad" begin - @testset "$(adname)" for (adname, adsymbol) in Dict( + @testset "$(adname)" for (adname, adtype) in Dict( :ForwardDiff => AutoForwardDiff(), :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(; safe_mode=false), + #:Enzyme => AutoEnzyme() ) D = 10 A = randn(D, D) λ = randn(D) - grad_buf = DiffResults.GradientResult(λ) f(λ′) = λ′' * A * λ′ / 2 - AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + + ad_st = AdvancedVI.init_adbackend(adtype, f, λ) + grad_buf = DiffResults.GradientResult(λ) + AdvancedVI.value_and_gradient!(adtype, ad_st, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A') * λ / 2 diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 5fec46ff..afaba148 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -39,6 +39,7 @@ end ADTypes.AutoReverseDiff(), ADTypes.AutoZygote(), ADTypes.AutoEnzyme(), + ADTypes.AutoTapir(false), ] q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) @@ -49,7 +50,7 @@ end aux = (rng=rng, obj=obj, problem=model, restructure=re, q_stop=q_true, adtype=ad) AdvancedVI.value_and_gradient!( - ad, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out + adtype, ad_st, AdvancedVI.estimate_repgradelbo_ad_forward, params, aux, out ) grad = DiffResults.gradient(out) @test norm(grad) ≈ 0 atol = 1e-5 diff --git a/test/runtests.jl b/test/runtests.jl index 80194a43..ac28d49f 
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -20,7 +20,7 @@ using DistributionsAD
 using LogDensityProblems
 using Optimisers
 using ADTypes
-using ForwardDiff, ReverseDiff, Zygote, Enzyme
+using ForwardDiff, ReverseDiff, Zygote, Enzyme, Tapir
 using AdvancedVI
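
For reference, here is a minimal sketch of how the reworked two-stage AD interface in this diff is intended to be used. The quadratic objective, the dimensions, and the `AutoTapir` settings below are illustrative (they mirror `test/interface/ad.jl`) and are not part of the patch:

```julia
using ADTypes, DiffResults
using AdvancedVI, Tapir

# Toy objective to differentiate; `A` is captured by the closure, as in the test suite.
A = randn(5, 5)
f(x) = x' * A * x / 2

adtype = AutoTapir(; safe_mode=false)
x = randn(5)

# Step 1: build the backend state once. For Tapir this pre-builds the reverse rule,
# so the compilation cost is paid a single time rather than on every gradient call.
ad_st = AdvancedVI.init_adbackend(adtype, f, x)

# Step 2: reuse the state for each value-and-gradient evaluation.
out = DiffResults.GradientResult(x)
AdvancedVI.value_and_gradient!(adtype, ad_st, f, x, out)

DiffResults.value(out)     # f(x)
DiffResults.gradient(out)  # ≈ (A + A') * x / 2
```

Inside `optimize`, the same split happens through `init(rng, obj, adtype, prob, params, restructure)`, which stores the backend state in the objective state (`(ad_st=ad_st,)` for `RepGradELBO`) so that `estimate_gradient!` can hand it to `value_and_gradient!` on every iteration; backends without a notion of state simply return `nothing` from `init_adbackend`.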