Skip to content

Commit

Permalink
update to fix bug where if PSIS time was 0 then it would not be printed
Browse files Browse the repository at this point in the history
  • Loading branch information
SteveBronder committed Feb 13, 2025
1 parent 7bd2826 commit 069d0e5
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 11 deletions.
11 changes: 4 additions & 7 deletions src/stan/services/pathfinder/multi.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ inline int pathfinder_lbfgs_multi(
}
double pathfinders_delta_time = stan::services::util::duration_diff(
start_pathfinders_time, std::chrono::steady_clock::now());
write_times<true>(parameter_writer, pathfinders_delta_time, 0);
write_times<true, false>(parameter_writer, pathfinders_delta_time, 0);
// Writes are done in loop, so just return
return error_codes::OK;
}
Expand Down Expand Up @@ -307,15 +307,13 @@ inline int pathfinder_lbfgs_multi(
++psis_writer_position;
}
}
double psis_delta_time = stan::services::util::duration_diff(
start_psis_time, std::chrono::steady_clock::now());
write_times<false>(single_writer, pathfinders_delta_time, 0);
write_times<false, false>(single_writer, pathfinders_delta_time, 0);
}
});
safe_write.wait();
double psis_delta_time = stan::services::util::duration_diff(
start_psis_time, std::chrono::steady_clock::now());
write_times<true>(parameter_writer, pathfinders_delta_time,
write_times<true, true>(parameter_writer, pathfinders_delta_time,
psis_delta_time);
return error_codes::OK;
}
Expand Down Expand Up @@ -357,10 +355,9 @@ inline int pathfinder_lbfgs_multi(
}
});
safe_write.wait();

double psis_delta_time = stan::services::util::duration_diff(
start_psis_time, std::chrono::steady_clock::now());
write_times<true>(parameter_writer, pathfinders_delta_time, psis_delta_time);
write_times<true, true>(parameter_writer, pathfinders_delta_time, psis_delta_time);
return error_codes::OK;
}
} // namespace pathfinder
Expand Down
9 changes: 5 additions & 4 deletions src/stan/services/pathfinder/single.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,12 +532,13 @@ auto pathfinder_impl(RNG&& rng, LPFun&& lp_fun, AlphaVec&& alpha,
/**
* Write time lines for a pathfinder output file
* @tparam MultiPathfinder If true, output uses (Pathfinders) else (Pathfinder)
* @tparam PSISTime If true, output includes PSIS time
* @tparam ParamWriter Type inheriting from `stan::callbacks::writer`
* @param[in,out] parameter_writer A callback writer for messages
* @param pathfinders_delta_time Time taken for pathfinders
* @param psis_delta_time Time taken for PSIS
*/
template <bool MultiPathfinder, typename ParamWriter>
template <bool MultiPathfinder, bool PSISTime, typename ParamWriter>
inline void write_times(ParamWriter&& parameter_writer,
double pathfinders_delta_time, double psis_delta_time) {
parameter_writer();
Expand All @@ -547,7 +548,7 @@ inline void write_times(ParamWriter&& parameter_writer,
+ std::string(" seconds")
+ (MultiPathfinder ? " (Pathfinders)" : " (Pathfinder)");
parameter_writer(optim_time_str);
if (psis_delta_time != 0) {
if constexpr (PSISTime) {
std::string psis_time_str = std::string(time_header.size(), ' ')
+ std::to_string(psis_delta_time)
+ " seconds (PSIS)";
Expand Down Expand Up @@ -1000,11 +1001,11 @@ inline auto pathfinder_lbfgs_single(
"intend to change this "
"please make it clear why.");
auto&& single_stream = std::get<0>(parameter_writer.get_stream());
internal::write_times<false>(single_stream, pathfinder_delta_time, 0);
internal::write_times<false, false>(single_stream, pathfinder_delta_time, 0);
return internal::ret_pathfinder<ReturnLpSamples>(error_codes::OK,
internal::elbo_est_t{});
} else {
internal::write_times<false>(parameter_writer, pathfinder_delta_time, 0);
internal::write_times<false, false>(parameter_writer, pathfinder_delta_time, 0);
return internal::ret_pathfinder<ReturnLpSamples>(error_codes::OK,
internal::elbo_est_t{});
}
Expand Down

0 comments on commit 069d0e5

Please sign in to comment.