Skip to content

Accumulators stage 2 #925

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: breaking
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/src/DynamicPPLBenchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ function make_suite(model, varinfo_choice::Symbol, adbackend::Symbol, islinked::
vi = DynamicPPL.link(vi, model)
end

f = DynamicPPL.LogDensityFunction(model, vi; adtype=adbackend)
f = DynamicPPL.LogDensityFunction(model, DynamicPPL.getlogjoint, vi; adtype=adbackend)
# The parameters at which we evaluate f.
θ = vi[:]

Expand Down
141 changes: 91 additions & 50 deletions src/logdensityfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"""
LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model);
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing
)

Expand All @@ -28,9 +29,10 @@
- and if `adtype` is provided, calculate the gradient of the log density at
that point.

At its most basic level, a LogDensityFunction wraps the model together with the
type of varinfo to be used. These must be known in order to calculate the log
density (using [`DynamicPPL.evaluate!!`](@ref)).
At its most basic level, a LogDensityFunction wraps the model together with a
function that specifies how to extract the log density, and the type of
VarInfo to be used. These must be known in order to calculate the log density
(using [`DynamicPPL.evaluate!!`](@ref)).

If the `adtype` keyword argument is provided, then this struct will also store
the adtype along with other information for efficient calculation of the
Expand Down Expand Up @@ -72,13 +74,13 @@
1

julia> # By default it uses `VarInfo` under the hood, but this is not necessary.
f = LogDensityFunction(model, SimpleVarInfo(model));
f = LogDensityFunction(model, getlogjoint, SimpleVarInfo(model));

julia> LogDensityProblems.logdensity(f, [0.0])
-2.3378770664093453

julia> # LogDensityFunction respects the accumulators in VarInfo:
f_prior = LogDensityFunction(model, setaccs!!(VarInfo(model), (LogPriorAccumulator(),)));
julia> # One can also specify evaluating e.g. the log prior only:
f_prior = LogDensityFunction(model, getlogprior);

julia> LogDensityProblems.logdensity(f_prior, [0.0]) == logpdf(Normal(), 0.0)
true
Expand All @@ -93,11 +95,13 @@
```
"""
struct LogDensityFunction{
M<:Model,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
M<:Model,F<:Function,V<:AbstractVarInfo,AD<:Union{Nothing,ADTypes.AbstractADType}
} <: AbstractModel
"model used for evaluation"
model::M
"varinfo used for evaluation"
"function to be called on `varinfo` to extract the log density. By default `getlogjoint`."
getlogdensity::F
"varinfo used for evaluation. If not specified, generated with `ldf_default_varinfo`."
varinfo::V
"AD type used for evaluation of log density gradient. If `nothing`, no gradient can be calculated"
adtype::AD
Expand All @@ -106,7 +110,8 @@

function LogDensityFunction(
model::Model,
varinfo::AbstractVarInfo=VarInfo(model);
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=ldf_default_varinfo(model, getlogdensity);
adtype::Union{ADTypes.AbstractADType,Nothing}=nothing,
)
if adtype === nothing
Expand All @@ -120,15 +125,22 @@
# Get a set of dummy params to use for prep
x = map(identity, varinfo[:])
if use_closure(adtype)
prep = DI.prepare_gradient(LogDensityAt(model, varinfo), adtype, x)
prep = DI.prepare_gradient(
LogDensityAt(model, getlogdensity, varinfo), adtype, x
)
else
prep = DI.prepare_gradient(
logdensity_at, adtype, x, DI.Constant(model), DI.Constant(varinfo)
logdensity_at,
adtype,
x,
DI.Constant(model),
DI.Constant(getlogdensity),
DI.Constant(varinfo),
)
end
end
return new{typeof(model),typeof(varinfo),typeof(adtype)}(
model, varinfo, adtype, prep
return new{typeof(model),typeof(getlogdensity),typeof(varinfo),typeof(adtype)}(
model, getlogdensity, varinfo, adtype, prep
)
end
end
Expand All @@ -149,83 +161,112 @@
return if adtype === f.adtype
f # Avoid recomputing prep if not needed
else
LogDensityFunction(f.model, f.varinfo; adtype=adtype)
LogDensityFunction(f.model, f.getlogdensity, f.varinfo; adtype=adtype)
end
end

"""
ldf_default_varinfo(model::Model, getlogdensity::Function)

Create the default AbstractVarInfo that should be used for evaluating the log density.

Only the accumulators necessary for `getlogdensity` will be used.
"""
function ldf_default_varinfo(::Model, getlogdensity::Function)
msg = """

Check warning on line 176 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L175-L176

Added lines #L175 - L176 were not covered by tests
LogDensityFunction does not know what sort of VarInfo should be used when \
`getlogdensity` is $getlogdensity. Please specify a VarInfo explicitly.
"""
return error(msg)

Check warning on line 180 in src/logdensityfunction.jl

View check run for this annotation

Codecov / codecov/patch

src/logdensityfunction.jl#L180

Added line #L180 was not covered by tests
end

# For the log joint, a plain VarInfo suffices: it carries both the LogPrior and
# LogLikelihood accumulators by default, which `getlogjoint` sums over.
ldf_default_varinfo(model::Model, ::typeof(getlogjoint)) = VarInfo(model)

# For the log prior only, strip the VarInfo's accumulators down to
# LogPriorAccumulator so that likelihood terms are never accumulated during
# evaluation (avoids wasted work and matches what `getlogprior` reads).
function ldf_default_varinfo(model::Model, ::typeof(getlogprior))
    return setaccs!!(VarInfo(model), (LogPriorAccumulator(),))
end

# Likewise for the log likelihood only: keep just LogLikelihoodAccumulator,
# the sole accumulator that `getloglikelihood` extracts from.
function ldf_default_varinfo(model::Model, ::typeof(getloglikelihood))
    return setaccs!!(VarInfo(model), (LogLikelihoodAccumulator(),))
end

"""
logdensity_at(
x::AbstractVector,
model::Model,
getlogdensity::Function,
varinfo::AbstractVarInfo,
)

Evaluate the log density of the given `model` at the given parameter values `x`,
using the given `varinfo`. Note that the `varinfo` argument is provided only
for its structure, in the sense that the parameters from the vector `x` are
inserted into it, and its own parameters are discarded. It does, however,
determine whether the log prior, likelihood, or joint is returned, based on
which accumulators are set in it.
Evaluate the log density of the given `model` at the given parameter values
`x`, using the given `varinfo`. Note that the `varinfo` argument is provided
only for its structure, in the sense that the parameters from the vector `x`
are inserted into it, and its own parameters are discarded. `getlogdensity` is
the function that extracts the log density from the evaluated varinfo.
"""
function logdensity_at(x::AbstractVector, model::Model, varinfo::AbstractVarInfo)
function logdensity_at(
x::AbstractVector, model::Model, getlogdensity::Function, varinfo::AbstractVarInfo
)
varinfo_new = unflatten(varinfo, x)
varinfo_eval = last(evaluate!!(model, varinfo_new))
has_prior = hasacc(varinfo_eval, Val(:LogPrior))
has_likelihood = hasacc(varinfo_eval, Val(:LogLikelihood))
if has_prior && has_likelihood
return getlogjoint(varinfo_eval)
elseif has_prior
return getlogprior(varinfo_eval)
elseif has_likelihood
return getloglikelihood(varinfo_eval)
else
error("LogDensityFunction: varinfo tracks neither log prior nor log likelihood")
end
return getlogdensity(varinfo_eval)
end

"""
LogDensityAt{M<:Model,V<:AbstractVarInfo}(
LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}(
model::M
getlogdensity::F,
varinfo::V
)

A callable struct that serves the same purpose as `x -> logdensity_at(x, model,
varinfo)`.
getlogdensity, varinfo)`.
"""
struct LogDensityAt{M<:Model,V<:AbstractVarInfo}
struct LogDensityAt{M<:Model,F<:Function,V<:AbstractVarInfo}
model::M
getlogdensity::F
varinfo::V
end
(ld::LogDensityAt)(x::AbstractVector) = logdensity_at(x, ld.model, ld.varinfo)
function (ld::LogDensityAt)(x::AbstractVector)
return logdensity_at(x, ld.model, ld.getlogdensity, ld.varinfo)
end

### LogDensityProblems interface

function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,Nothing}}
) where {M,V}
::Type{<:LogDensityFunction{M,F,V,Nothing}}
) where {M,F,V}
return LogDensityProblems.LogDensityOrder{0}()
end
function LogDensityProblems.capabilities(
::Type{<:LogDensityFunction{M,V,AD}}
) where {M,V,AD<:ADTypes.AbstractADType}
::Type{<:LogDensityFunction{M,F,V,AD}}
) where {M,F,V,AD<:ADTypes.AbstractADType}
return LogDensityProblems.LogDensityOrder{1}()
end
function LogDensityProblems.logdensity(f::LogDensityFunction, x::AbstractVector)
return logdensity_at(x, f.model, f.varinfo)
return logdensity_at(x, f.model, f.getlogdensity, f.varinfo)
end
function LogDensityProblems.logdensity_and_gradient(
f::LogDensityFunction{M,V,AD}, x::AbstractVector
) where {M,V,AD<:ADTypes.AbstractADType}
f::LogDensityFunction{M,F,V,AD}, x::AbstractVector
) where {M,F,V,AD<:ADTypes.AbstractADType}
f.prep === nothing &&
error("Gradient preparation not available; this should not happen")
x = map(identity, x) # Concretise type
# Make branching statically inferrable, i.e. type-stable (even if the two
# branches happen to return different types)
return if use_closure(f.adtype)
DI.value_and_gradient(LogDensityAt(f.model, f.varinfo), f.prep, f.adtype, x)
DI.value_and_gradient(
LogDensityAt(f.model, f.getlogdensity, f.varinfo), f.prep, f.adtype, x
)
else
DI.value_and_gradient(
logdensity_at, f.prep, f.adtype, x, DI.Constant(f.model), DI.Constant(f.varinfo)
logdensity_at,
f.prep,
f.adtype,
x,
DI.Constant(f.model),
DI.Constant(f.getlogdensity),
DI.Constant(f.varinfo),
)
end
end
Expand Down Expand Up @@ -264,9 +305,9 @@

