This example trains an Auxiliary Classifier Generative Adversarial Network (ACGAN) on the MNIST dataset.
For background on ACGAN, see:
- Augustus Odena, Christopher Olah, Jonathon Shlens. (2017) "Conditional image synthesis with auxiliary classifier GANs" https://arxiv.org/abs/1610.09585
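In an ACGAN, the generator is conditioned on a class label, while the discriminator outputs both an estimate of whether an image is real and a prediction of its class. Below is a minimal sketch of this two-headed discriminator in the TensorFlow.js Layers API; the layer sizes are illustrative, not the exact ones used in gan.js:

```js
const tf = require('@tensorflow/tfjs');

// Build a two-headed ACGAN discriminator for 28x28x1 MNIST images.
function buildDiscriminator() {
  const image = tf.input({shape: [28, 28, 1]});
  let x = tf.layers
      .conv2d({filters: 32, kernelSize: 3, strides: 2, padding: 'same'})
      .apply(image);
  x = tf.layers.leakyReLU({alpha: 0.2}).apply(x);
  x = tf.layers.flatten().apply(x);
  // Head 1: probability that the image is real rather than generated.
  const realness =
      tf.layers.dense({units: 1, activation: 'sigmoid'}).apply(x);
  // Head 2: auxiliary classifier predicting the digit class (0-9).
  const aux = tf.layers.dense({units: 10, activation: 'softmax'}).apply(x);
  return tf.model({inputs: image, outputs: [realness, aux]});
}
```

Training the discriminator on both heads is what pushes the generator toward producing class-consistent samples.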
The training script in this example (gan.js) is based on the Keras ACGAN example.
This TensorFlow.js example runs in two different environments, which can operate simultaneously:
- Training in Node.js. During the long-running training process, a checkpoint of the generator is saved to disk at the end of every epoch.
- Generation demo in the browser. The demo webpage loads the checkpoint saved by the training process and uses it to generate fake MNIST images in the browser.
This example can be used in two ways:
- Perform both training and the generation demo on your local machine, or
- Run only the generation demo, loading a hosted generator model from the web.
For the first approach, start the training with:
```sh
yarn
yarn train
```
If you have a CUDA-enabled GPU on your system, you can add the `--gpu` flag to train the model on the GPU, which should give you a significant boost in training speed:
```sh
yarn
yarn train --gpu
```
The training job is a long-running one: it takes a few hours to complete on a GPU (using @tensorflow/tfjs-node-gpu) and even longer on a CPU (using @tensorflow/tfjs-node). It saves the generator part of the ACGAN to the `./dist/generator` folder at the beginning of training and at the end of every training epoch. Some additional metadata is saved alongside the model.
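A minimal sketch of how such a checkpoint can be written from Node.js follows; the metadata file name and fields here are assumptions, and the actual ones in gan.js may differ:

```js
require('@tensorflow/tfjs-node');  // Registers the file:// save handler.
const path = require('path');
const fs = require('fs');

// `generator` is a tf.LayersModel; `epoch` is the zero-based epoch index.
async function saveGeneratorCheckpoint(generator, epoch) {
  const saveDir = path.join(__dirname, 'dist', 'generator');
  fs.mkdirSync(saveDir, {recursive: true});
  // Writes model.json and a binary weights file into saveDir.
  await generator.save(`file://${saveDir}`);
  // Hypothetical metadata file recording training progress.
  fs.writeFileSync(
      path.join(saveDir, 'acgan-metadata.json'),
      JSON.stringify({completedEpochs: epoch + 1}));
}
```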
The Node.js-based training script allows you to log the loss values of the generator and the discriminator to TensorBoard. Compared with printing loss values to the console, which the training script does by default, logging to TensorBoard has the following advantages:
- Persistence of the loss values, so you still have a copy of the training history even if the system crashes partway through training, whereas console logs are more ephemeral.
- Visualizing the loss values as curves makes the trends easier to see (e.g., see the screenshot below).
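A minimal sketch of how loss values can be written in this format using tfjs-node's summary-writer API follows; the scalar names are illustrative:

```js
const tf = require('@tensorflow/tfjs-node');

const summaryWriter = tf.node.summaryFileWriter('/tmp/mnist-acgan-logs');

// Call this from the training loop with the current loss values and step.
function logLosses(genLoss, discLoss, step) {
  summaryWriter.scalar('generatorLoss', genLoss, step);
  summaryWriter.scalar('discriminatorLoss', discLoss, step);
}
```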
To do this in this example, add the `--logDir` flag to the `yarn train` command, followed by the directory to which you want the logs to be written, e.g.,
```sh
yarn train --gpu --logDir /tmp/mnist-acgan-logs
```
Then install TensorBoard and start it, pointing it at the log directory:
```sh
# Skip this step if you have already installed tensorboard.
pip install tensorboard
tensorboard --logdir /tmp/mnist-acgan-logs
```
TensorBoard will print an HTTP URL in the terminal. Open your browser and navigate to that URL to view the loss curves in TensorBoard's Scalars dashboard.
To start the demo in the browser, run the following in a separate terminal:
```sh
yarn
yarn watch
```
When the browser demo starts up, it tries to load the generator model and metadata from `./generator`. If this succeeds, fake MNIST digits are generated with the loaded generator model and displayed on the page right away. If it fails (e.g., because no local training job has ever been run), you can still click the "Load Hosted Model" button to load a remotely hosted generator.
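Below is a sketch of this load-then-generate flow, assuming a latent-vector length of 100 and a placeholder hosted-model URL; both are illustrative, not the demo's actual values:

```js
import * as tf from '@tensorflow/tfjs';

// Hypothetical URL; the demo's actual hosted-model URL differs.
const HOSTED_MODEL_URL =
    'https://example.com/mnist-acgan/generator/model.json';

async function loadGenerator() {
  try {
    // Relative path served by the local dev server, as described above.
    return await tf.loadLayersModel('./generator/model.json');
  } catch (err) {
    // No local checkpoint yet: fall back to the remotely hosted model.
    return await tf.loadLayersModel(HOSTED_MODEL_URL);
  }
}

// Generate `numExamples` fake images of the digit `digitClass` (0-9).
function generateDigits(generator, digitClass, numExamples = 10) {
  return tf.tidy(() => {
    const latentSize = 100;  // Assumed latent-vector length.
    const latent = tf.randomUniform([numExamples, latentSize], -1, 1);
    const labels = tf.fill([numExamples, 1], digitClass);
    // The ACGAN generator takes [latentVector, classLabel] as inputs.
    return generator.predict([latent, labels]);
  });
}
```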
It is recommended to use tfjs-node-gpu to train the model on a CUDA-enabled GPU, as the convolution-heavy operations run several times faster on a GPU than on the CPU with tfjs-node.
By default, the training script runs on the CPU using tfjs-node. To run it on the GPU, replace the line
```js
require('@tensorflow/tfjs-node');
```
with
```js
require('@tensorflow/tfjs-node-gpu');
```
This example comes with JavaScript unit tests. To run them, do:
```sh
pushd ../  # Go to the root directory of tfjs-examples.
yarn
popd  # Go back to mnist-acgan/.
yarn
yarn test
```