From f85e82c72e01734f4fbf43a6374a5b2deda24b29 Mon Sep 17 00:00:00 2001 From: nskene Date: Mon, 22 Jul 2019 12:34:16 +0100 Subject: [PATCH] Parallelised optimizeALS --- R/liger.R | 76 ++++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 56 insertions(+), 20 deletions(-) diff --git a/R/liger.R b/R/liger.R index 9095b910..c7220eb6 100644 --- a/R/liger.R +++ b/R/liger.R @@ -610,10 +610,12 @@ removeMissingObs <- function(object, slot.use = "raw.data", use.cols = T) { #' @param V.init Initial values to use for V matrices (default NULL) #' @param rand.seed Random seed to allow reproducible results (default 1). #' @param print.obj Print objective function values after convergence (default FALSE). +#' @param num.cores Number of cores to use for optimizing factorizations in parallel (default 1). #' @param ... Arguments passed to other methods #' #' @return \code{liger} object with H, W, and V slots set. #' @export +#' @import parallel #' #' @examples #' \dontrun{ @@ -651,11 +653,13 @@ optimizeALS.list <- function( V.init = NULL, rand.seed = 1, print.obj = FALSE, + num.cores = 1, ... ) { if (!all(sapply(X = object, FUN = is.matrix))) { stop("All values in 'object' must be a matrix") } + E <- object N <- length(x = E) ns <- sapply(X = E, FUN = nrow) @@ -683,7 +687,43 @@ optimizeALS.list <- function( tmp <- gc() best_obj <- Inf run_stats <- matrix(data = 0, nrow = nrep, ncol = 2) - for (i in 1:nrep) { + + ### REMOVED THE FUNCTION FROM HERE [for (i in 1:nrep)] + if(num.cores==1){ + outputs = lapply(as.list(1:nrep),FUN=liger_mainSubFunc,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj) + }else{ + require(parallel) + cl <- makeCluster(num.cores) + outputs = mclapply(as.list(1:nrep),FUN=liger_mainSubFunc,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj,mc.cores=nrep) + stopCluster(cl) + } + + for(i in 1:length(outputs)){ + if (outputs[[i]]$obj < best_obj) { + W_m <- outputs[[i]]$W + H_m <- outputs[[i]]$H + V_m <- outputs[[i]]$V + best_obj <- outputs[[i]]$obj + best_seed <- outputs[[i]]$seed + } + } + + + cat("Best results with seed ", best_seed, ".\n", sep = "") + out <- list() + out$H <- H_m + for (i in 1:length(x = object)) { + rownames(x = out$H[[i]]) <- rownames(x = object[[i]]) + } + out$V <- V_m + names(x = out$V) <- names(x = out$H) <- names(x = object) + out$W <- W_m + + return(out) +} + +liger_mainSubFunc <- function(i,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj){ + #for (i in 1:nrep) set.seed(seed = rand.seed + i - 1) start_time <- Sys.time() W <- matrix( @@ -792,13 +832,7 @@ optimizeALS.list <- function( print("Warning: failed to converge within the allowed number of iterations. Re-running with a higher max.iters is recommended.") } - if (obj < best_obj) { - W_m <- W - H_m <- H - V_m <- V - best_obj <- obj - best_seed <- rand.seed + i - 1 - } + end_time <- difftime(time1 = Sys.time(), time2 = start_time, units = "auto") run_stats[i, 1] <- as.double(x = end_time) run_stats[i, 2] <- iters @@ -815,17 +849,17 @@ optimizeALS.list <- function( if (print.obj) { cat("Objective:", obj, "\n") } - } - cat("Best results with seed ", best_seed, ".\n", sep = "") - out <- list() - out$H <- H_m - for (i in 1:length(x = object)) { - rownames(x = out$H[[i]]) <- rownames(x = object[[i]]) - } - out$V <- V_m - names(x = out$V) <- names(x = out$H) <- names(x = object) - out$W <- W_m - return(out) + output=list() + output$obj = obj + output$iters=iters + output$end_time=end_time + output$run_stats=run_stats[i, ] + output$W = W + output$H = H + output$V = V + output$obj = obj + output$seed = rand.seed + i - 1 + return(output) } #' @importFrom methods slot<- @@ -846,6 +880,7 @@ optimizeALS.liger <- function( V.init = NULL, rand.seed = 1, print.obj = FALSE, + num.cores = 1, ... ) { object <- removeMissingObs( @@ -864,7 +899,8 @@ optimizeALS.liger <- function( W.init = W.init, V.init = V.init, rand.seed = rand.seed, - print.obj = print.obj + print.obj = print.obj, + num.cores = num.cores ) names(x = out$H) <- names(x = out$V) <- names(x = object@raw.data) for (i in 1:length(x = object@scale.data)) {