Skip to content

Commit

Permalink
Fix von Mises-Fisher sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Dec 16, 2024
1 parent 3de6038 commit ca68082
Show file tree
Hide file tree
Showing 3 changed files with 126 additions and 56 deletions.
139 changes: 97 additions & 42 deletions src/samplers/vonmisesfisher.jl
Original file line number Diff line number Diff line change
@@ -1,34 +1,39 @@
# Sampler for von Mises-Fisher
# Ref https://doi.org/10.18637/jss.v058.i10
# Ref https://hal.science/hal-04004568v3
struct VonMisesFisherSampler <: Sampleable{Multivariate,Continuous}
# All fields are filled in by the outer constructor from `μ` and `κ`.
p::Int # the dimension, i.e. `length(μ)`
κ::Float64 # the `κ` argument given to the constructor
b::Float64 # (p - 1) / (2κ + sqrt(4κ^2 + (p - 1)^2))
x0::Float64 # (1 - b) / (1 + b)
c::Float64 # κ * x0 + (p - 1) * log1p(-x0^2)
v::Vector{Float64} # vector returned by `_vmf_householder_vec(μ)`, used to rotate samples
rotate::Bool # whether to rotate the samples (false when μ is (approx.) (1, 0, ..., 0))
end

function VonMisesFisherSampler(μ::Vector{Float64}, κ::Float64)
# Step 1: Calculate b, x₀, and c
p = length(μ)
b = _vmf_bval(p, κ)
b = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))
x0 = (1.0 - b) / (1.0 + b)
c = κ * x0 + (p - 1) * log1p(-abs2(x0))
v = _vmf_householder_vec(μ)
VonMisesFisherSampler(p, κ, b, x0, c, v)

# Compute Householder transformation, and whether it has to be applied
v, rotate = _vmf_householder_vec(μ)

return VonMisesFisherSampler(p, κ, b, x0, c, v, rotate)
end

# Length of the samples produced by the sampler, read off the stored vector `v`
function Base.length(s::VonMisesFisherSampler)
    return length(s.v)
end

@inline function _vmf_rot!(v::AbstractVector, x::AbstractVector)
# rotate
scale = 2.0 * (v' * x)
@. x -= (scale * v)
return x
end
function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector{<:Real})
# TODO: Generalize to more general indices
Base.require_one_based_indexing(x)

# Sample angle `w`
w = _vmf_angle(rng, spl)

function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
w = _vmf_genw(rng, spl)
# Generate sample assuming `μ = (1, 0, 0, ..., 0)`
p = spl.p
x[1] = w
s = 0.0
Expand All @@ -43,47 +48,81 @@ function _rand!(rng::AbstractRNG, spl::VonMisesFisherSampler, x::AbstractVector)
x[i] *= r
end

return _vmf_rot!(spl.v, x)
# Rotate for general `μ` (if necessary)
return _vmf_rotate!(x, spl)
end

### Core computation

_vmf_bval(p::Int, κ::Real) = (p - 1) / (2.0κ + sqrt(4 * abs2(κ) + abs2(p - 1)))

function _vmf_genw3(rng::AbstractRNG, p, b, x0, c, κ)
ξ = rand(rng)
w = 1.0 + (log(ξ + (1.0 - ξ)*exp(-2κ))/κ)
return w::Float64
end

function _vmf_genwp(rng::AbstractRNG, p, b, x0, c, κ)
r = (p - 1) / 2.0
betad = Beta(r, r)
z = rand(rng, betad)
w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
while κ * w + (p - 1) * log(1 - x0 * w) - c < log(rand(rng))
z = rand(rng, betad)
w = (1.0 - (1.0 + b) * z) / (1.0 - (1.0 - b) * z)
end
return w::Float64
end
# Step 2: Sample angle W
function _vmf_angle(rng::AbstractRNG, spl::VonMisesFisherSampler)
p = spl.p
κ = spl.κ

