-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add Neal's Funnel and Warped Gaussian * fixed bug in warped gaussian * add reference for warped Gauss * add Cross dsitribution * udpate docs for cross * Update example/targets/cross.jl change comment into docs Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * update cross docs * minor ed * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * doc banana using * fixing docs with latex * baanan docs with latex * add NF quick intro * Revert "add NF quick intro" This reverts commit e399274. * rm unnecesary code for cross * rm example/manifest * Update example/targets/neal_funnel.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * Update example/targets/cross.jl Co-authored-by: Tor Erlend Fjelde <[email protected]> * minor update to cross docs --------- Co-authored-by: Tor Erlend Fjelde <[email protected]>
- Loading branch information
Showing
4 changed files
with
206 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
using Distributions, Random | ||
""" | ||
Cross(μ::Real=2.0, σ::Real=0.15) | ||
2-dimensional Cross distribution | ||
# Explanation | ||
The Cross distribution is a 2-dimension 4-component Gaussian distribution with a "cross" | ||
shape that is symmetric about the y- and x-axises. The mixture is defined as | ||
```math | ||
\begin{aligned} | ||
p(x) = | ||
& 0.25 \mathcal{N}(x | (0, \mu), (\sigma, 1)) + \\ | ||
& 0.25 \mathcal{N}(x | (\mu, 0), (1, \sigma)) + \\ | ||
& 0.25 \mathcal{N}(x | (0, -\mu), (\sigma, 1)) + \\ | ||
& 0.25 \mathcal{N}(x | (-\mu, 0), (1, \sigma))) | ||
\end{aligned} | ||
``` | ||
where ``μ`` and ``σ`` are the mean and standard deviation of the Gaussian components, | ||
respectively. See an example of the Cross distribution in Page 18 of [1]. | ||
# Reference | ||
[1] Zuheng Xu, Naitong Chen, Trevor Campbell | ||
"MixFlows: principled variational inference via mixed flows." | ||
International Conference on Machine Learning, 2023 | ||
""" | ||
Cross() = Cross(2.0, 0.15) | ||
function Cross(μ::T, σ::T) where {T<:Real} | ||
return MixtureModel([ | ||
MvNormal([zero(μ), μ], [σ, one(σ)]), | ||
MvNormal([-μ, one(μ)], [one(σ), σ]), | ||
MvNormal([μ, one(μ)], [one(σ), σ]), | ||
MvNormal([zero(μ), -μ], [σ, one(σ)]), | ||
]) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
using Distributions, Random | ||
|
||
""" | ||
Funnel{T<:Real} | ||
Multidimensional Neal's Funnel distribution | ||
# Fields | ||
$(FIELDS) | ||
# Explanation | ||
The Neal's Funnel distribution is a p-dimensional distribution with a funnel shape, | ||
originally proposed by Radford Neal in [2]. | ||
The marginal distribution of ``x_1`` is Gaussian with mean "μ" and standard | ||
deviation "σ". The conditional distribution of ``x_2, \dots, x_p | x_1`` are independent | ||
Gaussian distributions with mean 0 and standard deviation ``\\exp(x_1/2)``. | ||
The generative process is given by | ||
```math | ||
x_1 \sim \mathcal{N}(\mu, \sigma^2), \quad x_2, \ldots, x_p \sim \mathcal{N}(0, \exp(x_1)) | ||
``` | ||
# Reference | ||
[1] Stan User’s Guide: | ||
https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html#ref-Neal:2003 | ||
[2] Radford Neal 2003. “Slice Sampling.” Annals of Statistics 31 (3): 705–67. | ||
""" | ||
struct Funnel{T<:Real} <: ContinuousMultivariateDistribution | ||
"Dimension of the distribution, must be >= 2" | ||
dim::Int | ||
"Mean of the first dimension" | ||
μ::T | ||
"Standard deviation of the first dimension, must be > 0" | ||
σ::T | ||
function Funnel{T}(dim::Int, μ::T, σ::T) where {T<:Real} | ||
dim >= 2 || error("dim must be >= 2") | ||
σ > 0 || error("σ must be > 0") | ||
return new{T}(dim, μ, σ) | ||
end | ||
end | ||
Funnel(dim::Int, μ::T, σ::T) where {T<:Real} = Funnel{T}(dim, μ, σ) | ||
Funnel(dim::Int, σ::T) where {T<:Real} = Funnel{T}(dim, zero(T), σ) | ||
Funnel(dim::Int) = Funnel(dim, 0.0, 9.0) | ||
|
||
Base.length(p::Funnel) = p.dim | ||
Base.eltype(p::Funnel{T}) where {T<:Real} = T | ||
|
||
function Distributions._rand!(rng::AbstractRNG, p::Funnel, x::AbstractVecOrMat) | ||
T = eltype(x) | ||
d, μ, σ = p.dim, p.μ, p.σ | ||
d == size(x, 1) || error("Dimension mismatch") | ||
x[1, :] .= randn(rng, T, size(x, 2)) .* σ .+ μ | ||
x[2:end, :] .= randn(rng, T, d - 1, size(x, 2)) .* exp.(@view(x[1, :]) ./ 2)' | ||
return x | ||
end | ||
|
||
function Distributions._logpdf(p::Funnel, x::AbstractVector) | ||
d, μ, σ = p.dim, p.μ, p.σ | ||
lpdf1 = logpdf(Normal(μ, σ), x[1]) | ||
lpdfs = logpdf.(Normal.(zeros(T, d - 1), exp(x[1] / 2)), @view(x[2:end])) | ||
return lpdf1 + sum(lpdfs) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
using Distributions, Random, LinearAlgebra, IrrationalConstants | ||
|
||
""" | ||
WarpedGauss{T<:Real} | ||
2-dimensional warped Gaussian distribution | ||
# Fields | ||
$(FIELDS) | ||
# Explanation | ||
The banana distribution is obtained by applying a transformation ϕ to a 2-dimensional normal | ||
distribution ``\\mathcal{N}(0, diag(\\sigma_1, \\sigma_2))``. The transformation ϕ(x) is defined as | ||
```math | ||
ϕ(x_1, x_2) = (r*\cos(\theta + r/2), r*\sin(\theta + r/2)), | ||
``` | ||
where ``r = \\sqrt{x\_1^2 + x_2^2}``, ``\\theta = \\atan(x₂, x₁)``, | ||
and "atan(y, x) ∈ [-π, π]" is the angle, in radians, between the positive x axis and the | ||
ray to the point "(x, y)". See page 18. of [1] for reference. | ||
# Reference | ||
[1] Zuheng Xu, Naitong Chen, Trevor Campbell | ||
"MixFlows: principled variational inference via mixed flows." | ||
International Conference on Machine Learning, 2023 | ||
""" | ||
struct WarpedGauss{T<:Real} <: ContinuousMultivariateDistribution | ||
"Standard deviation of the first dimension, must be > 0" | ||
σ1::T | ||
"Standard deviation of the second dimension, must be > 0" | ||
σ2::T | ||
function WarpedGauss{T}(σ1, σ2) where {T<:Real} | ||
σ1 > 0 || error("σ₁ must be > 0") | ||
σ2 > 0 || error("σ₂ must be > 0") | ||
return new{T}(σ1, σ2) | ||
end | ||
end | ||
WarpedGauss(σ1::T, σ2::T) where {T<:Real} = WarpedGauss{T}(σ1, σ2) | ||
WarpedGauss() = WarpedGauss(1.0, 0.12) | ||
|
||
Base.length(p::WarpedGauss) = 2 | ||
Base.eltype(p::WarpedGauss{T}) where {T<:Real} = T | ||
Distributions.sampler(p::WarpedGauss) = p | ||
|
||
# Define the transformation function φ and the inverse ϕ⁻¹ for the warped Gaussian distribution | ||
function ϕ!(p::WarpedGauss, z::AbstractVector) | ||
length(z) == 2 || error("Dimension mismatch") | ||
x, y = z | ||
r = norm(z) | ||
θ = atan(y, x) #in [-π , π] | ||
θ -= r / 2 | ||
z .= r .* [cos(θ), sin(θ)] | ||
return z | ||
end | ||
|
||
function ϕ⁻¹(p::WarpedGauss, z::AbstractVector) | ||
length(z) == 2 || error("Dimension mismatch") | ||
x, y = z | ||
r = norm(z) | ||
θ = atan(y, x) #in [-π , π] | ||
# increase θ depending on r to "smear" | ||
θ += r / 2 | ||
|
||
# get the x,y coordinates foαtransformed point | ||
xn = r * cos(θ) | ||
yn = r * sin(θ) | ||
# compute jacobian | ||
logJ = log(r) | ||
return [xn, yn], logJ | ||
end | ||
|
||
function Distributions._rand!(rng::AbstractRNG, p::WarpedGauss, x::AbstractVecOrMat) | ||
size(x, 1) == 2 || error("Dimension mismatch") | ||
σ₁, σ₂ = p.σ₁, p.σ₂ | ||
randn!(rng, x) | ||
x .*= [σ₁, σ₂] | ||
for y in eachcol(x) | ||
ϕ!(p, y) | ||
end | ||
return x | ||
end | ||
|
||
function Distributions._logpdf(p::WarpedGauss, x::AbstractVector) | ||
size(x, 1) == 2 || error("Dimension mismatch") | ||
σ₁, σ₂ = p.σ₁, p.σ₂ | ||
S = [σ₁, σ₂] .^ 2 | ||
z, logJ = ϕ⁻¹(p, x) | ||
return -sum(z .^ 2 ./ S) / 2 - IrrationalConstants.log2π - log(σ₁) - log(σ₂) + logJ | ||
end |