This is the official Python implementation of the ICLR 2021 paper Wasserstein-2 Generative Networks (paper on OpenReview) by Alexander Korotin, Vage Egiazarian, Arip Asadulaev, Alexander Safin and Evgeny Burnaev.
The repository contains reproducible PyTorch source code for computing optimal transport maps (and distances) in high dimensions via the end-to-end non-minimax method proposed in the paper, using input convex neural networks. Examples are provided for various real-world problems: color transfer, latent-space mass transport, domain adaptation, and style transfer.
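In this approach the transport map is recovered as the gradient of a learned convex potential. Below is a minimal illustrative sketch of that idea; the simplified `TinyICNN` and helper names are ours, not the repository's actual `DenseICNN` API:

```python
import torch
import torch.nn.functional as F

class TinyICNN(torch.nn.Module):
    """Toy input convex network (hypothetical simplification, not the repo's
    DenseICNN). Convexity in x holds because the z-path weights (Wz*) are
    kept non-negative and CELU is convex and non-decreasing."""
    def __init__(self, dim, hidden=64):
        super().__init__()
        self.Wx0 = torch.nn.Linear(dim, hidden)
        self.Wz1 = torch.nn.Linear(hidden, hidden, bias=False)
        self.Wx1 = torch.nn.Linear(dim, hidden)
        self.Wz2 = torch.nn.Linear(hidden, 1, bias=False)
        self.Wx2 = torch.nn.Linear(dim, 1)

    def forward(self, x):
        z = F.celu(self.Wx0(x))
        z = F.celu(self.Wz1(z) + self.Wx1(x))
        return self.Wz2(z) + self.Wx2(x)

    def project_convex(self):
        # call after each optimizer step to keep the z-path weights >= 0
        for layer in (self.Wz1, self.Wz2):
            layer.weight.data.clamp_(min=0)

def transport_map(psi, x):
    """T(x) = grad_x psi(x), computed with autograd."""
    x = x.detach().requires_grad_(True)
    (grad,) = torch.autograd.grad(psi(x).sum(), x, create_graph=True)
    return grad

psi = TinyICNN(dim=2)
psi.project_convex()
x = torch.randn(8, 2)
y = transport_map(psi, x)  # mapped samples, same shape as x
```

The mapped samples `y` can then be fed into a downstream loss; see the paper and `src/icnn.py` for the actual architectures and training objective.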
- Video presentation by Alexander Korotin at ICLR 2021 (May 2021, EN);
- Talk by Alexander Korotin at BayesGroup research seminar workshop (2020, RU, slides);
```bibtex
@inproceedings{
  korotin2021wasserstein,
  title={Wasserstein-2 Generative Networks},
  author={Alexander Korotin and Vage Egiazarian and Arip Asadulaev and Alexander Safin and Evgeny Burnaev},
  booktitle={International Conference on Learning Representations},
  year={2021},
  url={https://openreview.net/forum?id=bEoxzW_EXsa}
}
```
The implementation is GPU-based; a single GPU (e.g., GTX 1080 Ti) is enough to run each experiment. Tested with `torch==1.3.0` and `torchvision==0.4.1`. The code might not run as intended with newer `torch` versions, and newer `torchvision` may conflict with the FID score evaluation.
- Repository for Kantorovich Strikes Back! Wasserstein GANs are not Optimal Transport? paper.
- Repository for Continuous Wasserstein-2 Barycenter Estimation without Minimax Optimization paper.
- Repository for Do Neural Optimal Transport Solvers Work? A Continuous Wasserstein-2 Benchmark paper.
- Repository for Large-Scale Wasserstein Gradient Flows paper.
All experiments are provided as self-explanatory Jupyter notebooks (`notebooks/`). For convenience, most of the evaluation output is preserved. Auxiliary source code is moved to `.py` modules (`src/`).
- `notebooks/W2GN_toy_experiments.ipynb` -- toy experiments (2D: Swiss Roll, 100 Gaussians, ...);
- `notebooks/W2GN_gaussians_high_dimensions.ipynb` -- optimal maps between Gaussians in high dimensions;
- `notebooks/W2GN_latent_space_optimal_transport.ipynb` -- latent-space optimal transport for generating CelebA 64x64 aligned images;
- `notebooks/W2GN_domain_adaptation.ipynb` -- domain adaptation for the MNIST and USPS digit datasets;
- `notebooks/W2GN_color_transfer.ipynb` -- cycle-monotone pixel-wise image-to-image color transfer (example images are provided in `data/color_transfer/`);
- `notebooks/W2GN_style_transfer.ipynb` -- cycle-monotone image dataset-to-dataset style transfer (the datasets used are publicly available in the official CycleGAN repo);
- `src/icnn.py` -- modules for Input Convex Neural Network architectures (DenseICNN, ConvICNN);
- `poster/W2GN_poster.png` -- poster (landscape format);
- `poster/W2GN_poster.svg` -- source file for the poster.
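The non-minimax training scheme ties together the gradients of a pair of such convex potentials via a cycle-consistency regularizer. A hedged sketch of that term (the helper name is ours; see the paper and `src/icnn.py` for the full objective), checked here on a pair of exactly conjugate quadratic potentials:

```python
import torch

def cycle_regularizer(psi, phi, x):
    """E ||grad phi(grad psi(x)) - x||^2: penalizes the failure of grad phi
    to invert grad psi (illustrative sketch, not the repo's exact loss)."""
    x = x.detach().requires_grad_(True)
    (gx,) = torch.autograd.grad(psi(x).sum(), x, create_graph=True)
    (ggx,) = torch.autograd.grad(phi(gx).sum(), gx, create_graph=True)
    return ((ggx - x) ** 2).sum(dim=1).mean()

# exactly conjugate quadratics: grad psi(x) = 2x, grad phi(y) = y / 2
psi = lambda x: (x ** 2).sum(dim=1)
phi = lambda y: 0.25 * (y ** 2).sum(dim=1)
x = torch.randn(16, 3)
r = cycle_regularizer(psi, phi, x)  # ~0 when the gradients are mutual inverses
```

For mutually inverse gradient maps the penalty vanishes; during training it is minimized jointly with the transport objective.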
Transforming a single Gaussian into a mixture of 100 Gaussians without mode dropping/collapse (and some other toy cases).
Assessing the quality of fitted optimal transport maps between two high-dimensional Gaussians (tested in dimensions up to 4096). The metric is Unexplained Variance Percentage (UVP, %).
| Dim | 2 | 4 | 8 | 16 | 32 | 64 | 128 | 256 | 512 | 1024 | 2048 | 4096 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Large-scale OT | <1 | 3.7 | 7.5 | 14.3 | 23 | 34.7 | 46.9 | >50 | >50 | >50 | >50 | >50 |
| Wasserstein-2 GN | <1 | <1 | <1 | <1 | <1 | <1 | 1 | 1.1 | 1.3 | 1.7 | 1.8 | 1.5 |
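UVP normalizes the squared L2 deviation of the fitted map from the true one by the variance of the target distribution. A hedged sketch of the metric (our helper, not the repository's exact evaluation code), on a toy pair of Gaussians where the true map is known in closed form:

```python
import torch

def uvp_percent(T_hat_x, T_star_x, target_samples):
    """L2-UVP: 100 * E||T_hat(x) - T*(x)||^2 / Var(target), in percent
    (sketch; the repo's evaluation code may differ in detail)."""
    err = ((T_hat_x - T_star_x) ** 2).sum(dim=1).mean()
    var = target_samples.var(dim=0, unbiased=True).sum()
    return (100.0 * err / var).item()

torch.manual_seed(0)
d = 4
x = torch.randn(50_000, d)        # source: N(0, I)
y = 2.0 * torch.randn(50_000, d)  # target: N(0, 4I)
T_star = 2.0 * x                  # true OT map between these Gaussians
T_hat = x                         # identity map as a poor candidate
uvp = uvp_percent(T_hat, T_star, y)  # about 25% in this setup
```

Here the expected error is E||x - 2x||^2 = d and the target variance is 4d, so UVP converges to 25% as the sample size grows; a perfect map would give 0%.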
CelebA 64x64 generated faces. The quality of the model strongly depends on the quality of the autoencoder. Use `notebooks/AE_Celeba.ipynb` to train an MSE or perceptual autoencoder (the latter on VGG features, to improve visual quality).
Pre-trained autoencoders: MSE-AE [Google Drive, Yandex Disk], VGG-AE [Google Drive, Yandex Disk].
Combining a simple pre-trained MSE autoencoder with W2GN is enough to surpass the Wasserstein GAN model in Fréchet Inception Distance (FID).
| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN |
|---|---|---|---|---|
| FID Score | 23.35 | 86.66 | 43.35 | 45.23 |
A perceptual VGG autoencoder combined with W2GN provides a nearly state-of-the-art FID (compared to Wasserstein GAN with Quadratic Cost).
| | AE Reconstruct | AE Raw Decode | AE + W2GN | WGAN-QC |
|---|---|---|---|---|
| FID Score | 7.5 | 31.81 | 17.21 | 14.41 |
Cycle monotone color transfer is applicable even to gigapixel images!
MNIST-USPS domain adaptation. PCA visualization of the feature spaces (see the paper for metrics).
Optimal transport map in the space of images. Photo2Cezanne and Winter2Summer datasets are used.
- CycleGAN repo with the Winter2Summer and Photo2Monet datasets;
- CelebA page with faces dataset and this page with its aligned 64x64 version;
- pytorch-fid repo to compute FID score;