Skip to content

Commit

Permalink
tests for Sasaki retraction
Browse files Browse the repository at this point in the history
  • Loading branch information
mateuszbaran committed Oct 10, 2023
1 parent c3a6b12 commit 22aaf3a
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/point_vector_fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,26 @@ macro default_manifold_fallbacks(TM, TP, TV, pfield::Symbol, vfield::Symbol)
ManifoldsBase.retract_embedded!(M, q.$pfield, p.$pfield, X.$vfield, t, m)
return q
end
function ManifoldsBase.retract_sasaki(
M::$TM,
p::$TP,
X::$TV,
t::Number,
m::SasakiRetraction,
)
return $TP(ManifoldsBase.retract_sasaki(M, p.$pfield, X.$vfield, t, m))
end
function ManifoldsBase.retract_sasaki!(
M::$TM,
q::$TP,
p::$TP,
X::$TV,
t::Number,
m::SasakiRetraction,
)
ManifoldsBase.retract_sasaki!(M, q.$pfield, p.$pfield, X.$vfield, t, m)
return q
end
end,
)
for f_postfix in [:polar, :project, :qr, :softmax]
Expand Down
9 changes: 9 additions & 0 deletions src/retractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,15 @@ retract_softmax!(M::AbstractManifold, q, p, X, t::Number)

function retract_softmax! end

"""
retract_sasaki!(M::AbstractManifold, q, p, X, t::Number, m::SasakiRetraction)
Compute the in-place variant of the [`SasakiRetraction`](@ref) `m`.
"""
retract_pade!(M::AbstractManifold, q, p, X, t::Number, m::SasakiRetraction)

function retract_sasaki! end

@doc raw"""
retract(M::AbstractManifold, p, X, method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)))
retract(M::AbstractManifold, p, X, t::Number=1, method::AbstractRetractionMethod=default_retraction_method(M, typeof(p)))
Expand Down
11 changes: 11 additions & 0 deletions test/default_manifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,16 @@ function ManifoldsBase.retract_exp_ode!(
return (q .= p .+ t .* X)
end
ManifoldsBase.retract_pade!(::DefaultManifold, q, p, X, t::Number, i) = (q .= p .+ t .* X)
function ManifoldsBase.retract_sasaki!(
::DefaultManifold,
q,
p,
X,
t::Number,
::SasakiRetraction,
)
return (q .= p .+ t .* X)
end
ManifoldsBase.retract_softmax!(::DefaultManifold, q, p, X, t::Number) = (q .= p .+ t .* X)
ManifoldsBase.get_embedding(M::DefaultManifold) = M # dummy embedding
ManifoldsBase.inverse_retract_polar!(::DefaultManifold, Y, p, q) = (Y .= q .- p)
Expand Down Expand Up @@ -723,6 +733,7 @@ Base.size(x::MatrixVectorTransport) = (size(x.m, 2),)
ODEExponentialRetraction(PolarRetraction(), DefaultBasis()),
PadeRetraction(2),
EmbeddedRetraction(ExponentialRetraction()),
SasakiRetraction(5),
]
@test retract(M, q, Y, retr) == DefaultPoint(q.value + Y.value)
@test retract(M, q, Y, 0.5, retr) == DefaultPoint(q.value + 0.5 * Y.value)
Expand Down

0 comments on commit 22aaf3a

Please sign in to comment.