From 24165428a7bdd06cf7629e02fa56efd84a7dd2b2 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 19:31:56 -0700 Subject: [PATCH 1/8] added barycenter_unbalanced --- src/PythonOT.jl | 2 +- src/lib.jl | 46 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 65aace2..60373e0 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -2,7 +2,7 @@ module PythonOT using PyCall: PyCall -export emd, emd2, sinkhorn, sinkhorn2, barycenter, sinkhorn_unbalanced, sinkhorn_unbalanced2 +export emd, emd2, sinkhorn, sinkhorn2, barycenter, barycenter_unbalanced, sinkhorn_unbalanced, sinkhorn_unbalanced2 const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index bbaf491..9ccb051 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -295,3 +295,49 @@ function barycenter(A, C, ε; kwargs...) kwargs..., ) end + +""" + barycenter_unbalanced(A, C, ε, λ; kwargs...) + +Compute the entropically regularized unbalanced Wasserstein barycenter with histograms `A`, cost matrix +`C`, entropic regularization parameter `ε` and marginal relaxation parameter `λ`. + +The Wasserstein barycenter is a histogram and solves +```math +\\inf_{a} \\sum_{i} W_{\\varepsilon,C,λ}(a, a_i), +``` +where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,λ}(a, a_i)}`` +is the optimal transport cost for the entropically regularized optimal transport problem +with marginals ``a`` and ``a_i``, cost matrix ``C``, entropic regularization parameter +``\\varepsilon`` and marginal relaxation parameter ``\\lambda``. Optionally, weights of the histograms ``a_i`` can be provided with the +keyword argument `weights`. + +This function is a wrapper of the function +[`barycenter_unbalanced`](https://pythonot.github.io/gen_modules/ot.unbalanced.html#ot.unbalanced.barycenter_unbalanced) in the +Python Optimal Transport package. Keyword arguments are listed in the documentation of the +Python function. + +# Examples + +```jldoctest +julia> A = rand(10, 3); + +julia> A ./= sum(A; dims=1); + +julia> C = rand(10, 10); + +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) +false + +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1000; method="sinkhorn_stabilized")), 1; atol=1e-4) +true +``` +""" +function barycenter_unbalanced(A, C, ε, λ; kwargs...) + return pot.barycenter_unbalanced( + PyCall.PyReverseDims(permutedims(A)), + PyCall.PyReverseDims(permutedims(C)), + ε, + λ; + kwargs...) +end From c43d9e6c3873c03ad9c9b49537dd7871bbc039d8 Mon Sep 17 00:00:00 2001 From: zsteve Date: Sun, 23 May 2021 19:45:51 -0700 Subject: [PATCH 2/8] update docs --- docs/src/api.md | 1 + src/PythonOT.jl | 9 ++++++++- src/lib.jl | 15 +++++++++------ 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index b19bddd..6db4165 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -20,4 +20,5 @@ barycenter ```@docs sinkhorn_unbalanced sinkhorn_unbalanced2 +barycenter_unbalanced ``` diff --git a/src/PythonOT.jl b/src/PythonOT.jl index 60373e0..2622886 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -2,7 +2,14 @@ module PythonOT using PyCall: PyCall -export emd, emd2, sinkhorn, sinkhorn2, barycenter, barycenter_unbalanced, sinkhorn_unbalanced, sinkhorn_unbalanced2 +export emd, + emd2, + sinkhorn, + sinkhorn2, + barycenter, + barycenter_unbalanced, + sinkhorn_unbalanced, + sinkhorn_unbalanced2 const pot = PyCall.PyNULL() diff --git a/src/lib.jl b/src/lib.jl index 9ccb051..a7ba7d1 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -329,15 +329,18 @@ julia> C = rand(10, 10); julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) false -julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1000; method="sinkhorn_stabilized")), 1; atol=1e-4) +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 10000; method="sinkhorn_stabilized")), 1; atol=1e-4) true ``` + +See also: [`barycenter`](@ref) """ function barycenter_unbalanced(A, C, ε, λ; kwargs...) return pot.barycenter_unbalanced( - PyCall.PyReverseDims(permutedims(A)), - PyCall.PyReverseDims(permutedims(C)), - ε, - λ; - kwargs...) + PyCall.PyReverseDims(permutedims(A)), + PyCall.PyReverseDims(permutedims(C)), + ε, + λ; + kwargs..., + ) end From 70de2054897a11ca7ff936cb04fb63deb0798d89 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 16:46:24 +0200 Subject: [PATCH 3/8] Fix format --- src/lib.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 2e1020c..0e9c38b 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -362,4 +362,6 @@ true See also: [`barycenter`](@ref) """ -barycenter_unbalanced(A, C, ε, λ; kwargs...) = pot.barycenter_unbalanced(A, C, ε, λ; kwargs...) +function barycenter_unbalanced(A, C, ε, λ; kwargs...) + return pot.barycenter_unbalanced(A, C, ε, λ; kwargs...) +end From b76f6330930eabcfa387dd47c83c734350a0daed Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 17:24:35 +0200 Subject: [PATCH 4/8] Reduce parameter --- src/lib.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib.jl b/src/lib.jl index 0e9c38b..9229544 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -331,9 +331,9 @@ Compute the entropically regularized unbalanced Wasserstein barycenter with hist The Wasserstein barycenter is a histogram and solves ```math -\\inf_{a} \\sum_{i} W_{\\varepsilon,C,λ}(a, a_i), +\\inf_{a} \\sum_{i} W_{\\varepsilon,C,\\lambda}(a, a_i), ``` -where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,λ}(a, a_i)}`` +where the histograms ``a_i`` are columns of matrix `A` and ``W_{\\varepsilon,C,\\lambda}(a, a_i)}`` is the optimal transport cost for the entropically regularized optimal transport problem with marginals ``a`` and ``a_i``, cost matrix ``C``, entropic regularization parameter ``\\varepsilon`` and marginal relaxation parameter ``\\lambda``. Optionally, weights of the histograms ``a_i`` can be provided with the @@ -356,7 +356,7 @@ julia> C = rand(10, 10); julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) false -julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 10000; method="sinkhorn_stabilized")), 1; atol=1e-4) +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1_000; method="sinkhorn_stabilized")), 1; atol=1e-4) true ``` From e24e45f63e7ff2082980ca9a20a16df49e4e0a69 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 17:47:43 +0200 Subject: [PATCH 5/8] Update lib.jl --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 9229544..66c545c 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -356,7 +356,7 @@ julia> C = rand(10, 10); julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) false -julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1_000; method="sinkhorn_stabilized")), 1; atol=1e-4) +julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 5_000; method="sinkhorn_stabilized")), 1; atol=1e-4) true ``` From 9673c92554bb26a1dfaf163fe3c7dd17ca983d5f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 17:57:23 +0200 Subject: [PATCH 6/8] Increase iterations --- src/lib.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 66c545c..0184f4b 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -356,7 +356,9 @@ julia> C = rand(10, 10); julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabilized")), 1; atol=1e-4) false -julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 5_000; method="sinkhorn_stabilized")), 1; atol=1e-4) +julia> isapprox(sum(barycenter_unbalanced( + A, C, 0.01, 5_000; method="sinkhorn_stabilized", numItermax=5_000 + )), 1; atol=1e-4) true ``` From 284bdd9ad15dc5b469b6577f49052ec69489a67f Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 18:01:04 +0200 Subject: [PATCH 7/8] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 46b166a..b31e817 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.1" +version = "0.1.2" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" From 5f8cdb967df13a9f864ead3852190d1ed95c2998 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 24 May 2021 18:05:27 +0200 Subject: [PATCH 8/8] Increase parameter again... --- src/lib.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index 0184f4b..da10c79 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -357,7 +357,7 @@ julia> isapprox(sum(barycenter_unbalanced(A, C, 0.01, 1; method="sinkhorn_stabil false julia> isapprox(sum(barycenter_unbalanced( - A, C, 0.01, 5_000; method="sinkhorn_stabilized", numItermax=5_000 + A, C, 0.01, 10_000; method="sinkhorn_stabilized", numItermax=5_000 )), 1; atol=1e-4) true ```