Skip to content

Commit

Permalink
Merge pull request stan-dev#2800 from stan-dev/feature/issue-2799-rob…
Browse files Browse the repository at this point in the history
…ust_no_u_turn

Feature/issue 2799 robust no u turn
  • Loading branch information
Bob Carpenter authored Aug 30, 2019
2 parents 530a4d4 + 7bab596 commit 6a98b12
Show file tree
Hide file tree
Showing 6 changed files with 508 additions and 187 deletions.
184 changes: 133 additions & 51 deletions src/stan/mcmc/hmc/nuts/base_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,56 +81,86 @@ namespace stan {
this->hamiltonian_.sample_p(this->z_, this->rand_int_);
this->hamiltonian_.init(this->z_, logger);

ps_point z_plus(this->z_);
ps_point z_minus(z_plus);
ps_point z_fwd(this->z_); // State at forward end of trajectory
ps_point z_bck(z_fwd); // State at backward end of trajectory

ps_point z_sample(z_plus);
ps_point z_propose(z_plus);
ps_point z_sample(z_fwd);
ps_point z_propose(z_fwd);

Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
Eigen::VectorXd p_sharp_dummy = p_sharp_plus;
Eigen::VectorXd p_sharp_minus = p_sharp_plus;
Eigen::VectorXd rho = this->z_.p;
// Momentum and sharp momentum at forward end of forward subtree
Eigen::VectorXd p_fwd_fwd = this->z_.p;
Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_);

// Momentum and sharp momentum at backward end of forward subtree
Eigen::VectorXd p_fwd_bck = this->z_.p;
Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd;

// Momentum and sharp momentum at forward end of backward subtree
Eigen::VectorXd p_bck_fwd = this->z_.p;
Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd;

// Momentum and sharp momentum at backward end of backward subtree
Eigen::VectorXd p_bck_bck = this->z_.p;
Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd;

// Integrated momenta along trajectory
Eigen::VectorXd rho = this->z_.p.transpose();

// Log sum of state weights (offset by H0) along trajectory
double log_sum_weight = 0; // log(exp(H0 - H0))
double H0 = this->hamiltonian_.H(this->z_);
int n_leapfrog = 0;
double sum_metro_prob = 0;

// Build a trajectory until the NUTS criterion is no longer satisfied
// Build a trajectory until the no-u-turn
// criterion is no longer satisfied
this->depth_ = 0;
this->divergent_ = false;

while (this->depth_ < this->max_depth_) {
// Build a new subtree in a random direction
Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size());
Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size());
Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size());

bool valid_subtree = false;
double log_sum_weight_subtree
= -std::numeric_limits<double>::infinity();

if (this->rand_uniform_() > 0.5) {
this->z_.ps_point::operator=(z_plus);
// Extend the current trajectory forward
this->z_.ps_point::operator=(z_fwd);
rho_bck = rho;
p_bck_fwd = p_fwd_fwd;
p_sharp_bck_fwd = p_sharp_fwd_fwd;

valid_subtree
= build_tree(this->depth_, z_propose,
p_sharp_dummy, p_sharp_plus, rho_subtree,
p_sharp_fwd_bck, p_sharp_fwd_fwd,
rho_fwd, p_fwd_bck, p_fwd_fwd,
H0, 1, n_leapfrog,
log_sum_weight_subtree, sum_metro_prob,
logger);
z_plus.ps_point::operator=(this->z_);
z_fwd.ps_point::operator=(this->z_);
} else {
this->z_.ps_point::operator=(z_minus);
// Extend the current trajectory backwards
this->z_.ps_point::operator=(z_bck);
rho_fwd = rho;
p_fwd_bck = p_bck_bck;
p_sharp_fwd_bck = p_sharp_bck_bck;

valid_subtree
= build_tree(this->depth_, z_propose,
p_sharp_dummy, p_sharp_minus, rho_subtree,
p_sharp_bck_fwd, p_sharp_bck_bck,
rho_bck, p_bck_fwd, p_bck_bck,
H0, -1, n_leapfrog,
log_sum_weight_subtree, sum_metro_prob,
logger);
z_minus.ps_point::operator=(this->z_);
z_bck.ps_point::operator=(this->z_);
}

