Skip to content

Commit

Permalink
two-level models, experiment with different saturated model matrices …
Browse files Browse the repository at this point in the history
…for within-only variables
  • Loading branch information
ecmerkle committed Jan 5, 2024
1 parent 172e84c commit 8373312
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 14 deletions.
62 changes: 58 additions & 4 deletions R/lav_export_stanmarg.R
Original file line number Diff line number Diff line change
Expand Up @@ -1250,21 +1250,75 @@ lav2standata <- function(lavobject) {
dat$log_lik_x <- array(rep(llx / dat$ncluster_sizes, dat$ncluster_sizes), sum(dat$ncluster_sizes))

## clusterwise data summaries, for loo and waic and etc
cidx <- lavInspect(lavobject, 'cluster.idx')
cidx <- cid <- lavInspect(lavobject, 'cluster.idx')
if (inherits(cidx, "list")) {
if (length(cidx) > 1) {
for (g in 2:length(cidx)) {
cidx[[g]] <- cidx[[g]] + max(cidx[[(g - 1)]])
}
}
cidx <- unlist(cidx)
} else {
cid <- list(cid)
}
mean_d_full <- rowsum.default(as.matrix(dat$YX), cidx) / dat$cluster_size
mean_d_full_sat <- mean_d_full
if (dat$N_within > 0) {
for (i in 1:dat$N_within) {
mean_d_full_sat[, dat$within_idx[i]] <- mean(as.matrix(dat$YX)[, dat$within_idx[i]])
}
}

## cinv for each group, for computing saturated _rep matrices for ppp
dat$gs <- array(unlist(sapply(YLp, function(x) x[[2]]$s, simplify = FALSE)))

## computations for "sat" versions of S_PW and cov_b
S_PW_sat <- cov_b_sat <- vector("list", Ng)
nclus <- dat$nclus

srow <- 1; erow <- nclus[1, 2]
for (g in 1:Ng) {
Y2c <- t( t(mean_d_full_sat[srow:erow, , drop = FALSE]) - colMeans(YX[[g]]))
Y1a <- YX[[g]] - mean_d_full_sat[cid[[g]], , drop = FALSE]
S.w <- crossprod(Y1a) / (nclus[g, 1] - nclus[g, 2])

csize <- Lp[[g]]$cluster.size[[2]]
S.b <- crossprod(Y2c * csize, Y2c) / (nclus[g, 2] - 1)

if (dat$N_between > 0) {
bidx <- dat$between_idx[1:dat$N_between]

S.w[bidx, ] <- 0
S.w[, bidx] <- 0

S.b[, bidx] <- (dat$gs[g] * nclus[g, 2] / nclus[g, 1]) * S.b[, bidx, drop = FALSE]
S.b[bidx, ] <- (dat$gs[g] * nclus[g, 2] / nclus[g, 1]) * S.b[bidx, , drop = FALSE]
S.b[bidx, bidx] <- dat$gs[g] * crossprod(Y2c[, bidx, drop = FALSE]) / nclus[g, 2]
}

Sigma.b <- (S.b - S.w)/dat$gs[g]

if (dat$N_within > 0) {
Sigma.b[dat$within_idx, ] <- 0
Sigma.b[, dat$within_idx] <- 0
}

notbidx <- dat$between_idx[(dat$N_between + 1):dat$p_tilde]
S_PW_sat[[g]] <- S.w[notbidx, notbidx, drop = FALSE]
cov_b_sat[[g]] <- Sigma.b

srow <- srow + erow
if (g < Ng) erow <- erow + nclus[(g + 1), 2]
}

dat$S_PW_sat <- S_PW_sat
dat$cov_b_sat <- cov_b_sat

tmpYX <- split.data.frame(dat$YX, cidx)
dat$YX <- do.call("rbind", tmpYX)
dat$log_lik_x_full <- llx_2l(Lp[[1]], dat$YX, mean_d_full, cidx)
dat$mean_d_full <- lapply(1:nrow(mean_d_full), function(i) mean_d_full[i, dat$between_idx])
dat$mean_d_full_sat <- lapply(1:nrow(mean_d_full_sat), function(i) mean_d_full_sat[i, dat$between_idx])

## cov_d is variability across clusters of same size, so irrelevant for clusterwise
## (just send 0s to satisfy the Stan function)
Expand All @@ -1278,9 +1332,6 @@ lav2standata <- function(lavobject) {
ncol(dat$xbar_b),
Ng))
dat$cov_b <- aperm(cov_b, c(3, 1, 2))

