Skip to content

Commit

Permalink
Add wasserstein-2 metric calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
peddie committed Jun 25, 2024
1 parent 544b2de commit a8d7d1a
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 0 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,7 @@ swift_cc_test(
linkopts = ["-lz"],
local_defines = ["CSV_IO_NO_THREAD"],
type = UNIT,
size = "large",
deps = [
":albatross",
":serialize-testsuite",
Expand Down
24 changes: 24 additions & 0 deletions include/albatross/src/evaluation/prediction_metrics.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,30 @@ struct ChiSquaredCdf : public PredictionMetric<JointDistribution> {
ChiSquaredCdf() : PredictionMetric<JointDistribution>(chi_squared_cdf){};
};

namespace distance {

namespace detail {

inline Eigen::MatrixXd principal_sqrt(const Eigen::MatrixXd &input) {
const Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> eigs(input);
return eigs.eigenvectors() *
eigs.eigenvalues().array().sqrt().matrix().asDiagonal() *
eigs.eigenvectors().transpose();
}

} // namespace detail

inline double wasserstein_2(const JointDistribution &a,
const JointDistribution &b) {
auto b_sqrt{detail::principal_sqrt(b.covariance)};
return (a.mean - b.mean).squaredNorm() +
(a.covariance + b.covariance -
2 * detail::principal_sqrt(b_sqrt * a.covariance * b_sqrt))
.trace();
}

} // namespace distance

} // namespace albatross

#endif /* ALBATROSS_EVALUATION_PREDICTION_METRICS_H_ */
98 changes: 98 additions & 0 deletions tests/test_stats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <albatross/Common>
#include <albatross/Distribution>
#include <albatross/Evaluation>
#include <albatross/Stats>
#include <albatross/utils/RandomUtils>

Expand Down Expand Up @@ -172,4 +173,101 @@ TEST(test_stats, test_chi_squared_cdf_monotonic_1d) {
}
}

template <typename RandomNumberGenerator>
JointDistribution random_distribution(Eigen::Index dimension,
RandomNumberGenerator &gen) {
const auto covariance = random_covariance_matrix(dimension, gen);
Eigen::VectorXd mean(dimension);
gaussian_fill(mean, gen);

return {mean, covariance};
}

static constexpr Eigen::Index cDistributionDimension = 30;
static constexpr std::size_t cNumIterations = 10000;

// The Wasserstein distance between a distribution and itself should
// be zero to within numerical precision.
TEST(test_stats, test_wasserstein_zero) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
const auto dist = random_distribution(dimension, gen);

EXPECT_LT(distance::wasserstein_2(dist, dist),
1.e-12 * dist.covariance.trace() +
1.e-12 * dist.mean.squaredNorm());
}
}

// The Wasserstein distance between two distributions should aways be
// nonnegative.
TEST(test_stats, test_wasserstein_nonnegative) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
const auto dist_a = random_distribution(dimension, gen);
const auto dist_b = random_distribution(dimension, gen);

EXPECT_GE(distance::wasserstein_2(dist_a, dist_b), 0);
}
}

// If two distributions differ only in their mean, then the
// Wasserstein 2-distance should differ according to the square of the
// distance between means (i.e. the Wasserstein distance has the same
// units as the mean).
TEST(test_stats, test_wasserstein_shift) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
auto dist_a = random_distribution(dimension, gen);
auto dist_b = dist_a;
gaussian_fill(dist_b.mean, gen);

const double distance = distance::wasserstein_2(dist_a, dist_b);

const double mean_distance = (dist_a.mean - dist_b.mean).squaredNorm();

EXPECT_LT(distance - mean_distance, 1.e-10);
}
}

// If we inflate the covariance of the distribution, the Wasserstein
// distance to the original distribution should increase.
TEST(test_stats, test_wasserstein_grows_with_covariance) {
std::default_random_engine gen(2222);

for (std::size_t iter = 0; iter < cNumIterations; ++iter) {
const Eigen::Index dimension = std::uniform_int_distribution<Eigen::Index>(
1, cDistributionDimension)(gen);
auto dist_a = random_distribution(dimension, gen);
const Eigen::SelfAdjointEigenSolver<Eigen::MatrixXd> cov_eigs(
dist_a.covariance);

auto dist_b = dist_a;
dist_b.covariance =
cov_eigs.eigenvectors() *
(cov_eigs.eigenvalues().array() * 2).matrix().asDiagonal() *
cov_eigs.eigenvectors().transpose();

auto dist_c = dist_a;
dist_c.covariance =
cov_eigs.eigenvectors() *
(cov_eigs.eigenvalues().array() * 4).matrix().asDiagonal() *
cov_eigs.eigenvectors().transpose();

const double distance_ab = distance::wasserstein_2(dist_a, dist_b);
const double distance_ac = distance::wasserstein_2(dist_a, dist_c);

EXPECT_GT(distance_ac, distance_ab);
}
}

} // namespace albatross

0 comments on commit a8d7d1a

Please sign in to comment.