Commit 455b5cc

Noise contrastive priors
1 parent f3ae767 commit 455b5cc

File tree

6 files changed, +1031 -1 lines changed


README.md

+5 -1
@@ -3,7 +3,11 @@
 [![DOI](https://zenodo.org/badge/125869131.svg)](https://zenodo.org/badge/latestdoi/125869131)
 
 This repository is a collection of notebooks related to *Bayesian Machine Learning*. The following links display
-the notebooks via [nbviewer](https://nbviewer.jupyter.org/) to ensure a proper rendering of formulas.
+some of the notebooks via [nbviewer](https://nbviewer.jupyter.org/) to ensure a proper rendering of formulas.
+
+- [Reliable uncertainty estimates for neural network predictions](https://github.com/krasserm/bayesian-machine-learning/blob/dev/noise-contrastive-priors/ncp.ipynb).
+  Applies noise contrastive priors to Bayesian neural networks to get more reliable uncertainty estimates for OOD data.
+  Implemented with Tensorflow 2 and Tensorflow Probability.
 
 - [Variational inference in Bayesian neural networks](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/dev/bayesian-neural-networks/bayesian_neural_networks.ipynb).
   Demonstrates how to implement a Bayesian neural network and variational inference of network parameters. Example implementation

noise-contrastive-priors/ncp.ipynb

+870
Large diffs are not rendered by default.
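
The notebook's contents are not shown in this view. As a rough, hypothetical sketch of the idea named in the commit message and README entry (not the notebook's actual code): noise contrastive priors perturb training inputs with noise to obtain OOD inputs and penalize divergence of the model's predictive distribution at those inputs from a wide output prior. The function name ncp_loss_term, the noise scales, and the assumption that the model outputs a tensorflow_probability distribution are illustrative choices.

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def ncp_loss_term(model, x_batch, input_noise=0.5, prior_scale=1.0):
    # Perturb training inputs to obtain out-of-distribution (OOD) inputs.
    x_ood = x_batch + tf.random.normal(tf.shape(x_batch), stddev=input_noise)
    # Predictive distribution at the OOD inputs (assumes the model ends in a
    # tfp layer, e.g. DistributionLambda, and returns a Normal distribution).
    y_pred_ood = model(x_ood)
    # Wide, zero-centered output prior at the OOD inputs.
    y_prior = tfd.Normal(loc=tf.zeros_like(y_pred_ood.mean()), scale=prior_scale)
    # Divergence term that would be added to the usual negative log-likelihood
    # loss on the training batch.
    return tf.reduce_mean(tfd.kl_divergence(y_prior, y_pred_ood))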
noise-contrastive-priors/requirements.txt

+4

@@ -0,0 +1,4 @@
scipy==1.4.1
tensorflow==2.3.0
tensorflow-probability==0.11.0
seaborn==0.11.0

noise-contrastive-priors/utils.py

+152
@@ -0,0 +1,152 @@

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt


# ------------------------------------------
# Data
# ------------------------------------------


def select_bands(x, y, mask):
    # Split x and y into len(mask) contiguous, equally sized bands and keep
    # only the bands whose mask entry is truthy.
    assert x.shape[0] == y.shape[0]

    num_bands = len(mask)

    if x.shape[0] % num_bands != 0:
        raise ValueError('size of first dimension must be a multiple of mask length')

    data_mask = np.repeat(mask, x.shape[0] // num_bands)
    return [arr[data_mask] for arr in (x, y)]


def select_subset(x, y, num, rng=np.random):
    # Randomly select num examples from x and y (without replacement).
    assert x.shape[0] == y.shape[0]

    choices = rng.choice(range(x.shape[0]), num, replace=False)
    return [arr[choices] for arr in (x, y)]
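
# Usage sketch (illustrative; not part of the committed utils.py): build a toy
# training set from the first and third of four contiguous bands of a simple
# function, then randomly subsample it. All sizes, the mask and the seed below
# are made-up example values.
x_demo = np.linspace(-1, 1, 200).reshape(-1, 1)
y_demo = np.sin(3 * x_demo)
x_band, y_band = select_bands(x_demo, y_demo, mask=[True, False, True, False])
x_train_demo, y_train_demo = select_subset(x_band, y_band, num=40, rng=np.random.RandomState(0))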

# ------------------------------------------
# Training
# ------------------------------------------


def data_loader(x, y, batch_size, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    if shuffle:
        ds = ds.shuffle(x.shape[0])
    return ds.batch(batch_size)


def scheduler(decay_steps, decay_rate=0.5, lr=1e-3):
    return tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=lr,
        decay_steps=decay_steps,
        decay_rate=decay_rate)


def optimizer(lr):
    return tf.optimizers.Adam(learning_rate=lr)


def backprop(model, loss, tape):
    # Compute gradients of loss w.r.t. the model's trainable variables and
    # pair them up for optimizer.apply_gradients.
    trainable_vars = model.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    return zip(gradients, trainable_vars)


def train(model, x, y,
          batch_size,
          epochs,
          step_fn,
          optimizer_fn=optimizer,
          scheduler_fn=scheduler,
          verbose=1,
          verbose_every=1000):
    # step_fn(model, optimizer, x_batch, y_batch) runs a single optimization
    # step and returns (loss, y_pred).
    steps_per_epoch = int(np.ceil(x.shape[0] / batch_size))
    steps = epochs * steps_per_epoch

    scheduler = scheduler_fn(steps)
    optimizer = optimizer_fn(scheduler)

    loss_tracker = tf.keras.metrics.Mean(name='loss')
    mse_tracker = tf.keras.metrics.MeanSquaredError(name='mse')

    loader = data_loader(x, y, batch_size=batch_size)

    for epoch in range(1, epochs + 1):
        for x_batch, y_batch in loader:
            loss, y_pred = step_fn(model, optimizer, x_batch, y_batch)

            loss_tracker.update_state(loss)
            mse_tracker.update_state(y_batch, y_pred)

        if verbose and epoch % verbose_every == 0:
            print(f'epoch {epoch}: loss = {loss_tracker.result():.3f}, mse = {mse_tracker.result():.3f}')
            loss_tracker.reset_states()
            mse_tracker.reset_states()
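
# Usage sketch (illustrative; not part of the committed utils.py): train()
# expects a step_fn(model, optimizer, x_batch, y_batch) that performs a single
# optimization step and returns (loss, y_pred). A plain MSE step and a small
# Keras model are used here as placeholders; the notebook presumably defines
# its own probabilistic model and NCP loss.
def mse_step_demo(model, optimizer, x_batch, y_batch):
    with tf.GradientTape() as tape:
        y_pred = model(x_batch, training=True)
        loss = tf.reduce_mean(tf.square(tf.cast(y_batch, y_pred.dtype) - y_pred))
    optimizer.apply_gradients(backprop(model, loss, tape))
    return loss, y_pred


model_demo = tf.keras.Sequential([tf.keras.layers.Dense(20, activation='relu'),
                                  tf.keras.layers.Dense(1)])
train(model_demo, x_train_demo, y_train_demo, batch_size=10, epochs=1000,
      step_fn=mse_step_demo, verbose_every=500)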

# ------------------------------------------
# Visualization
# ------------------------------------------


style = {
    'bg_line': {'ls': '--', 'c': 'black', 'lw': 1.0, 'alpha': 0.5},
    'fg_data': {'marker': '.', 'c': 'red', 'lw': 1.0, 'alpha': 1.0},
    'bg_data': {'marker': '.', 'c': 'gray', 'lw': 0.2, 'alpha': 0.2},
    'pred_sample': {'marker': 'x', 'c': 'blue', 'lw': 0.6, 'alpha': 0.5},
    'pred_mean': {'ls': '-', 'c': 'blue', 'lw': 1.0},
    'a_unc': {'color': 'lightgreen'},
    'e_unc': {'color': 'orange'},
}


def plot_data(x_train, y_train, x=None, y=None):
    if x is not None and y is not None:
        plt.plot(x, y, **style['bg_line'], label='f')
    plt.scatter(x_train, y_train, **style['fg_data'], label='Train data')
    plt.xlabel('x')
    plt.ylabel('y')


def plot_prediction(x, y_mean, y_samples=None, aleatoric_uncertainty=None, epistemic_uncertainty=None):
    x, y_mean, y_samples, epistemic_uncertainty, aleatoric_uncertainty = \
        flatten(x, y_mean, y_samples, epistemic_uncertainty, aleatoric_uncertainty)

    plt.plot(x, y_mean, **style['pred_mean'], label='Expected output')

    if y_samples is not None:
        plt.scatter(x, y_samples, **style['pred_sample'], label='Predictive samples')

    if aleatoric_uncertainty is not None:
        plt.fill_between(x,
                         y_mean + 2 * aleatoric_uncertainty,
                         y_mean - 2 * aleatoric_uncertainty,
                         **style['a_unc'], alpha=0.3, label='Aleatoric uncertainty')

    if epistemic_uncertainty is not None:
        plt.fill_between(x,
                         y_mean + 2 * epistemic_uncertainty,
                         y_mean - 2 * epistemic_uncertainty,
                         **style['e_unc'], alpha=0.3, label='Epistemic uncertainty')


def plot_uncertainty(x, aleatoric_uncertainty, epistemic_uncertainty=None):
    plt.plot(x, aleatoric_uncertainty, **style['a_unc'], label='Aleatoric uncertainty')

    if epistemic_uncertainty is not None:
        plt.plot(x, epistemic_uncertainty, **style['e_unc'], label='Epistemic uncertainty')

    plt.xlabel('x')
    plt.ylabel('Uncertainty')


def flatten(*ts):
    # Reshape each tensor to rank 1; None entries are passed through as None.
    def _flatten(t):
        if t is not None:
            return tf.reshape(t, -1)

    return [_flatten(t) for t in ts]
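
For completeness, a hypothetical example of how the plotting helpers might be combined, assuming utils.py from this commit is importable and using synthetic placeholder arrays rather than outputs of the notebook:

import numpy as np
import matplotlib.pyplot as plt

from utils import plot_data, plot_prediction

x_grid = np.linspace(-1, 1, 100).reshape(-1, 1).astype('float32')
y_true = np.sin(3 * x_grid)
aleatoric = np.full_like(x_grid, 0.1)   # constant noise level, for illustration
epistemic = 0.2 * np.abs(x_grid)        # grows away from 0, for illustration

plot_data(x_grid[::10], y_true[::10], x=x_grid, y=y_true)
plot_prediction(x_grid, y_true, aleatoric_uncertainty=aleatoric, epistemic_uncertainty=epistemic)
plt.legend()
plt.show()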