## cinv for each group, for computing saturated _rep matrices for ppp
dat$gs <- array(unlist(sapply(YLp, function(x) x[[2]]$s, simplify = FALSE)))
} else {
dat$nclus <- array(1, c(Ng, 2))
dat$cluster_size <- array(1, Ng)
Expand All @@ -1300,14 +1351,17 @@ lav2standata <- function(lavobject) {

dat$mean_d <- array(0, c(Ng, 0))
dat$cov_w <- array(0, c(Ng, 0, 0))
dat$S_PW_sat <- array(0, c(Ng, 0, 0))
dat$log_lik_x <- array(0, Ng)
dat$log_lik_x_full <- array(0, Ng)
dat$cov_d <- array(0, c(Ng, 0, 0))
dat$mean_d_full <- array(0, c(Ng, 0))
dat$mean_d_full_sat <- array(0, c(Ng, 0))
dat$cov_d_full <- array(0, c(Ng, 0, 0))
dat$xbar_w <- array(0, c(Ng, 0))
dat$xbar_b <- array(0, c(Ng, 0))
dat$cov_b <- array(0, c(Ng, 0, 0))
dat$cov_b_sat <- array(0, c(Ng, 0, 0))
dat$gs <- array(1, Ng)
} # multilevel

Expand Down
10 changes: 7 additions & 3 deletions R/stanmarg_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -226,9 +226,10 @@ stanmarg_data <- function(YX = NULL, S = NULL, YXo = NULL, N, Ng, grpnum, # data
wigind = NULL, # wiggle indicator
pri_only = FALSE, # prior predictive sampling
do_reg = FALSE, # regression sampling
multilev, mean_d, cov_w, log_lik_x, cov_d, nclus, cluster_size, # level 2 data
ncluster_sizes, mean_d_full, cov_d_full, log_lik_x_full, xbar_w, xbar_b,
cov_b, gs, cluster_sizes, cluster_size_ns, between_idx, N_between,
multilev, mean_d, cov_w, S_PW_sat, log_lik_x, cov_d, # level 2 data
nclus, cluster_size, ncluster_sizes, mean_d_full, mean_d_full_sat,
cov_d_full, log_lik_x_full, xbar_w, xbar_b, cov_b, cov_b_sat, gs,
cluster_sizes, cluster_size_ns, between_idx, N_between,
within_idx, N_within, both_idx, N_both, ov_idx1, ov_idx2, p_tilde, N_lev,
Lambda_y_skeleton_c = NULL, # level 2 matrices
B_skeleton_c = NULL, Theta_skeleton_c = NULL, Theta_r_skeleton_c = NULL,
Expand Down Expand Up @@ -336,14 +337,17 @@ stanmarg_data <- function(YX = NULL, S = NULL, YXo = NULL, N, Ng, grpnum, # data
## level 2 data
dat$mean_d <- mean_d
dat$cov_w <- cov_w
dat$S_PW_sat <- S_PW_sat
dat$log_lik_x <- log_lik_x
dat$cov_d <- cov_d
dat$mean_d_full <- mean_d_full
dat$mean_d_full_sat <- mean_d_full_sat
dat$cov_d_full <- cov_d_full
dat$log_lik_x_full <- log_lik_x_full
dat$xbar_w <- xbar_w
dat$xbar_b <- xbar_b
dat$cov_b <- cov_b
dat$cov_b_sat <- cov_b_sat
dat$gs <- gs
dat$nclus <- nclus
dat$cluster_size <- cluster_size
Expand Down
34 changes: 27 additions & 7 deletions inst/stan/stanmarg.stan
Original file line number Diff line number Diff line change
Expand Up @@ -615,9 +615,11 @@ data {
array[Ng] matrix[p_tilde, p_tilde] cov_w; // observed "within" covariance matrix
array[sum(nclus[,2])] vector[p_tilde] mean_d_full; // sample means/covs by cluster, for clusterwise log-densities
array[sum(nclus[,2])] matrix[p_tilde, p_tilde] cov_d_full;
array[sum(nclus[,2])] vector[p_tilde] mean_d_full_sat; // sample means by cluster, with within-only tweak for saturated computation
array[Ng] vector[p_tilde] xbar_w; // data estimates of within/between means/covs (for saturated logl)
array[Ng] vector[p_tilde] xbar_b;
array[Ng] matrix[p_tilde, p_tilde] cov_b;
array[Ng] matrix[p_tilde, p_tilde] cov_b_sat; // between covariance matrix for ppp computation
array[Ng] real gs; // group size constant, for computation of saturated logl
int N_within; // number of within variables
int N_between; // number of between variables
Expand All @@ -628,6 +630,7 @@ data {
array[N_lev[1]] int ov_idx1;
array[N_lev[2]] int ov_idx2;
array[N_both] int both_idx;
array[Ng] matrix[N_within + N_both, N_within + N_both] S_PW_sat; // within covariance matrix for ppp computation
vector[multilev ? sum(ncluster_sizes) : Ng] log_lik_x; // ll of fixed x variables by unique cluster size
vector[multilev ? sum(nclus[,2]) : Ng] log_lik_x_full; // ll of fixed x variables by cluster

Expand Down Expand Up @@ -1690,14 +1693,18 @@ generated quantities { // these matrices are saved in the output but do not figu
array[Ng] real logdetS_rep_sat_grp;
matrix[p + q, p + q] zmat;
array[sum(nclus[,2])] vector[p_tilde] mean_d_rep;
array[sum(nclus[,2])] vector[p_tilde] mean_d_rep_sat;
vector[multilev ? sum(nclus[,2]) : Ng] log_lik_x_rep;
array[Ng] matrix[N_both + N_within, N_both + N_within] S_PW_rep;
array[Ng] matrix[p_tilde, p_tilde] S_PW_rep_full;
array[Ng] matrix[N_both + N_within, N_both + N_within] S_PW_rep_sat;
array[Ng] matrix[p_tilde, p_tilde] S_PW_rep_full_sat;
array[Ng] vector[p_tilde] ov_mean_rep;
array[Ng] vector[p_tilde] xbar_b_rep;
array[Ng] matrix[N_between, N_between] S2_rep;
array[Ng] matrix[p_tilde, p_tilde] S_B_rep;
array[Ng] matrix[p_tilde, p_tilde] cov_b_rep;
array[Ng] matrix[p_tilde, p_tilde] cov_b_rep_sat;
real<lower=0, upper=1> ppp;

// first deal with sign constraints:
Expand Down Expand Up @@ -1798,6 +1805,8 @@ generated quantities { // these matrices are saved in the output but do not figu
matrix[p + q, p + q] Sigma_chol = cholesky_decompose(Sigma[gg]);
S_PW_rep[gg] = rep_matrix(0, N_both + N_within, N_both + N_within);
S_PW_rep_full[gg] = rep_matrix(0, p_tilde, p_tilde);
S_PW_rep_sat[gg] = rep_matrix(0, N_both + N_within, N_both + N_within);
S_PW_rep_full_sat[gg] = rep_matrix(0, p_tilde, p_tilde);
S_B_rep[gg] = rep_matrix(0, p_tilde, p_tilde);
ov_mean_rep[gg] = rep_vector(0, p_tilde);

Expand All @@ -1823,6 +1832,7 @@ generated quantities { // these matrices are saved in the output but do not figu
for (jj in 1:p_tilde) {
mean_d_rep[clusidx, jj] = mean(YXstar_rep[r1:(r1 + cluster_size[clusidx] - 1), jj]);
}
mean_d_rep_sat[clusidx] = mean_d_rep[clusidx];

r1 += cluster_size[clusidx];
clusidx += 1;
Expand All @@ -1841,8 +1851,13 @@ generated quantities { // these matrices are saved in the output but do not figu
}

for (cc in 1:nclus[gg, 2]) {
if (N_within > 0) {
mean_d_rep_sat[clusidx, within_idx] = ov_mean_rep[gg, within_idx];
}

for (ii in r1:(r1 + cluster_size[clusidx] - 1)) {
S_PW_rep_full[gg] += tcrossprod(to_matrix(YXstar_rep[ii] - mean_d_rep[clusidx]));
S_PW_rep_full_sat[gg] += tcrossprod(to_matrix(YXstar_rep[ii] - mean_d_rep_sat[clusidx]));
}

S_B_rep[gg] += cluster_size[clusidx] * tcrossprod(to_matrix(mean_d_rep[clusidx] - ov_mean_rep[gg]));
Expand All @@ -1854,6 +1869,7 @@ generated quantities { // these matrices are saved in the output but do not figu
clusidx += 1;
}
S_PW_rep_full[gg] *= pow(nclus[gg, 1] - nclus[gg, 2], -1);
S_PW_rep_full_sat[gg] *= pow(nclus[gg, 1] - nclus[gg, 2], -1);
S_B_rep[gg] *= pow(nclus[gg, 2] - 1, -1);
S2_rep[gg] *= pow(nclus[gg, 2], -1);
// mods to between-only variables:
Expand All @@ -1877,8 +1893,10 @@ generated quantities { // these matrices are saved in the output but do not figu
}

cov_b_rep[gg] = pow(gs[gg], -1) * (S_B_rep[gg] - S_PW_rep_full[gg]);
cov_b_rep_sat[gg] = pow(gs[gg], -1) * (S_B_rep[gg] - S_PW_rep_full_sat[gg]);
if (N_between > 0) {
cov_b_rep[gg, between_idx[1:N_between], between_idx[1:N_between]] = S2_rep[gg];
cov_b_rep_sat[gg, between_idx[1:N_between], between_idx[1:N_between]] = S2_rep[gg];
}

rr1 = r1 - nclus[gg, 1];
Expand All @@ -1891,6 +1909,7 @@ generated quantities { // these matrices are saved in the output but do not figu
}
}
S_PW_rep[gg] = S_PW_rep_full[gg, notbidx, notbidx];
S_PW_rep_sat[gg] = S_PW_rep_full_sat[gg, notbidx, notbidx];

if (Nx[gg] > 0 || Nx_between[gg] > 0) {
array[2] vector[p_tilde] mnvecs;
Expand Down Expand Up @@ -2074,22 +2093,23 @@ generated quantities { // these matrices are saved in the output but do not figu
ov_idx1, ov_idx2, within_idx, between_idx,
both_idx, p_tilde, N_within, N_between, N_both);

log_lik_sat[rr1:rr2] = twolevel_logdens(mean_d_full[rr1:rr2], cov_d_full[rr1:rr2],
S_PW[grpidx], YX[r3:r4],
log_lik_sat[rr1:rr2] = twolevel_logdens(mean_d_full_sat[rr1:rr2], cov_d_full[rr1:rr2],
S_PW_sat[grpidx], YX[r3:r4],
nclus[grpidx,], cluster_size[rr1:rr2],
cluster_size[rr1:rr2], nclus[grpidx,2],
intone[1:nclus[grpidx,2]], xbar_w[grpidx, ov_idx1],
S_PW[grpidx], xbar_b[grpidx, ov_idx2], cov_b[grpidx, ov_idx2, ov_idx2],
S_PW_sat[grpidx], xbar_b[grpidx, ov_idx2],
cov_b_sat[grpidx, ov_idx2, ov_idx2],
ov_idx1, ov_idx2, within_idx, between_idx,
both_idx, p_tilde, N_within, N_between, N_both);

log_lik_rep_sat[rr1:rr2] = twolevel_logdens(mean_d_rep[rr1:rr2], cov_d_full[rr1:rr2],
S_PW_rep[grpidx], YXstar_rep[r3:r4],
log_lik_rep_sat[rr1:rr2] = twolevel_logdens(mean_d_rep_sat[rr1:rr2], cov_d_full[rr1:rr2],
S_PW_rep_sat[grpidx], YXstar_rep[r3:r4],
nclus[grpidx,], cluster_size[rr1:rr2],
cluster_size[rr1:rr2], nclus[grpidx,2],
intone[1:nclus[grpidx,2]], Mu_rep_sat[grpidx],
S_PW_rep[grpidx], xbar_b_rep[grpidx, ov_idx2],
cov_b_rep[grpidx, ov_idx2, ov_idx2],
S_PW_rep_sat[grpidx], xbar_b_rep[grpidx, ov_idx2],
cov_b_rep_sat[grpidx, ov_idx2, ov_idx2],
ov_idx1, ov_idx2,
within_idx, between_idx, both_idx, p_tilde,
N_within, N_between, N_both);
Expand Down

0 comments on commit 8373312

Please sign in to comment.