From e7f8faa0716f7f75d4d98a589fda686579161344 Mon Sep 17 00:00:00 2001 From: Luyao Peng Date: Sun, 8 Nov 2020 10:17:01 -0700 Subject: [PATCH] Adding an example notebook of variational GP Demonstrating how variational GP is used in multiclass-classification --- ...an_Process_Multiclass_Classification.ipynb | 504 ++++++++++++++++++ 1 file changed, 504 insertions(+) create mode 100644 tensorflow_probability/examples/jupyter_notebooks/Variational_Gaussian_Process_Multiclass_Classification.ipynb diff --git a/tensorflow_probability/examples/jupyter_notebooks/Variational_Gaussian_Process_Multiclass_Classification.ipynb b/tensorflow_probability/examples/jupyter_notebooks/Variational_Gaussian_Process_Multiclass_Classification.ipynb new file mode 100644 index 0000000000..fa7c81e649 --- /dev/null +++ b/tensorflow_probability/examples/jupyter_notebooks/Variational_Gaussian_Process_Multiclass_Classification.ipynb @@ -0,0 +1,504 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Sparse Variational Gaussian Process in Multiclass Classification\n", + "\n", + "In this notebook, a sparse variational Gaussian process (VGP) model is applied to a multiclass classification problem. Because it relies on a small set of inducing points, VGP scales to large datasets." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Background\n", + "\n", + "$\\qquad$ Consider making inferences about a stochastic function $f$ given a likelihood $p(y|f)$ and $N$ observations $y=\\{y_1, y_2, \\dots, y_N\\}^T$ at observation index points $X=\\{x_1, x_2, \\dots, x_N\\}^T$. Place a GP prior on $f$: $p(f) = N(f|m(X), K(X, X))$. The joint distribution of the data and the latent function values is \n", + "\n", + "$$p(y, f) = \\prod_{i=1}^{N}p(y_i|f_i)N(f|m(X), K(X, X)) \\tag{1}$$ \n", + "\n", + "$\\qquad$ The main interest is the posterior over the function values given the observations, $p(f|y)$. 
The posterior is intractable when the likelihood $p(y|f)$ is non-Gaussian, which is often the case in classification problems. In addition, the computational cost is $O(N^3)$ due to the inversion of $K(X, X)$, which is prohibitive for large datasets.\n", + "\n", + "$\\qquad$ To reduce the computational complexity, $M \\ll N$ inducing index points $Z=\\{z_1, z_2, \\dots, z_M\\}^T$ and inducing variables $u=f(Z)$ are introduced. Under the GP prior, $f$ and $u$ are jointly Gaussian, $$p(f, u) = N\\begin{pmatrix} \\begin{bmatrix} f \\\\ u\\end{bmatrix}| \\begin{bmatrix} m(X) \\\\ m(Z)\\end{bmatrix}, \\begin{bmatrix} K(X, X) & K(X, Z) \\\\ K(Z, X) & K(Z, Z)\\end{bmatrix}\\end{pmatrix},$$ with marginal $$p(u) = N(u|m(Z), K(Z, Z)),$$ so the conditional distribution of $f$ given $u$ is $p(f|u) = N(f|\\mu, \\Sigma)$, where for $i, j = 1, \\dots, N$\n", + "\n", + "$$[\\mu]_i = m(x_i) + \\alpha(x_i)^T(u-m(Z)), \\tag{2}$$\n", + "\n", + "$$[\\Sigma]_{ij} = K(x_i, x_j) - \\alpha(x_i)^T K(Z, Z)\\alpha(x_j), \\tag{3}$$\n", + "where $\\alpha(x_i) = K(Z, Z)^{-1}K(Z, x_i)$, and the joint density of $y, f, u$ becomes\n", + "\n", + "$$p(y, f, u) = p(f|u; X, Z)p(u; Z)\\prod_{i=1}^{N}p(y_i|f_i)$$\n", + "\n", + "$\\qquad$ The goal is still to find the posterior of the function values $f$. However, the likelihood $p(y_i|f_i)$ is not Gaussian, so there is no closed-form solution for the posterior of $f$. Therefore, a variational approximation to the posterior is used. \n", + "\n", + "$\\qquad$ Replacing the posterior $p(u|y)$ with a full-rank Gaussian distribution $q(u)$ [Hensman et al. (2013)], the joint variational posterior of $f$ and $u$ becomes \n", + "\n", + "$$q(f, u; X, Z) = p(f|u; X, Z)q(u), \\tag{4}$$\n", + "\n", + "$$\\mbox{where } q(u) = N(u|\\mathbf{m}, \\mathbf{S}).$$\n", + "\n", + "The parameters $\\mathbf{m}$ and $\\mathbf{S}$ are chosen by maximizing an evidence lower bound (ELBO). 
\n", + "\n", + "$\\qquad$ Since both $p(f|u; X, Z)$ and $q(u)$ are Gaussian, the marginal variational posterior of $f$ can be computed analytically\n", + "\n", + "$$q(f|\\mathbf{m}, \\mathbf{S}; X, Z) = \\int p(f|u; X, Z)q(u) du = N(f|\\tilde{\\mu}, \\tilde{\\Sigma}) \\tag{5}$$\n", + "\n", + "with $[\\tilde{\\mu}]_i = \\mu_{\\mathbf{m}, Z}(x_i), [\\tilde{\\Sigma}]_{ij} = \\Sigma_{\\mathbf{S}, Z}(x_i, x_j)$, and\n", + "\n", + "$$\\mu_{\\mathbf{m}, Z}(x_i) = m(x_i)+\\alpha(x_i)^T(\\mathbf{m}-m(Z)) \\tag{6}$$\n", + "$$\\Sigma_{\\mathbf{S}, Z}(x_i, x_j) = K(x_i, x_j) - \\alpha(x_i)^T[K(Z, Z) - \\mathbf{S}]\\alpha(x_j) \\tag{7}$$\n", + "\n", + "$\\qquad$ The variational parameters $Z, \\mathbf{m}, \\mathbf{S}$ in $q(f|\\mathbf{m}, \\mathbf{S}; X, Z)$ are determined by maximizing the lower bound \n", + "\n", + "$$L = \\sum_{i=1}^N \\mathbb{E}_{q(f_i|\\mathbf{m}, \\mathbf{S}; X, Z)}[\\log p(y_i|f_i)] - \\mathrm{KL}[q(u)|| p(u)], \\tag{8}$$\n", + "\n", + "where the expected log-likelihood can be computed with Gauss–Hermite quadrature.\n", + "\n", + "$\\qquad$ The variational posterior is given as $q(f)$ in (5). To make predictions for a set of test index points $X^*$, the posterior of the new latent function values $f^*$ is approximated by\n", + "\n", + "\\begin{equation}\\begin{array}{rcl}p(f^*|y) &=& \\int p(f^*|f, u)p(f, u|y) df du\\\\ &\\approx& \\int p(f^*|f, u)p(f|u)q(u)df du \\\\ &=& \\int p(f^*|u)q(u) du \\\\ &=& q(f^*)\\end{array}\\end{equation}\n", + "\n", + "where the last line follows from (5), (6) and (7) with $x_i$ replaced by $x_i^*$.\n", + "\n", + "$\\qquad$ With the variational posterior in (5), the predictive mean and variance of $y^*$ are computed as \n", + "\n", + "$$\\hat{y}^* = \\mathbb{E}(y^*) = \\int\\int y^* p(y^*|f^*)q(f^*) df^* dy^* \\tag{9}$$\n", + "$$\\hat{\\mathbb{V}}(y^*) = \\int\\int y^{*2} p(y^*|f^*)q(f^*) df^* dy^* - (\\hat{y}^*)^2 \\tag{10}$$\n", + "\n", + "## References\n", + "\n", + "[1]: Titsias, M. \"Variational Model Selection for Sparse Gaussian Process Regression\", 2009. 
http://proceedings.mlr.press/v5/titsias09a/titsias09a.pdf \n", + "\n", + "[2]: Hensman, J., Fusi, N., and Lawrence, N. \"Gaussian Processes for Big Data\", 2013. https://arxiv.org/abs/1309.6835\n", + "\n", + "[3]: Salimbeni, H. and Deisenroth, M. \"Doubly Stochastic Variational Inference for Deep Gaussian Processes\", Advances in Neural Information Processing Systems, 2017. https://arxiv.org/pdf/1705.08933.pdf" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 87, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import tensorflow.compat.v2 as tf\n", + "import tensorflow_probability as tfp\n", + "import pandas as pd\n", + "from sklearn.preprocessing import LabelEncoder\n", + "from sklearn.model_selection import train_test_split\n", + "from scipy.cluster.vq import kmeans2\n", + "\n", + "tf.enable_v2_behavior()\n", + "\n", + "tfb = tfp.bijectors\n", + "tfd = tfp.distributions\n", + "tfk = tfp.math.psd_kernels\n", + "\n", + "dtype = np.float64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Load Glass Data\n", + "\n", + "This notebook uses a standard imbalanced machine learning dataset known as the “Glass Identification” dataset, or simply “glass”.\n", + "\n", + "The dataset describes the chemical properties of glass samples; the task is to classify each sample into one of six classes based on those properties. The dataset is credited to Vina Spiehler (1987)." + ] + }, + { + "cell_type": "code", + "execution_count": 100, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2020-11-08 09:08:57-- https://raw.githubusercontent.com/jbrownlee/Datasets/master/glass.csv\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 
151.101.0.133, 151.101.64.133, 151.101.192.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 11154 (11K) [text/plain]\n", + "Saving to: 'glass.csv.1'\n", + "\n", + "100%[======================================>] 11,154 --.-K/s in 0.001s \n", + "\n", + "2020-11-08 09:08:57 (11.2 MB/s) - 'glass.csv.1' saved [11154/11154]\n", + "\n" + ] + } + ], + "source": [ + "!wget https://raw.githubusercontent.com/jbrownlee/Datasets/master/glass.csv" + ] + }, + { + "cell_type": "code", + "execution_count": 102, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(171, 9) (43, 9) (214,)\n" + ] + } + ], + "source": [ + "data = pd.read_csv('glass.csv', header=None)\n", + "data = data.values\n", + "X = data[:, 0:9].astype(dtype)\n", + "Y = data[:, 9]\n", + "\n", + "encoder = LabelEncoder()\n", + "encoder.fit(Y)\n", + "encoded_Y = encoder.transform(Y)\n", + "encoded_Y = encoded_Y.astype(dtype)\n", + "num_outputs = 6\n", + "\n", + "X_train, X_test, y_train, y_test = train_test_split(X, encoded_Y, test_size=0.2, random_state=42)\n", + "print(X_train.shape, X_test.shape, encoded_Y.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Defining Trainable Variables in VGP\n", + "\n", + "* Using k-means to initialize 30 representative `inducing_index_points` $Z$ and making them trainable variables" + ] + }, + { + "cell_type": "code", + "execution_count": 89, + "metadata": {}, + "outputs": [], + "source": [ + "num_inducing_points_ = 30\n", + "inducing_index_points_init = kmeans2(X_train, num_inducing_points_, minit=\"points\")[0] #50, 60\n", + "inducing_index_points = tf.Variable(inducing_index_points_init, dtype=dtype, name='inducing_index_points')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* Initializing the RBF kernel and its parameters, which are 
`amplitude` and `length_scale` (the same length scale is used for all $X$ columns)\n", + "* Initializing the variational mean and covariance $\\mathbf{m}, \\mathbf{S}$ in $q(u)$" + ] + }, + { + "cell_type": "code", + "execution_count": 90, + "metadata": {}, + "outputs": [], + "source": [ + "amplitude = tfp.util.TransformedVariable(\n", + "    1., tfb.Softplus(), dtype=dtype, name='amplitude')\n", + "length_scale = tfp.util.TransformedVariable(\n", + "    1., tfb.Softplus(), dtype=dtype, name='length_scale')\n", + "kernel = tfk.ExponentiatedQuadratic(amplitude=amplitude, length_scale=length_scale)\n", + "\n", + "observation_noise_variance = tfp.util.TransformedVariable(1., tfb.Softplus(), dtype=dtype, name='observation_noise_variance')\n", + "\n", + "variational_inducing_observations_loc = tf.Variable(np.zeros([num_outputs, num_inducing_points_], dtype=dtype), name='variational_inducing_observations_loc')\n", + "\n", + "Ku = kernel.matrix(inducing_index_points, inducing_index_points)\n", + "variational_inducing_observations_scale_init = np.linalg.cholesky(Ku + np.eye(num_inducing_points_)*1e-6)\n", + "variational_inducing_observations_scale = tf.Variable(np.tile(variational_inducing_observations_scale_init[None, :, :], [num_outputs, 1, 1]), \n", + "    name='variational_inducing_observations_scale')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* Defining the log-likelihood. For multiclass classification, a Categorical distribution is used. 
The `observations` array is flat, with length `batch_size`. Since the expected log-likelihood in VGP is approximated by Gauss–Hermite quadrature, the input logits are transposed to shape (`quadrature_size, batch_size, num_outputs`) to match the layout expected by the `sparse_softmax_cross_entropy_with_logits` computation in `log_prob`.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 91, + "metadata": {}, + "outputs": [], + "source": [ + "def log_prob(observations, f):\n", + "    # f has shape (num_outputs, quadrature_size, batch_size), e.g. (6, 20, 64)\n", + "    categorical = tfd.Independent(tfd.Categorical(logits=tf.transpose(f, perm=[1, 2, 0])), 1)  # logits: (20, 64, 6) = (quadrature_size, batch_size, num_outputs)\n", + "    return categorical.log_prob(observations)  # as in sparse_softmax_cross_entropy_with_logits: logits are [..., batch_size, num_classes], labels are [batch_size]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Constructing Model and Training" + ] + }, + { + "cell_type": "code", + "execution_count": 92, + "metadata": {}, + "outputs": [], + "source": [ + "vgp = tfd.VariationalGaussianProcess(\n", + "    kernel,\n", + "    index_points=X_test,\n", + "    inducing_index_points=inducing_index_points,\n", + "    variational_inducing_observations_loc=variational_inducing_observations_loc,  # TensorShape([6, 30])\n", + "    variational_inducing_observations_scale=variational_inducing_observations_scale,  # TensorShape([6, 30, 30])\n", + "    observation_noise_variance=observation_noise_variance)" + ] + }, + { + "cell_type": "code", + "execution_count": 93, + "metadata": {}, + "outputs": [], + "source": [ + "batch_size = 64\n", + "\n", + "optimizer = tf.optimizers.Adam(learning_rate=.01)\n", + "\n", + "@tf.function\n", + "def optimize(x_train_batch, y_train_batch):\n", + "    with tf.GradientTape() as tape:\n", + "        # Create the loss function we want to optimize.\n", + "        recon = vgp.surrogate_posterior_expected_log_likelihood(\n", + "            observations=y_train_batch,\n", + "            observation_index_points=x_train_batch,\n", + "            log_likelihood_fn=log_prob,\n", + "
            quadrature_size=20)\n", + "\n", + "        neg_elbo = -tf.reduce_sum(recon) + tf.reduce_sum(vgp.surrogate_posterior_kl_divergence_prior())\n", + "\n", + "    grads = tape.gradient(neg_elbo, vgp.trainable_variables)\n", + "    optimizer.apply_gradients(zip(grads, vgp.trainable_variables))\n", + "    return neg_elbo" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training by Batch" + ] + }, + { + "cell_type": "code", + "execution_count": 94, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0 114.67260603059552\n", + "120 67.871934322895\n", + "240 60.0149377742081\n", + "360 56.29566984959375\n", + "480 53.7197290981309\n", + "600 51.48258717207506\n", + "720 52.6730473693974\n", + "840 44.445047997931475\n", + "960 50.50109738207267\n", + "1080 44.75255389124902\n", + "1199 51.55517736426542\n" + ] + } + ], + "source": [ + "num_iters = 1200\n", + "num_logs = 10\n", + "num_training_points_ = X_train.shape[0]\n", + "\n", + "for i in range(num_iters):\n", + "    batch_idxs = np.random.randint(num_training_points_, size=[batch_size])\n", + "    x_train_batch = X_train[batch_idxs, ...]\n", + "    y_train_batch = y_train[batch_idxs]\n", + "    loss = optimize(x_train_batch, y_train_batch)\n", + "\n", + "    if i % (num_iters // num_logs) == 0 or i + 1 == num_iters:\n", + "        print(i, loss.numpy())\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Computing the Predictive Mean and Variance\n", + "\n", + "To compute the predictive mean and variance for a set of new index points $X^*$, `predict_mean_and_var` from GPflow's `Softmax` likelihood is used to evaluate (9) and (10). 
" + ] + }, + { + "cell_type": "code", + "execution_count": 95, + "metadata": {}, + "outputs": [], + "source": [ + "from gpflow.likelihoods.multiclass import Softmax\n", + "\n", + "Fmu = tf.cast(tf.transpose(vgp.mean()), tf.float32)  # vgp.mean() is [6, 43]; transposed to [43, 6]\n", + "Fvar = tf.cast(tf.transpose(vgp.variance()), tf.float32)  # vgp.variance() is [6, 43]; transposed to [43, 6]\n", + "\n", + "S = Softmax(num_outputs)\n", + "m, v = S.predict_mean_and_var(Fmu, Fvar)  # each has shape (43, 6)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Results and Plots" + ] + }, + { + "cell_type": "code", + "execution_count": 96, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Multiclass classification accuracy for 6 is 0.7209302325581395\n" + ] + } + ], + "source": [ + "acc = np.mean(np.argmax(m, 1).astype(int) == y_test.astype(int))\n", + "print(\"Multiclass classification accuracy for {} is {}\".format(num_outputs, acc))" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Multiclass classification variances for 6 is [0.23027779 0.21305028 0.23469782 0.2065194 0.23566929 0.18067268\n", + " 0.21430728 0.2347648 0.2097274 0.21258086 0.2179921 0.21743922\n", + " 0.2189273 0.22974259 0.22108854 0.20057674 0.1710678 0.21648297\n", + " 0.19312762 0.18159598 0.18870537 0.22279613 0.2087079 0.2288218\n", + " 0.19952135 0.21122998 0.21671966 0.22578193 0.2281411 0.24098359\n", + " 0.17000449 0.2195439 0.20093569 0.20745346 0.19372132 0.24052121\n", + " 0.15541928 0.22103915 0.19316027 0.20382655 0.16869901 0.22654828\n", + " 0.19985217]\n" + ] + } + ], + "source": [ + "indices = np.argmax(m, 1)\n", + "v_ = tf.reduce_sum(tf.one_hot(indices, num_outputs) * v, 1)\n", + "print(\"Multiclass classification variances for {} is {}\".format(num_outputs, v_))" + ] + }, + { + "cell_type": "code", + "execution_count": 98, + "metadata": {}, + "outputs": [], + "source": [ + 
"from sklearn.decomposition import PCA\n", + "\n", + "pca = PCA(n_components=1)\n", + "X_test_pca = pca.fit_transform(X_test)\n", + "y_preds = np.argmax(m, 1).astype(int)\n", + "y_sd = v_ ** 0.5" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Each dot is the predicted class for a test point $x_i^*$, and the error bar is one standard deviation of $y_i^*$. Each cross is the true class; where a dot and a cross overlap, the prediction is correct." + ] + }, + { + "cell_type": "code", + "execution_count": 99, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 0, 'X_test_pca')" + ] + }, + "execution_count": 99, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.figure(figsize=(15, 5))\n", + "plt.scatter(X_test_pca, y_test,\n", + "            marker='x', s=50, c=y_test, zorder=10)\n", + "\n", + "plt.errorbar(X_test_pca, y_preds, yerr=y_sd, fmt='o', capthick=1, label='Prediction', alpha=0.5)\n", + "plt.legend(loc='upper right')\n", + "plt.ylabel('6 Classes of Glasses')\n", + "plt.xlabel('X_test_pca')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "JupyterPy2", + "language": "python", + "name": "ipykernel_py2" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}