Commit 2bb18ef

Author: zhaoxiang

fix two bugs in common.py

bug 1: it was not using all thetas; bug 2: wrong negative sign before sinθ

1 parent 99d71b0
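As a sanity check on the sign fix in apply_orthogonal (a plain-NumPy sketch, not part of the commit): under the row-vector convention used in the code (diff below), the corrected matrix rotates each coordinate pair by +θ, matching the RBS gate convention; the old matrix applied the inverse rotation. Both matrices are orthogonal, so only the rotation direction differed.

import numpy as np

theta = 0.3
c, s = np.cos(theta), np.sin(theta)
v = np.array([1.0, 0.0])

fixed = v @ np.array([[c, s], [-s, c]])  # -> ( cosθ,  sinθ): rotation by +θ
old = v @ np.array([[c, -s], [s, c]])    # -> ( cosθ, -sinθ): rotation by -θ
print(fixed, old)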

4 files changed: +353 −4 lines changed


examples/mnist_classifier.py

+215
@@ -0,0 +1,215 @@
import array
import gzip
import os
import struct
import sys
import urllib.request
from os import path

import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
import optax
from sklearn.decomposition import PCA

import orthax


def mnist_raw():
    base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/"

    _DATA = "/tmp/"

    def _download(url, filename):
        """Download a url to a file in the JAX data temp directory."""
        if not path.exists(_DATA):
            os.makedirs(_DATA)
        out_file = path.join(_DATA, filename)
        if not path.isfile(out_file):
            urllib.request.urlretrieve(url, out_file)
            print("downloaded {} to {}".format(url, _DATA))

    def parse_labels(filename):
        with gzip.open(filename, "rb") as fh:
            _ = struct.unpack(">II", fh.read(8))
            return np.array(array.array("B", fh.read()), dtype=np.uint8)

    def parse_images(filename):
        with gzip.open(filename, "rb") as fh:
            _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16))
            return np.array(array.array("B", fh.read()),
                            dtype=np.uint8).reshape(num_data, rows, cols)

    for filename in [
            "train-images-idx3-ubyte.gz", "train-labels-idx1-ubyte.gz",
            "t10k-images-idx3-ubyte.gz", "t10k-labels-idx1-ubyte.gz"
    ]:
        _download(base_url + filename, filename)

    train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz"))
    train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz"))
    test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz"))
    test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz"))

    return train_images, train_labels, test_images, test_labels


def mnist(digits=None):
    def _maybe_filter(images, labels, digits):
        mask = np.isin(labels, digits)
        return images[mask], labels[mask]

    def _partial_flatten(x):
        return np.reshape(x, (x.shape[0], -1))

    def _one_hot(x, d, dtype=np.float32):
        return np.array(x[:, None] == d, dtype)

    train_images, train_labels, test_images, test_labels = mnist_raw()
    if digits is not None:
        train_images, train_labels = _maybe_filter(train_images, train_labels,
                                                   digits)
        test_images, test_labels = _maybe_filter(test_images, test_labels,
                                                 digits)
        train_labels = _one_hot(train_labels, np.array(digits))
        test_labels = _one_hot(test_labels, np.array(digits))
    else:
        train_labels = _one_hot(train_labels, np.arange(10))
        test_labels = _one_hot(test_labels, np.arange(10))

    train_images = _partial_flatten(train_images) / np.float32(255.)
    test_images = _partial_flatten(test_images) / np.float32(255.)

    return train_images, train_labels, test_images, test_labels


def pca(train_x, test_x, n_components=8):
    decomposition = PCA(n_components).fit(train_x)
    train_x = decomposition.transform(train_x)
    test_x = decomposition.transform(test_x)
    return train_x, test_x


# training parameters
seed = 123
batch_size = 50
n_components = 8
digits = [6, 9]
learning_rate = 0.001
train_steps = 5000

# network parameters
output_sizes = [4, 2]
normalize_inputs = False
with_bias = True
t_init = hk.initializers.RandomUniform(minval=-np.pi, maxval=np.pi)
b_init = hk.initializers.Constant(0.)
activation = jax.nn.sigmoid
activate_final = False

# set random state
random_state = np.random.RandomState(seed)
rng_key = jax.random.PRNGKey(
    random_state.randint(-sys.maxsize - 1, sys.maxsize + 1,
                         dtype=np.int64))

# load data
train_images, train_labels, test_images, test_labels = jax.device_put(mnist(digits))
train_features, test_features = pca(train_images, test_images,
                                    n_components)

# build batch iterator
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream(batch_size):
    while True:
        perm = random_state.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_features[batch_idx], train_labels[batch_idx]

batches = iter(data_stream(batch_size))

# build network
def orthogonal_net(x):
    module = orthax.haiku.OrthogonalMLP(output_sizes,
                                        normalize_inputs,
                                        with_bias,
                                        t_init,
                                        b_init,
                                        activation,
                                        activate_final)
    return module(x)

net = hk.without_apply_rng(hk.transform(orthogonal_net))
params = avg_params = net.init(rng_key, next(batches)[0])

# build optimizer
opt = optax.rmsprop(learning_rate)
opt_state = opt.init(params)

# build model
def loss(params, features, labels):
    logits = net.apply(params, features)
    l2_loss = 0.5 * sum(
        jnp.sum(jnp.square(p)) for p in jax.tree_leaves(params))
    softmax_xent = -jnp.sum(labels * jax.nn.log_softmax(logits))
    softmax_xent /= labels.shape[0]
    return softmax_xent + 1e-4 * l2_loss

@jax.jit
def accuracy(params, features, labels):
    predictions = net.apply(params, features)
    return jnp.mean(
        jnp.argmax(predictions, axis=1) == jnp.argmax(labels, axis=1))

@jax.jit
def update(params, opt_state, features, labels):
    grads = jax.grad(loss)(params, features, labels)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state

@jax.jit
def ema_update(params, avg_params):
    return optax.incremental_update(params, avg_params, step_size=0.001)


# train/eval loop.
for step in range(train_steps):
    batch_features, batch_labels = next(batches)
    if step % 100 == 0:
        # evaluate classification accuracy on train & test sets.
        train_accuracy = accuracy(avg_params, batch_features, batch_labels)
        test_accuracy = accuracy(avg_params, test_features, test_labels)
        train_accuracy, test_accuracy = jax.device_get(
            (train_accuracy, test_accuracy))
        print(f"[Step {step}] Train / Test accuracy: "
              f"{train_accuracy:.3f} / {test_accuracy:.3f}.")

    # update params
    params, opt_state = update(params, opt_state, batch_features, batch_labels)
    avg_params = ema_update(params, avg_params)
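For context on the network parameters above (a hypothetical helper, not in the repo): each orthogonal layer mapping n inputs to d outputs takes (2 * max(n, d) - min(n, d) - 1) * min(n, d) // 2 angles, per the parameter counts noted in common_comments.py below. With n_components = 8 and output_sizes = [4, 2], the two layers use 22 and 5 thetas respectively.

def n_thetas(n, d):
    # angle count for one pyramidal orthogonal layer, n inputs -> d outputs
    big, small = max(n, d), min(n, d)
    return (2 * big - small - 1) * small // 2

print(n_thetas(8, 4))  # 22 angles for the 8 -> 4 layer
print(n_thetas(4, 2))  # 5 angles for the 4 -> 2 layer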

requirements.txt

+2 −1

@@ -1,3 +1,4 @@
 numpy
 jax
-dm-haiku
+dm-haiku
+optax

src/orthax/common.py

+5 −3

@@ -37,14 +37,16 @@ def apply_orthogonal(thetas, inputs, output_size):
     thetas_idxs = np.cumsum(slice_sizes // 2)
     for start_index, slice_size, thetas_index in zip(start_idxs, slice_sizes,
                                                      thetas_idxs):
-        slice_thetas = lax.dynamic_slice_in_dim(thetas, thetas_index,
+        # bug: it was not using all thetas, fixed below
+        slice_thetas = lax.dynamic_slice_in_dim(thetas, thetas_index - slice_size // 2,
                                                 slice_size // 2, 0)
         slice_out = lax.dynamic_slice_in_dim(out, start_index, slice_size, -1)
         slice_out = lax.reshape(slice_out,
                                 (slice_out.shape[0], *slice_thetas.shape))
+        # bug: wrong negative sign, fixed below
         slice_mat = jnp.array([
-            [slice_thetas[:, 0], -slice_thetas[:, 1]],
-            [slice_thetas[:, 1], slice_thetas[:, 0]],
+            [slice_thetas[:, 0], slice_thetas[:, 1]],
+            [-slice_thetas[:, 1], slice_thetas[:, 0]],
         ]).transpose(2, 0, 1)
         slice_res = lax.batch_matmul(slice_out.transpose(1, 0, 2), slice_mat)
         slice_res = slice_res.transpose(1, 0, 2)
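The indexing fix is easiest to see with plain NumPy (hypothetical slice sizes, not from the repo): np.cumsum yields the end offset of each theta group, so each slice must start slice_size // 2 entries earlier. The old code used the cumsum value itself as the start, so thetas[0] was never read and later slices read past their own group.

import numpy as np

slice_sizes = np.array([2, 4, 4, 2])       # hypothetical widths, 2 per RBS
thetas_idxs = np.cumsum(slice_sizes // 2)  # [1, 3, 5, 6]: end of each group

for slice_size, thetas_index in zip(slice_sizes, thetas_idxs):
    start = thetas_index - slice_size // 2  # fixed starts: 0, 1, 3, 5
    print(start, thetas_index)              # old code started at thetas_index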

src/orthax/common_comments.py

+131
@@ -0,0 +1,131 @@
import jax
import numpy as np
from jax import lax
from jax import numpy as jnp

__all__ = ['apply_orthogonal']


def apply_orthogonal(thetas, inputs, output_size):
    """Applies a sequence of orthogonal transformations to a sequence of inputs.

    <thetas> is a sequence of scalars, with length (2 * n - d - 1) * d // 2.
    <inputs> is a sequence of vectors, each of shape n.
    <output_size> is the desired output size, equal to d.
    Implementation of Figure 12 of https://arxiv.org/abs/2106.07198
    """
    input_size = inputs.shape[-1]
    max_size = max(input_size, output_size)
    min_size = min(input_size, output_size)

    if max_size == min_size:
        min_size -= 1

    slice_end_idxs = np.concatenate([
        np.arange(1, max_size - 1),
        max_size - np.arange(1, min_size + 1)
    ])
    print('debug', 'slice_end_idxs:\n', slice_end_idxs)

    slice_start_idxs = np.concatenate([
        np.arange(slice_end_idxs.shape[0] + min_size - max_size) % 2,  # [0, 1, 0, 1, ...]
        np.arange(max_size - min_size)
    ])

    slice_sizes = slice_end_idxs - slice_start_idxs + 1

    if input_size < output_size:
        slice_start_idxs = slice_start_idxs[::-1]
        slice_sizes = slice_sizes[::-1]

        # add zeros to inputs
        out = jnp.concatenate([
            jnp.zeros((*inputs.shape[:-1], output_size - input_size)), inputs
        ], axis=-1)  # (batch_size, output_size)
    else:
        out = inputs

    print('debug', 'slice_start_idxs:\n', slice_start_idxs)
    print('debug', 'slice_sizes:\n', slice_sizes)

    thetas = jnp.stack([lax.cos(thetas), lax.sin(thetas)], -1)  # (len(thetas), 2)
    thetas_idxs = np.cumsum(slice_sizes // 2)

    print('debug', 'thetas_idxs:\n', thetas_idxs)
    print('debug', 'thetas:\n', ' cosθ sinθ')
    print(thetas)
    print()

    cnt = 0
    print('debug', cnt, 'out:\n', out)
    for start_index, slice_size, thetas_index in zip(slice_start_idxs, slice_sizes, thetas_idxs):

        # bug: it was not using all thetas, fixed below
        slice_thetas = lax.dynamic_slice_in_dim(thetas, thetas_index - slice_size // 2,
                                                slice_size // 2, 0)  # (n_RBS, 2)
        print('debug', cnt, 'slice_thetas:\n', slice_thetas)

        slice_out = lax.dynamic_slice_in_dim(out, start_index, slice_size, -1)  # (batch_size, slice_size)
        print('debug', cnt, 'slice_out:\n', slice_out)

        # distribute input features to each RBS
        slice_out = lax.reshape(slice_out,
                                (slice_out.shape[0], *slice_thetas.shape))  # (batch_size, n_RBS, 2)
        print('debug', cnt, 'slice_out_after_reshape:\n', slice_out)
        slice_out = slice_out.transpose(1, 0, 2)  # (n_RBS, batch_size, 2)

        # bug: wrong negative sign, fixed below
        slice_mat = jnp.array([
            [slice_thetas[:, 0], slice_thetas[:, 1]],
            [-slice_thetas[:, 1], slice_thetas[:, 0]],
        ]).transpose(2, 0, 1)  # (n_RBS, 2, 2)
        print('debug', cnt, 'slice_mat:\n', slice_mat)
        print()

        slice_res = lax.batch_matmul(slice_out, slice_mat)  # (n_RBS, batch_size, 2)
        slice_res = slice_res.transpose(1, 0, 2)  # (batch_size, n_RBS, 2)
        slice_res = lax.reshape(slice_res, (slice_res.shape[0], slice_size))  # (batch_size, slice_size = n_RBS * 2)
        out = lax.dynamic_update_slice_in_dim(out, slice_res, start_index, -1)

        cnt += 1
        print('debug', cnt, 'out:\n', out)

    if input_size > output_size:
        out = out[:, -output_size:]

    return out


if __name__ == "__main__":

    # # n=d=5, n_params = (2 * n - d - 1) * d // 2 = 10
    # thetas = jnp.array([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
    # inputs = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]])  # (batch_size=1, 5)
    # output_size = 5

    # # n=4 d=2, n_params = (2 * n - d - 1) * d // 2 = 5
    # thetas = jnp.array([1., 2., 3., 4., 5., 6.])
    # inputs = np.array([[0.1, 0.2, 0.3, 0.4]])  # (batch_size=1, 4)
    # output_size = 2

    # # n=8 d=4, n_params = (2 * n - d - 1) * d // 2 = 22
    # thetas = jnp.arange(22) + 1.0
    # inputs = np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])  # (batch_size=1, 8)
    # output_size = 4

    # n=4 d=8, n_params = (2 * d - n - 1) * n // 2 = 22
    thetas = jnp.arange(22) + 1.0
    inputs = np.array([[0.1, 0.2, 0.3, 0.4]])  # (batch_size=1, 4)
    output_size = 8

    out = apply_orthogonal(thetas, inputs, output_size)
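A quick check one could append under __main__ above (an assumption, not part of the commit): the n=4 → d=8 case only zero-pads the input and then applies 2×2 RBS rotations, so the fixed code must preserve the Euclidean norm of the input.

    # sanity check: RBS rotations preserve the norm of the zero-padded input
    print(jnp.linalg.norm(inputs, axis=-1))  # [0.5477...]
    print(jnp.linalg.norm(out, axis=-1))     # [0.5477...], equal up to float error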