diff --git a/Project.toml b/Project.toml index 46b166a..b31e817 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.1" +version = "0.1.2" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" diff --git a/docs/src/api.md b/docs/src/api.md index a66266b..157af1d 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -31,4 +31,5 @@ PythonOT.Smooth.smooth_ot_dual ```@docs sinkhorn_unbalanced sinkhorn_unbalanced2 +barycenter_unbalanced ``` diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 4240406..b1cd684 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -2,7 +2,14 @@ module PythonOT using PyCall: PyCall -export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn_unbalanced2 +export emd, + emd2, + sinkhorn, + sinkhorn2, + barycenter, + barycenter_unbalanced, + sinkhorn_unbalanced, + sinkhorn_unbalanced2 const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index 7028019..da10c79 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -322,3 +322,48 @@ true ``` """ barycenter(A, C, ε; kwargs...) = pot.barycenter(A, C, ε; kwargs...) + +""" + barycenter_unbalanced(A, C, ε, λ; kwargs...) + +Compute the entropically regularized unbalanced Wasserstein barycenter with histograms `A`, cost matrix +`C`, entropic regularization parameter `ε` and marginal relaxation parameter `λ`. + +The Wasserstein barycenter is a histogram and solves +```math +\\inf_{a} \\sum_{i} W_{\\varepsilon,C,\\lambda}(a, a_i), +``` +where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,\\lambda}(a, a_i)}`` +is the optimal transport cost for the entropically regularized optimal transport problem +with marginals ``a`` and ``a_i``, cost matrix ``C``, entropic regularization parameter +``\\varepsilon`` and marginal relaxation parameter ``\\lambda``. Optionally, weights of the histograms ``a_i`` can be provided with the +keyword argument `weights`. + +This function is a wrapper of the function +[`barycenter_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.barycenter_unbalanced) in the +Python Optimal Transport package. Keyword arguments are listed in the documentation of the +Python function. + +# Examples + +```jldoctest +julia> A = rand(10, 3); + +julia> A ./= sum(A; dims=1); + +julia> C = rand(10, 10); + +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) +false + +julia> isapprox(sum(barycenter_unbalanced( + A, C, 0.01, 10_000; method="sinkhorn_stabilized", numItermax=5_000 + )), 1; atol=1e-4) +true +``` + +See also: [`barycenter`](@ref) +""" +function barycenter_unbalanced(A, C, ε, λ; kwargs...) + return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...) +end