Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

update error message for different init types #3291

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 40 additions & 7 deletions src/stan/services/util/initialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,17 +102,19 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
model.transform_inits(context, disc_vector, unconstrained, &msg);
}
} catch (std::domain_error& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.warn("Rejecting initial value:");
logger.warn(
" Error evaluating the log probability"
" at the initial value.");
logger.warn(e.what());
continue;
} catch (std::exception& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(
"Unrecoverable error evaluating the log probability"
" at the initial value.");
Expand All @@ -127,8 +129,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
// the parameters.
log_prob = model.template log_prob<false, Jacobian>(unconstrained,
disc_vector, &msg);
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
} catch (std::domain_error& e) {
if (msg.str().length() > 0)
logger.info(msg);
Expand All @@ -139,8 +142,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
logger.warn(e.what());
continue;
} catch (std::exception& e) {
if (msg.str().length() > 0)
if (msg.str().length() > 0) {
logger.info(msg);
}
logger.error(
"Unrecoverable error evaluating the log probability"
" at the initial value.");
Expand All @@ -165,8 +169,9 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
log_prob = stan::model::log_prob_grad<true, Jacobian>(
model, unconstrained, disc_vector, gradient, &log_prob_msg);
} catch (const std::exception& e) {
if (log_prob_msg.str().length() > 0)
if (log_prob_msg.str().length() > 0) {
logger.info(log_prob_msg);
}
logger.error(e.what());
throw;
}
Expand Down Expand Up @@ -210,8 +215,36 @@ std::vector<double> initialize(Model& model, const InitContext& init, RNG& rng,
return unconstrained;
}
}

