From 837331219871c9b1ba4e6cdb8797301a419cb9be Mon Sep 17 00:00:00 2001 From: ecmerkle Date: Fri, 5 Jan 2024 11:47:27 -0600 Subject: [PATCH] two-level models, experiment with different saturated model matrices for within-only variables --- R/lav_export_stanmarg.R | 62 ++++++++++++++++++++++++++++++++++++++--- R/stanmarg_data.R | 10 +++++-- inst/stan/stanmarg.stan | 34 +++++++++++++++++----- 3 files changed, 92 insertions(+), 14 deletions(-) diff --git a/R/lav_export_stanmarg.R b/R/lav_export_stanmarg.R index 479ef9ad..75e63282 100644 --- a/R/lav_export_stanmarg.R +++ b/R/lav_export_stanmarg.R @@ -1250,7 +1250,7 @@ lav2standata <- function(lavobject) { dat$log_lik_x <- array(rep(llx / dat$ncluster_sizes, dat$ncluster_sizes), sum(dat$ncluster_sizes)) ## clusterwise data summaries, for loo and waic and etc - cidx <- lavInspect(lavobject, 'cluster.idx') + cidx <- cid <- lavInspect(lavobject, 'cluster.idx') if (inherits(cidx, "list")) { if (length(cidx) > 1) { for (g in 2:length(cidx)) { @@ -1258,13 +1258,67 @@ lav2standata <- function(lavobject) { } } cidx <- unlist(cidx) + } else { + cid <- list(cid) } mean_d_full <- rowsum.default(as.matrix(dat$YX), cidx) / dat$cluster_size + mean_d_full_sat <- mean_d_full + if (dat$N_within > 0) { + for (i in 1:dat$N_within) { + mean_d_full_sat[, dat$within_idx[i]] <- mean(as.matrix(dat$YX)[, dat$within_idx[i]]) + } + } + + ## cinv for each group, for computing saturated _rep matrices for ppp + dat$gs <- array(unlist(sapply(YLp, function(x) x[[2]]$s, simplify = FALSE))) + + ## computations for "sat" versions of S_PW and cov_b + S_PW_sat <- cov_b_sat <- vector("list", Ng) + nclus <- dat$nclus + + srow <- 1; erow <- nclus[1, 2] + for (g in 1:Ng) { + Y2c <- t( t(mean_d_full_sat[srow:erow, , drop = FALSE]) - colMeans(YX[[g]])) + Y1a <- YX[[g]] - mean_d_full_sat[cid[[g]], , drop = FALSE] + S.w <- crossprod(Y1a) / (nclus[g, 1] - nclus[g, 2]) + + csize <- Lp[[g]]$cluster.size[[2]] + S.b <- crossprod(Y2c * csize, Y2c) / (nclus[g, 2] - 1) + + 
if (dat$N_between > 0) { + bidx <- dat$between_idx[1:dat$N_between] + + S.w[bidx, ] <- 0 + S.w[, bidx] <- 0 + + S.b[, bidx] <- (dat$gs[g] * nclus[g, 2] / nclus[g, 1]) * S.b[, bidx, drop = FALSE] + S.b[bidx, ] <- (dat$gs[g] * nclus[g, 2] / nclus[g, 1]) * S.b[bidx, , drop = FALSE] + S.b[bidx, bidx] <- dat$gs[g] * crossprod(Y2c[, bidx, drop = FALSE]) / nclus[g, 2] + } + + Sigma.b <- (S.b - S.w)/dat$gs[g] + + if (dat$N_within > 0) { + Sigma.b[dat$within_idx, ] <- 0 + Sigma.b[, dat$within_idx] <- 0 + } + + notbidx <- dat$between_idx[(dat$N_between + 1):dat$p_tilde] + S_PW_sat[[g]] <- S.w[notbidx, notbidx, drop = FALSE] + cov_b_sat[[g]] <- Sigma.b + + srow <- erow + 1 ## next group's cluster rows begin right after this group's last row + if (g < Ng) erow <- erow + nclus[(g + 1), 2] + } + + dat$S_PW_sat <- S_PW_sat + dat$cov_b_sat <- cov_b_sat tmpYX <- split.data.frame(dat$YX, cidx) dat$YX <- do.call("rbind", tmpYX) dat$log_lik_x_full <- llx_2l(Lp[[1]], dat$YX, mean_d_full, cidx) dat$mean_d_full <- lapply(1:nrow(mean_d_full), function(i) mean_d_full[i, dat$between_idx]) + dat$mean_d_full_sat <- lapply(1:nrow(mean_d_full_sat), function(i) mean_d_full_sat[i, dat$between_idx]) ## cov_d is variability across clusters of same size, so irrelevant for clusterwise ## (just send 0s to satisfy the Stan function) @@ -1278,9 +1332,6 @@ lav2standata <- function(lavobject) { ncol(dat$xbar_b), Ng)) dat$cov_b <- aperm(cov_b, c(3, 1, 2)) - - ## cinv for each group, for computing saturated _rep matrices for ppp - dat$gs <- array(unlist(sapply(YLp, function(x) x[[2]]$s, simplify = FALSE))) } else { dat$nclus <- array(1, c(Ng, 2)) dat$cluster_size <- array(1, Ng) @@ -1300,14 +1351,17 @@ lav2standata <- function(lavobject) { dat$mean_d <- array(0, c(Ng, 0)) dat$cov_w <- array(0, c(Ng, 0, 0)) + dat$S_PW_sat <- array(0, c(Ng, 0, 0)) dat$log_lik_x <- array(0, Ng) dat$log_lik_x_full <- array(0, Ng) dat$cov_d <- array(0, c(Ng, 0, 0)) dat$mean_d_full <- array(0, c(Ng, 0)) + dat$mean_d_full_sat <- array(0, c(Ng, 0)) dat$cov_d_full <- array(0, c(Ng, 0, 0)) 
dat$xbar_w <- array(0, c(Ng, 0)) dat$xbar_b <- array(0, c(Ng, 0)) dat$cov_b <- array(0, c(Ng, 0, 0)) + dat$cov_b_sat <- array(0, c(Ng, 0, 0)) dat$gs <- array(1, Ng) } # multilevel diff --git a/R/stanmarg_data.R b/R/stanmarg_data.R index 41bbd603..139d232b 100644 --- a/R/stanmarg_data.R +++ b/R/stanmarg_data.R @@ -226,9 +226,10 @@ stanmarg_data <- function(YX = NULL, S = NULL, YXo = NULL, N, Ng, grpnum, # data wigind = NULL, # wiggle indicator pri_only = FALSE, # prior predictive sampling do_reg = FALSE, # regression sampling - multilev, mean_d, cov_w, log_lik_x, cov_d, nclus, cluster_size, # level 2 data - ncluster_sizes, mean_d_full, cov_d_full, log_lik_x_full, xbar_w, xbar_b, - cov_b, gs, cluster_sizes, cluster_size_ns, between_idx, N_between, + multilev, mean_d, cov_w, S_PW_sat, log_lik_x, cov_d, # level 2 data + nclus, cluster_size, ncluster_sizes, mean_d_full, mean_d_full_sat, + cov_d_full, log_lik_x_full, xbar_w, xbar_b, cov_b, cov_b_sat, gs, + cluster_sizes, cluster_size_ns, between_idx, N_between, within_idx, N_within, both_idx, N_both, ov_idx1, ov_idx2, p_tilde, N_lev, Lambda_y_skeleton_c = NULL, # level 2 matrices B_skeleton_c = NULL, Theta_skeleton_c = NULL, Theta_r_skeleton_c = NULL, @@ -336,14 +337,17 @@ stanmarg_data <- function(YX = NULL, S = NULL, YXo = NULL, N, Ng, grpnum, # data ## level 2 data dat$mean_d <- mean_d dat$cov_w <- cov_w + dat$S_PW_sat <- S_PW_sat dat$log_lik_x <- log_lik_x dat$cov_d <- cov_d dat$mean_d_full <- mean_d_full + dat$mean_d_full_sat <- mean_d_full_sat dat$cov_d_full <- cov_d_full dat$log_lik_x_full <- log_lik_x_full dat$xbar_w <- xbar_w dat$xbar_b <- xbar_b dat$cov_b <- cov_b + dat$cov_b_sat <- cov_b_sat dat$gs <- gs dat$nclus <- nclus dat$cluster_size <- cluster_size diff --git a/inst/stan/stanmarg.stan b/inst/stan/stanmarg.stan index aa723def..4c82caf0 100644 --- a/inst/stan/stanmarg.stan +++ b/inst/stan/stanmarg.stan @@ -615,9 +615,11 @@ data { array[Ng] matrix[p_tilde, p_tilde] cov_w; // observed "within" covariance 
matrix array[sum(nclus[,2])] vector[p_tilde] mean_d_full; // sample means/covs by cluster, for clusterwise log-densities array[sum(nclus[,2])] matrix[p_tilde, p_tilde] cov_d_full; + array[sum(nclus[,2])] vector[p_tilde] mean_d_full_sat; // sample means by cluster, with within-only tweak for saturated computation array[Ng] vector[p_tilde] xbar_w; // data estimates of within/between means/covs (for saturated logl) array[Ng] vector[p_tilde] xbar_b; array[Ng] matrix[p_tilde, p_tilde] cov_b; + array[Ng] matrix[p_tilde, p_tilde] cov_b_sat; // between covariance matrix for ppp computation array[Ng] real gs; // group size constant, for computation of saturated logl int N_within; // number of within variables int N_between; // number of between variables @@ -628,6 +630,7 @@ data { array[N_lev[1]] int ov_idx1; array[N_lev[2]] int ov_idx2; array[N_both] int both_idx; + array[Ng] matrix[N_within + N_both, N_within + N_both] S_PW_sat; // within covariance matrix for ppp computation vector[multilev ? sum(ncluster_sizes) : Ng] log_lik_x; // ll of fixed x variables by unique cluster size vector[multilev ? sum(nclus[,2]) : Ng] log_lik_x_full; // ll of fixed x variables by cluster @@ -1690,14 +1693,18 @@ generated quantities { // these matrices are saved in the output but do not figu array[Ng] real logdetS_rep_sat_grp; matrix[p + q, p + q] zmat; array[sum(nclus[,2])] vector[p_tilde] mean_d_rep; + array[sum(nclus[,2])] vector[p_tilde] mean_d_rep_sat; vector[multilev ? 
sum(nclus[,2]) : Ng] log_lik_x_rep; array[Ng] matrix[N_both + N_within, N_both + N_within] S_PW_rep; array[Ng] matrix[p_tilde, p_tilde] S_PW_rep_full; + array[Ng] matrix[N_both + N_within, N_both + N_within] S_PW_rep_sat; + array[Ng] matrix[p_tilde, p_tilde] S_PW_rep_full_sat; array[Ng] vector[p_tilde] ov_mean_rep; array[Ng] vector[p_tilde] xbar_b_rep; array[Ng] matrix[N_between, N_between] S2_rep; array[Ng] matrix[p_tilde, p_tilde] S_B_rep; array[Ng] matrix[p_tilde, p_tilde] cov_b_rep; + array[Ng] matrix[p_tilde, p_tilde] cov_b_rep_sat; real ppp; // first deal with sign constraints: @@ -1798,6 +1805,8 @@ generated quantities { // these matrices are saved in the output but do not figu matrix[p + q, p + q] Sigma_chol = cholesky_decompose(Sigma[gg]); S_PW_rep[gg] = rep_matrix(0, N_both + N_within, N_both + N_within); S_PW_rep_full[gg] = rep_matrix(0, p_tilde, p_tilde); + S_PW_rep_sat[gg] = rep_matrix(0, N_both + N_within, N_both + N_within); + S_PW_rep_full_sat[gg] = rep_matrix(0, p_tilde, p_tilde); S_B_rep[gg] = rep_matrix(0, p_tilde, p_tilde); ov_mean_rep[gg] = rep_vector(0, p_tilde); @@ -1823,6 +1832,7 @@ generated quantities { // these matrices are saved in the output but do not figu for (jj in 1:p_tilde) { mean_d_rep[clusidx, jj] = mean(YXstar_rep[r1:(r1 + cluster_size[clusidx] - 1), jj]); } + mean_d_rep_sat[clusidx] = mean_d_rep[clusidx]; r1 += cluster_size[clusidx]; clusidx += 1; @@ -1841,8 +1851,13 @@ generated quantities { // these matrices are saved in the output but do not figu } for (cc in 1:nclus[gg, 2]) { + if (N_within > 0) { + mean_d_rep_sat[clusidx, within_idx] = ov_mean_rep[gg, within_idx]; + } + for (ii in r1:(r1 + cluster_size[clusidx] - 1)) { S_PW_rep_full[gg] += tcrossprod(to_matrix(YXstar_rep[ii] - mean_d_rep[clusidx])); + S_PW_rep_full_sat[gg] += tcrossprod(to_matrix(YXstar_rep[ii] - mean_d_rep_sat[clusidx])); } S_B_rep[gg] += cluster_size[clusidx] * tcrossprod(to_matrix(mean_d_rep[clusidx] - ov_mean_rep[gg])); @@ -1854,6 +1869,7 @@ generated 
quantities { // these matrices are saved in the output but do not figu clusidx += 1; } S_PW_rep_full[gg] *= pow(nclus[gg, 1] - nclus[gg, 2], -1); + S_PW_rep_full_sat[gg] *= pow(nclus[gg, 1] - nclus[gg, 2], -1); S_B_rep[gg] *= pow(nclus[gg, 2] - 1, -1); S2_rep[gg] *= pow(nclus[gg, 2], -1); // mods to between-only variables: @@ -1877,8 +1893,10 @@ generated quantities { // these matrices are saved in the output but do not figu } cov_b_rep[gg] = pow(gs[gg], -1) * (S_B_rep[gg] - S_PW_rep_full[gg]); + cov_b_rep_sat[gg] = pow(gs[gg], -1) * (S_B_rep[gg] - S_PW_rep_full_sat[gg]); if (N_between > 0) { cov_b_rep[gg, between_idx[1:N_between], between_idx[1:N_between]] = S2_rep[gg]; + cov_b_rep_sat[gg, between_idx[1:N_between], between_idx[1:N_between]] = S2_rep[gg]; } rr1 = r1 - nclus[gg, 1]; @@ -1891,6 +1909,7 @@ generated quantities { // these matrices are saved in the output but do not figu } } S_PW_rep[gg] = S_PW_rep_full[gg, notbidx, notbidx]; + S_PW_rep_sat[gg] = S_PW_rep_full_sat[gg, notbidx, notbidx]; if (Nx[gg] > 0 || Nx_between[gg] > 0) { array[2] vector[p_tilde] mnvecs; @@ -2074,22 +2093,23 @@ generated quantities { // these matrices are saved in the output but do not figu ov_idx1, ov_idx2, within_idx, between_idx, both_idx, p_tilde, N_within, N_between, N_both); - log_lik_sat[rr1:rr2] = twolevel_logdens(mean_d_full[rr1:rr2], cov_d_full[rr1:rr2], - S_PW[grpidx], YX[r3:r4], + log_lik_sat[rr1:rr2] = twolevel_logdens(mean_d_full_sat[rr1:rr2], cov_d_full[rr1:rr2], + S_PW_sat[grpidx], YX[r3:r4], nclus[grpidx,], cluster_size[rr1:rr2], cluster_size[rr1:rr2], nclus[grpidx,2], intone[1:nclus[grpidx,2]], xbar_w[grpidx, ov_idx1], - S_PW[grpidx], xbar_b[grpidx, ov_idx2], cov_b[grpidx, ov_idx2, ov_idx2], + S_PW_sat[grpidx], xbar_b[grpidx, ov_idx2], + cov_b_sat[grpidx, ov_idx2, ov_idx2], ov_idx1, ov_idx2, within_idx, between_idx, both_idx, p_tilde, N_within, N_between, N_both); - log_lik_rep_sat[rr1:rr2] = twolevel_logdens(mean_d_rep[rr1:rr2], cov_d_full[rr1:rr2], - 
S_PW_rep[grpidx], YXstar_rep[r3:r4], + log_lik_rep_sat[rr1:rr2] = twolevel_logdens(mean_d_rep_sat[rr1:rr2], cov_d_full[rr1:rr2], + S_PW_rep_sat[grpidx], YXstar_rep[r3:r4], nclus[grpidx,], cluster_size[rr1:rr2], cluster_size[rr1:rr2], nclus[grpidx,2], intone[1:nclus[grpidx,2]], Mu_rep_sat[grpidx], - S_PW_rep[grpidx], xbar_b_rep[grpidx, ov_idx2], - cov_b_rep[grpidx, ov_idx2, ov_idx2], + S_PW_rep_sat[grpidx], xbar_b_rep[grpidx, ov_idx2], + cov_b_rep_sat[grpidx, ov_idx2, ov_idx2], ov_idx1, ov_idx2, within_idx, between_idx, both_idx, p_tilde, N_within, N_between, N_both);