Commit

add kfoldnn example
jkitchin committed Jan 14, 2025
1 parent c60c217 commit 2629559
Showing 1 changed file with 209 additions and 0 deletions.
209 changes: 209 additions & 0 deletions pycse/sklearn/kfoldnn.py
@@ -0,0 +1,209 @@
"""A K-fold Neural network model in jax.
The idea of the k-fold model is that you train each neuron in the last layer on
a different fold of data. Then, at inference time you get a distribution of
predictions that you can use for uncertainty quantification.
The main hyperparameter that affects the distribution is the fraction of data
used. Empirically I find that a fraction of 0.1 works pretty well. Note that the
neurons before the last layer all end up seeing all the data, it is only the
last layer that sees different parts of the data. If you use a fraction of 1.0,
then each neuron converges to the same result.
There isn't currently an obvious way to choose a fraction that leads to the
"right" UQ distribution. You can try many values and see what works best.
Example usage:
import jax
import numpy as np
import matplotlib.pyplot as plt
key = jax.random.PRNGKey(19)
x = np.linspace(0, 1, 100)[:, None]
y = x**(1/3) + (1 + jax.random.normal(key, x.shape) * 0.05)
from pycse.sklearn.kfoldnn import KfoldNN
model = KfoldNN((1, 15, 25), xtrain=0.1)
model.fit(x, y)
model.report()
print(model.score(x, y))
model.plot(x, y, distribution=True);
"""

import os
import jax


from jax import jit
import jax.numpy as np
from jax import value_and_grad
from jaxopt import LBFGS
import matplotlib.pyplot as plt
from sklearn.base import BaseEstimator, RegressorMixin
from flax import linen as nn

os.environ["JAX_ENABLE_X64"] = "True"
jax.config.update("jax_enable_x64", True)


class _NN(nn.Module):
    """A flax neural network.

    layers: a tuple of integers. Each integer is the number of neurons in that
    layer. The last entry is the number of output neurons, i.e., one per fold.
    """

    layers: tuple

    @nn.compact
    def __call__(self, x):
        for i in self.layers[0:-1]:
            x = nn.Dense(i)(x)
            x = nn.swish(x)

        # Linear last layer
        x = nn.Dense(self.layers[-1])(x)
        return x
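
# A rough orientation sketch (assumed shapes, not executed at import time):
# with layers=(1, 15, 25), an (n, 1) input passes through swish-activated
# Dense(1) and Dense(15) layers and a linear Dense(25) output, giving an
# (n, 25) array with one prediction column per last-layer neuron / fold:
#
#     net = _NN((1, 15, 25))
#     params = net.init(jax.random.PRNGKey(0), np.ones((4, 1)))
#     out = net.apply(params, np.ones((4, 1)))  # shape (4, 25)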


class KfoldNN(BaseEstimator, RegressorMixin):
    """sklearn-compatible model for a k-fold neural network."""

    def __init__(self, layers, xtrain=0.1, seed=19):
        """Initialize a k-fold nn.

        Args:
            layers: tuple of integers for the number of neurons in each layer.
            xtrain: fraction of the data to use in each fold.
            seed: integer seed for the random number generator.
        """
        self.layers = layers
        self.key = jax.random.PRNGKey(seed)
        self.nn = _NN(layers)
        self.xtrain = xtrain

    def fit(self, X, y, **kwargs):
        """Fit the k-fold nn.

        Args:
            X: a 2d array of x values.
            y: an array of y values.

        kwargs are passed to the LBFGS solver.
        """
        # This allows retraining.
        if not hasattr(self, "optpars"):
            params = self.nn.init(self.key, X)
        else:
            params = self.optpars

        last_layer = f"Dense_{len(self.layers) - 1}"
        w = params["params"][last_layer]["kernel"].shape
        N = w[-1]  # number of functions in the last layer

        # Each row of folds is an independent permutation of the sample
        # indices, one row per neuron in the last layer.
        folds = jax.random.permutation(
            self.key, np.tile(np.arange(0, len(X))[:, None], N), axis=0, independent=True
        ).T

        # make a smooth, differentiable cutoff
        fx = np.arange(0, len(X))

        # We use fy to mask out the errors on the part of each fold that
        # neuron should not see.
        _y = len(X) / 2 * (fx - len(X) * self.xtrain)
        fy = 1 - 0.5 * (np.tanh(_y / 2) + 1)

        @jit
        def objective(pars):
            agge = 0

            for i, fold in enumerate(folds):
                # predict for a fold
                P = self.nn.apply(pars, np.asarray(X)[fold])
                errs = (P - y[fold])[:, i] * fy  # errors for this fold

                mae = np.mean(np.abs(errs))  # MAE for the fold
                agge += mae
            return agge

        if "maxiter" not in kwargs:
            kwargs["maxiter"] = 1500

        if "tol" not in kwargs:
            kwargs["tol"] = 1e-3

        solver = LBFGS(fun=value_and_grad(objective), value_and_grad=True, **kwargs)

        self.optpars, self.state = solver.run(params)

    def report(self):
        """Print the state variables."""
        print(f"Iterations: {self.state.iter_num} Value: {self.state.value}")

    def predict(self, X, return_std=False):
        """Predict the model for X.

        Args:
            X: a 2d array of points to predict.
            return_std: Boolean, if true, return an error estimate for each point.

        Returns:
            If return_std is False, the predictions, else (predictions, errors).
        """
        X = np.atleast_2d(X)
        P = self.nn.apply(self.optpars, X)

        if return_std:
            return np.mean(P, axis=1), np.std(P, axis=1)
        else:
            return np.mean(P, axis=1)

    def __call__(self, X, return_std=False, distribution=False):
        """Execute the model.

        Args:
            X: a 2d array to make predictions for.
            return_std: Boolean, if true, return errors for each point.
            distribution: Boolean, if true, return the distribution, else the mean.
        """
        if not hasattr(self, "optpars"):
            raise Exception("You need to fit the model first.")

        # get predictions
        P = self.nn.apply(self.optpars, X)
        se = P.std(axis=1)
        if not distribution:
            P = P.mean(axis=1)

        if return_std:
            return (P, se)
        else:
            return P

    def plot(self, X, y, distribution=False):
        """Return a plot of the data, mean prediction, and uncertainty band.

        Args:
            X: 2d array of data.
            y: corresponding y-values.
            distribution: Boolean, if true, plot the distribution of predictions.
        """
        P = self.nn.apply(self.optpars, X)
        mp = P.mean(axis=1)
        se = P.std(axis=1)

        plt.plot(X, y, "b.", label="data")
        plt.plot(X, mp, label="mean")
        plt.plot(X, mp + 2 * se, "k--")
        plt.plot(X, mp - 2 * se, "k--", label="+/- 2sd")
        if distribution:
            plt.plot(X, P, alpha=0.2)
        plt.xlabel("X")
        plt.ylabel("y")
        plt.legend()
        return plt.gcf()
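

# A minimal usage sketch, assuming the synthetic cube-root data from the module
# docstring; it demonstrates predict(return_std=True) and calling the fitted
# model with distribution=True to get the per-neuron prediction distribution.
if __name__ == "__main__":
    import numpy as onp

    key = jax.random.PRNGKey(19)
    x = onp.linspace(0, 1, 100)[:, None]
    y = x ** (1 / 3) + (1 + jax.random.normal(key, x.shape) * 0.05)

    model = KfoldNN((1, 15, 25), xtrain=0.1)
    model.fit(x, y)
    model.report()

    # Mean prediction and standard deviation (error estimate) for each point.
    yp, se = model.predict(x, return_std=True)

    # Full distribution of last-layer predictions, one column per fold/neuron.
    dist = model(x, distribution=True)
    print(yp.shape, se.shape, dist.shape)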
