From e95648f999def1c44231db857df9dd0390a8c4b1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Sun, 2 May 2021 17:55:28 +0200 Subject: [PATCH] Update CI and dependencies + format code consistently (#8) --- .JuliaFormatter.toml | 1 + .github/workflows/{ci.yml => CI.yml} | 24 +----- .github/workflows/CompatHelper.yml | 17 ++++- .github/workflows/DocCleanup.yml | 26 ------- .github/workflows/Docs.yml | 27 +++++++ .github/workflows/DocsPreviewCleanup.yml | 26 +++++++ .github/workflows/Format.yml | 20 +++++ Project.toml | 6 +- README.md | 7 +- docs/Manifest.toml | 97 ++++++++++++++++++++---- docs/Project.toml | 5 ++ docs/make.jl | 24 ++++-- src/StochasticOptimalTransport.jl | 15 ++-- src/discrete.jl | 16 ++-- src/semidiscrete.jl | 31 ++------ src/utils.jl | 32 +++----- test/Project.toml | 2 + test/discrete.jl | 16 ++-- test/runtests.jl | 24 ++++-- test/semidiscrete.jl | 26 +++---- 20 files changed, 271 insertions(+), 171 deletions(-) create mode 100644 .JuliaFormatter.toml rename .github/workflows/{ci.yml => CI.yml} (70%) delete mode 100644 .github/workflows/DocCleanup.yml create mode 100644 .github/workflows/Docs.yml create mode 100644 .github/workflows/DocsPreviewCleanup.yml create mode 100644 .github/workflows/Format.yml diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml new file mode 100644 index 0000000..1e72b50 --- /dev/null +++ b/.JuliaFormatter.toml @@ -0,0 +1 @@ +style="blue" diff --git a/.github/workflows/ci.yml b/.github/workflows/CI.yml similarity index 70% rename from .github/workflows/ci.yml rename to .github/workflows/CI.yml index 88bdad5..d8d1e52 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/CI.yml @@ -3,7 +3,7 @@ name: CI on: push: branches: - - master + - main pull_request: jobs: @@ -60,25 +60,3 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} path-to-lcov: lcov.info - docs: - name: Documentation - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v2 - - uses: julia-actions/setup-julia@v1 - with: - version: '1' - - run: | - julia --project=docs -e ' - using Pkg - Pkg.develop(PackageSpec(path=pwd())) - Pkg.instantiate()' - - run: | - julia --project=docs -e ' - using Documenter: doctest - using StochasticOptimalTransport - doctest(StochasticOptimalTransport)' - - run: julia --project=docs docs/make.jl - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.github/workflows/CompatHelper.yml b/.github/workflows/CompatHelper.yml index cba9134..428ecee 100644 --- a/.github/workflows/CompatHelper.yml +++ b/.github/workflows/CompatHelper.yml @@ -7,10 +7,19 @@ jobs: CompatHelper: runs-on: ubuntu-latest steps: - - name: Pkg.add("CompatHelper") - run: julia -e 'using Pkg; Pkg.add("CompatHelper")' - - name: CompatHelper.main() + - name: "Install CompatHelper" + run: | + import Pkg + name = "CompatHelper" + uuid = "aa819f21-2bde-4658-8897-bab36330d9b7" + version = "2" + Pkg.add(; name, uuid, version) + shell: julia --color=yes {0} + - name: "Run CompatHelper" + run: | + import CompatHelper + CompatHelper.main(; subdirs=["", "test", "docs"]) + shell: julia --color=yes {0} env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }} - run: julia -e 'using CompatHelper; CompatHelper.main()' diff --git a/.github/workflows/DocCleanup.yml b/.github/workflows/DocCleanup.yml deleted file mode 100644 index 7a867d1..0000000 --- a/.github/workflows/DocCleanup.yml +++ /dev/null @@ -1,26 +0,0 @@ -name: Doc Preview Cleanup - -on: - pull_request: - types: [closed] - -jobs: - doc-preview-cleanup: - runs-on: ubuntu-latest - steps: - - name: Checkout gh-pages branch - uses: actions/checkout@v2 - with: - ref: gh-pages - - name: Delete preview and history - run: | - git config user.name "Documenter.jl" - git config user.email "documenter@juliadocs.github.io" - git rm -rf "previews/PR$PRNUM" - git commit -m "delete preview" - git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) - env: - PRNUM: ${{ github.event.number }} - - name: Push changes - run: | - git push --force origin gh-pages-new:gh-pages diff --git a/.github/workflows/Docs.yml b/.github/workflows/Docs.yml new file mode 100644 index 0000000..68c5eb0 --- /dev/null +++ b/.github/workflows/Docs.yml @@ -0,0 +1,27 @@ +name: Docs + +on: + push: + branches: + - main + tags: '*' + pull_request: + +jobs: + docs: + name: Documentation + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@v1 + with: + version: '1' + - run: | + julia --project=docs -e ' + using Pkg + Pkg.develop(PackageSpec(path=pwd())) + Pkg.instantiate()' + - run: julia --project=docs docs/make.jl + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} diff --git a/.github/workflows/DocsPreviewCleanup.yml b/.github/workflows/DocsPreviewCleanup.yml new file mode 100644 index 0000000..4f57bc4 --- /dev/null +++ b/.github/workflows/DocsPreviewCleanup.yml @@ -0,0 +1,26 @@ +name: DocsPreviewCleanup + +on: + pull_request: + types: [closed] + +jobs: + cleanup: + runs-on: ubuntu-latest + steps: + - name: Checkout gh-pages branch + uses: actions/checkout@v2 + with: + ref: gh-pages + - name: Delete preview and history + push changes + run: | + if [ -d "previews/PR$PRNUM" ]; then + git config user.name "Documenter.jl" + git config user.email "documenter@juliadocs.github.io" + git rm -rf "previews/PR$PRNUM" + git commit -m "delete preview" + git branch gh-pages-new $(echo "delete history" | git commit-tree HEAD^{tree}) + git push --force origin gh-pages-new:gh-pages + fi + env: + PRNUM: ${{ github.event.number }} diff --git a/.github/workflows/Format.yml b/.github/workflows/Format.yml new file mode 100644 index 0000000..3fd18fa --- /dev/null +++ b/.github/workflows/Format.yml @@ -0,0 +1,20 @@ +name: Format + +on: + pull_request: + +jobs: + format: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: julia-actions/setup-julia@latest + with: + version: 1 + - run: | + julia -e 'using Pkg; Pkg.add("JuliaFormatter")' + julia -e 'using JuliaFormatter; format("."; verbose=true)' + - uses: reviewdog/action-suggester@v1 + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/Project.toml b/Project.toml index 5b39ac2..022569b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,14 +1,14 @@ name = "StochasticOptimalTransport" uuid = "d0107fbf-5e6e-4997-ac8d-099d2392f4a6" authors = ["David Widmann "] -version = "0.1.0" +version = "0.1.1" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] -StatsFuns = "0.9" +LogExpFunctions = "0.2" julia = "1.3" diff --git a/README.md b/README.md index 5c419c2..a2db873 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ Julia implementation of stochastic optimization algorithms for large-scale optim [![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://devmotion.github.io/StochasticOptimalTransport.jl/stable) [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://devmotion.github.io/StochasticOptimalTransport.jl/dev) -[![Build Status](https://github.com/devmotion/StochasticOptimalTransport.jl/workflows/CI/badge.svg?branch=master)](https://github.com/devmotion/StochasticOptimalTransport.jl/actions?query=workflow%3ACI%20branch%3Amaster) -[![Coverage](https://codecov.io/gh/devmotion/StochasticOptimalTransport.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/devmotion/StochasticOptimalTransport.jl) -[![Coverage](https://coveralls.io/repos/github/devmotion/StochasticOptimalTransport.jl/badge.svg?branch=master)](https://coveralls.io/github/devmotion/StochasticOptimalTransport.jl?branch=master) +[![Build Status](https://github.com/devmotion/StochasticOptimalTransport.jl/workflows/CI/badge.svg?branch=main)](https://github.com/devmotion/StochasticOptimalTransport.jl/actions?query=workflow%3ACI%20branch%3Amain) +[![Coverage](https://codecov.io/gh/devmotion/StochasticOptimalTransport.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/devmotion/StochasticOptimalTransport.jl) +[![Coverage](https://coveralls.io/repos/github/devmotion/StochasticOptimalTransport.jl/badge.svg?branch=main)](https://coveralls.io/github/devmotion/StochasticOptimalTransport.jl?branch=main) +[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle) # Bibliography diff --git a/docs/Manifest.toml b/docs/Manifest.toml index 2aa3878..f50b22e 100644 --- a/docs/Manifest.toml +++ b/docs/Manifest.toml @@ -1,5 +1,11 @@ # This file is machine-generated - editing it directly is not advised +[[ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" + +[[Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + [[Base64]] uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" @@ -7,21 +13,21 @@ uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" deps = ["Printf"] uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" -[[Distributed]] -deps = ["Random", "Serialization", "Sockets"] -uuid = "8ba89e20-285c-5b6f-9357-94700520ee1b" - [[DocStringExtensions]] deps = ["LibGit2", "Markdown", "Pkg", "Test"] -git-tree-sha1 = "50ddf44c53698f5e784bbebb3f4b21c5807401b1" +git-tree-sha1 = "9d4f64f79012636741cf01133158a54b24924c32" uuid = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -version = "0.8.3" +version = "0.8.4" [[Documenter]] deps = ["Base64", "Dates", "DocStringExtensions", "IOCapture", "InteractiveUtils", "JSON", "LibGit2", "Logging", "Markdown", "REPL", "Test", "Unicode"] -git-tree-sha1 = "c01a7e8bcf7a6693444a52a0c5ac8b4e9528600e" +git-tree-sha1 = "3ebb967819b284dc1e3c0422229b58a40a255649" uuid = "e30172f5-a6a5-5a46-863b-614d45cd2de4" -version = "0.26.0" +version = "0.26.3" + +[[Downloads]] +deps = ["ArgTools", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" [[IOCapture]] deps = ["Logging"] @@ -39,13 +45,35 @@ git-tree-sha1 = "81690084b6198a2e1da36fcfda16eeca9f9f24e4" uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" version = "0.21.1" +[[LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" + +[[LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" + [[LibGit2]] -deps = ["Printf"] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" +[[LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" + [[Libdl]] uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" +[[LinearAlgebra]] +deps = ["Libdl"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[LogExpFunctions]] +deps = ["DocStringExtensions", "LinearAlgebra"] +git-tree-sha1 = "ed26854d7c2c867d143f0e07c198fc9e8b721d10" +uuid = "2ab3a3ac-af41-5b50-aa03-7779005ae688" +version = "0.2.3" + [[Logging]] uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" @@ -53,17 +81,27 @@ uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" deps = ["Base64"] uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" +[[MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" + [[Mmap]] uuid = "a63ad114-7e13-5084-954f-fe012c677804" +[[MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" + +[[NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" + [[Parsers]] deps = ["Dates"] -git-tree-sha1 = "6370b5b3cf2ce5a3d2b6f7ab2dc10f374e4d7d2b" +git-tree-sha1 = "c8abc88faa3f7a3950832ac5d6e690881590d6dc" uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" -version = "1.0.14" +version = "1.1.0" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] @@ -71,7 +109,7 @@ deps = ["Unicode"] uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" [[REPL]] -deps = ["InteractiveUtils", "Markdown", "Sockets"] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" [[Random]] @@ -87,13 +125,30 @@ uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" [[Sockets]] uuid = "6462fe0b-24de-5631-8697-dd941f90decc" +[[SparseArrays]] +deps = ["LinearAlgebra", "Random"] +uuid = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" + +[[Statistics]] +deps = ["LinearAlgebra", "SparseArrays"] +uuid = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" + [[StochasticOptimalTransport]] +deps = ["LinearAlgebra", "LogExpFunctions", "Random", "Statistics"] path = ".." uuid = "d0107fbf-5e6e-4997-ac8d-099d2392f4a6" -version = "0.1.0" +version = "0.1.1" + +[[TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" + +[[Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" [[Test]] -deps = ["Distributed", "InteractiveUtils", "Logging", "Random"] +deps = ["InteractiveUtils", "Logging", "Random", "Serialization"] uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [[UUIDs]] @@ -102,3 +157,15 @@ uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" [[Unicode]] uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" + +[[nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" + +[[p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" diff --git a/docs/Project.toml b/docs/Project.toml index b564819..ab5eb67 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -1,3 +1,8 @@ [deps] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" StochasticOptimalTransport = "d0107fbf-5e6e-4997-ac8d-099d2392f4a6" + +[compat] +Documenter = "0.26" +StochasticOptimalTransport = "0.1" +julia = "1.3" diff --git a/docs/make.jl b/docs/make.jl index e4107b3..70255c2 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,6 +1,19 @@ -using StochasticOptimalTransport using Documenter +# Print `@debug` statements (https://github.com/JuliaDocs/Documenter.jl/issues/955) +if haskey(ENV, "GITHUB_ACTIONS") + ENV["JULIA_DEBUG"] = "Documenter" +end + +using StochasticOptimalTransport + +DocMeta.setdocmeta!( + StochasticOptimalTransport, + :DocTestSetup, + :(using StochasticOptimalTransport); + recursive=true, +) + makedocs(; modules=[StochasticOptimalTransport], authors="David Widmann ", @@ -11,12 +24,13 @@ makedocs(; canonical="https://devmotion.github.io/StochasticOptimalTransport.jl", assets=String[], ), - pages=[ - "Home" => "index.md", - ], + pages=["Home" => "index.md"], + strict=true, + checkdocs=:exports, ) deploydocs(; repo="github.com/devmotion/StochasticOptimalTransport.jl", - push_preview = true, + push_preview=true, + devbranch="main", ) diff --git a/src/StochasticOptimalTransport.jl b/src/StochasticOptimalTransport.jl index 8af79d2..0b6e5d5 100644 --- a/src/StochasticOptimalTransport.jl +++ b/src/StochasticOptimalTransport.jl @@ -1,10 +1,10 @@ module StochasticOptimalTransport -import StatsFuns +using LogExpFunctions: LogExpFunctions -import LinearAlgebra -import Random -import Statistics +using LinearAlgebra: LinearAlgebra +using Random: Random +using Statistics: Statistics include("utils.jl") include("discrete.jl") @@ -60,12 +60,7 @@ Peyré, Gabriel, & Marco Cuturi (2019). Computational Optimal Transport. Foundat """ wasserstein(args...; kwargs...) = wasserstein(Random.GLOBAL_RNG, args...; kwargs...) function wasserstein( - rng::Random.AbstractRNG, - c, - μ, - ν, - ε::Union{Real,Nothing} = nothing; - kwargs..., + rng::Random.AbstractRNG, c, μ, ν, ε::Union{Real,Nothing}=nothing; kwargs... ) # approximate solution `v` of the dual problem v = dual_v(rng, c, μ, ν, ε; kwargs...) diff --git a/src/discrete.jl b/src/discrete.jl index ceeb57e..7fd7359 100644 --- a/src/discrete.jl +++ b/src/discrete.jl @@ -1,11 +1,5 @@ function dual_cost( - rng::Random.AbstractRNG, - c, - v, - μ::DiscreteMeasure, - ν::DiscreteMeasure, - ε; - kwargs..., + rng::Random.AbstractRNG, c, v, μ::DiscreteMeasure, ν::DiscreteMeasure, ε; kwargs... ) # compute mean c-transform mean_ctransform = sum(p * ctransform(c, v, x, ν, ε) for (x, p) in zip(μ.xs, μ.ps)) @@ -19,10 +13,10 @@ function dual_v( μ::DiscreteMeasure, ν::DiscreteMeasure, ε; - stepsize = 1, - maxiters::Int = 10_000, - atol = 0, - rtol = iszero(atol) ? typeof(float(atol))(1 // 10_000) : 0, + stepsize=1, + maxiters::Int=10_000, + atol=0, + rtol=iszero(atol) ? typeof(float(atol))(1//10_000) : 0, ) # initial iterates k = 1 diff --git a/src/semidiscrete.jl b/src/semidiscrete.jl index f5acf11..6fb3eeb 100644 --- a/src/semidiscrete.jl +++ b/src/semidiscrete.jl @@ -1,12 +1,4 @@ -function dual_cost( - rng::Random.AbstractRNG, - c, - v, - μ::DiscreteMeasure, - ν, - ε; - kwargs..., -) +function dual_cost(rng::Random.AbstractRNG, c, v, μ::DiscreteMeasure, ν, ε; kwargs...) return dual_cost(rng, c, v, ν, μ, ε; kwargs...) end function dual_cost( @@ -16,7 +8,7 @@ function dual_cost( μ, ν::DiscreteMeasure, ε; - montecarlo_samples = 10_000, + montecarlo_samples=10_000, kwargs..., ) # compute MC estimate of the expected c-transform with respect to `μ` @@ -27,14 +19,7 @@ function dual_cost( return LinearAlgebra.dot(v, ν.ps) + mean_ctransform end -function dual_v( - rng::Random.AbstractRNG, - c, - μ::DiscreteMeasure, - ν, - ε; - kwargs..., -) +function dual_v(rng::Random.AbstractRNG, c, μ::DiscreteMeasure, ν, ε; kwargs...) return dual_v(rng, c, ν, μ, ε; kwargs...) end function dual_v( @@ -43,11 +28,11 @@ function dual_v( μ, ν::DiscreteMeasure, ε; - maxiters::Int = 10_000, - stepsize = 1, - warmup_phase = 1, - atol = 0, - rtol = iszero(atol) ? typeof(float(atol))(1 // 10_000) : 0, + maxiters::Int=10_000, + stepsize=1, + warmup_phase=1, + atol=0, + rtol=iszero(atol) ? typeof(float(atol))(1//10_000) : 0, ) # initial iterates k = 1 diff --git a/src/utils.jl b/src/utils.jl index ff51b0c..d0a9e7f 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -5,7 +5,7 @@ struct DiscreteMeasure{X<:AbstractVector,P<:AbstractVector} function DiscreteMeasure{X,P}(xs::X, ps::P) where {X,P} length(xs) == length(ps) || error("length of support `xs` and probabilities `ps` must be equal") - new{X,P}(xs, ps) + return new{X,P}(xs, ps) end end @@ -31,8 +31,8 @@ end # initial gradient step (regularized gradient) function gradient_step(c, τ, ν::DiscreteMeasure, x, ε::Real) - tmp = @. - c((x,), ν.xs) / ε - StatsFuns.softmax!(tmp) + tmp = @. -c((x,), ν.xs) / ε + LogExpFunctions.softmax!(tmp) z = @. τ * (ν.ps - tmp) return z end @@ -47,7 +47,7 @@ function gradient_step!( x, tmp::AbstractVector, ::Nothing; - reset::Bool = false, + reset::Bool=false, ) @. tmp = c((x,), ν.xs) - v if reset @@ -69,10 +69,10 @@ function gradient_step!( x, tmp::AbstractVector, ε::Real; - reset::Bool = false, + reset::Bool=false, ) @. tmp = (v - c((x,), ν.xs)) / ε - StatsFuns.softmax!(tmp) + LogExpFunctions.softmax!(tmp) if reset @. z = τ * (ν.ps - tmp) else @@ -92,24 +92,12 @@ v^{c,ε}(x) = \begin{cases} \end{cases} ``` """ -function ctransform( - c, - v::AbstractVector, - x, - ν::DiscreteMeasure, - ::Nothing, -) +function ctransform(c, v::AbstractVector, x, ν::DiscreteMeasure, ::Nothing) return minimum(c(x, yᵢ) - vᵢ for (vᵢ, yᵢ) in zip(v, ν.xs)) end -function ctransform( - c, - v::AbstractVector, - x, - ν::DiscreteMeasure, - ε::Real, -) - t = StatsFuns.logsumexp( +function ctransform(c, v::AbstractVector, x, ν::DiscreteMeasure, ε::Real) + t = LogExpFunctions.logsumexp( (vᵢ - c(x, yᵢ)) / ε + log(νᵢ) for (vᵢ, yᵢ, νᵢ) in zip(v, ν.xs, ν.ps) ) - return - ε * (t + 1) + return -ε * (t + 1) end diff --git a/test/Project.toml b/test/Project.toml index 20f9226..60e0c9d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,8 +1,10 @@ [deps] Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Distributions = "0.24" +Documenter = "0.26" julia = "1.3" diff --git a/test/discrete.jl b/test/discrete.jl index cbc8b8b..cccbb71 100644 --- a/test/discrete.jl +++ b/test/discrete.jl @@ -8,9 +8,9 @@ μ = SOT.DiscreteMeasure(xs, ps) ν = SOT.DiscreteMeasure(xs, ps) - @test SOT.wasserstein(c, μ, ν; stepsize = 0.05) ≈ 0 atol=1e-4 - @test SOT.wasserstein(c, μ, ν, 1e-6; stepsize = 0.05) ≈ 0 atol=1e-4 - @test SOT.wasserstein(c, μ, ν, 1e-3; stepsize = 0.05) ≈ 0 atol=1e-4 + @test SOT.wasserstein(c, μ, ν; stepsize=0.05) ≈ 0 atol = 1e-4 + @test SOT.wasserstein(c, μ, ν, 1e-6; stepsize=0.05) ≈ 0 atol = 1e-4 + @test SOT.wasserstein(c, μ, ν, 1e-3; stepsize=0.05) ≈ 0 atol = 1e-4 end @testset "uniform weights" begin @@ -19,15 +19,15 @@ n = 5 xs = randn(n) ys = randn(n) - ps = fill(1/n, n) + ps = fill(1 / n, n) μ = SOT.DiscreteMeasure(xs, ps) ν = SOT.DiscreteMeasure(ys, ps) # analytic Wasserstein distance d = sum(abs, sort(xs) .- sort(ys)) / n - @test SOT.wasserstein(c, μ, ν; stepsize = 0.05) ≈ d atol=5e-2 - @test SOT.wasserstein(c, μ, ν, 1e-6; stepsize = 0.05) ≈ d atol=5e-2 - @test SOT.wasserstein(c, μ, ν, 1e-3; stepsize = 0.05) ≈ d atol=5e-2 + @test SOT.wasserstein(c, μ, ν; stepsize=0.05) ≈ d atol = 5e-2 + @test SOT.wasserstein(c, μ, ν, 1e-6; stepsize=0.05) ≈ d atol = 5e-2 + @test SOT.wasserstein(c, μ, ν, 1e-3; stepsize=0.05) ≈ d atol = 5e-2 end -end \ No newline at end of file +end diff --git a/test/runtests.jl b/test/runtests.jl index e6ee95a..aebf87f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,7 +1,6 @@ using StochasticOptimalTransport - using Distributions - +using Documenter using Random using Test @@ -10,7 +9,22 @@ Random.seed!(1234) const SOT = StochasticOptimalTransport @testset "StochasticOptimalTransport.jl" begin - @testset "Utilities" begin include("utils.jl") end - @testset "Discrete OT" begin include("discrete.jl") end - @testset "Semi-discrete OT" begin include("semidiscrete.jl") end + @testset "Utilities" begin + include("utils.jl") + end + @testset "Discrete OT" begin + include("discrete.jl") + end + @testset "Semi-discrete OT" begin + include("semidiscrete.jl") + end + @testset "doctests" begin + DocMeta.setdocmeta!( + StochasticOptimalTransport, + :DocTestSetup, + :(using StochasticOptimalTransport); + recursive=true, + ) + doctest(StochasticOptimalTransport) + end end diff --git a/test/semidiscrete.jl b/test/semidiscrete.jl index 53ed67b..1ea0428 100644 --- a/test/semidiscrete.jl +++ b/test/semidiscrete.jl @@ -8,13 +8,13 @@ μ = DiscreteNonParametric(xs, ps) ν = SOT.DiscreteMeasure(xs, ps) - @test SOT.wasserstein(c, μ, ν) ≈ 0 atol=2e-2 - @test SOT.wasserstein(c, μ, ν, 1e-6) ≈ 0 atol=2e-2 - @test SOT.wasserstein(c, μ, ν, 1e-3) ≈ 0 atol=2e-2 + @test SOT.wasserstein(c, μ, ν) ≈ 0 atol = 2e-2 + @test SOT.wasserstein(c, μ, ν, 1e-6) ≈ 0 atol = 2e-2 + @test SOT.wasserstein(c, μ, ν, 1e-3) ≈ 0 atol = 2e-2 - @test SOT.wasserstein(c, ν, μ) ≈ 0 atol=2e-2 - @test SOT.wasserstein(c, ν, μ, 1e-6) ≈ 0 atol=2e-2 - @test SOT.wasserstein(c, ν, μ, 1e-3) ≈ 0 atol=2e-2 + @test SOT.wasserstein(c, ν, μ) ≈ 0 atol = 2e-2 + @test SOT.wasserstein(c, ν, μ, 1e-6) ≈ 0 atol = 2e-2 + @test SOT.wasserstein(c, ν, μ, 1e-3) ≈ 0 atol = 2e-2 end @testset "uniform weights" begin @@ -23,19 +23,19 @@ n = 5 xs = randn(n) ys = randn(n) - ps = fill(1/n, n) + ps = fill(1 / n, n) μ = DiscreteNonParametric(xs, ps) ν = SOT.DiscreteMeasure(ys, ps) # analytic Wasserstein distance d = sum(abs, sort(xs) .- sort(ys)) / n - @test SOT.wasserstein(c, μ, ν) ≈ d atol=2e-2 - @test SOT.wasserstein(c, μ, ν, 1e-6) ≈ d atol=2e-2 - @test SOT.wasserstein(c, μ, ν, 1e-3) ≈ d atol=2e-2 + @test SOT.wasserstein(c, μ, ν) ≈ d atol = 2e-2 + @test SOT.wasserstein(c, μ, ν, 1e-6) ≈ d atol = 2e-2 + @test SOT.wasserstein(c, μ, ν, 1e-3) ≈ d atol = 2e-2 - @test SOT.wasserstein(c, ν, μ) ≈ d atol=2e-2 - @test SOT.wasserstein(c, ν, μ, 1e-6) ≈ d atol=2e-2 - @test SOT.wasserstein(c, ν, μ, 1e-3) ≈ d atol=2e-2 + @test SOT.wasserstein(c, ν, μ) ≈ d atol = 2e-2 + @test SOT.wasserstein(c, ν, μ, 1e-6) ≈ d atol = 2e-2 + @test SOT.wasserstein(c, ν, μ, 1e-3) ≈ d atol = 2e-2 end end