diff --git a/Project.toml b/Project.toml index 9c1cd5b..d42fc33 100644 --- a/Project.toml +++ b/Project.toml @@ -15,7 +15,7 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3" [compat] BayesBase = "1.3" -ExponentialFamily = "1.4.3" +ExponentialFamily = "1.5.1" LinearAlgebra = "1.10" Manifolds = "0.9" ManifoldsBase = "0.15" diff --git a/docs/src/index.md b/docs/src/index.md index b58e96b..1002d99 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -81,6 +81,7 @@ ExponentialFamilyManifolds.partition_point ExponentialFamilyManifolds.ShiftedPositiveNumbers ExponentialFamilyManifolds.ShiftedNegativeNumbers ExponentialFamilyManifolds.SymmetricNegativeDefinite +ExponentialFamilyManifolds.SinglePointManifold ``` ## Optimization example diff --git a/src/ExponentialFamilyManifolds.jl b/src/ExponentialFamilyManifolds.jl index 390e9ff..2a82cad 100644 --- a/src/ExponentialFamilyManifolds.jl +++ b/src/ExponentialFamilyManifolds.jl @@ -5,12 +5,14 @@ using BayesBase, ExponentialFamily, ManifoldsBase, Manifolds, Random, LinearAlge include("symmetric_negative_definite.jl") include("shifted_negative_numbers.jl") include("shifted_positive_numbers.jl") +include("single_point_manifold.jl") include("natural_manifolds.jl") include("natural_manifolds/bernoulli.jl") include("natural_manifolds/beta.jl") include("natural_manifolds/binomial.jl") include("natural_manifolds/chisq.jl") +include("natural_manifolds/categorical.jl") include("natural_manifolds/dirichlet.jl") include("natural_manifolds/exponential.jl") include("natural_manifolds/gamma.jl") diff --git a/src/natural_manifolds/categorical.jl b/src/natural_manifolds/categorical.jl new file mode 100644 index 0000000..33c080c --- /dev/null +++ b/src/natural_manifolds/categorical.jl @@ -0,0 +1,20 @@ + +""" + get_natural_manifold_base(::Type{Categorical}, dims::Tuple{Int}, conditioner=nothing) + +Get the natural manifold base for the `Categorical` distribution. +""" +function get_natural_manifold_base(::Type{Categorical}, ::Tuple{}, conditioner=nothing) + return ProductManifold( + Euclidean(conditioner-1), SinglePointManifold([0]) + ) +end + +""" + partition_point(::Type{Categorical}, dims::Tuple{Int}, p, conditioner=nothing) + +Converts the `point` to a compatible representation for the natural manifold of type `Categorical`. +""" +function partition_point(::Type{Categorical}, ::Tuple{}, p, conditioner=nothing) + return ArrayPartition(view(p, 1:conditioner-1), view(p, conditioner:conditioner)) +end \ No newline at end of file diff --git a/src/single_point_manifold.jl b/src/single_point_manifold.jl new file mode 100644 index 0000000..2c36d81 --- /dev/null +++ b/src/single_point_manifold.jl @@ -0,0 +1,76 @@ +using ManifoldsBase +using Random + +""" + SinglePointManifold(point) + +This manifold represents a set from one point. +""" +struct SinglePointManifold{T, R} <: AbstractManifold{ℝ} + point::T + representation_size::R +end + +function SinglePointManifold(point::T) where {T} + return SinglePointManifold(point, size(point)) +end + +function Base.show(io::IO, M::SinglePointManifold) + print(io, "SinglePointManifold(", M.point, ")") +end + +ManifoldsBase.manifold_dimension(::SinglePointManifold) = 0 +ManifoldsBase.representation_size(M::SinglePointManifold) = M.representation_size +ManifoldsBase.injectivity_radius(M::SinglePointManifold) = zero(eltype(M.point)) + +ManifoldsBase.default_retraction_method(::SinglePointManifold) = ExponentialRetraction() + +function ManifoldsBase.check_point(M::SinglePointManifold, p; kwargs...) + if !(p ≈ M.point) + return DomainError(p, "The point $(p) does not lie on $(M), which contains only $(M.point).") + end + return nothing +end + +function ManifoldsBase.check_vector(M::SinglePointManifold, p, X; kwargs...) + if !iszero(X) && size(M.point) == size(X) + return DomainError(X, "The tangent space of $(M) contains only the zero vector.") + end + return nothing +end + +ManifoldsBase.is_flat(::SinglePointManifold) = true + +ManifoldsBase.embed(::SinglePointManifold, p) = p +ManifoldsBase.embed(::SinglePointManifold, p, X) = X + +function ManifoldsBase.inner(::SinglePointManifold, p, X, Y) + return zero(eltype(X)) +end + +function ManifoldsBase.exp!(M::SinglePointManifold, q, p, X, t::Number=1) + q .= M.point + return q +end + +function ManifoldsBase.log!(::SinglePointManifold, X, p, q) + X .= zero(eltype(X)) + return X +end + +function ManifoldsBase.project!(::SinglePointManifold, Y, p, X) + fill!(Y, zero(eltype(Y))) + return Y +end + +function ManifoldsBase.zero_vector!(::SinglePointManifold, X, p) + return fill!(X, zero(eltype(X))) +end + +function Random.rand(M::SinglePointManifold; kwargs...) + return rand(Random.default_rng(), M; kwargs...) +end + +function Random.rand(rng::AbstractRNG, M::SinglePointManifold; kwargs...) + return M.point +end diff --git a/test/natural_manifolds/categorical_tests.jl b/test/natural_manifolds/categorical_tests.jl new file mode 100644 index 0000000..c613a20 --- /dev/null +++ b/test/natural_manifolds/categorical_tests.jl @@ -0,0 +1,40 @@ +@testitem "Check `Categorical` natural manifold" begin + include("natural_manifolds_setuptests.jl") + + test_natural_manifold() do rng + p = rand(rng, 10) + normalize!(p, 1) + return Categorical(p) + end +end + +@testitem "Check that optimization work on Categorical" begin + include("natural_manifolds_setuptests.jl") + + using Manopt, ForwardDiff + using BayesBase + + rng = StableRNG(42) + p = rand(StableRNG(42), 10) + normalize!(p, 1) + distribution = Categorical(p) + sample = rand(rng, distribution) + dims = size(sample) + ef = convert(ExponentialFamilyDistribution, distribution) + T = ExponentialFamily.exponential_family_typetag(ef) + M = get_natural_manifold(T, dims, getconditioner(ef)) + + function f(M, p) + ef = convert(ExponentialFamilyDistribution, M, p) + η = getnaturalparameters(ef) + return (mean(η) - 0.5)^2 + end + + function g(M, p) + return project(M, p, 2 * p ./ 10) + end + + q = gradient_descent(M, f, g, rand(rng, M)) + @test q ∈ M + @test mean(q) ≈ 0.5 atol = 1e-1 +end diff --git a/test/single_point_manifold_tests.jl b/test/single_point_manifold_tests.jl new file mode 100644 index 0000000..0f18a17 --- /dev/null +++ b/test/single_point_manifold_tests.jl @@ -0,0 +1,108 @@ +@testitem "Generic properties of SinglePointManifold" begin + import ManifoldsBase: check_point, check_vector, embed, representation_size, injectivity_radius, get_embedding, is_flat, inner, manifold_dimension + import ExponentialFamilyManifolds: SinglePointManifold + using ManifoldsBase, Static, StaticArrays, JET, Manifolds + using StableRNGs + using Random + + rng = StableRNG(42) + + + points = [ + 0, + 0.0, + 0.0f0, + 1, + 1.0, + 1.0f0, + -1, + 2, + π, + rand(), + randn() + ] + + for p in points + M = SinglePointManifold(p) + + @test repr(M) == "SinglePointManifold($p)" + + @test @inferred(representation_size(M)) === () + @test @inferred(manifold_dimension(M)) === 0 + @test @inferred(is_flat(M)) === true + @test injectivity_radius(M) ≈ 0 + @test default_retraction_method(M) == ExponentialRetraction() + + @test_throws MethodError get_embedding(M) + + @test check_point(M, p) === nothing + @test check_point(M, p + 1) isa DomainError + @test check_point(M, p - 1) isa DomainError + + @test check_vector(M, p, 0) === nothing + @test check_vector(M, p, 1) isa DomainError + @test check_vector(M, p, -1) isa DomainError + + @test @eval(@allocated(representation_size($M))) === 0 + @test @eval(@allocated(manifold_dimension($M))) === 0 + @test @eval(@allocated(is_flat($M))) === 0 + + X = [1] + Y = [1] + + @test_opt inner(M, p, X, Y) + @test_opt inner(M, p, 0, 0) + + @test embed(M, p) == p + @test embed(M, p, 0) == 0 + @test inner(M, p, 0, 0) == 0 + end + + vector_points = [[1], [1, 2], [1, 2, 3]] + + for p in vector_points + M = SinglePointManifold(p) + q = similar(p) + X = zero_vector(M, p) + @test ManifoldsBase.exp!(M, q, p, X) == p + @test ManifoldsBase.log!(M, X, p, p) == zero_vector(M, p) + @test ManifoldsBase.log(M, p, p) == zero_vector(M, p) + @test ManifoldsBase.project!(M, similar(X), p, similar(X)) == zero_vector(M, p) + @test rand(rng, M) ∈ M + @test rand(M) ∈ M + end +end + +@testitem "Simple manifold optimization problem #1" begin + using Manopt, ForwardDiff, Static, StableRNGs, LinearAlgebra + + import ExponentialFamilyManifolds: SinglePointManifold + + for a in (2.0, 3.0), + b in (10.0, 5.0), + c in (1.0, 10.0, -1.0), + eps in (1e-4, 1e-5, 1e-8, 1e-10), + stepsize in (ConstantStepsize(0.1), ConstantStepsize(0.01), ConstantStepsize(0.001)) + + f(M, x) = (a .* x .^ 2 .+ b .* x .+ c)[1] + grad_f(M, x) = 2 .* a .* x .+ b + + rng = StableRNG(42) + + for s in [0, 0.0, 10] + M = SinglePointManifold(s) + p0 = rand(rng, M) + + q1 = gradient_descent( + M, + f, + grad_f, + p0; + stepsize=stepsize, + stopping_criterion=StopAfterIteration(1) + ) + + @test q1 ≈ s + end + end +end \ No newline at end of file