Type error in backward pass #282

Open
Red-Portal opened this issue Oct 4, 2024 · 3 comments
Labels
bug Something isn't working high priority

@Red-Portal

Hi, here's a minimal working example (MWE) that reproduces the failing test in AdvancedVI.

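```julia
using Bijectors
using DifferentiationInterface
using Distributions
using Mooncake
using Optimisers
using Random

# Score-function surrogate objective. `q_stop` is the same distribution as
# `re(params)` but is passed in as a constant, so `score_grad - score_grad_stop`
# is zero in value while still contributing the score-function gradient.
function f(params, aux)
    (; re, samples, baseline, q_stop) = aux
    q = re(params)

    ℓq = logpdf.(Ref(q), eachcol(samples))
    ℓq_stop = logpdf.(Ref(q_stop), eachcol(samples))
    # Toy stand-in for the target log-density: per-sample sum of squares.
    ℓπ = sum(abs2, samples, dims=1)[1,:]
    ℓπ_mean = mean(ℓπ)
    score_grad = mean(@. ℓq * (ℓπ - baseline))
    score_grad_stop = mean(@. ℓq_stop * (ℓπ - baseline))
    energy = ℓπ_mean + (score_grad - score_grad_stop)
    energy
end

function main()
    rng = Random.default_rng()
    q0  = MvNormal(zeros(3), ones(3))
    # Coordinate 1 uses the LogNormal bijector, coordinates 2:3 the MvNormal one.
    b   = Bijectors.Stacked(
        Bijectors.bijector.([LogNormal(0, 1), MvNormal(zeros(2), ones(2))]),
        [1:1, 2:3]
    )
    # Push q0 through the inverse bijector to get a distribution on the constrained space.
    q  = Bijectors.transformed(q0, Bijectors.inverse(b))

    # Flatten the distribution's parameters; `re` rebuilds it from a vector.
    params, re = Optimisers.destructure(q)

    adtype = AutoMooncake(; config=nothing)
    aux = (
        samples  = rand(rng, q, 10),
        baseline = 1.0,
        re       = re,
        q_stop   = q,
    )
    # This is the call that hits the type error in the backward pass.
    value_and_gradient(f, adtype, params, Constant(aux))
end

main()
```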
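For reference, the objective in `f` is a score-function (REINFORCE) surrogate with a baseline. The following is a sketch read directly off the code. It writes $z_i$ for the $n$ columns of `samples` ($n = 10$ in `main`), $b$ for `baseline`, $\lambda$ for `params`, and $\tilde\lambda$ for the parameters of `q_stop`, which are held fixed through `Constant`:

```math
\begin{aligned}
E(\lambda) &= \frac{1}{n}\sum_{i=1}^{n} \ell_\pi(z_i)
  + \frac{1}{n}\sum_{i=1}^{n}\bigl[\log q_\lambda(z_i) - \log q_{\tilde\lambda}(z_i)\bigr]\bigl(\ell_\pi(z_i) - b\bigr),
  \qquad \ell_\pi(z) = \lVert z \rVert_2^2, \\
\nabla_\lambda E(\lambda)\big|_{\tilde\lambda = \lambda} &= \frac{1}{n}\sum_{i=1}^{n} \nabla_\lambda \log q_\lambda(z_i)\,\bigl(\ell_\pi(z_i) - b\bigr).
\end{aligned}
```

Because `q_stop = q`, we have $\tilde\lambda = \lambda$ at the evaluation point. The bracketed term is therefore zero, and the returned value is just `ℓπ_mean`. The gradient, however, is the score-function estimator with baseline $b$. The samples $z_i$ carry no gradient because they enter through `Constant(aux)`.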
@willtebbutt
Member

Thanks for narrowing this down. Unfortunately I won't have time to look at it today, but I should be able to on Monday.

@willtebbutt willtebbutt added bug Something isn't working high priority labels Oct 4, 2024
@willtebbutt
Member

Alas, the arrival of Julia v1.11 in CI means I need to focus on that first. I'll try to get this resolved by the end of the week.

@Red-Portal
Author

The v0.3 release of AdvancedVI will wait until this is resolved!
