Skip to content

Commit

Permalink
Add wasserstein-2 metric calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
peddie committed Jun 21, 2024
1 parent 544b2de commit bea9640
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions include/albatross/src/evaluation/prediction_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,30 @@ struct ChiSquaredCdf : public PredictionMetric<JointDistribution> {
ChiSquaredCdf() : PredictionMetric<JointDistribution>(chi_squared_cdf){};
};

namespace distance {

namespace detail {

inline Eigen::MatrixXd principal_sqrt(const Eigen::MatrixXd &input) {
const Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> eigs(input);
return eigs.eigenvectors() *
eigs.eigenvalues().array().sqrt().matrix().asDiagonal() *
eigs.eigenvectors().transpose();
}

} // namespace detail

inline double wasserstein_2(const JointDistribution &a,
const JointDistribution &b) {
auto b_sqrt{detail::principal_sqrt(b.covariance)};
return (a.mean - b.mean).squaredNorm() +
(a.covariance + b.covariance -
2 * detail::principal_sqrt(b_sqrt * a.covariance * b_sqrt))
.trace();
}

} // namespace distance

} // namespace albatross

#endif /* ALBATROSS_EVALUATION_PREDICTION_METRICS_H_ */

0 comments on commit bea9640

Please sign in to comment.