Skip to content

Commit

Permalink
Merge pull request #50 from AnqiWang2021/master
Browse files Browse the repository at this point in the history
fix list name
  • Loading branch information
gaow authored Mar 17, 2024
2 parents f999c14 + 5287a73 commit 308b496
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
15 changes: 14 additions & 1 deletion R/mash.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,22 @@ create_mixture_prior = function (mixture_prior, R, null_weight = NULL,
if(!("matrices" %in% names(mixture_prior))){
stop("mixture_prior must contain 'matrices'.")
}
if (is.null(mixture_prior$weights))
if (is.null(mixture_prior$weights)) {
mixture_prior$weights = rep(1/length(mixture_prior$matrices),
length(mixture_prior$matrices))
}
if(!is.null(include_indices)) {
mixture_prior$matrices <- lapply(mixture_prior$matrices, function(x, to_keep) {
x[to_keep, to_keep]
}, include_indices)
}
#Check whether the elements of the each matrix are equal to zero and record the corresponding index
null_matrix_index <- which(sapply(mixture_prior$matrices, function(mat) all(mat == 0)))
#Remove the corresponding indices in matrices
if (length(null_matrix_index) > 0) {
mixture_prior$matrices <- mixture_prior$matrices[-null_matrix_index]
}
mixture_prior$weights <- mixture_prior$weights[-null_matrix_index]/sum(mixture_prior$weights[-null_matrix_index])
return(MashInitializer$new(NULL,NULL,xUlist = mixture_prior$matrices,
prior_weights = mixture_prior$weights,
null_weight = null_weight,
Expand Down
9 changes: 5 additions & 4 deletions R/mash_regression.R
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,9 @@ MashInitializer <- R6Class("MashInitializer",
xUlist = mashr:::expand_cov(Ulist,grid,usepointmass = TRUE)
} else {
if (!all(xUlist[[1]] == 0))
xUlist = c(list(matrix(0,nrow(xUlist[[1]]),ncol(xUlist[[1]]))),
xUlist)
xUlist = c(list(null_model = matrix(0,nrow(xUlist[[1]]),ncol(xUlist[[1]]),
dimnames = list(rownames= rownames(xUlist[[1]]), colnames = colnames(xUlist[[1]])))),
xUlist)
}
if (!is.null(include_conditions)) {
for (l in 1:length(xUlist)) {
Expand Down Expand Up @@ -349,7 +350,7 @@ MashInitializer <- R6Class("MashInitializer",
if (length(unique(u_rows)) > 1)
stop("Ulist contains matrices of different dimensions")
prior_weights = prior_weights/sum(prior_weights)
private$xU = list(pi = c(null_weight,prior_weights * (1 - null_weight)),
private$xU = list(pi = c(null_model = null_weight,prior_weights * (1 - null_weight)),
xUlist = xUlist)

return(invisible(self))
Expand Down Expand Up @@ -461,4 +462,4 @@ MashInitializer <- R6Class("MashInitializer",
xU = NULL,
inv_mats = NULL
)
)
)

0 comments on commit 308b496

Please sign in to comment.