Skip to content

Commit

Permalink
Merge pull request #3303 from stan-dev/bugfix/3071-fix-unit-test
Browse files Browse the repository at this point in the history
Fix and add unit tests for issue 3071 - adaptive sampler runs on models w/ zero params.
  • Loading branch information
mitzimorris authored Jul 26, 2024
2 parents cd8b2e0 + e546442 commit 2f8ae28
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 2 deletions.
7 changes: 7 additions & 0 deletions src/test/test-models/good/services/zero_params.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
transformed data {
int N = 2;
}
generated quantities {
real theta = beta_rng(1, 1);
real eta = beta_rng(10, 10);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
#include <stan/services/sample/hmc_nuts_dense_e.hpp>
#include <gtest/gtest.h>
#include <stan/io/empty_var_context.hpp>
#include <test/test-models/good/services/zero_params.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <iostream>

auto&& blah = stan::math::init_threadpool_tbb();

static constexpr size_t num_chains = 4;
class ServicesSampleHMCNutsDenseENoParams : public testing::Test {
public:
ServicesSampleHMCNutsDenseENoParams() : model(data_context, 0, &model_log) {
for (int i = 0; i < num_chains; ++i) {
init.push_back(stan::test::unit::instrumented_writer{});
parameter.push_back(stan::test::unit::instrumented_writer{});
diagnostic.push_back(stan::test::unit::instrumented_writer{});
context.push_back(std::make_shared<stan::io::empty_var_context>());
}
}
stan::io::empty_var_context data_context;
std::stringstream model_log;
stan::test::unit::instrumented_logger logger;
std::vector<stan::test::unit::instrumented_writer> init;
std::vector<stan::test::unit::instrumented_writer> parameter;
std::vector<stan::test::unit::instrumented_writer> diagnostic;
std::vector<std::shared_ptr<stan::io::empty_var_context>> context;
stan_model model;
};

TEST_F(ServicesSampleHMCNutsDenseENoParams, call_count) {
unsigned int random_seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int num_warmup = 200;
int num_samples = 400;
int num_thin = 5;
bool save_warmup = true;
int refresh = 0;
double stepsize = 0.1;
double stepsize_jitter = 0;
int max_depth = 8;
double delta = .1;
double gamma = .1;
double kappa = .1;
double t0 = .1;
unsigned int init_buffer = 50;
unsigned int term_buffer = 50;
unsigned int window = 100;
stan::test::unit::instrumented_interrupt interrupt;
EXPECT_EQ(interrupt.call_count(), 0);

int return_code = stan::services::sample::hmc_nuts_dense_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic);

EXPECT_EQ(0, return_code);

int num_output_lines = (num_warmup + num_samples) / num_thin;
EXPECT_EQ((num_warmup + num_samples) * num_chains, interrupt.call_count());
for (int i = 0; i < num_chains; ++i) {
EXPECT_EQ(1, parameter[i].call_count("vector_string"));
EXPECT_EQ(num_output_lines, parameter[i].call_count("vector_double"));
EXPECT_EQ(1, diagnostic[i].call_count("vector_string"));
EXPECT_EQ(num_output_lines, diagnostic[i].call_count("vector_double"));
}
}

TEST_F(ServicesSampleHMCNutsDenseENoParams, parameter_checks) {
unsigned int random_seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int num_warmup = 200;
int num_samples = 400;
int num_thin = 5;
bool save_warmup = true;
int refresh = 0;
double stepsize = 0.1;
double stepsize_jitter = 0;
int max_depth = 8;
double delta = .1;
double gamma = .1;
double kappa = .1;
double t0 = .1;
unsigned int init_buffer = 50;
unsigned int term_buffer = 50;
unsigned int window = 100;
stan::test::unit::instrumented_interrupt interrupt;
EXPECT_EQ(interrupt.call_count(), 0);

int return_code = stan::services::sample::hmc_nuts_dense_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic);

