Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelised optimizeALS #90

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 56 additions & 20 deletions R/liger.R
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,12 @@ removeMissingObs <- function(object, slot.use = "raw.data", use.cols = T) {
#' @param V.init Initial values to use for V matrices (default NULL)
#' @param rand.seed Random seed to allow reproducible results (default 1).
#' @param print.obj Print objective function values after convergence (default FALSE).
#' @param num.cores Number of cores to use for optimizing factorizations in parallel (default 1).
#' @param ... Arguments passed to other methods
#'
#' @return \code{liger} object with H, W, and V slots set.
#' @export
#' @import parallel
#'
#' @examples
#' \dontrun{
Expand Down Expand Up @@ -651,11 +653,13 @@ optimizeALS.list <- function(
V.init = NULL,
rand.seed = 1,
print.obj = FALSE,
num.cores = 1,
...
) {
if (!all(sapply(X = object, FUN = is.matrix))) {
stop("All values in 'object' must be a matrix")
}

E <- object
N <- length(x = E)
ns <- sapply(X = E, FUN = nrow)
Expand Down Expand Up @@ -683,7 +687,43 @@ optimizeALS.list <- function(
tmp <- gc()
best_obj <- Inf
run_stats <- matrix(data = 0, nrow = nrep, ncol = 2)
for (i in 1:nrep) {

### REMOVED THE FUNCTION FROM HERE [for (i in 1:nrep)]
if(num.cores==1){
outputs = lapply(as.list(1:nrep),FUN=liger_mainSubFunc,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj)
}else{
require(parallel)
cl <- makeCluster(num.cores)
outputs = mclapply(as.list(1:nrep),FUN=liger_mainSubFunc,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj,mc.cores=nrep)
stopCluster(cl)
}

for(i in 1:length(outputs)){
if (outputs[[i]]$obj < best_obj) {
W_m <- outputs[[i]]$W
H_m <- outputs[[i]]$H
V_m <- outputs[[i]]$V
best_obj <- outputs[[i]]$obj
best_seed <- outputs[[i]]$seed
}
}


cat("Best results with seed ", best_seed, ".\n", sep = "")
out <- list()
out$H <- H_m
for (i in 1:length(x = object)) {
rownames(x = out$H[[i]]) <- rownames(x = object[[i]])
}
out$V <- V_m
names(x = out$V) <- names(x = out$H) <- names(x = object)
out$W <- W_m

return(out)
}

liger_mainSubFunc <- function(i,rand.seed,g,k,N,ns,W.init,V.init,H.init,lambda,thresh,max.iters,best_obj){
#for (i in 1:nrep)
set.seed(seed = rand.seed + i - 1)
start_time <- Sys.time()
W <- matrix(
Expand Down Expand Up @@ -792,13 +832,7 @@ optimizeALS.list <- function(
print("Warning: failed to converge within the allowed number of iterations.
Re-running with a higher max.iters is recommended.")
}
if (obj < best_obj) {
W_m <- W
H_m <- H
V_m <- V
best_obj <- obj
best_seed <- rand.seed + i - 1
}

end_time <- difftime(time1 = Sys.time(), time2 = start_time, units = "auto")
run_stats[i, 1] <- as.double(x = end_time)
run_stats[i, 2] <- iters
Expand All @@ -815,17 +849,17 @@ optimizeALS.list <- function(
if (print.obj) {
cat("Objective:", obj, "\n")
}
}
cat("Best results with seed ", best_seed, ".\n", sep = "")
out <- list()
out$H <- H_m
for (i in 1:length(x = object)) {
rownames(x = out$H[[i]]) <- rownames(x = object[[i]])
}
out$V <- V_m
names(x = out$V) <- names(x = out$H) <- names(x = object)
out$W <- W_m
return(out)
output=list()
output$obj = obj
output$iters=iters
output$end_time=end_time
output$run_stats=run_stats[i, ]
output$W = W
output$H = H
output$V = V
output$obj = obj
output$seed = rand.seed + i - 1
return(output)
}

#' @importFrom methods slot<-
Expand All @@ -846,6 +880,7 @@ optimizeALS.liger <- function(
V.init = NULL,
rand.seed = 1,
print.obj = FALSE,
num.cores = 1,
...
) {
object <- removeMissingObs(
Expand All @@ -864,7 +899,8 @@ optimizeALS.liger <- function(
W.init = W.init,
V.init = V.init,
rand.seed = rand.seed,
print.obj = print.obj
print.obj = print.obj,
num.cores = num.cores
)
names(x = out$H) <- names(x = out$V) <- names(x = [email protected])
for (i in 1:length(x = [email protected])) {
Expand Down