1. Construct a closure over the model, i.e. let g = Base.Fix1(logdensity, f)

2. Use a constant context. This lets us pass a two-argument function to
DifferentiationInterface, as long as we also give it the 'inactive argument'
(i.e. the model) wrapped in `DI.Constant`.
2. Use a constant DI.Context. This lets us pass a two-argument function to DI,
as long as we also give it the 'inactive argument' (i.e. the model) wrapped
in `DI.Constant`.

The relative performance of the two approaches, however, depends on the AD
backend used. Some benchmarks are provided here:
Expand All @@ -292,7 +333,7 @@
Set the `DynamicPPL.Model` in the given log-density function `f` to `model`.
"""
function setmodel(f::DynamicPPL.LogDensityFunction, model::DynamicPPL.Model)
return LogDensityFunction(model, f.varinfo; adtype=f.adtype)
return LogDensityFunction(model, f.getlogdensity, f.varinfo; adtype=f.adtype)
end

"""
Expand Down
8 changes: 6 additions & 2 deletions src/simple_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,9 @@
x = vi.values
y, logjac = with_logabsdet_jacobian(b, x)
vi_new = Accessors.@set(vi.values = y)
vi_new = acclogprior!!(vi_new, -logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, -logjac)

Check warning on line 617 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L616-L617

