diff --git a/include/albatross/src/core/distribution.hpp b/include/albatross/src/core/distribution.hpp index 9a921ecb..bd0a1f61 100644 --- a/include/albatross/src/core/distribution.hpp +++ b/include/albatross/src/core/distribution.hpp @@ -20,6 +20,8 @@ inline bool operator==(const albatross::DiagonalMatrixXd &x, namespace albatross { +struct MarginalDistribution; + constexpr double cDefaultApproximatelyEqualEpsilon = 1e-3; template struct DistributionBase { @@ -48,6 +50,8 @@ template struct DistributionBase { return derived().get_diagonal(i); } + MarginalDistribution operator[](std::size_t index) const; + Eigen::VectorXd mean; std::map metadata; @@ -245,6 +249,13 @@ inline void set_subset(const DistributionBase &from, to->derived().set_subset(from, indices); } +template +MarginalDistribution +DistributionBase::operator[](std::size_t index) const { + return MarginalDistribution(mean[cast::to_index(index)], + get_diagonal(cast::to_index(index))); +} + inline MarginalDistribution concatenate_marginals(const MarginalDistribution &x, const MarginalDistribution &y) { diff --git a/tests/test_core_distribution.cc b/tests/test_core_distribution.cc index 5bf4e5ee..1492c758 100644 --- a/tests/test_core_distribution.cc +++ b/tests/test_core_distribution.cc @@ -236,6 +236,19 @@ TYPED_TEST_P(DistributionTest, test_subtract) { EXPECT_EQ(actual, expected); }; +TYPED_TEST_P(DistributionTest, test_operator_indexing) { + + TypeParam test_case; + const auto dist = test_case.create(); + + for (std::size_t idx = 0; idx < dist.size(); ++idx) { + MarginalDistribution m = dist[idx]; + + EXPECT_EQ(m.mean[0], dist.mean[cast::to_index(idx)]); + EXPECT_EQ(m.get_diagonal(0), dist.get_diagonal(cast::to_index(idx))); + } +}; + REGISTER_TYPED_TEST_SUITE_P(DistributionTest, test_subset, test_multiply_with_matrix_joint, test_multiply_with_matrix_marginal, @@ -243,7 +256,7 @@ REGISTER_TYPED_TEST_SUITE_P(DistributionTest, test_subset, test_multiply_with_sparse_matrix_marginal, test_multiply_with_vector, test_multiply_by_scalar, test_equal, test_approximately_equal, test_add, - test_subtract); + test_subtract, test_operator_indexing); Eigen::VectorXd arange(Eigen::Index k = 5) { Eigen::VectorXd mean(k);