diff --git a/src/SampleChainsDynamicHMC.jl b/src/SampleChainsDynamicHMC.jl index d79c971..b4adcc6 100644 --- a/src/SampleChainsDynamicHMC.jl +++ b/src/SampleChainsDynamicHMC.jl @@ -35,7 +35,7 @@ function DynamicHMCChain(t::TransformVariables.TransformTuple, Q::DynamicHMC.Eva meta = steps transform = t - return DynamicHMCChain{T}(samples, logq, info, meta, Q, transform) + return DynamicHMCChain{T}(samples, logq, info, meta, zeroarr(Q), transform) end TupleVectors.summarize(ch::DynamicHMCChain) = summarize(samples(ch)) @@ -51,7 +51,9 @@ function SampleChains.pushsample!(chain::DynamicHMCChain, Q::DynamicHMC.Evaluate end function SampleChains.step!(chain::DynamicHMCChain) - Q, tree_stats = DynamicHMC.mcmc_next_step(getfield(chain, :meta), getfield(chain, :state)) + Q, tree_stats = DynamicHMC.mcmc_next_step(getfield(chain, :meta), getfield(chain, :state)[]) + getfield(chain, :state)[] = Q + return Q, tree_stats end @concrete struct DynamicHMCConfig <: ChainConfig{DynamicHMCChain} @@ -156,4 +158,11 @@ function SampleChains.sample!(chain::DynamicHMCChain, n::Int=1000) return chain end +# I have no idea if this makes sense. But it's kind of cool so I wanted to try it +function zeroarr(x::T) where {T} + a = Array{T,0}(undef) + a[] = x + return a +end + end