Skip to content

Commit

Permalink
add Tapir extension
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Jul 12, 2024
1 parent 767274c commit 28174c9
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions ext/AdvancedVITapirExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@

module AdvancedVITapirExt

if isdefined(Base, :get_extension)
using AdvancedVI
using AdvancedVI: ADTypes, DiffResults
using Tapir
else
using ..AdvancedVI
using ..AdvancedVI: ADTypes, DiffResults
using ..Tapir
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x) = Tapir.build_rrule(f, x)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x ::AbstractVector{<:Real},
out ::DiffResults.MutableDiffResult
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
return out
end

AdvancedVI.init_adbackend(::ADTypes.AutoTapir, f, x, aux) = Tapir.build_rrule(f, x, aux)

function AdvancedVI.value_and_gradient!(
::ADTypes.AutoTapir,
st_ad,
f,
x ::AbstractVector{<:Real},
aux,
out ::DiffResults.MutableDiffResult
)
rule = st_ad
y, g = Tapir.value_and_gradient!!(rule, f, x, aux)
DiffResults.value!(out, y)
DiffResults.gradient!(out, last(g))
return out
end

end

0 comments on commit 28174c9

Please sign in to comment.