Skip to content

Commit

Permalink
Merge pull request #30 from ReactiveBayes/bump-manifolds
Browse files Browse the repository at this point in the history
feat: bump Manifolds ecosystem
  • Loading branch information
wouterwln authored Feb 11, 2025
2 parents f3f00a3 + 745c852 commit ea0c972
Show file tree
Hide file tree
Showing 7 changed files with 95 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
BayesBase = "1.3"
ExponentialFamily = "1.6.0"
LinearAlgebra = "1.10"
Manifolds = "0.9"
ManifoldsBase = "0.15"
Manifolds = "0.10"
ManifoldsBase = "1"
Random = "1.10"
RecursiveArrayTools = "3"
Static = "0.8, 1"
Expand Down
17 changes: 12 additions & 5 deletions src/shifted_negative_numbers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
ShiftedNegativeNumbers(shift)
Expand Down Expand Up @@ -56,16 +55,20 @@ function ManifoldsBase.inner(M::ShiftedNegativeNumbers, p, X, Y)
)
end

function ManifoldsBase.exp!(M::ShiftedNegativeNumbers, q, p, X, t::Number=1)
function ManifoldsBase.exp_fused!(M::ShiftedNegativeNumbers, q, p, X, t::Number)
@inbounds q[1] = shift(
M,
-ManifoldsBase.exp(
-ManifoldsBase.exp_fused(
PositiveNumbers(), -unshift(M, @inbounds(p[1])), -@inbounds(X[1]), t
),
)
return q
end

function ManifoldsBase.exp!(M::ShiftedNegativeNumbers, q, p, X)
return ManifoldsBase.exp_fused!(M, q, p, X, one(eltype(p)))
end

function ManifoldsBase.log!(M::ShiftedNegativeNumbers, X, p, q)
@inbounds X[1] =
-ManifoldsBase.log(
Expand All @@ -82,10 +85,14 @@ end

ManifoldsBase.default_retraction_method(::ShiftedNegativeNumbers) = ExponentialRetraction()

function ManifoldsBase.retract!(
function ManifoldsBase.retract_fused!(
M::ShiftedNegativeNumbers, q, p, X, t::Number, ::ExponentialRetraction
)
return ManifoldsBase.exp!(M, q, p, X, t)
return ManifoldsBase.exp_fused!(M, q, p, X, t)
end

function ManifoldsBase.retract!(M::ShiftedNegativeNumbers, q, p, X, ::ExponentialRetraction)
return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(p)), ExponentialRetraction())
end

function ManifoldsBase.parallel_transport_to!(M::ShiftedNegativeNumbers, Y, p, X, q)
Expand Down
19 changes: 14 additions & 5 deletions src/shifted_positive_numbers.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
ShiftedPositiveNumbers(shift)
Expand Down Expand Up @@ -56,16 +55,22 @@ function ManifoldsBase.inner(M::ShiftedPositiveNumbers, p, X, Y)
)
end

function ManifoldsBase.exp!(M::ShiftedPositiveNumbers, q, p, X, t::Number=1)
function ManifoldsBase.exp_fused!(M::ShiftedPositiveNumbers, q, p, X, t::Number)
@inbounds q[1] = shift(
M,
ManifoldsBase.exp(
PositiveNumbers(), unshift(M, @inbounds(p[1])), @inbounds(X[1]), t
PositiveNumbers(),
unshift(M, @inbounds(p[1])),
@inbounds(X[1]) * t # Scale the tangent vector by t
),
)
return q
end

function ManifoldsBase.exp!(M::ShiftedPositiveNumbers, q, p, X)
return ManifoldsBase.exp_fused!(M, q, p, X, one(eltype(p)))
end

function ManifoldsBase.log!(M::ShiftedPositiveNumbers, X, p, q)
@inbounds X[1] = ManifoldsBase.log(
PositiveNumbers(), unshift(M, @inbounds(p[1])), unshift(M, @inbounds(q[1]))
Expand All @@ -82,10 +87,14 @@ end

ManifoldsBase.default_retraction_method(::ShiftedPositiveNumbers) = ExponentialRetraction()

function ManifoldsBase.retract!(
function ManifoldsBase.retract_fused!(
M::ShiftedPositiveNumbers, q, p, X, t::Number, ::ExponentialRetraction
)
return ManifoldsBase.exp!(M, q, p, X, t)
return ManifoldsBase.exp_fused!(M, q, p, X, t)
end

function ManifoldsBase.retract!(M::ShiftedPositiveNumbers, q, p, X, ::ExponentialRetraction)
return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(p)), ExponentialRetraction())
end

function ManifoldsBase.parallel_transport_to!(M::ShiftedPositiveNumbers, Y, p, X, q)
Expand Down
17 changes: 12 additions & 5 deletions src/symmetric_negative_definite.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

