From ca6808208604497890f4ecc835be0a51022030e4 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 17 Dec 2024 00:24:57 +0100 Subject: [PATCH] Fix von Mises-Fisher sampler --- src/samplers/vonmisesfisher.jl | 139 +++++++++++++++++++--------- src/univariate/continuous/gamma.jl | 1 - test/multivariate/vonmisesfisher.jl | 42 ++++++--- 3 files changed, 126 insertions(+), 56 deletions(-) diff --git a/src/samplers/vonmisesfisher.jl b/src/samplers/vonmisesfisher.jl index fd3eb2df08..53dc23d3a0 100644 --- a/src/samplers/vonmisesfisher.jl +++ b/src/samplers/vonmisesfisher.jl @@ -1,4 +1,6 @@ # Sampler for von Mises-Fisher +# Ref https://doi.org/10.18637/jss.v058.i10 +# Ref https://hal.science/hal-04004568v3 struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} p::Int # the dimension κ::Float64 @@ -6,29 +8,32 @@ struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous} x0::Float64 c::Float64 v::Vector{Float64} + rotate::Bool # whether to rotate the samples end function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64) + # Step 1: Calculate b, x₀, and c p = length(μ) - b = _vmf_bval(p, κ) + b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) x0 = (1.0 - b) / (1.0 + b) c = κ * x0 + (p - 1) * log1p(-abs2(x0)) - v = _vmf_householder_vec(μ) - VonMisesFisherSampler(p, κ, b, x0, c, v) + + # Compute Householder transformation, and whether it has to be applied + v, rotate = _vmf_householder_vec(μ) + + return VonMisesFisherSampler(p, κ, b, x0, c, v, rotate) end Base.length(s::VonMisesFisherSampler) = length(s.v) -@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector) - # rotate - scale = 2.0 * (v' * x) - @. x -= (scale * v) - return x -end +function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real}) + # TODO: Generalize to more general indices + Base.require_one_based_indexing(x) + # Sample angle `w` + w = _vmf_angle(rng, spl) -function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) - w = _vmf_genw(rng, spl) + # Generate sample assuming `μ = (1, 0, 0, ..., 0)` p = spl.p x[1] = w s = 0.0 @@ -43,47 +48,81 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector) x[i] *= r end - return _vmf_rot!(spl.v, x) + # Rotate for general `μ` (if necessary) + return _vmf_rotate!(x, spl) end ### Core computation -_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1))) - -function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ) - ξ = rand(rng) - w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) - return w::Float64 -end - -function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ) - r = (p - 1) / 2.0 - betad = Beta(r, r) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - while κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand(rng)) - z = rand(rng, betad) - w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z) - end - return w::Float64 -end +# Step 2: Sample angle W +function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler) + p = spl.p + κ = spl.κ -# generate the W value -- the key step in simulating vMF -# -# following movMF's document for the p != 3 case -# and Wenzel Jakob's document for the p == 3 case -function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ) if p == 3 - return _vmf_genw3(rng, p, b, x0, c, κ) + _vmf_angle3(rng, κ) else - return _vmf_genwp(rng, p, b, x0, c, κ) + # General case: Rejection sampling + # Ref https://doi.org/10.18637/jss.v058.i10 + b = spl.b + c = spl.c + p = spl.p + κ = spl.κ + x0 = spl.x0 + pm1 = p - 1 + + if p == 2 + # In this case the distribution reduces to the von Mises distribution on the circle + # We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)` + dist = Arcsine(zero(b), one(b)) + while true + z = rand(rng, dist) + w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + else + # We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly + # Therefore we construct a sampler + # To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))` + # we directly construct the Gamma sampler for Gamma((p - 1)/2, 1) + # Since (p - 1)/2 > 1, we construct a `GammaMTSampler` + r = pm1 / 2 + gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r))) + while true + # w is supposed to be generated as + # z ~ Beta((p - 1)/ 2, (p - 1)/2) + # w = (1 - (1 + b) * z) / (1 - (1 - b) * z) + # We sample z as + # z1 ~ Gamma((p - 1) / 2, 1) + # z2 ~ Gamma((p - 1) / 2, 1) + # z = z1 / (z1 + z2) + # and rewrite the expression for w + # Cf. case p == 2 above + z1 = rand(rng, gammasampler) + z2 = rand(rng, gammasampler) + b_z1 = b * z1 + w = (z2 - b_z1) / (z2 + b_z1) + if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng) + return w::Float64 + end + end + end end end +# Special case: 2-sphere +@inline function _vmf_angle3(rng::AbstractRNG, κ::Real) + # In this case, we can directly sample the angle + # Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf + ξ = rand(rng) + w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ) + return w::Float64 +end -_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) = - _vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ) - +# Create Householder transformation to rotate samples for `μ = (1, 0, ..., 0)` +# to samples for general `μ` function _vmf_householder_vec(μ::Vector{Float64}) # assuming μ is a unit-vector (which it should be) # can compute v in a single pass over μ @@ -92,11 +131,27 @@ function _vmf_householder_vec(μ::Vector{Float64}) v = similar(μ) v[1] = μ[1] - 1.0 s = sqrt(-2*v[1]) + if iszero(s) + # In this case, μ is (approx.) (1, 0, ..., 0) + # Hence no rotation has to be performed and `v` is not used + return v, false + end + v[1] /= s @inbounds for i in 2:p v[i] = μ[i] / s end - return v + return v, true +end + +# Rotate samples for general `μ` (if needed) +@inline function _vmf_rotate!(x::AbstractVector{<:Real}, spl::VonMisesFisherSampler) + if spl.rotate + v = spl.v + scale = 2.0 * (v' * x) + @. x -= (scale * v) + end + return x end diff --git a/src/univariate/continuous/gamma.jl b/src/univariate/continuous/gamma.jl index 866255fb7d..8ba207d2c7 100644 --- a/src/univariate/continuous/gamma.jl +++ b/src/univariate/continuous/gamma.jl @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma) # TODO: shape(d) = 0.5 : use scaled chisq return rand(rng, GammaIPSampler(d)) elseif shape(d) == 1.0 - θ = return rand(rng, Exponential{partype(d)}(scale(d))) else return rand(rng, GammaMTSampler(d)) diff --git a/test/multivariate/vonmisesfisher.jl b/test/multivariate/vonmisesfisher.jl index cc45f41ed5..7dc6f90028 100644 --- a/test/multivariate/vonmisesfisher.jl +++ b/test/multivariate/vonmisesfisher.jl @@ -23,6 +23,7 @@ function gen_vmf_tdata(n::Int, p::Int, end function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) + # Random μ if ismissing(rng) μ = randn(p) x = randn(p) @@ -34,16 +35,24 @@ function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing) μ = μ ./ κ s = Distributions.VonMisesFisherSampler(μ, κ) + @test s.rotate v = μ - vcat(1, zeros(p-1)) H = I - 2*v*v'/(v'*v) - @test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x) - -end + @test Distributions._vmf_rotate!(copy(x), s) ≈ (H*x) + # Special case: μ = (1, 0, ..., 0) + # In this case no rotation is performed + μ = zeros(p) + μ[1] = 1 + s = Distributions.VonMisesFisherSampler(μ, κ) + @test !s.rotate + @test Distributions._vmf_rotate!(copy(x), s) == x + return nothing +end -function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) +function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing) p = 3 if ismissing(rng) @@ -53,21 +62,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin end μ = μ ./ norm(μ) - s = Distributions.VonMisesFisherSampler(μ, float(κ)) + spl = Distributions.VonMisesFisherSampler(μ, float(κ)) + angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns] + angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns] - genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns] - - @test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01) - @test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ) + @test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2 + @test std(angle3_res) ≈ std(angle_res) rtol=1e-2 # test mean and stdev against analytical formulas coth_κ = coth(κ) mean_w = coth_κ - 1/κ var_w = 1 - coth_κ^2 + 1/κ^2 - @test isapprox(mean(genw3_res), mean_w, atol=0.01) - @test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ) + @test mean(angle3_res) ≈ mean_w rtol=5e-2 + @test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2 end @@ -178,7 +186,15 @@ ns = 10^6 if !ismissing(rng) @testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0] - test_genw3(κ, ns, rng) + test_angle3(κ, ns, rng) end end end + +# issue #1423 +@testset "Special case: No rotation" begin + for n in 2:10 + d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0) + @test sum(abs2, rand(d)) ≈ 1 + end +end