From 0e56cd7d159694d5308fa76939c2dbc4fc56606c Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Mon, 6 Feb 2023 08:38:46 -0800 Subject: [PATCH 01/21] Pass optimization configuration options to postTreatment, so that it can be used e.g. in bootstrap/jackknife --- R/PLN.R | 2 +- R/PLNLDA.R | 2 +- R/PLNLDAfit-class.R | 4 +-- R/PLNPCA.R | 2 +- R/PLNPCAfit-class.R | 4 +-- R/PLNfamily-class.R | 11 +++++---- R/PLNfit-class.R | 37 +++++++++++++++------------- R/PLNmixture.R | 2 +- R/PLNmixturefit-class.R | 5 ++-- R/PLNnetwork.R | 14 ++++++++--- tests/testthat/test-standard-error.R | 2 +- 11 files changed, 49 insertions(+), 36 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index df8e6fd5..9a08bcec 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -45,7 +45,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) { ## post-treatment if (control$trace > 0) cat("\n Post-treatments...") - myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post) + myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN diff --git a/R/PLNLDA.R b/R/PLNLDA.R index 51c8bd63..f40d7b0e 100644 --- a/R/PLNLDA.R +++ b/R/PLNLDA.R @@ -64,7 +64,7 @@ PLNLDA <- function(formula, data, subset, weights, grouping, control = PLN_param myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim) ## Post-treatment: prepare LDA visualization - myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post) + myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myLDA diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R index 540120bc..dc28f34e 100644 --- a/R/PLNLDAfit-class.R +++ b/R/PLNLDAfit-class.R @@ -85,9 +85,9 @@ PLNLDAfit <- R6Class( ## Post treatment -------------------- #' @description Update R2, fisher and std_err fields and visualization #' @param config list controlling the 
post-treatment - postTreatment = function(grouping, responses, covariates, offsets, config) { + postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) { covariates <- cbind(covariates, model.matrix( ~ grouping + 0)) - super$postTreatment(responses, covariates, offsets, config = config) + super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim) rownames(private$C) <- colnames(private$C) <- colnames(responses) colnames(private$S) <- 1:self$q if (config$trace > 1) cat("\n\tCompute LD scores for visualization...") diff --git a/R/PLNPCA.R b/R/PLNPCA.R index ca72e145..2d2b128e 100644 --- a/R/PLNPCA.R +++ b/R/PLNPCA.R @@ -52,7 +52,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA ## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization if (control$trace > 0) cat("\n Post-treatments") config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace - myPCA$postTreatment(config_post) + myPCA$postTreatment(config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPCA diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R index dffb550f..a6ea44ac 100644 --- a/R/PLNPCAfit-class.R +++ b/R/PLNPCAfit-class.R @@ -169,8 +169,8 @@ PLNPCAfit <- R6Class( #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. 
should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) colnames(private$C) <- colnames(private$M) <- 1:self$q rownames(private$C) <- colnames(responses) self$setVisualization() diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R index 4c686855..fda35a9c 100644 --- a/R/PLNfamily-class.R +++ b/R/PLNfamily-class.R @@ -64,17 +64,18 @@ PLNfamily <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config a list for controlling the post-treatment. - postTreatment = function(config) { - nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + #' @param config_post a list for controlling the post-treatment. 
+ postTreatment = function(config_post, config_optim) { + #nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) for (model in self$models) model$postTreatment( self$responses, self$covariates, self$offsets, self$weights, - config, - nullModel = nullModel + config_post=config_post, + config_optim=config_optim, + nullModel = NULL ) }, diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index b65e2b1a..6b76d7f7 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -168,7 +168,7 @@ PLNfit <- R6Class( ## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - variance_variational = function(X) { + variance_variational = function(X, config = config_default_nlopt) { ## Variance of B for n data points fisher <- Matrix::bdiag(lapply(1:self$p, function(j) { crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X @@ -375,7 +375,7 @@ PLNfit <- R6Class( #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { ## PARAMATERS DIMNAMES ## Set names according to those of the data matrices. If missing, use sensible defaults if (is.null(colnames(responses))) @@ -392,24 +392,27 @@ PLNfit <- R6Class( ## OPTIONAL POST-TREATMENT (potentially costly) ## 1. 
compute and store approximated R2 with Poisson-based deviance - if (config$rsquared) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") + if (config_post$rsquared) { + if(config_post$trace > 1) cat("\n\tComputing approximate R^2...") private$approx_r2(responses, covariates, offsets, weights, nullModel) } ## 2. compute and store matrix of standard variances for B and Omega with rough variational approximation - if (config$variational_var) { - if(config$trace > 1) cat("\n\tComputing variational estimator of the variance...") - private$variance_variational(covariates) + if (config_post$variational_var) { + if(config_post$trace > 1) cat("\n\tComputing variational estimator of the variance...") + private$variance_variational(covariates, config = config_optim) } ## 3. Jackknife estimation of bias and variance - if (config$jackknife) { - if(config$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") - private$variance_jackknife(responses, covariates, offsets, weights) + if (config_post$jackknife) { + if(config_post$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") + private$variance_jackknife(responses, covariates, offsets, weights, config = config_optim) } ## 4. Bootstrap estimation of variance - if (config$bootstrap > 0) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") - private$variance_bootstrap(responses, covariates, offsets, weights, config$bootstrap) + if (config_post$bootstrap > 0) { + if(config_post$trace > 1) { + cat("\n\tComputing bootstrap estimator of the variance...") + print (str(config_optim)) + } + private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim) } }, @@ -804,11 +807,11 @@ PLNfit_fixedcov <- R6Class( #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). 
#' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) ## 6. compute and store matrix of standard variances for B with sandwich correction approximation - if (config$sandwich_var) { - if(config$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") + if (config_post$sandwich_var) { + if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") private$vcov_sandwich_B(responses, covariates) } } diff --git a/R/PLNmixture.R b/R/PLNmixture.R index dea6e9af..4e0252bb 100644 --- a/R/PLNmixture.R +++ b/R/PLNmixture.R @@ -60,7 +60,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt ## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA if (control$trace > 0) cat("\n Post-treatments") config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + myPLN$postTreatment(config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN diff --git a/R/PLNmixturefit-class.R b/R/PLNmixturefit-class.R index 23363380..12eca6a9 100644 --- a/R/PLNmixturefit-class.R +++ b/R/PLNmixturefit-class.R @@ -281,7 +281,7 @@ PLNmixturefit <- ## Post treatment -------------------- #' @description Update fields after optimization #' 
@param config a list for controlling the post-treatment - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { ## restoring the full design matrix (group means + covariates) mu_k <- matrix(1, self$n, ncol = 1); colnames(mu_k) <- 'Intercept' @@ -292,7 +292,8 @@ PLNmixturefit <- mu_k, offsets, private$tau[,k_], - config, + config_post, + config_optim, nullModel = nullModel ) }, diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 32be12b7..93924f46 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -41,8 +41,9 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control ## Post-treatments if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNnetwork; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + #config_post <- config_post_default_PLNnetwork; + #config_post$trace <- control$trace + myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -85,18 +86,24 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' #' @export PLNnetwork_param <- function( - backend = "nlopt", + backend = c("nlopt", "torch"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , penalize_diagonal = TRUE , penalty_weights = NULL , + config_post = list(), config_optim = list(), inception = NULL ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLN + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -123,6 +130,7 @@ PLNnetwork_param <- function( jackknife = FALSE , bootstrap = 0 , variance = TRUE , + config_post = config_pst , config_optim = config_opt , inception = inception ), class = 
"PLNmodels_param") } diff --git a/tests/testthat/test-standard-error.R b/tests/testthat/test-standard-error.R index 9bd20803..43768350 100644 --- a/tests/testthat/test-standard-error.R +++ b/tests/testthat/test-standard-error.R @@ -95,7 +95,7 @@ test_that("Check that variance estimation are coherent in PLNfit", { trace = 2 ) - myPLN$postTreatment(Y, X, exp(log_O), config = config_post) + myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post) tr_variational <- sum(standard_error(myPLN, "variational")^2) tr_bootstrap <- sum(standard_error(myPLN, "bootstrap")^2) From 4dd745106e5789c95e42267987b7d4ff7f5a4379 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Mon, 6 Feb 2023 08:40:02 -0800 Subject: [PATCH 02/21] Various improvements for torch optimizer needed to run it on the GPU --- R/PLN.R | 2 +- R/PLNfit-class.R | 68 ++++++++++++++++++++++++++++++++++-------------- R/utils.R | 8 +++++- 3 files changed, 57 insertions(+), 21 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index 9a08bcec..d4148f83 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -95,7 +95,7 @@ PLN_param <- function( Omega = NULL, config_post = list(), config_optim = list(), - inception = NULL # pretrained PLNfit used as initialization + inception = NULL # pretrained PLNfit used as initialization, ) { covariance <- match.arg(covariance) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 6b76d7f7..2fb320e0 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -80,21 +80,27 @@ PLNfit <- R6Class( torch_vloglik = function(data, params) { S2 <- torch_square(params$S) - Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric( - .5 * torch_logdet(params$Omega) + - torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - - .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) - ) - attr(Ji, "weights") <- as.numeric(data$w) + + Ji_tmp = .5 * torch_logdet(params$Omega) + + torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), 
dim = 2) - + .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) + Ji_tmp = Ji_tmp$cpu() + Ji_tmp = as.numeric(Ji_tmp) + Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y$cpu()))) + Ji_tmp + + attr(Ji, "weights") <- as.numeric(data$w$cpu()) Ji }, #' @import torch torch_optimize = function(data, params, config) { + #config$device = "mps" + if (config$trace > 1) + message (paste("optimizing with device: ", config$device)) ## Conversion of data and parameters to torch tensors (pointers) - data <- lapply(data, torch_tensor) # list with Y, X, O, w - params <- lapply(params, torch_tensor, requires_grad = TRUE) # list with B, M, S + data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) # list with Y, X, O, w + params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) # list with B, M, S ## Initialize optimizer optimizer <- switch(config$algorithm, @@ -111,11 +117,14 @@ PLNfit <- R6Class( batch_size <- floor(self$n/num_batch) objective <- double(length = config$num_epoch + 1) + #B_old = optimizer$param_groups[[1]]$params$B$clone() for (iterate in 1:num_epoch) { - B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) - + #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) + B_old = optimizer$param_groups[[1]]$params$B$clone() # rearrange the data each epoch - permute <- torch::torch_randperm(self$n) + 1L + #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L + permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device) + for (batch_idx in 1:num_batch) { # here index is a vector of the indices in the batch index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)] @@ -129,14 +138,21 @@ PLNfit <- R6Class( ## assess convergence objective[iterate + 1] <- loss$item() - B_new <- as.numeric(optimizer$param_groups[[1]]$params$B) + B_new <- optimizer$param_groups[[1]]$params$B 
delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) - delta_x <- sum(abs(B_old - B_new))/sum(abs(B_new)) + delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) + + #print (delta_f) + #print (delta_x) + delta_x = delta_x$cpu() + #print (delta_x) + delta_x = as.matrix(delta_x) + #print (delta_x) ## display progress if (config$trace > 1 && (iterate %% 50 == 0)) cat('\niteration: ', iterate, 'objective', objective[iterate + 1], - 'delta_f' , round(delta_f, 6), 'delta_x', ro% map("B") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>% @@ -228,17 +247,28 @@ PLNfit <- R6Class( variance_bootstrap = function(Y, X, O, w, n_resamples = 100, config = config_default_nlopt) { resamples <- replicate(n_resamples, sample.int(self$n, replace = TRUE), simplify = FALSE) - boots <- future.apply::future_lapply(resamples, function(resample) { + boots <- lapply(resamples, function(resample) { data <- list(Y = Y[resample, , drop = FALSE], X = X[resample, , drop = FALSE], O = O[resample, , drop = FALSE], w = w[resample]) + #print (config$torch_device) + #print (config) + if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w + + #print (data$Y$device) + args <- list(data = data, params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]), config = config) + if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S + optim_out <- do.call(private$optimizer$main, args) + #print (optim_out) optim_out[c("B", "Omega", "monitoring")] - }, future.seed = TRUE) + }) B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples 
attr(private$B, "variance_bootstrap") <- diff --git a/R/utils.R b/R/utils.R index 2d6e52ad..f6d3f9db 100644 --- a/R/utils.R +++ b/R/utils.R @@ -26,7 +26,8 @@ config_default_torch <- step_sizes = c(1e-3, 50), etas = c(0.5, 1.2), centered = FALSE, - trace = 1 + trace = 1, + device = "cpu" ) config_post_default_PLN <- @@ -107,6 +108,11 @@ trace <- function(x) sum(diag(x)) x } +.logfactorial_torch <- function(n){ + n[n == 0] <- 1 ## 0! = 1! + n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2 +} + .logfactorial <- function(n) { # Ramanujan's formula n[n == 0] <- 1 ## 0! = 1! n*log(n) - n + log(8*n^3 + 4*n^2 + n + 1/30)/6 + log(pi)/2 From 635dd223b9a1429611b5b4cd1f9141606c0dd89f Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Fri, 10 Feb 2023 12:48:47 -0800 Subject: [PATCH 03/21] Compute vcov on parameters when using jackknife or bootstrap --- R/PLNfit-class.R | 61 ++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 59 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 2fb320e0..0e901507 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -60,6 +60,8 @@ PLNfit <- R6Class( ## PRIVATE TORCH METHODS FOR OPTIMIZATION ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% torch_elbo = function(data, params, index=torch_tensor(1:self$n)) { + #print (index) + #print (params$S) S2 <- torch_square(params$S[index]) Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B) res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) + @@ -140,11 +142,12 @@ PLNfit <- R6Class( objective[iterate + 1] <- loss$item() B_new <- optimizer$param_groups[[1]]$params$B delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) + #delta_x = 0 delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) + delta_x = delta_x$cpu() #print (delta_f) #print (delta_x) - delta_x = delta_x$cpu() 
#print (delta_x) delta_x = as.matrix(delta_x) #print (delta_x) @@ -156,7 +159,7 @@ PLNfit <- R6Class( ## Check for convergence if (delta_f < config$ftol_rel) status <- 3 - if (delta_x < config$xtol_rel) status <- 4 + #if (delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) { objective <- objective[1:iterate + 1] break @@ -217,6 +220,54 @@ PLNfit <- R6Class( invisible(list(var_B = var_B, var_Omega = var_Omega)) }, + compute_vcov_from_resamples = function(resamples){ + # compute the covariance of the parameters + get_cov_mat = function(data, cell_group) { + + cov_matrix = cov(data) + rownames(cov_matrix) = paste0(cell_group, "_", rownames(cov_matrix)) + colnames(cov_matrix) = paste0(cell_group, "_", colnames(cov_matrix)) + return(cov_matrix) + } + + + B_list = resamples %>% map("B") + #print (B_list) + vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ + param_ests_for_col = B_list %>% map(~.x[, B_col]) + param_ests_for_col = do.call(rbind, param_ests_for_col) + print (param_ests_for_col) + row_vcov = cov(param_ests_for_col) + }) + #print ("vcov blocks") + #print (vcov_B) + + #B_vcov <- resamples %>% map("B") %>% map(~( . )) %>% reduce(cov) + + #var_jack <- jacks %>% map("B") %>% map(~( (. 
- B_jack)^2)) %>% reduce(`+`) %>% + # `dimnames<-`(dimnames(private$B)) + #B_hat <- private$B[,] ## strips attributes while preserving names + + vcov_B = Matrix::bdiag(vcov_B) %>% as.matrix() + + rownames(vcov_B) <- colnames(vcov_B) <- + expand.grid(covariates = rownames(private$B), + responses = colnames(private$B)) %>% rev() %>% + ## Hack to make sure that species is first and varies slowest + apply(1, paste0, collapse = "_") + + #print (pheatmap::pheatmap(vcov_B, cluster_rows=FALSE, cluster_cols=FALSE)) + + + #names = lapply(bootstrapped_df$cov_mat, function(m){ colnames(m)}) %>% unlist() + #rownames(bootstrapped_vhat) = names + #colnames(bootstrapped_vhat) = names + + vcov_B = methods::as(vcov_B, "dgCMatrix") + + return(vcov_B) + }, + variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) { jacks <- lapply(seq_len(self$n), function(i) { data <- list(Y = Y[-i, , drop = FALSE], @@ -237,6 +288,9 @@ PLNfit <- R6Class( attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat) attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack + vcov_boots = private$compute_vcov_from_resamples(boots) + attr(private$B, "vcov_jackknife") <- vcov_boots + Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$Omega)) @@ -275,6 +329,9 @@ PLNfit <- R6Class( boots %>% map("B") %>% map(~( (. - B_boots)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$B)) / n_resamples + vcov_boots = private$compute_vcov_from_resamples(boots) + attr(private$B, "vcov_bootstrap") <- vcov_boots + Omega_boots <- boots %>% map("Omega") %>% reduce(`+`) / n_resamples attr(private$Omega, "variance_bootstrap") <- boots %>% map("Omega") %>% map(~( (. 
- Omega_boots)^2)) %>% reduce(`+`) %>% From 0d5b9a75f6dc665234e40ede776adf5b670af286 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Tue, 14 Feb 2023 11:49:06 -0800 Subject: [PATCH 04/21] Remove stray print statement --- R/PLNfit-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 0e901507..ce0c08cf 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -236,7 +236,7 @@ PLNfit <- R6Class( vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ param_ests_for_col = B_list %>% map(~.x[, B_col]) param_ests_for_col = do.call(rbind, param_ests_for_col) - print (param_ests_for_col) + #print (param_ests_for_col) row_vcov = cov(param_ests_for_col) }) #print ("vcov blocks") @@ -497,7 +497,7 @@ PLNfit <- R6Class( if (config_post$bootstrap > 0) { if(config_post$trace > 1) { cat("\n\tComputing bootstrap estimator of the variance...") - print (str(config_optim)) + #print (str(config_optim)) } private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim) } From 6a745c5f28fe7b6d31a94efe6804313b30fc2d3d Mon Sep 17 00:00:00 2001 From: maddyduran Date: Tue, 14 Feb 2023 13:06:38 -0800 Subject: [PATCH 05/21] pass covariance type to PLNnetwork --- R/PLNnetwork.R | 3 +++ R/PLNnetworkfamily-class.R | 7 ++++++- R/PLNnetworkfit-class.R | 2 +- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 93924f46..4735b7fc 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -87,6 +87,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' @export PLNnetwork_param <- function( backend = c("nlopt", "torch"), + covariance = c("fixed", "spherical", "diagonal"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , @@ -115,6 +116,7 @@ PLNnetwork_param <- function( stopifnot(config_optim$algorithm %in% available_algorithms_torch) config_opt <- config_default_torch } + covariance <- 
match.arg(covariance) config_opt$trace <- trace config_opt$ftol_out <- 1e-5 config_opt$maxit_out <- 20 @@ -123,6 +125,7 @@ PLNnetwork_param <- function( structure(list( backend = backend , trace = trace , + covariance = covariance , n_penalties = n_penalties , min_ratio = min_ratio , penalize_diagonal = penalize_diagonal, diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index 1f9b9f3b..e6ce0a7f 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -45,7 +45,12 @@ PLNnetworkfamily <- R6Class( ## A basic model for inception, useless one is defined by the user ### TODO check if it is useful if (is.null(control$inception)) { - myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) + + myPLN <- switch(control$covariance, + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), + PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed + # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) myPLN$optimize(responses, covariates, offsets, weights, control$config_optim) control$inception <- myPLN } diff --git a/R/PLNnetworkfit-class.R b/R/PLNnetworkfit-class.R index 1dff42da..6a24bc32 100644 --- a/R/PLNnetworkfit-class.R +++ b/R/PLNnetworkfit-class.R @@ -33,7 +33,7 @@ #' @seealso The function [PLNnetwork()], the class [`PLNnetworkfamily`] PLNnetworkfit <- R6Class( classname = "PLNnetworkfit", - inherit = PLNfit_fixedcov, + inherit = PLNfit_spherical, #_fixedcov, ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## PUBLIC MEMBERS ---- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% From 057fd3bdfcba4fa7d971c918a0c07e5c5f663026 Mon Sep 17 00:00:00 2001 From: maddyduran Date: Tue, 14 Feb 2023 14:23:22 -0800 Subject: [PATCH 06/21] fix jackknife bug --- R/PLNfit-class.R | 
4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index ce0c08cf..7a0b7550 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -288,8 +288,8 @@ PLNfit <- R6Class( attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat) attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack - vcov_boots = private$compute_vcov_from_resamples(boots) - attr(private$B, "vcov_jackknife") <- vcov_boots + vcov_jacks = private$compute_vcov_from_resamples(jacks) + attr(private$B, "vcov_jackknife") <- vcov_jacks Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>% From 423775bd4f58c383efa2dfa5538d27c6529d943b Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Sat, 18 Feb 2023 10:59:05 -0800 Subject: [PATCH 07/21] cast Sigma to dense when fetching the max penalty in order to avoid warnings about ineffiecient access --- R/PLNnetworkfamily-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index e6ce0a7f..315cea40 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -47,7 +47,7 @@ PLNnetworkfamily <- R6Class( if (is.null(control$inception)) { myPLN <- switch(control$covariance, - "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) @@ -74,7 +74,7 @@ PLNnetworkfamily <- R6Class( if (is.null(penalties)) { if (control$trace > 1) cat("\n Recovering an appropriate grid of penalties.") max_pen <- list_penalty_weights %>% - map(~ 
myPLN$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From f801ca6a886988c7f08eec594c57df1177aec948 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Sat, 18 Feb 2023 10:59:31 -0800 Subject: [PATCH 08/21] Revert back to using fixed covariance instead of spherical as base class for PLNnetworkfit --- R/PLNnetworkfit-class.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PLNnetworkfit-class.R b/R/PLNnetworkfit-class.R index 6a24bc32..1dff42da 100644 --- a/R/PLNnetworkfit-class.R +++ b/R/PLNnetworkfit-class.R @@ -33,7 +33,7 @@ #' @seealso The function [PLNnetwork()], the class [`PLNnetworkfamily`] PLNnetworkfit <- R6Class( classname = "PLNnetworkfit", - inherit = PLNfit_spherical, #_fixedcov, + inherit = PLNfit_fixedcov, ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## PUBLIC MEMBERS ---- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% From 31b6153035ebb6074087ed118a2eeff25cb9aa0d Mon Sep 17 00:00:00 2001 From: maddyduran Date: Mon, 14 Aug 2023 16:05:53 -0700 Subject: [PATCH 09/21] changing line 85 to match PLNmodels/master --- R/PLNnetworkfamily-class.R | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index ca629e39..e8f0b273 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -79,9 +79,10 @@ PLNnetworkfamily <- R6Class( # CHECK_ME_TORCH_GPU # This appears to be in torch_gpu only. The commented out line below is # in both PLNmodels/master and PLNmodels/dev. 
+ # changed it to other one max_pen <- list_penalty_weights %>% - map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% - # map(~ control$inception$model_par$Sigma / .x) %>% + # map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From ed7c1811fdbcd62c53a9f7fa6940f094261aedb0 Mon Sep 17 00:00:00 2001 From: maddyduran Date: Fri, 25 Aug 2023 09:20:08 -0700 Subject: [PATCH 10/21] actually putting line 85 back --- R/PLNnetworkfamily-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index e8f0b273..efe58309 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -81,8 +81,8 @@ PLNnetworkfamily <- R6Class( # in both PLNmodels/master and PLNmodels/dev. # changed it to other one max_pen <- list_penalty_weights %>% - # map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% - map(~ control$inception$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + # map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) From 02e3501047a3a071bf9f7ad1b754f0dd86fca8a5 Mon Sep 17 00:00:00 2001 From: Cole Trapnell Date: Tue, 17 Oct 2023 10:48:24 -0700 Subject: [PATCH 11/21] slight code cleanup in the torch optimizer --- R/PLNfit-class.R | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 476d0d8e..ac1dd363 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -60,12 +60,11 @@ PLNfit <- R6Class( ## PRIVATE TORCH METHODS FOR OPTIMIZATION ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% torch_elbo 
= function(data, params, index=torch_tensor(1:self$n)) { - #print (index) - #print (params$S) S2 <- torch_square(params$S[index]) Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B) + A <- torch_exp(Z + .5 * S2) res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) + - sum(data$w[index,NULL] * (torch_exp(Z + .5 * S2) - data$Y[index] * Z - .5 * torch_log(S2))) + sum(data$w[index,NULL] * (A - data$Y[index] * Z - .5 * torch_log(S2))) res }, @@ -122,11 +121,11 @@ PLNfit <- R6Class( #B_old = optimizer$param_groups[[1]]$params$B$clone() for (iterate in 1:num_epoch) { #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) - B_old = optimizer$param_groups[[1]]$params$B$clone() # rearrange the data each epoch #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device) + #print (paste("num batches", num_batch)) for (batch_idx in 1:num_batch) { # here index is a vector of the indices in the batch index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)] @@ -140,24 +139,15 @@ PLNfit <- R6Class( ## assess convergence objective[iterate + 1] <- loss$item() - B_new <- optimizer$param_groups[[1]]$params$B delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) - #delta_x = 0 - delta_x <- torch::torch_sum(torch::torch_abs(B_old - B_new))/torch::torch_sum(torch::torch_abs(B_new)) - delta_x = delta_x$cpu() - - #print (delta_f) - #print (delta_x) - #print (delta_x) - delta_x = as.matrix(delta_x) - #print (delta_x) ## display progress - if (config$trace > 1 && (iterate %% 50 == 0)) + if (config$trace > 1 && (iterate %% 50 == 1)) cat('\niteration: ', iterate, 'objective', objective[iterate + 1], - 'delta_f' , round(delta_f, 6), 'delta_x', round(delta_x, 6)) + 'delta_f' , round(delta_f, 6)) ## Check for convergence + #print (delta_f) if (delta_f < config$ftol_rel) status <- 3 #if 
(delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) { From 44a1cbd046173e52ea408ecad2b127dc123316db Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 15:11:57 +0100 Subject: [PATCH 12/21] PLNnetwork: Add `inception_cov` argument to `PLNnetwork_param()` to allow the user to control the covariance structure of the inception model. Defaults to "full" (same as before). --- R/PLNnetwork.R | 11 ++++++----- R/PLNnetworkfamily-class.R | 13 ++++++++----- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 197bdf91..88cc5438 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -54,6 +54,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' Helper to define list of parameters to control the PLN fit. All arguments have defaults. #' #' @param backend optimization back used, either "nlopt" or "torch". Default is "nlopt" +#' @param inception_cov Covariance structure used for the inception model used to initialize the PLNfamily. Defaults to "full" and can be constrained to "diagonal" and "spherical". #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details #' @param trace a integer for verbosity. #' @param n_penalties an integer that specifies the number of values for the penalty grid when internally generated. 
Ignored when penalties is non `NULL` @@ -74,14 +75,14 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' @export PLNnetwork_param <- function( backend = c("nlopt", "torch"), - covariance = c("fixed", "spherical", "diagonal"), + inception_cov = c("full", "spherical", "diagonal"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , penalize_diagonal = TRUE , penalty_weights = NULL , - config_post = list(), - config_optim = list(), + config_post = list(), + config_optim = list(), inception = NULL ) { @@ -103,7 +104,7 @@ PLNnetwork_param <- function( stopifnot(config_optim$algorithm %in% available_algorithms_torch) config_opt <- config_default_torch } - covariance <- match.arg(covariance) + inception_cov <- match.arg(inception_cov) config_opt$trace <- trace config_opt$ftol_out <- 1e-5 config_opt$maxit_out <- 20 @@ -112,7 +113,7 @@ PLNnetwork_param <- function( structure(list( backend = backend , trace = trace , - covariance = covariance , + inception_cov = inception_cov , n_penalties = n_penalties , min_ratio = min_ratio , penalize_diagonal = penalize_diagonal, diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index efe58309..00f2cec8 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -49,11 +49,14 @@ PLNnetworkfamily <- R6Class( # CHECK_ME_TORCH_GPU # This appears to be in torch_gpu only. The commented out line below is # in both PLNmodels/master and PLNmodels/dev. 
- myPLN <- switch(control$covariance, - "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), - "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), - PLNfit$new(responses, covariates, offsets, weights, formula, control)) # defaults to fixed - # myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) + myPLN <- switch( + control$inception_cov, + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), + PLNfit$new(responses, covariates, offsets, weights, formula, control) # defaults to full + ) + ## Allow inception with spherical / diagonal / full PLNfit before switching back to PLNfit_fixedcov + ## for the inner-outer loop of PLNnetwork. myPLN$optimize(responses, covariates, offsets, weights, control$config_optim) control$inception <- myPLN } From 48b5925d83a03b1c00b234a45f774f9cbe5cdf4c Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:07:47 +0100 Subject: [PATCH 13/21] Parameter `config_post`: use the same interface for all `PLN*_param()` functions and use `control` in the same way in all `postTreatment()` functions. --- R/PLN.R | 2 +- R/PLNLDAfit-class.R | 3 ++- R/PLNPCA.R | 11 +++++++++-- R/PLNPCAfit-class.R | 4 +++- R/PLNfamily-class.R | 3 ++- R/PLNfit-class.R | 3 ++- R/PLNmixture.R | 11 +++++++++-- R/PLNnetwork.R | 3 +-- 8 files changed, 29 insertions(+), 11 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index 16131e9b..451850de 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -90,7 +90,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) { #' * "etas" pair of multiplicative increase and decrease factors. Default is (0.5, 1.2). 
Only used in RPROP #' * "centered" if TRUE, compute the centered RMSProp where the gradient is normalized by an estimation of its variance weight_decay (L2 penalty). Default to FALSE. Only used in RMSPROP #' -#' The list of parameters `config_post` controls the post-treatment processing (for PLN and PLNLDA), with the following entries: +#' The list of parameters `config_post` controls the post-treatment processing (for most `PLN*()` functions), with the following entries (defaults may vary depending on the specific function, check `config_post_default_*` for defaults values): #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R index feb6c5c5..9404fda1 100644 --- a/R/PLNLDAfit-class.R +++ b/R/PLNLDAfit-class.R @@ -84,7 +84,8 @@ PLNLDAfit <- R6Class( ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update R2, fisher and std_err fields and visualization - #' @param config list controlling the post-treatment + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). 
+ #' @param config_optim list controlling the optimization parameters postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) { covariates <- cbind(covariates, model.matrix( ~ grouping + 0)) super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim) diff --git a/R/PLNPCA.R b/R/PLNPCA.R index 4606550b..505061d4 100644 --- a/R/PLNPCA.R +++ b/R/PLNPCA.R @@ -51,8 +51,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA ## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace - myPCA$postTreatment(config_post, control$config_optim) + myPCA$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPCA @@ -65,6 +64,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA #' @param backend optimization back used, either "nlopt" or "torch". Default is "nlopt" #' @param trace a integer for verbosity. #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details #' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on #' log-transformed data, and with the same formula as the one provided by the user. However, the user can provide a PLNfit (typically obtained from a previous fit), #' which sometimes speeds up the inference. 
@@ -77,11 +77,17 @@ PLNPCA_param <- function( backend = "nlopt", trace = 1 , config_optim = list() , + config_post = list() , inception = NULL # pretrained PLNfit used as initialization ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLNPCA + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -100,5 +106,6 @@ PLNPCA_param <- function( backend = backend , trace = trace , config_optim = config_opt, + config_post = config_pst, inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R index 9b55e5a9..583ae05c 100644 --- a/R/PLNPCAfit-class.R +++ b/R/PLNPCAfit-class.R @@ -163,7 +163,9 @@ PLNPCAfit <- R6Class( }, #' @description Update R2, fisher, std_err fields and set up visualization - #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: + #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @details The list of parameters `config_post` controls the post-treatment processing, with the following entries: #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. 
diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R index fda35a9c..0aaabb6e 100644 --- a/R/PLNfamily-class.R +++ b/R/PLNfamily-class.R @@ -64,7 +64,8 @@ PLNfamily <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config_post a list for controlling the post-treatment. + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). + #' @param config_optim a list for controlling the optimization parameters used during post_treatments postTreatment = function(config_post, config_optim) { #nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) for (model in self$models) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 065137ce..797073e1 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -456,7 +456,8 @@ PLNfit <- R6Class( }, #' @description Update R2, fisher and std_err fields after optimization - #' @param config a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_optim a list for controlling the optimization performed during post-treatments (e.g. bootstrap, jackknife). See details #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated).
diff --git a/R/PLNmixture.R b/R/PLNmixture.R index b723c6b2..7ecbe61d 100644 --- a/R/PLNmixture.R +++ b/R/PLNmixture.R @@ -59,8 +59,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt ## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace - myPLN$postTreatment(config_post, control$config_optim) + myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -75,6 +74,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt #' @param smoothing The smoothing to apply. Either, 'none', forward', 'backward' or 'both'. Default is 'both'. #' @param init_cl The initial clustering to apply. Either, 'kmeans', CAH' or a user defined clustering given as a list of clusterings, the size of which is equal to the number of clusters considered. Default is 'kmeans'. #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). #' @param trace a integer for verbosity. #' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on #' log-transformed data, and with the same formula as the one provided by the user. 
However, the user can provide a PLNfit (typically obtained from a previous fit), @@ -95,10 +95,16 @@ PLNmixture_param <- function( init_cl = "kmeans" , smoothing = "both" , config_optim = list() , + config_post = list() , inception = NULL # pretrained PLNfit used as initialization ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLNmixture + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -123,5 +129,6 @@ PLNmixture_param <- function( init_cl = init_cl , smoothing = smoothing , config_optim = config_opt , + config_post = config_pst , inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 88cc5438..6ede31f6 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -41,8 +41,6 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control ## Post-treatments if (control$trace > 0) cat("\n Post-treatments") - #config_post <- config_post_default_PLNnetwork; - #config_post$trace <- control$trace myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") @@ -56,6 +54,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' @param backend optimization back used, either "nlopt" or "torch". Default is "nlopt" #' @param inception_cov Covariance structure used for the inception model used to initialize the PLNfamily. Defaults to "full" and can be constrained to "diagonal" and "spherical". #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatment (optional bootstrap, jackknife, R2, etc). #' @param trace a integer for verbosity. 
#' @param n_penalties an integer that specifies the number of values for the penalty grid when internally generated. Ignored when penalties is non `NULL` #' @param min_ratio the penalty grid ranges from the minimal value that produces a sparse to this value multiplied by `min_ratio`. Default is 0.1. From bd8bfc209dff26d1938c55b4f036f23b50ab80b2 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:26:16 +0100 Subject: [PATCH 14/21] - Use only torch function to compute the logfactorial - Defer moving data back to the CPU until the end when computing the objective function --- R/PLNfit-class.R | 5 ++--- R/utils.R | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 797073e1..c5e76d49 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -85,9 +85,8 @@ PLNfit <- R6Class( Ji_tmp = .5 * torch_logdet(params$Omega) + torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) - Ji_tmp = Ji_tmp$cpu() - Ji_tmp = as.numeric(Ji_tmp) - Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y$cpu()))) + Ji_tmp + Ji <- - torch_sum(.logfactorial_torch(data$Y), dim = 2) + Ji_tmp + Ji <- .5 * self$p + as.numeric(Ji$cpu()) attr(Ji, "weights") <- as.numeric(data$w$cpu()) Ji diff --git a/R/utils.R b/R/utils.R index 60351772..8e7ec48b 100644 --- a/R/utils.R +++ b/R/utils.R @@ -110,7 +110,7 @@ trace <- function(x) sum(diag(x)) .logfactorial_torch <- function(n){ n[n == 0] <- 1 ## 0! = 1! 
- n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + log(pi)/2 + n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + torch_log(pi)/2 } .logfactorial <- function(n) { # Ramanujan's formula From 6e93d99cf7829dcccf29cffa1a4f01e3d1d4c2bc Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:27:53 +0100 Subject: [PATCH 15/21] Reinstate convergence check on update of the parameters values (check can be deactivated by setting xtol_rel to infty). --- R/PLNfit-class.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index c5e76d49..f208e6f0 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -154,7 +154,7 @@ PLNfit <- R6Class( ## Check for convergence #print (delta_f) if (delta_f < config$ftol_rel) status <- 3 - #if (delta_x < config$xtol_rel) status <- 4 + if (delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) { objective <- objective[1:iterate + 1] break From ac79cb68b195d306717f7dcc78cb4bb1076f6809 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:29:16 +0100 Subject: [PATCH 16/21] Cosmetic changes to clean code --- R/PLN.R | 2 +- R/PLNfit-class.R | 10 ---------- 2 files changed, 1 insertion(+), 11 deletions(-) diff --git a/R/PLN.R b/R/PLN.R index 451850de..fb6a3a61 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -105,7 +105,7 @@ PLN_param <- function( Omega = NULL, config_post = list(), config_optim = list(), - inception = NULL # pretrained PLNfit used as initialization, + inception = NULL # pretrained PLNfit used as initialization ) { covariance <- match.arg(covariance) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index f208e6f0..6b298349 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -216,16 +216,6 @@ PLNfit <- R6Class( }, compute_vcov_from_resamples = function(resamples){ - # compute the covariance of the parameters - get_cov_mat = function(data, cell_group) { - - cov_matrix =
cov(data) - rownames(cov_matrix) = paste0(cell_group, "_", rownames(cov_matrix)) - colnames(cov_matrix) = paste0(cell_group, "_", colnames(cov_matrix)) - return(cov_matrix) - } - - B_list = resamples %>% map("B") #print (B_list) vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ From 8ccf9afc930f9687b574d894903c66685b9ea810 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:31:51 +0100 Subject: [PATCH 17/21] Add `.$backend` element to `config_optim` and use it when computing bootstrap variances. --- R/PLNfit-class.R | 6 ++---- R/utils.R | 2 ++ 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 6b298349..71ab570e 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -291,9 +291,7 @@ PLNfit <- R6Class( X = X[resample, , drop = FALSE], O = O[resample, , drop = FALSE], w = w[resample]) - #print (config$torch_device) - #print (config) - if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + if (config$backend == "torch") # Convert data to torch tensors data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w #print (data$Y$device) @@ -301,7 +299,7 @@ PLNfit <- R6Class( args <- list(data = data, params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]), config = config) - if (config$algorithm %in% c("RPROP", "RMSPROP", "ADAM", "ADAGRAD")) # hack, to know if we're doing torch or not + if (config$backend == "torch") # Convert data to torch tensors args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S optim_out <- do.call(private$optimizer$main, args) diff --git a/R/utils.R b/R/utils.R index 8e7ec48b..a382712c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -4,6 +4,7 @@ available_algorithms_torch <- c("RPROP", "RMSPROP", "ADAM", "ADAGRAD") config_default_nlopt <- list( algorithm = "CCSAQ", + backend = "nlopt", maxeval = 
10000 , ftol_rel = 1e-8 , xtol_rel = 1e-6 , @@ -15,6 +16,7 @@ config_default_nlopt <- config_default_torch <- list( algorithm = "RPROP", + backend = "torch", maxeval = 10000 , num_epoch = 1000 , num_batch = 1 , From 772b3ade4d946ede36c71dddc6e124afea499983 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:37:27 +0100 Subject: [PATCH 18/21] Move back to future.apply for lapply calls. --- R/PLNfit-class.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 71ab570e..948c8510 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -254,7 +254,7 @@ PLNfit <- R6Class( }, variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) { - jacks <- lapply(seq_len(self$n), function(i) { + jacks <- future.apply::future_lapply(seq_len(self$n), function(i) { data <- list(Y = Y[-i, , drop = FALSE], X = X[-i, , drop = FALSE], O = O[-i, , drop = FALSE], @@ -286,7 +286,7 @@ PLNfit <- R6Class( variance_bootstrap = function(Y, X, O, w, n_resamples = 100, config = config_default_nlopt) { resamples <- replicate(n_resamples, sample.int(self$n, replace = TRUE), simplify = FALSE) - boots <- lapply(resamples, function(resample) { + boots <- future.apply::future_lapply(resamples, function(resample) { data <- list(Y = Y[resample, , drop = FALSE], X = X[resample, , drop = FALSE], O = O[resample, , drop = FALSE], From 5a2f391edb396c4e9bb435ab070b943a987aae44 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Mon, 13 Nov 2023 16:56:18 +0100 Subject: [PATCH 19/21] Fix failed tests due to change in `control` interface. 
--- R/PLNLDAfit-class.R | 2 +- tests/testthat/test-standard-error.R | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R index 9404fda1..f52d1dee 100644 --- a/R/PLNLDAfit-class.R +++ b/R/PLNLDAfit-class.R @@ -91,7 +91,7 @@ PLNLDAfit <- R6Class( super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim) rownames(private$C) <- colnames(private$C) <- colnames(responses) colnames(private$S) <- 1:self$q - if (config$trace > 1) cat("\n\tCompute LD scores for visualization...") + if (config_post$trace > 1) cat("\n\tCompute LD scores for visualization...") self$setVisualization() }, diff --git a/tests/testthat/test-standard-error.R b/tests/testthat/test-standard-error.R index 43768350..84edbc12 100644 --- a/tests/testthat/test-standard-error.R +++ b/tests/testthat/test-standard-error.R @@ -95,7 +95,9 @@ test_that("Check that variance estimation are coherent in PLNfit", { trace = 2 ) - myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post) + config_optim <- config_default_nlopt + + myPLN$postTreatment(Y, X, exp(log_O), config_post = config_post, config_optim = config_optim) tr_variational <- sum(standard_error(myPLN, "variational")^2) tr_bootstrap <- sum(standard_error(myPLN, "bootstrap")^2) From 21d592071f831b461b9aa6c069b7abb0e591469d Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Tue, 28 Nov 2023 16:50:17 +0100 Subject: [PATCH 20/21] Fix failing tests for R2: - Use correct post_treatment default for PLNnetwork - Instantiate null model for PLNfamily if R2 is requested in post treatments --- R/PLNfamily-class.R | 8 ++++++-- R/PLNnetwork.R | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R index 0aaabb6e..5bb1f8d8 100644 --- a/R/PLNfamily-class.R +++ b/R/PLNfamily-class.R @@ -67,7 +67,11 @@ PLNfamily <- #' @param config_post a list for controlling the post-treatments (optional bootstrap, 
jackknife, R2, etc.). #' @param config_optim a list for controlling the optimization parameters used during post_treatments postTreatment = function(config_post, config_optim) { - #nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + if (config_post$rsquared) { + nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + } else { + nullModel <- NULL + } for (model in self$models) model$postTreatment( self$responses, @@ -76,7 +80,7 @@ PLNfamily <- self$weights, config_post=config_post, config_optim=config_optim, - nullModel = NULL + nullModel = nullModel ) }, diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 6ede31f6..a6f70c0e 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -88,7 +88,7 @@ PLNnetwork_param <- function( if (!is.null(inception)) stopifnot(isPLNfit(inception)) ## post-treatment config - config_pst <- config_post_default_PLN + config_pst <- config_post_default_PLNnetwork config_pst[names(config_post)] <- config_post config_pst$trace <- trace From 2eae352d5688dd09b823224569ffe011c116c548 Mon Sep 17 00:00:00 2001 From: Mahendra Mariadassou Date: Tue, 28 Nov 2023 17:00:21 +0100 Subject: [PATCH 21/21] Update postTreatment argument description - add @param for config_post and config_optim to avoid warnings when building package. --- R/PLNfit-class.R | 3 ++- R/PLNmixturefit-class.R | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 948c8510..bf1e92e3 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -876,7 +876,8 @@ PLNfit_fixedcov <- R6Class( }, #' @description Update R2, fisher and std_err fields after optimization - #' @param config a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). 
See details + #' @param config_optim a list for controlling the optimization parameter. See details #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: #' * trace integer for verbosity. should be > 1 to see output in post-treatments #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. diff --git a/R/PLNmixturefit-class.R b/R/PLNmixturefit-class.R index 12eca6a9..6f0c65ff 100644 --- a/R/PLNmixturefit-class.R +++ b/R/PLNmixturefit-class.R @@ -280,7 +280,8 @@ PLNmixturefit <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config a list for controlling the post-treatment + #' @param config_post a list for controlling the post-treatment + #' @param config_optim a list for controlling the optimization during the post-treatment computations postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { ## restoring the full design matrix (group means + covariates)