# generate the W value -- the key step in simulating vMF
#
# following movMF's document for the p != 3 case
# and Wenzel Jakob's document for the p == 3 case
function _vmf_genw(rng::AbstractRNG, p, b, x0, c, κ)
if p == 3
return _vmf_genw3(rng, p, b, x0, c, κ)
_vmf_angle3(rng, κ)
else
return _vmf_genwp(rng, p, b, x0, c, κ)
# General case: Rejection sampling
# Ref https://doi.org/10.18637/jss.v058.i10
b = spl.b
c = spl.c
p = spl.p
κ = spl.κ
x0 = spl.x0
pm1 = p - 1

if p == 2
# In this case the distribution reduces to the von Mises distribution on the circle
# We exploit the fact that `Beta(1/2, 1/2) = Arcsine(0, 1)`
dist = Arcsine(zero(b), one(b))
while true
z = rand(rng, dist)
w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
return w::Float64
end
end
else
# We sample from a `Beta((p - 1)/2, (p - 1)/2)` distribution, possibly repeatedly
# Therefore we construct a sampler
# To avoid the type instability of `sampler(Beta(...))` and `sampler(Gamma(...))`
# we directly construct the Gamma sampler for Gamma((p - 1)/2, 1)
# Since (p - 1)/2 > 1, we construct a `GammaMTSampler`
r = pm1 / 2
gammasampler = GammaMTSampler(Gamma{typeof(r)}(r, one(r)))
while true
# w is supposed to be generated as
# z ~ Beta((p - 1)/ 2, (p - 1)/2)
# w = (1 - (1 + b) * z) / (1 - (1 - b) * z)
# We sample z as
# z1 ~ Gamma((p - 1) / 2, 1)
# z2 ~ Gamma((p - 1) / 2, 1)
# z = z1 / (z1 + z2)
# and rewrite the expression for w
# Cf. case p == 2 above
z1 = rand(rng, gammasampler)
z2 = rand(rng, gammasampler)
b_z1 = b * z1
w = (z2 - b_z1) / (z2 + b_z1)
if κ * w + pm1 * log1p(- x0 * w) >= c - randexp(rng)
return w::Float64
end
end
end
end
end

# Special case: 2-sphere
# Special case: 2-sphere
#
# Draw the component `w` of a sample along the mean direction for dimension
# p = 3. A single uniform variate `ξ` is transformed in closed form, so no
# rejection step is needed.
# Ref https://www.mitsuba-renderer.org/~wenzel/files/vmf.pdf
#
# Arguments:
# - `rng`: random number generator supplying the uniform variate
# - `κ`: the `κ` parameter of the distribution
#
# Returns `w` as a `Float64`.
@inline function _vmf_angle3(rng::AbstractRNG, κ::Real)
    ξ = rand(rng)
    # w = 1 + log(ξ + (1 - ξ) * exp(-2κ)) / κ
    # ξ = 0 gives w = -1, and w approaches 1 as ξ approaches 1
    w = 1.0 + (log(ξ + (1.0 - ξ) * exp(-2κ)) / κ)
    return w::Float64
end

_vmf_genw(rng::AbstractRNG, s::VonMisesFisherSampler) =
_vmf_genw(rng, s.p, s.b, s.x0, s.c, s.κ)

# Create Householder transformation to rotate samples for `μ = (1, 0, ..., 0)`
# to samples for general `μ`
function _vmf_householder_vec(μ::Vector{Float64})
# assuming μ is a unit-vector (which it should be)
# can compute v in a single pass over μ
Expand All @@ -92,11 +131,27 @@ function _vmf_householder_vec(μ::Vector{Float64})
v = similar(μ)
v[1] = μ[1] - 1.0
s = sqrt(-2*v[1])
if iszero(s)
# In this case, μ is (approx.) (1, 0, ..., 0)
# Hence no rotation has to be performed and `v` is not used
return v, false
end

v[1] /= s

@inbounds for i in 2:p
v[i] = μ[i] / s
end

return v
return v, true
end