if (!is_initialized_with_zero) {
if (is_fully_initialized) {
logger.info("");
logger.error("User-specified initialization failed.");
logger.error(
" Try specifying new initial values,"
" using partially specialized initialization,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else if (any_initialized) {
logger.info("");
std::stringstream msg;
msg << "Partial user-specified initialization failed. "
"Initialization of non user specified parameters "
"between (-"
<< init_radius << ", " << init_radius << ") failed after"
<< " " << MAX_INIT_TRIES << " attempts. ";
logger.error(msg);
logger.error(
" Try specifying full initial values,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else if (is_initialized_with_zero) {
logger.info("");
logger.error("Initial values of 0 failed to initialize.");
logger.error(
" Try specifying new initial values,"
" using partially specialized initialization,"
" reducing the range of constrained values,"
" or reparameterizing the model.");
} else {
logger.info("");
std::stringstream msg;
msg << "Initialization between (-" << init_radius << ", " << init_radius
Expand Down
6 changes: 6 additions & 0 deletions src/test/test-models/good/services/test_fail.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
parameters {
array[2] real<lower=-10, upper=10> y;
}
model {
reject("");
}
17 changes: 17 additions & 0 deletions src/test/unit/services/instrumented_callbacks.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,23 @@ class instrumented_logger : public stan::callbacks::logger {
return count;
}

public:
std::vector<std::string> return_all_logs() {
std::vector<std::string> all_logs;
all_logs.reserve(debug_.size() + info_.size() + warn_.size() + error_.size()
+ fatal_.size() + 5);
all_logs.emplace_back("DEBUG");
all_logs.insert(all_logs.end(), debug_.begin(), debug_.end());
all_logs.emplace_back("INFO");
all_logs.insert(all_logs.end(), info_.begin(), info_.end());
all_logs.emplace_back("WARN");
all_logs.insert(all_logs.end(), warn_.begin(), warn_.end());
all_logs.emplace_back("ERROR");
all_logs.insert(all_logs.end(), error_.begin(), error_.end());
all_logs.emplace_back("FATAL");
all_logs.insert(all_logs.end(), fatal_.begin(), fatal_.end());
return all_logs;
}
std::vector<std::string> debug_;
std::vector<std::string> info_;
std::vector<std::string> warn_;
Expand Down
58 changes: 58 additions & 0 deletions src/test/unit/services/util/fail_init_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/array_var_context.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <test/test-models/good/services/test_fail.hpp>
#include <test/unit/util.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <gtest/gtest.h>
#include <sstream>

class ServicesUtilInitialize : public testing::Test {
public:
ServicesUtilInitialize()
: model(empty_context, 12345, &model_ss),
message(message_ss),
rng(stan::services::util::create_rng(0, 1)) {}

stan_model model;
stan::io::empty_var_context empty_context;
std::stringstream model_ss;
std::stringstream message_ss;
stan::callbacks::stream_writer message;
stan::test::unit::instrumented_logger logger;
stan::test::unit::instrumented_writer init;
stan::rng_t rng;
};

TEST_F(ServicesUtilInitialize, model_throws__full_init) {
std::vector<std::string> names_r;
std::vector<double> values_r;
std::vector<std::vector<size_t> > dim_r;
names_r.push_back("y");
values_r.push_back(6.35149); // 1.5 unconstrained: -10 + 20 * inv.logit(1.5)
values_r.push_back(-2.449187); // -0.5 unconstrained
std::vector<size_t> d;
d.push_back(2);
dim_r.push_back(d);
stan::io::array_var_context init_context(names_r, values_r, dim_r);

double init_radius = 2;
bool print_timing = false;
EXPECT_THROW(
stan::services::util::initialize(model, init_context, rng, init_radius,
print_timing, logger, init),
std::domain_error);
/* Uncomment to print all logs
auto logs = logger.return_all_logs();
for (auto&& m : logs) {
std::cout << m << std::endl;
}
*/
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(0, logger.find_warn("throwing within log_prob"));
}
21 changes: 10 additions & 11 deletions src/test/unit/services/util/initialize_test.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
#include <stan/services/util/initialize.hpp>
#include <stan/services/util/create_rng.hpp>
#include <gtest/gtest.h>
#include <test/unit/util.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <sstream>
#include <test/test-models/good/services/test_lp.hpp>
#include <stan/io/empty_var_context.hpp>
#include <stan/io/array_var_context.hpp>
#include <stan/services/util/create_rng.hpp>
#include <stan/callbacks/stream_writer.hpp>
#include <stan/callbacks/stream_logger.hpp>
#include <test/test-models/good/services/test_lp.hpp>
#include <test/unit/util.hpp>
#include <test/unit/services/instrumented_callbacks.hpp>
#include <gtest/gtest.h>
#include <sstream>

class ServicesUtilInitialize : public testing::Test {
public:
Expand All @@ -28,7 +28,7 @@ class ServicesUtilInitialize : public testing::Test {
stan::rng_t rng;
};

TEST_F(ServicesUtilInitialize, radius_zero__print_false) {
TEST_F(ServicesUtilInitialize, radius_zero_print_false) {
std::vector<double> params;

double init_radius = 0;
Expand Down Expand Up @@ -250,7 +250,7 @@ class mock_throwing_model : public stan::model::prob_grad {

} // namespace test

TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
TEST_F(ServicesUtilInitialize, model_throws_radius_zero) {
test::mock_throwing_model throwing_model;

double init_radius = 0;
Expand All @@ -259,8 +259,7 @@ TEST_F(ServicesUtilInitialize, model_throws__radius_zero) {
stan::services::util::initialize(throwing_model, empty_context, rng,
init_radius, print_timing, logger, init),
std::domain_error);

EXPECT_EQ(3, logger.call_count());
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(1, logger.find_warn("throwing within log_prob"));
}
Expand Down Expand Up @@ -533,7 +532,7 @@ TEST_F(ServicesUtilInitialize, model_throws_in_write_array__radius_zero) {
init_radius, print_timing, logger, init),
std::domain_error);

EXPECT_EQ(3, logger.call_count());
EXPECT_EQ(6, logger.call_count());
EXPECT_EQ(3, logger.call_count_warn());
EXPECT_EQ(1, logger.find_warn("throwing within write_array"));
}
Expand Down
Loading