Skip to content

Commit

Permalink
increase test precision, fix split chain logic
Browse files Browse the repository at this point in the history
  • Loading branch information
mitzimorris committed Oct 22, 2024
1 parent 608217c commit bf0d581
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 10 deletions.
14 changes: 8 additions & 6 deletions src/stan/analyze/mcmc/split_chains.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,17 @@ namespace analyze {
inline Eigen::MatrixXd split_chains(const std::vector<Eigen::MatrixXd>& chains,
const int index) {
size_t num_chains = chains.size();
size_t num_samples = chains[0].rows();
size_t half = std::floor(num_samples / 2.0);
size_t num_draws = chains[0].rows();
size_t half = std::floor(num_draws / 2.0);
size_t tail_start = std::floor((num_draws + 1) / 2.0);

Eigen::MatrixXd split_draws_matrix(half, num_chains * 2);
int split_i = 0;
for (std::size_t i = 0; i < num_chains; ++i) {
Eigen::Map<const Eigen::VectorXd> head_block(chains[i].col(index).data(),
half);
Eigen::Map<const Eigen::VectorXd> tail_block(
chains[i].col(index).data() + half, half);
chains[i].col(index).data() + tail_start, half);

split_draws_matrix.col(split_i) = head_block;
split_draws_matrix.col(split_i + 1) = tail_block;
Expand All @@ -47,14 +48,15 @@ inline Eigen::MatrixXd split_chains(const std::vector<Eigen::MatrixXd>& chains,
*/
inline Eigen::MatrixXd split_chains(const Eigen::MatrixXd& samples) {
size_t num_chains = samples.cols();
size_t num_samples = samples.rows();
size_t half = std::floor(num_samples / 2.0);
size_t num_draws = samples.rows();
size_t half = std::floor(num_draws / 2.0);
size_t tail_start = std::floor((num_draws + 1) / 2.0);

Eigen::MatrixXd split_draws_matrix(half, num_chains * 2);
int split_i = 0;
for (std::size_t i = 0; i < num_chains; ++i) {
Eigen::Map<const Eigen::VectorXd> head_block(samples.col(i).data(), half);
Eigen::Map<const Eigen::VectorXd> tail_block(samples.col(i).data() + half,
Eigen::Map<const Eigen::VectorXd> tail_block(samples.col(i).data() + tail_start,
half);

split_draws_matrix.col(split_i) = head_block;
Expand Down
4 changes: 2 additions & 2 deletions src/test/unit/analyze/mcmc/ess_basic_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ TEST_F(EssBasic, test_basic_ess) {
EXPECT_NEAR(ess_lp_expect, ess_basic_lp, 1e-4);
EXPECT_NEAR(ess_theta_expect, ess_basic_theta, 1e-4);

EXPECT_NEAR(old_ess_basic_lp, ess_basic_lp, 1e-9);
EXPECT_NEAR(old_ess_basic_theta, ess_basic_theta, 1e-9);
EXPECT_NEAR(old_ess_basic_lp, ess_basic_lp, 1e-12);
EXPECT_NEAR(old_ess_basic_theta, ess_basic_theta, 1e-12);
}
4 changes: 2 additions & 2 deletions src/test/unit/analyze/mcmc/rhat_basic_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,6 @@ TEST_F(RhatBasic, test_basic_rhat) {
EXPECT_NEAR(rhat_lp_basic_expect, rhat_basic_lp, 1e-5);
EXPECT_NEAR(rhat_theta_basic_expect, rhat_basic_theta, 1e-5);

EXPECT_NEAR(old_rhat_basic_lp, rhat_basic_lp, 1e-9);
EXPECT_NEAR(old_rhat_basic_theta, rhat_basic_theta, 1e-9);
EXPECT_NEAR(old_rhat_basic_lp, rhat_basic_lp, 1e-12);
EXPECT_NEAR(old_rhat_basic_theta, rhat_basic_theta, 1e-12);
}
46 changes: 46 additions & 0 deletions src/test/unit/analyze/mcmc/split_chains_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,3 +109,49 @@ TEST_F(SplitChains, split_chains_convenience) {
}
}
}

TEST_F(SplitChains, split_draws_matrix_odd_rows) {
// When the number of total draws N is odd, the (N+1)/2th draw is ignored.
Eigen::MatrixXd foo(7, 2);
int val = 0;
for (size_t col = 0; col < 2; ++col) {
for (size_t row = 0; row < 7; ++row) {
val += 1;
foo(row, col) = val;
}
}
auto bar = stan::analyze::split_chains(foo);
EXPECT_EQ(4, bar.cols());
EXPECT_EQ(3, bar.rows());
EXPECT_EQ(bar(0,0), 1);
EXPECT_EQ(bar(1,0), 2);
EXPECT_EQ(bar(2,0), 3);
EXPECT_EQ(bar(0,1), 5);
EXPECT_EQ(bar(1,1), 6);
EXPECT_EQ(bar(2,1), 7);
EXPECT_EQ(bar(0,2), 8);
EXPECT_EQ(bar(1,2), 9);
EXPECT_EQ(bar(2,2), 10);
EXPECT_EQ(bar(0,3), 12);
EXPECT_EQ(bar(1,3), 13);
EXPECT_EQ(bar(2,3), 14);

Eigen::MatrixXd baz(4, 2);
for (size_t col = 0; col < 2; ++col) {
for (size_t row = 0; row < 4; ++row) {
val += 1;
baz(row, col) = val;
}
}
auto boz = stan::analyze::split_chains(baz);
EXPECT_EQ(4, boz.cols());
EXPECT_EQ(2, boz.rows());
EXPECT_EQ(boz(0,0), 15);
EXPECT_EQ(boz(1,0), 16);
EXPECT_EQ(boz(0,1), 17);
EXPECT_EQ(boz(1,1), 18);
EXPECT_EQ(boz(0,2), 19);
EXPECT_EQ(boz(1,2), 20);
EXPECT_EQ(boz(0,3), 21);
EXPECT_EQ(boz(1,3), 22);
}

0 comments on commit bf0d581

Please sign in to comment.