From ef1fa2b34d4587eb682aa5f58dcc84bfd145bf7c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 1 Jun 2021 13:28:42 +0200 Subject: [PATCH 1/4] Add `emd_1d` and `emd2_1d` --- src/PythonOT.jl | 2 ++ src/lib.jl | 73 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/src/PythonOT.jl b/src/PythonOT.jl index b1cd684..9b3c534 100644 --- a/src/PythonOT.jl +++ b/src/PythonOT.jl @@ -4,6 +4,8 @@ using PyCall: PyCall export emd, emd2, + emd_1d, + emd2_1d, sinkhorn, sinkhorn2, barycenter, diff --git a/src/lib.jl b/src/lib.jl index da10c79..f2ef283 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -69,6 +69,79 @@ function emd2(μ, ν, C; kwargs...) return pot.lp.emd2(μ, ν, PyCall.PyReverseDims(permutedims(C)); kwargs...) end +""" + emd_1d(xsource, xtarget; kwargs...) + +Compute the optimal transport plan for the Monge-Kantorovich problem with univariate +discrete measures with support `xsource` and `xtarget` as source and target marginals. + +This function is a wrapper of the function +[`emd_1d`](https://pythonot.github.io/all.html#ot.emd_1d) in the Python Optimal Transport +package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> xsource = [0.2, 0.5]; + +julia> xtarget = [0.8, 0.3]; + +julia> emd_1d(xsource, xtarget) +2×2 Matrix{Float64}: + 0.0 0.5 + 0.5 0.0 + +julia> histogram_source = [0.8, 0.2]; + +julia> histogram_target = [0.7, 0.3]; + +julia> emd_1d(xsource, xtarget; a=histogram_source, b=histogram_target) +2×2 Matrix{Float64}: + 0.5 0.3 + 0.2 0.0 +``` + +See also: [`emd`](@ref), [`emd2_1d`](@ref) +""" +function emd_1d(xsource, xtarget; kwargs...) + return pot.lp.emd_1d(xsource, xtarget; kwargs...) +end + + +""" + emd2_1d(xsource, xtarget; kwargs...) + +Compute the optimal transport cost for the Monge-Kantorovich problem with univariate +discrete measures with support `xsource` and `xtarget` as source and target marginals. + +This function is a wrapper of the function +[`emd2_1d`](https://pythonot.github.io/all.html#ot.emd2_1d) in the Python Optimal Transport +package. Keyword arguments are listed in the documentation of the Python function. + +# Examples + +```jldoctest +julia> xsource = [0.2, 0.5]; + +julia> xtarget = [0.8, 0.3]; + +julia> round(emd2_1d(xsource, xtarget); sigdigits=6) +0.05 + +julia> histogram_source = [0.8, 0.2]; + +julia> histogram_target = [0.7, 0.3]; + +julia> round(emd2_1d(xsource, xtarget; a=histogram_source, b=histogram_target); sigdigits=6) +0.201 +``` + +See also: [`emd2`](@ref), [`emd2_1d`](@ref) +""" +function emd2_1d(xsource, xtarget; kwargs...) + return pot.lp.emd2_1d(xsource, xtarget; kwargs...) +end + """ sinkhorn(μ, ν, C, ε; kwargs...) From 7579f1f2e72e71029c2fb6106da738f554357ab6 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 1 Jun 2021 13:31:16 +0200 Subject: [PATCH 2/4] Bump version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b31e817..42f0260 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "PythonOT" uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef" authors = ["David Widmann"] -version = "0.1.2" +version = "0.1.3" [deps] PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" From f5fee2b3a99cbd7903e4ba7a5ad8674f53276249 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 1 Jun 2021 13:33:19 +0200 Subject: [PATCH 3/4] Fix format Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/lib.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lib.jl b/src/lib.jl index f2ef283..219820f 100644 --- a/src/lib.jl +++ b/src/lib.jl @@ -107,7 +107,6 @@ function emd_1d(xsource, xtarget; kwargs...) return pot.lp.emd_1d(xsource, xtarget; kwargs...) end - """ emd2_1d(xsource, xtarget; kwargs...) From 85f195ce791df4d420f2e9103100242838e04c52 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 1 Jun 2021 14:41:16 +0200 Subject: [PATCH 4/4] Update documentation --- docs/src/api.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/src/api.md b/docs/src/api.md index 157af1d..feaa2d7 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -5,6 +5,8 @@ ```@docs emd emd2 +emd_1d +emd2_1d ``` ## Regularized optimal transport