if (!valid_subtree) break;

// Sample from an accepted subtree
// Sample from accepted subtree
++(this->depth_);

if (log_sum_weight_subtree > log_sum_weight) {
Expand All @@ -145,9 +175,30 @@ namespace stan {
log_sum_weight
= math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

// Break when NUTS criterion is no longer satisfied
rho += rho_subtree;
if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
// Break when no-u-turn criterion is no longer satisfied
rho = rho_bck + rho_fwd;

// Demand satisfaction around merged subtrees
bool persist_criterion =
compute_criterion(p_sharp_bck_bck,
p_sharp_fwd_fwd,
rho);

// Demand satisfaction between subtrees
Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck;

persist_criterion &=
compute_criterion(p_sharp_bck_bck,
p_sharp_fwd_bck,
rho_extended);

rho_extended = rho_fwd + p_bck_fwd;
persist_criterion &=
compute_criterion(p_sharp_bck_fwd,
p_sharp_fwd_fwd,
rho_extended);

if (!persist_criterion)
break;
}

Expand Down Expand Up @@ -193,9 +244,11 @@ namespace stan {
*
* @param depth Depth of the desired subtree
* @param z_propose State proposed from subtree
* @param p_sharp_left p_sharp from left boundary of returned tree
* @param p_sharp_right p_sharp from the right boundary of returned tree
* @param p_sharp_beg Sharp momentum at beginning of new tree
* @param p_sharp_end Sharp momentum at end of new tree
* @param rho Summed momentum across trajectory
* @param p_beg Momentum at beginning of returned tree
* @param p_end Momentum at end of returned tree
* @param H0 Hamiltonian of initial state
* @param sign Direction in time to built subtree
* @param n_leapfrog Summed number of leapfrog evaluations
Expand All @@ -204,9 +257,11 @@ namespace stan {
* @param logger Logger for messages
*/
bool build_tree(int depth, ps_point& z_propose,
Eigen::VectorXd& p_sharp_left,
Eigen::VectorXd& p_sharp_right,
Eigen::VectorXd& p_sharp_beg,
Eigen::VectorXd& p_sharp_end,
Eigen::VectorXd& rho,
Eigen::VectorXd& p_beg,
Eigen::VectorXd& p_end,
double H0, double sign, int& n_leapfrog,
double& log_sum_weight, double& sum_metro_prob,
callbacks::logger& logger) {
Expand All @@ -231,63 +286,90 @@ namespace stan {
sum_metro_prob += std::exp(H0 - h);

z_propose = this->z_;
rho += this->z_.p;

p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
p_sharp_right = p_sharp_left;
p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_);
p_sharp_end = p_sharp_beg;

rho += this->z_.p;
p_beg = this->z_.p;
p_end = p_beg;

return !this->divergent_;
}
// General recursion
Eigen::VectorXd p_sharp_dummy(this->z_.p.size());

// Build the left subtree
double log_sum_weight_left = -std::numeric_limits<double>::infinity();
Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());
// Build the initial subtree
double log_sum_weight_init = -std::numeric_limits<double>::infinity();

// Momentum and sharp momentum at end of the initial subtree
Eigen::VectorXd p_init_end(this->z_.p.size());
Eigen::VectorXd p_sharp_init_end(this->z_.p.size());

Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size());

bool valid_left
bool valid_init
= build_tree(depth - 1, z_propose,
p_sharp_left, p_sharp_dummy, rho_left,
p_sharp_beg, p_sharp_init_end,
rho_init, p_beg, p_init_end,
H0, sign, n_leapfrog,
log_sum_weight_left, sum_metro_prob,
log_sum_weight_init, sum_metro_prob,
logger);

if (!valid_left) return false;
if (!valid_init) return false;

// Build the right subtree
ps_point z_propose_right(this->z_);
// Build the final subtree
ps_point z_propose_final(this->z_);

double log_sum_weight_right = -std::numeric_limits<double>::infinity();
Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());
double log_sum_weight_final = -std::numeric_limits<double>::infinity();

