This is the official Python implementation of the NeurIPS 2021 paper Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark (paper on arXiv) by Alexander Korotin, Lingxiao Li, Aude Genevay, Justin Solomon, Alexander Filippov and Evgeny Burnaev.
The repository contains a set of continuous benchmark measures for testing optimal transport solvers with the quadratic cost (the Wasserstein-2 distance), together with the code for the solvers themselves and for their evaluation.
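For context, the quadratic cost referred to above is the one in the Monge formulation of the Wasserstein-2 distance; the block below is the standard textbook definition (background, not repository code):

```latex
% Wasserstein-2 distance between P and Q as a Monge problem over
% transport maps T that push P forward to Q:
\mathbb{W}_2^2(\mathbb{P},\mathbb{Q})
  = \min_{T\sharp\mathbb{P}=\mathbb{Q}}
    \int \tfrac{1}{2}\,\lVert x - T(x)\rVert^2 \, d\mathbb{P}(x)
% By Brenier's theorem the minimizer is the gradient of a convex
% potential, T^* = \nabla\psi; this ground-truth map is what the
% benchmark pairs expose.
```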
Citation:

```
@article{korotin2021neural,
  title={Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark},
  author={Korotin, Alexander and Li, Lingxiao and Genevay, Aude and Solomon, Justin M and Filippov, Alexander and Burnaev, Evgeny},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}
```
The implementation is GPU-based. A single GPU (e.g., GTX 1080 Ti) is enough to run each particular experiment. Tested with `torch==1.3.0 torchvision==0.4.1`. The code might not run as intended in newer torch versions.
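As a quick sanity check before running the notebooks, a sketch like the following can verify the pinned environment (the version strings come from the line above; everything else is illustrative):

```python
import torch
import torchvision

# The benchmark was tested with torch==1.3.0 / torchvision==0.4.1.
print('torch:', torch.__version__, '| torchvision:', torchvision.__version__)
print('CUDA available:', torch.cuda.is_available())  # experiments assume a single GPU

if not torch.__version__.startswith('1.3'):
    print('Warning: untested torch version; the code might not run as intended.')
```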
- Repository for Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport? paper.
- Repository for Wasserstein-2 Generative Networks paper.
- Repository for Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization paper.
- Repository for Continuous Regularized Wasserstein Barycenters paper.
- Repository for Large-Scale Wasserstein Gradient Flows paper.
```python
from src import map_benchmark as mbm

# Load benchmark pair for dimension 16 (2, 4, ..., 256)
benchmark = mbm.Mix3ToMix10Benchmark(16)
# OR load 'Early' images benchmark pair ('Early', 'Mid', 'Late')
# benchmark = mbm.CelebA64Benchmark('Early')

# Sample 32 random points from the benchmark measures
X = benchmark.input_sampler.sample(32)
Y = benchmark.output_sampler.sample(32)

# Compute the true forward map for points X
X.requires_grad_(True)
Y_true = benchmark.map_fwd(X, nograd=True)
```
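Continuing the snippet above, here is a hedged sketch of how a fitted map could be scored against the ground truth with the L2-UVP metric used in the paper; the candidate map `T_hat` and the sample size are illustrative placeholders, not part of the repo API:

```python
import torch

# Hypothetical candidate map to evaluate; the identity is a trivial baseline.
T_hat = lambda x: x

X = benchmark.input_sampler.sample(4096)
X.requires_grad_(True)                      # the true map is evaluated via autograd
Y_true = benchmark.map_fwd(X, nograd=True)  # ground-truth forward OT map

Y = benchmark.output_sampler.sample(4096)
var_Q = Y.var(dim=0).sum()                  # total variance of the target measure

# L2-UVP (%): squared L2 error of the candidate map, normalized by Var(Q).
l2_uvp = 100 * ((T_hat(X.detach()) - Y_true) ** 2).sum(dim=1).mean() / var_Q
print(f'L2-UVP: {l2_uvp.item():.2f}%')
```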
All the experiments are provided as fairly self-explanatory Jupyter notebooks (`notebooks/`). Auxiliary source code is moved to `.py` modules (`src/`). Continuous benchmark pairs are stored as `.pt` checkpoints (`benchmarks/`).
We provide all the code to evaluate existing dual OT solvers on our benchmark pairs. The qualitative results are shown below. For quantitative results, see the paper.
- `notebooks/MM_test_hd_benchmark.ipynb` -- testing [MM], [MMv2] solvers and their reversed versions
- `notebooks/MMv1_test_hd_benchmark.ipynb` -- testing [MMv1] solver
- `notebooks/MM-B_test_hd_benchmark.ipynb` -- testing [MM-B] solver
- `notebooks/W2_test_hd_benchmark.ipynb` -- testing [W2] solver and its reversed version
- `notebooks/QC_test_hd_benchmark.ipynb` -- testing [QC] solver
- `notebooks/LS_test_hd_benchmark.ipynb` -- testing [LS] solver
- `notebooks/MM_test_images_benchmark.ipynb` -- testing [MM] solver and its reversed version
- `notebooks/W2_test_images_benchmark.ipynb` -- testing [W2] solver
- `notebooks/MM-B_test_images_benchmark.ipynb` -- testing [MM-B] solver
- `notebooks/QC_test_images_benchmark.ipynb` -- testing [QC] solver
[LS], [MMv2], [MMv1] solvers are not considered in this experiment.
Warning: training may take several days before achieving reasonable FID scores!
- `notebooks/MM_test_image_generation.ipynb` -- generative modeling by [MM] solver or its reversed version
- `notebooks/W2_test_image_generation.ipynb` -- generative modeling by [W2] solver
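The FID scores mentioned in the warning above can be computed with the pytorch-fid package credited at the end of this README. Below is a minimal sketch, not repo code: the two image folders are placeholders, and the call matches recent pytorch-fid releases (older releases use a slightly different signature):

```python
import torch
from pytorch_fid.fid_score import calculate_fid_given_paths

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Placeholder directories with real and generated image files.
fid = calculate_fid_given_paths(
    ['path/to/real_images', 'path/to/generated_images'],
    batch_size=50,
    device=device,
    dims=2048,  # InceptionV3 pool3 features, the standard FID setting
)
print(f'FID: {fid:.2f}')
```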
For the [QC] solver, we used the code from the official WGAN-QC repo.
This code is provided for completeness; it is not intended to be used to retrain the existing benchmark pairs, but it might serve as a base for training new pairs on new datasets. High-dimensional benchmark pairs can be trained from scratch. Training the images benchmark pairs requires generator network checkpoints; we used the WGAN-QC model to provide such checkpoints.
- `notebooks/W2_train_hd_benchmark.ipynb` -- training high-dimensional benchmark pairs by [W2] solver
- `notebooks/W2_train_images_benchmark.ipynb` -- training images benchmark pairs by [W2] solver
- Weights & Biases developer tools for machine learning;
- CelebA page with the faces dataset and this page with its aligned 64x64 version;
- pytorch-fid repo used to compute the FID score;
- UNet architecture for the transporter network;
- ResNet architectures for the generator and discriminator.