Added lines #L616 - L617 were not covered by tests
end
return settrans!!(vi_new, t)
end

Expand All @@ -626,7 +628,9 @@
y = vi.values
x, logjac = with_logabsdet_jacobian(b, y)
vi_new = Accessors.@set(vi.values = x)
vi_new = acclogprior!!(vi_new, logjac)
if hasacc(vi_new, Val(:LogPrior))
vi_new = acclogprior!!(vi_new, logjac)

Check warning on line 632 in src/simple_varinfo.jl

View check run for this annotation

Codecov / codecov/patch

src/simple_varinfo.jl#L631-L632

Added lines #L631 - L632 were not covered by tests
end
return settrans!!(vi_new, NoTransformation())
end

Expand Down
19 changes: 9 additions & 10 deletions src/test_utils/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@
using Chairmarks: @be
import DifferentiationInterface as DI
using DocStringExtensions
using DynamicPPL:
Model,
LogDensityFunction,
VarInfo,
AbstractVarInfo,
link,
DefaultContext,
AbstractContext
using DynamicPPL: Model, LogDensityFunction, VarInfo, AbstractVarInfo, getlogjoint, link
using LogDensityProblems: logdensity, logdensity_and_gradient
using Random: Random, Xoshiro
using Statistics: median
Expand Down Expand Up @@ -58,6 +51,8 @@
struct ADResult{Tparams<:AbstractFloat,Tresult<:AbstractFloat}
"The DynamicPPL model that was tested"
model::Model
"The function used to extract the log density from the model"
getlogdensity::Function
"The VarInfo that was used"
varinfo::AbstractVarInfo
"The values at which the model was evaluated"
Expand Down Expand Up @@ -184,6 +179,7 @@
benchmark::Bool=false,
value_atol::AbstractFloat=1e-6,
grad_atol::AbstractFloat=1e-6,
getlogdensity::Function=getlogjoint,
varinfo::AbstractVarInfo=link(VarInfo(model), model),
params::Union{Nothing,Vector{<:AbstractFloat}}=nothing,
reference_adtype::AbstractADType=REFERENCE_ADTYPE,
Expand All @@ -197,7 +193,7 @@

verbose && @info "Running AD on $(model.f) with $(adtype)\n"
verbose && println(" params : $(params)")
ldf = LogDensityFunction(model, varinfo; adtype=adtype)
ldf = LogDensityFunction(model, getlogdensity, varinfo; adtype=adtype)

value, grad = logdensity_and_gradient(ldf, params)
grad = collect(grad)
Expand All @@ -206,7 +202,9 @@
if test
# Calculate ground truth to compare against
value_true, grad_true = if expected_value_and_grad === nothing
ldf_reference = LogDensityFunction(model, varinfo; adtype=reference_adtype)
ldf_reference = LogDensityFunction(

Check warning on line 205 in src/test_utils/ad.jl

View check run for this annotation

Codecov / codecov/patch

src/test_utils/ad.jl#L205

Added line #L205 was not covered by tests
model, getlogdensity, varinfo; adtype=reference_adtype
)
logdensity_and_gradient(ldf_reference, params)
else
expected_value_and_grad
Expand Down Expand Up @@ -234,6 +232,7 @@

return ADResult(
model,
getlogdensity,
varinfo,
params,
adtype,
Expand Down
Loading
Loading