"""
SymmetricNegativeDefinite(k)
Expand Down Expand Up @@ -66,11 +65,15 @@ function ManifoldsBase.inner(M::SymmetricNegativeDefinite, p, X, Y)
return ManifoldsBase.inner(M.base, Negated(p), Negated(X), Negated(Y))
end

function ManifoldsBase.exp!(M::SymmetricNegativeDefinite, q, p, X, t::Number=1)
ManifoldsBase.exp!(M.base, q, Negated(p), Negated(X), t)
function ManifoldsBase.exp_fused!(M::SymmetricNegativeDefinite, q, p, X, t::Number)
ManifoldsBase.exp_fused!(M.base, q, Negated(p), Negated(X), t)
return negate!(q)
end

function ManifoldsBase.exp!(M::SymmetricNegativeDefinite, q, p, X)
return ManifoldsBase.exp_fused!(M, q, p, X, one(eltype(p)))
end

function ManifoldsBase.log!(M::SymmetricNegativeDefinite, X, p, q)
ManifoldsBase.log!(M.base, X, Negated(p), Negated(q))
return negate!(X)
Expand All @@ -85,10 +88,14 @@ function ManifoldsBase.default_retraction_method(::SymmetricNegativeDefinite)
return ExponentialRetraction()
end

function ManifoldsBase.retract!(
function ManifoldsBase.retract_fused!(
M::SymmetricNegativeDefinite, q, p, X, t::Number, ::ExponentialRetraction
)
return ManifoldsBase.exp!(M, q, p, X, t)
return ManifoldsBase.exp_fused!(M, q, p, X, t)
end

function ManifoldsBase.retract!(M::SymmetricNegativeDefinite, q, p, X, ::ExponentialRetraction)
return ManifoldsBase.retract_fused!(M, q, p, X, one(eltype(p)), ExponentialRetraction())
end

function ManifoldsBase.parallel_transport_to!(M::SymmetricNegativeDefinite, Y, p, X, q)
Expand Down
18 changes: 18 additions & 0 deletions test/shifted_negative_numbers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,21 @@ end
@test f(M, q1)[1] expected_minimum rtol = 1e-6
@test norm(M, f(M, q1), grad_f(M, q1)) <= 1e-8
end

@testitem "Retraction methods consistency" begin
using ManifoldsBase, Random
import ExponentialFamilyManifolds: ShiftedNegativeNumbers

rng = Random.default_rng()
M = ShiftedNegativeNumbers(0.0)
p = rand(rng, M)
X = randn(rng, 1)

q1 = similar(p)
q2 = similar(p)

# Test that retract! is equivalent to retract_fused! with t=1
retract!(M, q1, p, X, ExponentialRetraction())
ManifoldsBase.retract_fused!(M, q2, p, X, 1.0, ExponentialRetraction())
@test q1 q2
end
18 changes: 18 additions & 0 deletions test/shifted_positive_numbers_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,21 @@ end
@test f(M, q1)[1] expected_minimum rtol = 1e-6
@test norm(M, f(M, q1), grad_f(M, q1)) <= 1e-8
end

@testitem "Retraction methods consistency" begin
using ManifoldsBase, Random
import ExponentialFamilyManifolds: ShiftedPositiveNumbers

rng = Random.default_rng()
M = ShiftedPositiveNumbers(0.0)
p = rand(rng, M)
X = randn(rng, 1)

q1 = similar(p)
q2 = similar(p)

# Test that retract! is equivalent to retract_fused! with t=1
retract!(M, q1, p, X, ExponentialRetraction())
ManifoldsBase.retract_fused!(M, q2, p, X, 1.0, ExponentialRetraction())
@test q1 q2
end
19 changes: 19 additions & 0 deletions test/symmetric_negative_definite_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,4 +215,23 @@ end
@test f(M, q1) expected_minimum rtol = 1e-7
@test norm(M, q1, g(M, q1)) <= eps
end
end

@testitem "Retraction methods consistency" begin
using ManifoldsBase, Random, LinearAlgebra
import ExponentialFamilyManifolds: SymmetricNegativeDefinite

rng = Random.default_rng()
M = SymmetricNegativeDefinite(2)
p = rand(rng, M)
X = randn(rng, 2, 2)
X = (X + X')/2 # Make symmetric

q1 = similar(p)
q2 = similar(p)

# Test that retract! is equivalent to retract_fused! with t=1
retract!(M, q1, p, X, ExponentialRetraction())
ManifoldsBase.retract_fused!(M, q2, p, X, 1.0, ExponentialRetraction())
@test q1 q2
end

0 comments on commit ea0c972

Please sign in to comment.