|
| 1 | +import array |
| 2 | +import gzip |
| 3 | +import os |
| 4 | +import struct |
| 5 | +import sys |
| 6 | +import urllib.request |
| 7 | +from os import path |
| 8 | + |
| 9 | +import haiku as hk |
| 10 | +import jax |
| 11 | +import jax.numpy as jnp |
| 12 | +import numpy as np |
| 13 | +import optax |
| 14 | +from sklearn.decomposition import PCA |
| 15 | + |
| 16 | +import orthax |
| 17 | + |
| 18 | + |
| 19 | + |
| 20 | + |
| 21 | + |
| 22 | +def mnist_raw(): |
| 23 | + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" |
| 24 | + |
| 25 | + _DATA = "/tmp/" |
| 26 | + |
| 27 | + def _download(url, filename): |
| 28 | + """Download a url to a file in the JAX data temp directory.""" |
| 29 | + |
| 30 | + if not path.exists(_DATA): |
| 31 | + os.makedirs(_DATA) |
| 32 | + out_file = path.join(_DATA, filename) |
| 33 | + if not path.isfile(out_file): |
| 34 | + urllib.request.urlretrieve(url, out_file) |
| 35 | + print("downloaded {} to {}".format(url, _DATA)) |
| 36 | + |
| 37 | + def parse_labels(filename): |
| 38 | + with gzip.open(filename, "rb") as fh: |
| 39 | + _ = struct.unpack(">II", fh.read(8)) |
| 40 | + return np.array(array.array("B", fh.read()), dtype=np.uint8) |
| 41 | + |
| 42 | + def parse_images(filename): |
| 43 | + with gzip.open(filename, "rb") as fh: |
| 44 | + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) |
| 45 | + return np.array(array.array("B", fh.read()), |
| 46 | + dtype=np.uint8).reshape(num_data, rows, cols) |
| 47 | + |
| 48 | + for filename in [ |
| 49 | + "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz", |
| 50 | + "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz" |
| 51 | + ]: |
| 52 | + _download(base_url + filename, filename) |
| 53 | + |
| 54 | + train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) |
| 55 | + train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) |
| 56 | + test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) |
| 57 | + test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) |
| 58 | + |
| 59 | + return train_images, train_labels, test_images, test_labels |
| 60 | + |
| 61 | + |
| 62 | +def mnist(digits=None): |
| 63 | + def _maybe_filter(images, labels, digits): |
| 64 | + mask = np.isin(labels, digits) |
| 65 | + return images[mask], labels[mask] |
| 66 | + |
| 67 | + def _partial_flatten(x): |
| 68 | + return np.reshape(x, (x.shape[0], -1)) |
| 69 | + |
| 70 | + def _one_hot(x, d, dtype=np.float32): |
| 71 | + return np.array(x[:, None] == d, dtype) |
| 72 | + |
| 73 | + train_images, train_labels, test_images, test_labels = mnist_raw() |
| 74 | + if digits is not None: |
| 75 | + train_images, train_labels = _maybe_filter(train_images, train_labels, |
| 76 | + digits) |
| 77 | + test_images, test_labels = _maybe_filter(test_images, test_labels, |
| 78 | + digits) |
| 79 | + train_labels = _one_hot(train_labels, np.array(digits)) |
| 80 | + test_labels = _one_hot(test_labels, np.array(digits)) |
| 81 | + else: |
| 82 | + train_labels = _one_hot(train_labels, np.arange(10)) |
| 83 | + test_labels = _one_hot(test_labels, np.arange(10)) |
| 84 | + |
| 85 | + train_images = _partial_flatten(train_images) / np.float32(255.) |
| 86 | + test_images = _partial_flatten(test_images) / np.float32(255.) |
| 87 | + |
| 88 | + return train_images, train_labels, test_images, test_labels |
| 89 | + |
| 90 | + |
| 91 | +def pca(train_x, test_x, n_components=8): |
| 92 | + decomposition = PCA(n_components).fit(train_x) |
| 93 | + train_x = decomposition.transform(train_x) |
| 94 | + test_x = decomposition.transform(test_x) |
| 95 | + return train_x, test_x |
| 96 | + |
| 97 | + |
| 98 | + |
| 99 | + |
| 100 | + |
| 101 | + |
| 102 | + |
| 103 | +# training parameters |
| 104 | +seed = 123 |
| 105 | +batch_size = 50 |
| 106 | +n_components = 8 |
| 107 | +digits = [6,9] |
| 108 | +learning_rate = 0.001 |
| 109 | +train_steps = 5000 |
| 110 | + |
| 111 | +# network parameters |
| 112 | +output_sizes = [4,2] |
| 113 | +normalize_inputs = False |
| 114 | +with_bias = True |
| 115 | +t_init = hk.initializers.RandomUniform(minval=-np.pi, maxval=np.pi) |
| 116 | +b_init = hk.initializers.Constant(0.) |
| 117 | +activation = jax.nn.sigmoid |
| 118 | +activate_final = False |
| 119 | + |
| 120 | + |
| 121 | + |
| 122 | + |
| 123 | + |
| 124 | + |
| 125 | + |
| 126 | + |
| 127 | + |
| 128 | +# set random state |
| 129 | +random_state = np.random.RandomState(seed) |
| 130 | +rng_key = jax.random.PRNGKey( |
| 131 | + random_state.randint(-sys.maxsize - 1, sys.maxsize + 1, |
| 132 | + dtype=np.int64)) |
| 133 | + |
| 134 | +# load data |
| 135 | +train_images, train_labels, test_images, test_labels = jax.device_put(mnist(digits)) |
| 136 | +train_features, test_features = pca(train_images, test_images, |
| 137 | + n_components) |
| 138 | + |
| 139 | +# build batch iterator |
| 140 | +num_train = train_images.shape[0] |
| 141 | +num_complete_batches, leftover = divmod(num_train, batch_size) |
| 142 | +num_batches = num_complete_batches + bool(leftover) |
| 143 | + |
| 144 | +def data_stream(batch_size): |
| 145 | + while True: |
| 146 | + perm = random_state.permutation(num_train) |
| 147 | + for i in range(num_batches): |
| 148 | + batch_idx = perm[i * batch_size:(i + 1) * batch_size] |
| 149 | + yield train_features[batch_idx], train_labels[batch_idx] |
| 150 | + |
| 151 | +batches = iter(data_stream(batch_size)) |
| 152 | + |
| 153 | +# build network |
| 154 | +def orthogonal_net(x): |
| 155 | + module = orthax.haiku.OrthogonalMLP(output_sizes, |
| 156 | + normalize_inputs, |
| 157 | + with_bias, |
| 158 | + t_init, |
| 159 | + b_init, |
| 160 | + activation, |
| 161 | + activate_final) |
| 162 | + return module(x) |
| 163 | +net = hk.without_apply_rng(hk.transform(orthogonal_net)) |
| 164 | +params = avg_params = net.init(rng_key, next(batches)[0]) |
| 165 | + |
| 166 | +# build optimizer |
| 167 | +opt = optax.rmsprop(learning_rate) |
| 168 | +opt_state = opt.init(params) |
| 169 | + |
| 170 | +# build model |
| 171 | +def loss(params, features, labels): |
| 172 | + logits = net.apply(params, features) |
| 173 | + l2_loss = 0.5 * sum( |
| 174 | + jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params)) |
| 175 | + softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits)) |
| 176 | + softmax_xent /= labels.shape[0] |
| 177 | + return softmax_xent + 1e-4 * l2_loss |
| 178 | + |
| 179 | +@jax.jit |
| 180 | +def accuracy(params, features, labels): |
| 181 | + predictions = net.apply(params, features) |
| 182 | + return jnp.mean( |
| 183 | + jnp.argmax(predictions, axis=1) == jnp.argmax(labels, axis=1)) |
| 184 | + |
| 185 | +@jax.jit |
| 186 | +def update(params, opt_state, features, labels): |
| 187 | + grads = jax.grad(loss)(params, features, labels) |
| 188 | + updates, opt_state = opt.update(grads, opt_state) |
| 189 | + new_params = optax.apply_updates(params, updates) |
| 190 | + return new_params, opt_state |
| 191 | + |
| 192 | +@jax.jit |
| 193 | +def ema_update(params, avg_params): |
| 194 | + return optax.incremental_update(params, avg_params, step_size=0.001) |
| 195 | + |
| 196 | + |
| 197 | + |
| 198 | + |
| 199 | + |
| 200 | + |
| 201 | +# train/eval loop. |
| 202 | +for step in range(train_steps): |
| 203 | + batch_features, batch_labels = next(batches) |
| 204 | + if step % 100 == 0: |
| 205 | + # evaluate classification accuracy on train & test sets. |
| 206 | + train_accuracy = accuracy(avg_params, batch_features, batch_labels) |
| 207 | + test_accuracy = accuracy(avg_params, test_features, test_labels) |
| 208 | + train_accuracy, test_accuracy = jax.device_get( |
| 209 | + (train_accuracy, test_accuracy)) |
| 210 | + print(f"[Step {step}] Train / Test accuracy: " |
| 211 | + f"{train_accuracy:.3f} / {test_accuracy:.3f}.") |
| 212 | + |
| 213 | + # update params |
| 214 | + params, opt_state = update(params, opt_state, batch_features, batch_labels) |
| 215 | + avg_params = ema_update(params, avg_params) |
0 commit comments