Skip to content

Commit

Permalink
Use PermutationMatrix instead of indices (#475)
Browse files Browse the repository at this point in the history
  • Loading branch information
akleeman authored Feb 5, 2024
1 parent c0e3b06 commit 2b99854
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 64 deletions.
23 changes: 23 additions & 0 deletions include/albatross/src/cereal/eigen.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,29 @@ inline void load(Archive &archive,
v.indices() = indices;
}

/*
 * Writes an Eigen::PermutationMatrix to a cereal archive by storing its
 * index vector under the name "indices".
 *
 * The trailing std::uint32_t argument is unnamed and not read here
 * (presumably the cereal class version -- confirm against the cereal docs).
 */
template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
          typename _StorageIndex>
inline void
save(Archive &archive,
     const Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
                                    _StorageIndex> &v,
     const std::uint32_t /*version*/) {
  // `v` is const, so this is a read-only view of the stored indices.
  const auto &stored_indices = v.indices();
  archive(cereal::make_nvp("indices", stored_indices));
}

template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
typename _StorageIndex>
inline void
load(Archive &archive,
Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex> &v,
const std::uint32_t) {
typename Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex>::IndicesType indices;
archive(cereal::make_nvp("indices", indices));
v.indices() = indices;
}

template <typename Archive, typename _Scalar, int SizeAtCompileTime>
inline void serialize(Archive &archive,
Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> &matrix,
Expand Down
16 changes: 8 additions & 8 deletions include/albatross/src/cereal/gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,
archive(cereal::make_nvp("information", fit.information));
archive(cereal::make_nvp("train_covariance", fit.train_covariance));
archive(cereal::make_nvp("train_features", fit.train_features));
archive(cereal::make_nvp("sigma_R", fit.sigma_R));
archive(cereal::make_nvp("permutation_indices", fit.permutation_indices));
archive(cereal::make_nvp("R", fit.R));
archive(cereal::make_nvp("P", fit.P));
if (version > 1) {
archive(cereal::make_nvp("numerical_rank", fit.numerical_rank));
} else {
Expand All @@ -53,19 +53,19 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
inline void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
archive(cereal::make_nvp("name", gp.get_name()));
archive(cereal::make_nvp("params", gp.get_params()));
archive(cereal::make_nvp("insights", gp.insights));
}

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
inline void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
if (version > 0) {
std::string model_name;
archive(cereal::make_nvp("name", model_name));
Expand Down
7 changes: 7 additions & 0 deletions include/albatross/src/core/declarations.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ template <typename... Ts> class variant;

using mapbox::util::variant;

/*
 * Permutations
 *
 * PermutationMatrixX is shorthand for Eigen's PermutationMatrix with both
 * size template arguments set to Dynamic and Index as the third template
 * argument, mirroring the naming of Eigen's own dynamic aliases such as
 * MatrixXd.
 *
 * NOTE(review): this adds a name to the third-party Eigen namespace; confirm
 * it cannot collide with an alias introduced by a future Eigen release.
 */
namespace Eigen {
using PermutationMatrixX = PermutationMatrix<Dynamic, Dynamic, Index>;
}

namespace albatross {

/*
Expand Down
22 changes: 12 additions & 10 deletions include/albatross/src/linalg/qr_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,31 @@ get_R(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
.template triangularView<Eigen::Upper>();
}

/*
 * Returns the column permutation of a dense column-pivoting QR
 * decomposition as an Eigen::PermutationMatrixX, casting the
 * decomposition's permutation indices to Eigen::Index.
 */
inline Eigen::PermutationMatrixX
get_P(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
  const auto &column_permutation = qr.colsPermutation();
  Eigen::PermutationMatrixX P(
      column_permutation.indices().template cast<Eigen::Index>());
  return P;
}

/*
* Computes R^-T P^T rhs given R and P from a QR decomposition.
*/
template <typename MatrixType, typename PermutationIndicesType>
template <typename MatrixType, typename PermutationScalar>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::MatrixXd &R,
const PermutationIndicesType &permutation_indices,
const Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
PermutationScalar> &P,
const MatrixType &rhs) {

Eigen::MatrixXd sqrt(rhs.rows(), rhs.cols());
for (Eigen::Index i = 0; i < permutation_indices.size(); ++i) {
sqrt.row(i) = rhs.row(permutation_indices.coeff(i));
}
sqrt = R.template triangularView<Eigen::Upper>().transpose().solve(sqrt);
return sqrt;
return R.template triangularView<Eigen::Upper>().transpose().solve(
P.transpose() * rhs);
}

template <typename MatrixType>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr,
const MatrixType &rhs) {
const Eigen::MatrixXd R = get_R(qr);
return sqrt_solve(R, qr.colsPermutation().indices(), rhs);
return sqrt_solve(R, qr.colsPermutation(), rhs);
}

} // namespace albatross
Expand Down
11 changes: 6 additions & 5 deletions include/albatross/src/linalg/spqr_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,20 @@ using SparseMatrix = Eigen::SparseMatrix<double>;

using SPQR = Eigen::SPQR<SparseMatrix>;

using SparsePermutationMatrix =
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
SPQR::StorageIndex>;

/*
 * Returns, as a dense matrix, the upper triangular view of the top left
 * qr.cols() x qr.cols() corner of the R factor of a sparse QR decomposition.
 */
inline Eigen::MatrixXd get_R(const SPQR &qr) {
  const Eigen::Index n = qr.cols();
  // Keep the matrixR() chain in a single expression so that any temporary it
  // produces stays alive until the dense result has been built.
  return qr.matrixR()
      .topLeftCorner(n, n)
      .template triangularView<Eigen::Upper>();
}

/*
 * Returns the column permutation of a sparse QR decomposition as an
 * Eigen::PermutationMatrixX, casting the decomposition's permutation
 * indices to Eigen::Index.
 */
inline Eigen::PermutationMatrixX get_P(const SPQR &qr) {
  // Bind the permutation object itself (not its indices): if
  // colsPermutation() returns by value, this keeps the temporary alive for
  // the rest of the function.
  const auto &column_permutation = qr.colsPermutation();
  return Eigen::PermutationMatrixX(
      column_permutation.indices().template cast<Eigen::Index>());
}

template <typename MatrixType>
inline Eigen::MatrixXd sqrt_solve(const SPQR &qr, const MatrixType &rhs) {
return sqrt_solve(get_R(qr), qr.colsPermutation().indices(), rhs);
return sqrt_solve(get_R(qr), get_P(qr), rhs);
}

// Matrices with any dimension smaller than this will use a special
Expand Down
59 changes: 22 additions & 37 deletions include/albatross/src/models/sparse_gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,21 +97,19 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

std::vector<FeatureType> train_features;
Eigen::SerializableLDLT train_covariance;
Eigen::MatrixXd sigma_R;
PermutationIndices permutation_indices;
Eigen::MatrixXd R;
Eigen::PermutationMatrixX P;
Eigen::VectorXd information;
Eigen::Index numerical_rank;

Fit(){};

Fit(const std::vector<FeatureType> &features_,
const Eigen::SerializableLDLT &train_covariance_,
const Eigen::MatrixXd &sigma_R_,
PermutationIndices &&permutation_indices_,
const Eigen::MatrixXd &R_, const Eigen::PermutationMatrixX &P_,
const Eigen::VectorXd &information_, Eigen::Index numerical_rank_)
: train_features(features_), train_covariance(train_covariance_),
sigma_R(sigma_R_), permutation_indices(std::move(permutation_indices_)),
information(information_), numerical_rank(numerical_rank_) {}
: train_features(features_), train_covariance(train_covariance_), R(R_),
P(P_), information(information_), numerical_rank(numerical_rank_) {}

void shift_mean(const Eigen::VectorXd &mean_shift) {
ALBATROSS_ASSERT(mean_shift.size() == information.size());
Expand All @@ -120,9 +118,8 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

bool operator==(const Fit<SparseGPFit<FeatureType>> &other) const {
return (train_features == other.train_features &&
train_covariance == other.train_covariance &&
sigma_R == other.sigma_R &&
permutation_indices == other.permutation_indices &&
train_covariance == other.train_covariance && R == other.R &&
P.indices() == other.P.indices() &&
information == other.information &&
numerical_rank == other.numerical_rank);
}
Expand Down Expand Up @@ -325,20 +322,17 @@ class SparseGaussianProcessRegression
compute_internal_components(old_fit.train_features, features, targets,
&A_ldlt, &K_uu_ldlt, &K_fu, &y);

const Eigen::Index n_old = old_fit.sigma_R.rows();
const Eigen::Index n_old = old_fit.R.rows();
const Eigen::Index n_new = A_ldlt.rows();
const Eigen::Index k = old_fit.sigma_R.cols();
const Eigen::Index k = old_fit.R.cols();
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n_old + n_new, k);

ALBATROSS_ASSERT(n_old == k);

// Form:
// B = |R_old P_old^T| = |Q_1| R P^T
// |A^{-1/2} K_fu| |Q_2|
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
const Eigen::Index &pi = old_fit.permutation_indices.coeff(i);
B.col(pi).topRows(i + 1) = old_fit.sigma_R.col(i).topRows(i + 1);
}
B.topRows(old_fit.P.rows()) = old_fit.R * old_fit.P.transpose();
B.bottomRows(n_new) = A_ldlt.sqrt_solve(K_fu);
const auto B_qr = QRImplementation::compute(B, Base::threads_.get());