# Rotate samples for general `μ` (if needed)
# Map a sample drawn for `μ = (1, 0, ..., 0)` to the sampler's general `μ`.
#
# When `spl.rotate` is set, `x` is overwritten in place with
# `x - 2 * (v' * x) * v`, where `v = spl.v`; otherwise `x` is left untouched.
# The (possibly modified) `x` is returned in both cases.
@inline function _vmf_rotate!(x::AbstractVector{<:Real}, spl::VonMisesFisherSampler)
    # Nothing to do if the constructor decided that no rotation is needed
    spl.rotate || return x

    h = spl.v
    twice_proj = 2.0 * (h' * x)
    @. x -= (twice_proj * h)
    return x
end
1 change: 0 additions & 1 deletion src/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,6 @@ function rand(rng::AbstractRNG, d::Gamma)
# TODO: shape(d) = 0.5 : use scaled chisq
return rand(rng, GammaIPSampler(d))
elseif shape(d) == 1.0
θ =
return rand(rng, Exponential{partype(d)}(scale(d)))
else
return rand(rng, GammaMTSampler(d))
Expand Down
42 changes: 29 additions & 13 deletions test/multivariate/vonmisesfisher.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ function gen_vmf_tdata(n::Int, p::Int,
end

function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
# Random μ
if ismissing(rng)
μ = randn(p)
x = randn(p)
Expand All @@ -34,16 +35,24 @@ function test_vmf_rot(p::Int, rng::Union{AbstractRNG, Missing} = missing)
μ = μ ./ κ

s = Distributions.VonMisesFisherSampler(μ, κ)
@test s.rotate
v = μ - vcat(1, zeros(p-1))
H = I - 2*v*v'/(v'*v)

@test Distributions._vmf_rot!(s.v, copy(x)) ≈ (H*x)

end
@test Distributions._vmf_rotate!(copy(x), s) ≈ (H*x)

# Special case: μ = (1, 0, ..., 0)
# In this case no rotation is performed
μ = zeros(p)
μ[1] = 1
s = Distributions.VonMisesFisherSampler(μ, κ)
@test !s.rotate
@test Distributions._vmf_rotate!(copy(x), s) == x

return nothing
end

function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
function test_angle3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missing)
p = 3

if ismissing(rng)
Expand All @@ -53,21 +62,20 @@ function test_genw3(κ::Real, ns::Int, rng::Union{AbstractRNG, Missing} = missin
end
μ = μ ./ norm(μ)

s = Distributions.VonMisesFisherSampler(μ, float(κ))
spl = Distributions.VonMisesFisherSampler(μ, float(κ))
angle3_res = [Distributions._vmf_angle3(rng, spl.κ) for _ in 1:ns]
angle_res = [Distributions._vmf_angle(rng, spl) for _ in 1:ns]

genw3_res = [Distributions._vmf_genw3(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]
genwp_res = [Distributions._vmf_genwp(rng, s.p, s.b, s.x0, s.c, s.κ) for _ in 1:ns]

@test isapprox(mean(genw3_res), mean(genwp_res), atol=0.01)
@test isapprox(std(genw3_res), std(genwp_res), atol=0.01/κ)
@test mean(angle3_res) ≈ mean(angle_res) rtol=5e-2
@test std(angle3_res) ≈ std(angle_res) rtol=1e-2

# test mean and stdev against analytical formulas
coth_κ = coth(κ)
mean_w = coth_κ - 1/κ
var_w = 1 - coth_κ^2 + 1/κ^2

@test isapprox(mean(genw3_res), mean_w, atol=0.01)
@test isapprox(std(genw3_res), sqrt(var_w), atol=0.01/κ)
@test mean(angle3_res) ≈ mean_w rtol=5e-2
@test std(angle3_res) ≈ sqrt(var_w) rtol=1e-2
end


Expand Down Expand Up @@ -178,7 +186,15 @@ ns = 10^6

if !ismissing(rng)
@testset "Testing genw with $key at (3, $κ)" for κ in [0.1, 0.5, 1.0, 2.0, 5.0]
test_genw3(κ, ns, rng)
test_angle3(κ, ns, rng)
end
end
end

# issue #1423
@testset "Special case: No rotation" begin
    # With μ = (1, 0, ..., 0) the sampler performs no rotation.
    # A draw must still have unit squared norm in every dimension tested.
    for n in 2:10
        d = VonMisesFisher(vcat(1, zeros(n - 1)), 1.0)
        @test sum(abs2, rand(d)) ≈ 1
    end
end

0 comments on commit ca68082

Please sign in to comment.