Skip to content

Commit

Permalink
New Rebase Approach
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman committed Feb 2, 2024
1 parent 8e2f132 commit 9617994
Showing 1 changed file with 75 additions and 4 deletions.
79 changes: 75 additions & 4 deletions include/albatross/src/models/sparse_gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -703,15 +703,86 @@ class SparseGaussianProcessRegression

// rebase_inducing_points takes a Sparse GP which was fit using some set of
// inducing points and creates a new fit relative to new inducing points.
//
// Note that this will NOT be the equivalent to having fit the model with
// the new inducing points since some information may have been lost in
// the process.
template <typename ModelType, typename FeatureType, typename NewFeatureType>
//
// For example, consider the extreme case where your first fit
// doesn't have any inducing points at all, all the information from the first
// observations will have been lost, and when you rebase on new inducing points
// you'd have the prior for those new points.
//
// For implementation details see the online documentation.
//
// The summary involves:
// - Compute K_nn = cov(new, new)
// - Compute K_pn = cov(prev, new)
// - Compute A = L_pp^-1 K_pn
// - Solve for Lhat_nn = chol(K_nn - A^T A)
// - Solve for QRP^T = [Lat_nn
// R_p P_p^T L_pp^-T A]
// - Solve for L_nn = chol(K_nn)
// - Solve for v_n = K_nn^-1 K_np v_p
//
template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
typename InducingPointStrategy, typename QRImplementation,
typename FeatureType, typename NewFeatureType>
auto rebase_inducing_points(
const FitModel<ModelType, Fit<SparseGPFit<FeatureType>>> &fit_model,
const FitModel<SparseGaussianProcessRegression<
CovFunc, MeanFunc, GrouperFunction,
InducingPointStrategy, QRImplementation>,
Fit<SparseGPFit<FeatureType>>> &fit_model,
const std::vector<NewFeatureType> &new_inducing_points) {
return fit_model.get_model().fit_from_prediction(
new_inducing_points, fit_model.predict(new_inducing_points).joint());

const auto &cov = fit_model.get_model().get_covariance();
// Compute K_nn = cov(new, new)
const Eigen::MatrixXd K_nn =
cov(new_inducing_points, fit_model.get_model().threads_.get());

// Compute K_pn = cov(prev, new)
const Fit<SparseGPFit<FeatureType>> &prev_fit = fit_model.get_fit();
const auto &prev_inducing_points = prev_fit.train_features;
const Eigen::MatrixXd K_pn = cov(prev_inducing_points, new_inducing_points,
fit_model.get_model().threads_.get());
// A = L_pp^-1 K_pn
const Eigen::MatrixXd A = prev_fit.train_covariance.sqrt_solve(K_pn);
const Eigen::Index p = K_pn.rows();
const Eigen::Index n = K_nn.rows();
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n + p, n);

// B[upper] = R P^T L_pp^-T A
const auto LTiA = prev_fit.train_covariance.sqrt_transpose_solve(A);
B.topRows(p) = prev_fit.R.template triangularView<Eigen::Upper>() *
(prev_fit.P.transpose() * LTiA);

// B[lower] = chol(K_nn - A^T A)^T
Eigen::MatrixXd S_nn = K_nn - A.transpose() * A;
// This cholesky operation here is the most likely to experience numerical
// instability because of the A^T A subtraction involved, so we add a nugget.
const double nugget =
fit_model.get_model().get_params()[details::inducing_nugget_name()].value;
assert(nugget >= 0);
S_nn.diagonal() += Eigen::VectorXd::Constant(S_nn.rows(), nugget);
B.bottomRows(n) = Eigen::SerializableLDLT(S_nn).sqrt_transpose();

const auto B_qr =
QRImplementation::compute(B, fit_model.get_model().threads_.get());

Fit<SparseGPFit<FeatureType>> new_fit;
new_fit.train_features = new_inducing_points;
new_fit.train_covariance = Eigen::SerializableLDLT(K_nn);
// v_n = K_nn^-1 K_np v_p
new_fit.information = new_fit.train_covariance.solve(
fit_model.predict(new_inducing_points).mean());
new_fit.P = get_P(*B_qr);
new_fit.R = get_R(*B_qr);
new_fit.numerical_rank = B_qr->rank();

return FitModel<
SparseGaussianProcessRegression<CovFunc, MeanFunc, GrouperFunction,
InducingPointStrategy, QRImplementation>,
Fit<SparseGPFit<FeatureType>>>(fit_model.get_model(), std::move(new_fit));
}

template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
Expand Down

0 comments on commit 9617994

Please sign in to comment.