diff --git a/include/albatross/src/evaluation/prediction_metrics.hpp b/include/albatross/src/evaluation/prediction_metrics.hpp index 29ee165f..f0543d72 100644 --- a/include/albatross/src/evaluation/prediction_metrics.hpp +++ b/include/albatross/src/evaluation/prediction_metrics.hpp @@ -145,6 +145,30 @@ struct ChiSquaredCdf : public PredictionMetric { ChiSquaredCdf() : PredictionMetric(chi_squared_cdf){}; }; +namespace distance { + +namespace detail { + +inline Eigen::MatrixXd principal_sqrt(const Eigen::MatrixXd &input) { + const Eigen::SelfAdjointEigenSolver eigs(input); + return eigs.eigenvectors() * + eigs.eigenvalues().array().sqrt().matrix().asDiagonal() * + eigs.eigenvectors().transpose(); +} + +} // namespace detail + +inline double wasserstein_2(const JointDistribution &a, + const JointDistribution &b) { + auto b_sqrt{detail::principal_sqrt(b.covariance)}; + return (a.mean - b.mean).squaredNorm() + + (a.covariance + b.covariance - + 2 * detail::principal_sqrt(b_sqrt * a.covariance * b_sqrt)) + .trace(); +} + +} // namespace distance + } // namespace albatross #endif /* ALBATROSS_EVALUATION_PREDICTION_METRICS_H_ */