Skip to content

Commit af649a0

Browse files
Add emd_1d and emd2_1d (#12)
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 53806cf commit af649a0

File tree

4 files changed

+77
-1
lines changed

4 files changed

+77
-1
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "PythonOT"
22
uuid = "3c485715-4278-42b2-9b5f-8f00e43c12ef"
33
authors = ["David Widmann"]
4-
version = "0.1.2"
4+
version = "0.1.3"
55

66
[deps]
77
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"

docs/src/api.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
```@docs
66
emd
77
emd2
8+
emd_1d
9+
emd2_1d
810
```
911

1012
## Regularized optimal transport

src/PythonOT.jl

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ using PyCall: PyCall
44

55
export emd,
66
emd2,
7+
emd_1d,
8+
emd2_1d,
79
sinkhorn,
810
sinkhorn2,
911
barycenter,

src/lib.jl

+72
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,78 @@ function emd2(μ, ν, C; kwargs...)
6969
return pot.lp.emd2(μ, ν, PyCall.PyReverseDims(permutedims(C)); kwargs...)
7070
end
7171

72+
"""
73+
emd_1d(xsource, xtarget; kwargs...)
74+
75+
Compute the optimal transport plan for the Monge-Kantorovich problem with univariate
76+
discrete measures with support `xsource` and `xtarget` as source and target marginals.
77+
78+
This function is a wrapper of the function
79+
[`emd_1d`](https://pythonot.github.io/all.html#ot.emd_1d) in the Python Optimal Transport
80+
package. Keyword arguments are listed in the documentation of the Python function.
81+
82+
# Examples
83+
84+
```jldoctest
85+
julia> xsource = [0.2, 0.5];
86+
87+
julia> xtarget = [0.8, 0.3];
88+
89+
julia> emd_1d(xsource, xtarget)
90+
2×2 Matrix{Float64}:
91+
0.0 0.5
92+
0.5 0.0
93+
94+
julia> histogram_source = [0.8, 0.2];
95+
96+
julia> histogram_target = [0.7, 0.3];
97+
98+
julia> emd_1d(xsource, xtarget; a=histogram_source, b=histogram_target)
99+
2×2 Matrix{Float64}:
100+
0.5 0.3
101+
0.2 0.0
102+
```
103+
104+
See also: [`emd`](@ref), [`emd2_1d`](@ref)
105+
"""
106+
function emd_1d(xsource, xtarget; kwargs...)
107+
return pot.lp.emd_1d(xsource, xtarget; kwargs...)
108+
end
109+
110+
"""
111+
emd2_1d(xsource, xtarget; kwargs...)
112+
113+
Compute the optimal transport cost for the Monge-Kantorovich problem with univariate
114+
discrete measures with support `xsource` and `xtarget` as source and target marginals.
115+
116+
This function is a wrapper of the function
117+
[`emd2_1d`](https://pythonot.github.io/all.html#ot.emd2_1d) in the Python Optimal Transport
118+
package. Keyword arguments are listed in the documentation of the Python function.
119+
120+
# Examples
121+
122+
```jldoctest
123+
julia> xsource = [0.2, 0.5];
124+
125+
julia> xtarget = [0.8, 0.3];
126+
127+
julia> round(emd2_1d(xsource, xtarget); sigdigits=6)
128+
0.05
129+
130+
julia> histogram_source = [0.8, 0.2];
131+
132+
julia> histogram_target = [0.7, 0.3];
133+
134+
julia> round(emd2_1d(xsource, xtarget; a=histogram_source, b=histogram_target); sigdigits=6)
135+
0.201
136+
```
137+
138+
See also: [`emd2`](@ref), [`emd2_1d`](@ref)
139+
"""
140+
function emd2_1d(xsource, xtarget; kwargs...)
141+
return pot.lp.emd2_1d(xsource, xtarget; kwargs...)
142+
end
143+
72144
"""
73145
sinkhorn(μ, ν, C, ε; kwargs...)
74146

0 commit comments

Comments
 (0)