bool valid_right
= build_tree(depth - 1, z_propose_right,
p_sharp_dummy, p_sharp_right, rho_right,
// Momentum and sharp momentum at beginning of the final subtree
Eigen::VectorXd p_final_beg(this->z_.p.size());
Eigen::VectorXd p_sharp_final_beg(this->z_.p.size());

Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size());

bool valid_final
= build_tree(depth - 1, z_propose_final,
p_sharp_final_beg, p_sharp_end,
rho_final, p_final_beg, p_end,
H0, sign, n_leapfrog,
log_sum_weight_right, sum_metro_prob,
log_sum_weight_final, sum_metro_prob,
logger);

if (!valid_right) return false;
if (!valid_final) return false;

// Multinomial sample from right subtree
double log_sum_weight_subtree
= math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
= math::log_sum_exp(log_sum_weight_init, log_sum_weight_final);
log_sum_weight
= math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

if (log_sum_weight_right > log_sum_weight_subtree) {
z_propose = z_propose_right;
if (log_sum_weight_final > log_sum_weight_subtree) {
z_propose = z_propose_final;
} else {
double accept_prob
= std::exp(log_sum_weight_right - log_sum_weight_subtree);
= std::exp(log_sum_weight_final - log_sum_weight_subtree);
if (this->rand_uniform_() < accept_prob)
z_propose = z_propose_right;
z_propose = z_propose_final;
}

Eigen::VectorXd rho_subtree = rho_left + rho_right;
Eigen::VectorXd rho_subtree = rho_init + rho_final;
rho += rho_subtree;

return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
// Demand satisfaction around merged subtrees
bool persist_criterion =
compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree);

// Demand satisfaction between subtrees
rho_subtree = rho_init + p_final_beg;
persist_criterion &=
compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree);

rho_subtree = rho_final + p_init_end;
persist_criterion &=
compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree);

return persist_criterion;
}

int depth_;
Expand Down
14 changes: 7 additions & 7 deletions src/test/performance/logistic_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,31 +108,31 @@ TEST_F(performance, values_from_tagged_version) {
<< "last tagged version, 2.17.0, had " << N_values << " elements";

std::vector<double> first_run = last_draws_per_run[0];
EXPECT_FLOAT_EQ(-65.781998, first_run[0])
EXPECT_FLOAT_EQ(-65.216301, first_run[0])
<< "lp__: index 0";

EXPECT_FLOAT_EQ(1.0, first_run[1])
EXPECT_FLOAT_EQ(0.91851199, first_run[1])
<< "accept_stat__: index 1";

EXPECT_FLOAT_EQ(0.76853198, first_run[2])
EXPECT_FLOAT_EQ(0.76885802, first_run[2])
<< "stepsize__: index 2";

EXPECT_FLOAT_EQ(2, first_run[3])
<< "treedepth__: index 3";

EXPECT_FLOAT_EQ(7, first_run[4])
EXPECT_FLOAT_EQ(3, first_run[4])
<< "n_leapfrog__: index 4";

EXPECT_FLOAT_EQ(0, first_run[5])
<< "divergent__: index 5";

EXPECT_FLOAT_EQ(66.6695, first_run[6])
EXPECT_FLOAT_EQ(66.696503, first_run[6])
<< "energy__: index 6";

EXPECT_FLOAT_EQ(1.55186, first_run[7])
EXPECT_FLOAT_EQ(1.3577, first_run[7])
<< "beta.1: index 7";

EXPECT_FLOAT_EQ(-0.52400702, first_run[8])
EXPECT_FLOAT_EQ(-0.51189202, first_run[8])
<< "beta.2: index 8";

matches_tagged_version = !HasNonfatalFailure();
Expand Down
2 changes: 1 addition & 1 deletion src/test/unit/mcmc/hmc/mock_hmc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace stan {

// Ensures that NUTS non-termination criterion is always true
Eigen::VectorXd dtau_dp(ps_point& z) {
return Eigen::VectorXd::Ones(this->model_.num_params_r());
return z.q;
}

Eigen::VectorXd dphi_dq(ps_point& z,
Expand Down
Loading

0 comments on commit 6a98b12

Please sign in to comment.