Skip to content

Commit

Permalink
Merge pull request #1 from generable/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
jburos authored Nov 18, 2024
2 parents 87b0a58 + a533a93 commit 1011c27
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 5 deletions.
10 changes: 10 additions & 0 deletions R/fits-StanModelFit.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,16 @@ StanModelFit <- R6::R6Class("StanModelFit",
private$stan_data <- stan_data
},

save_object = function(file, ...) {
# replicate internals of `cmdstanr::CmdStanFit$save_object`
private$stan_fit$draws()
try(private$stan_fit$sampler_diagnostics(), silent = TRUE)
try(private$stan_fit$init(), silent = TRUE)
try(private$stan_fit$profiles(), silent = TRUE)
saveRDS(self, file = file, ...)
invisible(self)
},

#' @description Get the underlying 'Stan' fit object.
get_stan_fit = function() {
private$stan_fit
Expand Down
6 changes: 4 additions & 2 deletions R/main-treatment_effect.R
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ treatment_effect <- function(fit_post, fit_prior, time_var = "t",
trajectory_metrics <- function(trajectories, id_var, time_var) {
checkmate::assert_class(trajectories, "FunctionDraws")
checkmate::assert_character(id_var, len = 1)
# use first provided time_var
time_var <- time_var[[1]]
checkmate::assert_character(time_var, len = 1)
df <- trajectories$as_data_frame_long()
df %>%
Expand Down Expand Up @@ -120,12 +122,12 @@ predict_new_subjects <- function(fit_post, fit_prior,
prior_param_names = NULL) {
checkmate::assert_class(fit_post, "TSModelFit")
checkmate::assert_class(fit_prior, "TSModelFit")
t_data <- fit_post$get_data("LON")[[time_var]]
t_data <- fit_post$get_data("LON")[[time_var[[1]]]]
t_pred <- t_pred_auto(t_pred, t_data)

# Pick one subject from each group
newdat <- group_pred_input(fit_post, group_var, t_pred, time_var)

newdat <- add_sff_input(newdat, fit_post$get_model())
# Take draws
if (is.null(prior_param_names)) {
prior_param_names <- id_specific_param_names(fit_post)
Expand Down
6 changes: 4 additions & 2 deletions R/utils-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
#' @export
#' @param df original data frame
#' @param t vector of new time values
#' @param time_var name of time variable
#' @param time_var name of time variables
#' @return new data frame
extend_df <- function(df, t, time_var) {
u <- df_unique_factor_rows(df)
df <- df_replicate_rows(u, length(t))
df[[time_var]] <- rep(t, nrow(u))
for (v in time_var) {
df[[v]] <- rep(t, nrow(u))
}
df
}

Expand Down
5 changes: 4 additions & 1 deletion R/utils-stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,13 @@ stancode_ts_gq <- function(mod, stanname_y, sc_terms_gq) {
" vector[n_LON] f_sum_capped = fmin(f_sum, ", cap_val, ");"
)
def_ycap <- paste0(" vector[n_LON] ", y_var, "_log_pred_capped;")
line_ll <- paste0(
" log_lik[i] = normal_lpdf(", stanname_y, "_log[i] | ", "f_sum[i], sigma);"
)
line_yp <- paste0(
" ", sylp, "[i] = normal_rng(", "f_sum[i], sigma);"
)
loop <- paste0(" for(i in 1:n_LON) {\n", line_yp, "\n }")
loop <- paste0(" for(i in 1:n_LON) {\n", line_yp, "\n", line_ll, "\n }")
line_ycap <- paste0(
" ", y_var, "_log_pred_capped = fmin(", sylp, ", ",
cap_val, ");"
Expand Down

0 comments on commit 1011c27

Please sign in to comment.