Expand All @@ -347,13 +341,9 @@ class SparseGaussianProcessRegression
// |A^{-1/2} y |
ALBATROSS_ASSERT(old_fit.information.size() == n_old);
Eigen::VectorXd y_augmented(n_old + n_new);
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
y_augmented[i] =
old_fit.information[old_fit.permutation_indices.coeff(i)];
}
y_augmented.topRows(n_old) =
old_fit.sigma_R.template triangularView<Eigen::Upper>() *
y_augmented.topRows(n_old);
old_fit.R.template triangularView<Eigen::Upper>() *
(old_fit.P.transpose() * old_fit.information);

y_augmented.bottomRows(n_new) = A_ldlt.sqrt_solve(y, Base::threads_.get());
const Eigen::VectorXd v = B_qr->solve(y_augmented);
Expand All @@ -365,10 +355,9 @@ class SparseGaussianProcessRegression
Eigen::VectorXd::Constant(B_qr->cols(), details::cSparseRNugget);
}
using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
old_fit.train_features, old_fit.train_covariance, R,
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());

return FitType(old_fit.train_features, old_fit.train_covariance, R,
get_P(*B_qr), v, B_qr->rank());
}

// Here we create the QR decomposition of:
Expand Down Expand Up @@ -415,10 +404,7 @@ class SparseGaussianProcessRegression
using InducingPointFeatureType = typename std::decay<decltype(u[0])>::type;

