An implementation of the Wasserstein Autoencoder (WAE). This work focuses on the WAE-GAN variant. In this implementation the encoder is deterministic (a Dirac measure), although the paper claims the approach extends to probabilistic encoders as well.
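For intuition, here is a minimal NumPy sketch of the WAE-GAN objective with a deterministic (Dirac) encoder: a reconstruction term plus an adversarial penalty that pushes the encoded distribution toward the prior. All names (`encoder`, `decoder`, `discriminator`, `lam`) and the linear/logistic parameterizations are illustrative assumptions, not taken from this repo's code.

```python
import numpy as np

rng = np.random.default_rng(0)

def encoder(x, W):
    # Deterministic (Dirac) encoder: each input maps to a single latent point.
    return x @ W

def decoder(z, V):
    # Toy linear decoder back to input space.
    return z @ V

def discriminator(z, w):
    # Logistic score: estimated probability that z was drawn from the prior.
    return 1.0 / (1.0 + np.exp(-(z @ w)))

def wae_gan_losses(x, W, V, w, lam=1.0):
    z = encoder(x, W)                         # q(z|x) is a point mass
    x_rec = decoder(z, V)
    recon = np.mean(np.sum((x - x_rec) ** 2, axis=1))
    # Adversarial penalty: the encoder is trained to make D(z) score like prior samples.
    penalty = -np.mean(np.log(discriminator(z, w) + 1e-8))
    return recon + lam * penalty, recon, penalty

# Tiny random example (shapes only, no training loop).
x = rng.normal(size=(8, 4))
W = rng.normal(size=(4, 2))
V = rng.normal(size=(2, 4))
w = rng.normal(size=(2,))
total, recon, penalty = wae_gan_losses(x, W, V, w)
```

In the full model the encoder/decoder are deep networks and the discriminator is trained adversarially against the encoder; the sketch only shows how the two loss terms combine.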
Model weights can be downloaded from here. The link contains weights for each model trained on the CelebA, MNIST, and CIFAR-10 datasets.
- Python 3.5+
- TensorFlow 1.9
The file config.py contains the hyperparameters used for the reported WAE-GAN results.
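As a rough illustration, a config of this kind typically looks like the snippet below. Every name and value here is a hypothetical placeholder; the actual hyperparameters live in this repo's config.py.

```python
# Hypothetical excerpt -- names and values are placeholders, not the repo's actual config.py.
batch_size = 100   # minibatch size
z_dim = 8          # latent dimension
lr = 1e-3          # encoder/decoder learning rate
lr_adv = 5e-4      # latent discriminator learning rate
lam = 10.0         # weight on the adversarial penalty term
epochs = 100       # training epochs
```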
The file wae_gan.py contains the code to both train and test the WAE-GAN model. For training, call the train() function.
NOTE: For CelebA, make sure you have downloaded the dataset from here and placed it in the project's current directory.
python wae_gan.py
To test, comment out the train() function call and place the model weights in model_directory (defined in wae_gan.py).
python wae_gan.py
| MNIST | Celeb-A | CIFAR-10 |
|---|---|---|