for (size_t i = 0; i < num_chains; ++i) {
std::vector<std::vector<std::string>> parameter_names;
parameter_names = parameter[i].vector_string_values();
std::vector<std::vector<double>> parameter_values;
parameter_values = parameter[i].vector_double_values();
std::vector<std::string> strings;
strings = parameter[i].string_values();
std::vector<std::vector<std::string>> diagnostic_names;
diagnostic_names = diagnostic[i].vector_string_values();
std::vector<std::vector<double>> diagnostic_values;
diagnostic_values = diagnostic[i].vector_double_values();

// Expect message at end of warmup
EXPECT_EQ("Adaptation terminated", strings[0]);

// Expectations of sampler and model variables names.
ASSERT_EQ(9, parameter_names[0].size());
EXPECT_EQ("lp__", parameter_names[0][0]);
EXPECT_EQ("accept_stat__", parameter_names[0][1]);
EXPECT_EQ("stepsize__", parameter_names[0][2]);
EXPECT_EQ("treedepth__", parameter_names[0][3]);
EXPECT_EQ("n_leapfrog__", parameter_names[0][4]);
EXPECT_EQ("divergent__", parameter_names[0][5]);
EXPECT_EQ("energy__", parameter_names[0][6]);
EXPECT_EQ("theta", parameter_names[0][7]);
EXPECT_EQ("eta", parameter_names[0][8]);

// Expect one name per parameter value.
EXPECT_EQ(parameter_names[0].size(), parameter_values[0].size());
EXPECT_EQ(diagnostic_names[0].size(), diagnostic_values[0].size());

EXPECT_EQ((num_warmup + num_samples) / num_thin, parameter_values.size());

// Expect one call to set parameter names, and one set of output per
// iteration.
EXPECT_EQ("lp__", diagnostic_names[0][0]);
EXPECT_EQ("accept_stat__", diagnostic_names[0][1]);
}
EXPECT_EQ(return_code, 0);
}

TEST_F(ServicesSampleHMCNutsDenseENoParams, output_regression) {
unsigned int random_seed = 0;
unsigned int chain = 1;
double init_radius = 0;
int num_warmup = 200;
int num_samples = 400;
int num_thin = 5;
bool save_warmup = true;
int refresh = 0;
double stepsize = 0.1;
double stepsize_jitter = 0;
int max_depth = 8;
double delta = .1;
double gamma = .1;
double kappa = .1;
double t0 = .1;
unsigned int init_buffer = 50;
unsigned int term_buffer = 50;
unsigned int window = 100;
stan::test::unit::instrumented_interrupt interrupt;
EXPECT_EQ(interrupt.call_count(), 0);

stan::services::sample::hmc_nuts_dense_e(
model, num_chains, context, random_seed, chain, init_radius, num_warmup,
num_samples, num_thin, save_warmup, refresh, stepsize, stepsize_jitter,
max_depth, interrupt, logger, init, parameter, diagnostic);

for (auto&& init_it : init) {
std::vector<std::string> init_values;
init_values = init_it.string_values();

EXPECT_EQ(0, init_values.size());
}

EXPECT_EQ(num_chains, logger.find_info("Elapsed Time:"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Warm-up)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Sampling)"));
EXPECT_EQ(num_chains, logger.find_info("seconds (Total)"));
EXPECT_EQ(0, logger.call_count_error());
}
11 changes: 9 additions & 2 deletions src/test/unit/services/util/inv_metric_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ TEST(inv_metric, create_diag_sz100) {
ASSERT_NEAR(1.0, diag_vals[99], 0.0001);
}

TEST(inv_metric, create_dense_sz0) {
auto default_metric = stan::services::util::create_unit_e_dense_inv_metric(0);
stan::io::var_context& inv_inv_metric = default_metric;
std::vector<double> diag_vals = inv_inv_metric.vals_r("inv_metric");
EXPECT_EQ(0, diag_vals.size());
}

TEST(inv_metric, create_dense_sz2) {
auto default_metric = stan::services::util::create_unit_e_dense_inv_metric(2);
stan::io::var_context& inv_inv_metric = default_metric;
Expand Down Expand Up @@ -122,9 +129,9 @@ TEST(inv_metric, read_dense_OK) {

TEST(inv_metric, read_dense_sz0) {
stan::callbacks::logger logger;
stan::io::dump dmp = stan::services::util::create_unit_e_dense_inv_metric(0);
auto zero_metric = stan::services::util::create_unit_e_dense_inv_metric(0);
Eigen::MatrixXd inv_inv_metric
= stan::services::util::read_dense_inv_metric(dmp, 0, logger);
= stan::services::util::read_dense_inv_metric(zero_metric, 0, logger);
EXPECT_EQ(0, inv_inv_metric.size());
EXPECT_EQ(0, inv_inv_metric.rows());
EXPECT_EQ(0, inv_inv_metric.cols());
Expand Down

0 comments on commit 2f8ae28

Please sign in to comment.