using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
u, K_uu_ldlt, get_R(*B_qr),
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());
return FitType(u, K_uu_ldlt, get_R(*B_qr), get_P(*B_qr), v, B_qr->rank());
}

template <typename FeatureType>
Expand Down Expand Up @@ -471,9 +457,8 @@ class SparseGaussianProcessRegression
const Eigen::MatrixXd sigma_inv_sqrt = C_ldlt.sqrt_solve(K_zz);
const auto B_qr = QRImplementation::compute(sigma_inv_sqrt, nullptr);

new_fit.permutation_indices =
B_qr->colsPermutation().indices().template cast<Eigen::Index>();
new_fit.sigma_R = get_R(*B_qr);
new_fit.P = get_P(*B_qr);
new_fit.R = get_R(*B_qr);
new_fit.numerical_rank = B_qr->rank();

return output;
Expand Down Expand Up @@ -519,8 +504,8 @@ class SparseGaussianProcessRegression
Q_sqrt.cwiseProduct(Q_sqrt).array().colwise().sum();
marginal_variance -= Q_diag;

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);
const Eigen::VectorXd S_diag =
S_sqrt.cwiseProduct(S_sqrt).array().colwise().sum();
marginal_variance += S_diag;
Expand All @@ -537,8 +522,8 @@ class SparseGaussianProcessRegression
this->covariance_function_(sparse_gp_fit.train_features, features);
const Eigen::MatrixXd prior_cov = this->covariance_function_(features);

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);

const Eigen::MatrixXd Q_sqrt =
sparse_gp_fit.train_covariance.sqrt_solve(cross_cov);
Expand Down
8 changes: 4 additions & 4 deletions tests/test_sparse_gp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,10 +322,10 @@ TYPED_TEST(SparseGaussianProcessTest, test_update) {
(updated_in_place_pred.covariance - full_pred.covariance).norm();

auto compute_sigma = [](const auto &fit_model) -> Eigen::MatrixXd {
const Eigen::Index n = fit_model.get_fit().sigma_R.cols();
Eigen::MatrixXd sigma = sqrt_solve(fit_model.get_fit().sigma_R,
fit_model.get_fit().permutation_indices,
Eigen::MatrixXd::Identity(n, n));
const Eigen::Index n = fit_model.get_fit().R.cols();
Eigen::MatrixXd sigma =
sqrt_solve(fit_model.get_fit().R, fit_model.get_fit().P,
Eigen::MatrixXd::Identity(n, n));
return sigma.transpose() * sigma;
};

Expand Down

0 comments on commit 2b99854

Please sign in to comment.