diff --git a/DESCRIPTION b/DESCRIPTION index f53a5069..9aa3f21a 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: PLNmodels Title: Poisson Lognormal Models -Version: 1.0.4-0300 +Version: 1.0.5-0000 Authors@R: c( person("Julien", "Chiquet", role = c("aut", "cre"), email = "julien.chiquet@inrae.fr", comment = c(ORCID = "0000-0002-3629-3429")), diff --git a/NEWS.md b/NEWS.md index 428c9259..a8a3b2e4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,11 @@ * Update documentation of PLN*_param() functions to include torch optimization parameters * Add (somehow) explicit error message when torch convergence fails * Change initialization in `variance_jackknife()` and `variance_bootstrap()` to prevent estimation recycling, results from those functions are now comparable to doing jackknife / bootstrap "by hand". +* Merge PR #110 from Cole Trapnell to add: + - bootstrap estimation of the variance of model parameter + - improved interface for model initialization / optimisation parameters, which + are now passed on to jackknife / bootstrap post-treatments + - better support of GPU when using torch backend # PLNmodels 1.0.4 (2023-08-24) diff --git a/R/PLN.R b/R/PLN.R index 9a1e321e..fb6a3a61 100644 --- a/R/PLN.R +++ b/R/PLN.R @@ -45,7 +45,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) { ## post-treatment if (control$trace > 0) cat("\n Post-treatments...") - myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post) + myPLN$postTreatment(args$Y, args$X, args$O, args$w, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -90,7 +90,7 @@ PLN <- function(formula, data, subset, weights, control = PLN_param()) { #' * "etas" pair of multiplicative increase and decrease factors. Default is (0.5, 1.2). Only used in RPROP #' * "centered" if TRUE, compute the centered RMSProp where the gradient is normalized by an estimation of its variance weight_decay (L2 penalty). Default to FALSE. Only used in RMSPROP #' -#' The list of parameters `config_post` controls the post-treatment processing (for PLN and PLNLDA), with the following entries: +#' The list of parameters `config_post` controls the post-treatment processing (for most `PLN*()` functions), with the following entries (defaults may vary depending on the specific function, check `config_post_default_*` for defaults values): #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. diff --git a/R/PLNLDA.R b/R/PLNLDA.R index a56036cd..80549d8c 100644 --- a/R/PLNLDA.R +++ b/R/PLNLDA.R @@ -64,7 +64,7 @@ PLNLDA <- function(formula, data, subset, weights, grouping, control = PLN_param myLDA$optimize(grouping, args$Y, args$X, args$O, args$w, control$config_optim) ## Post-treatment: prepare LDA visualization - myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post) + myLDA$postTreatment(grouping, args$Y, args$X, args$O, control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myLDA diff --git a/R/PLNLDAfit-class.R b/R/PLNLDAfit-class.R index f4a8762d..f52d1dee 100644 --- a/R/PLNLDAfit-class.R +++ b/R/PLNLDAfit-class.R @@ -84,13 +84,14 @@ PLNLDAfit <- R6Class( ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update R2, fisher and std_err fields and visualization - #' @param config list controlling the post-treatment - postTreatment = function(grouping, responses, covariates, offsets, config) { + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). + #' @param config_optim list controlling the optimization parameters + postTreatment = function(grouping, responses, covariates, offsets, config_post, config_optim) { covariates <- cbind(covariates, model.matrix( ~ grouping + 0)) - super$postTreatment(responses, covariates, offsets, config = config) + super$postTreatment(responses, covariates, offsets, config_post = config_post, config_optim = config_optim) rownames(private$C) <- colnames(private$C) <- colnames(responses) colnames(private$S) <- 1:self$q - if (config$trace > 1) cat("\n\tCompute LD scores for visualization...") + if (config_post$trace > 1) cat("\n\tCompute LD scores for visualization...") self$setVisualization() }, diff --git a/R/PLNPCA.R b/R/PLNPCA.R index c5c26a95..505061d4 100644 --- a/R/PLNPCA.R +++ b/R/PLNPCA.R @@ -51,8 +51,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA ## Post-treatments: pseudo-R2, rearrange criteria and prepare PCA visualization if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNPCA; config_post$trace <- control$trace - myPCA$postTreatment(config_post) + myPCA$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPCA @@ -65,6 +64,7 @@ PLNPCA <- function(formula, data, subset, weights, ranks = 1:5, control = PLNPCA #' @param backend optimization back used, either "nlopt" or "torch". Default is "nlopt" #' @param trace a integer for verbosity. #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details #' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on #' log-transformed data, and with the same formula as the one provided by the user. However, the user can provide a PLNfit (typically obtained from a previous fit), #' which sometimes speeds up the inference. @@ -77,11 +77,17 @@ PLNPCA_param <- function( backend = "nlopt", trace = 1 , config_optim = list() , + config_post = list() , inception = NULL # pretrained PLNfit used as initialization ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLNPCA + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -100,5 +106,6 @@ PLNPCA_param <- function( backend = backend , trace = trace , config_optim = config_opt, + config_post = config_pst, inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNPCAfit-class.R b/R/PLNPCAfit-class.R index 85f799ca..583ae05c 100644 --- a/R/PLNPCAfit-class.R +++ b/R/PLNPCAfit-class.R @@ -163,14 +163,16 @@ PLNPCAfit <- R6Class( }, #' @description Update R2, fisher, std_err fields and set up visualization - #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: + #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @details The list of parameters `config_post` controls the post-treatment processing, with the following entries: #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) colnames(private$C) <- colnames(private$M) <- 1:self$q rownames(private$C) <- colnames(responses) self$setVisualization() diff --git a/R/PLNfamily-class.R b/R/PLNfamily-class.R index 4c686855..5bb1f8d8 100644 --- a/R/PLNfamily-class.R +++ b/R/PLNfamily-class.R @@ -64,16 +64,22 @@ PLNfamily <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config a list for controlling the post-treatment. - postTreatment = function(config) { - nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). + #' @param config_optim a list for controlling the optimization parameters used during post_treatments + postTreatment = function(config_post, config_optim) { + if (config_post$rsquared) { + nullModel <- nullModelPoisson(self$responses, self$covariates, self$offsets, self$weights) + } else { + nullModel <- NULL + } for (model in self$models) model$postTreatment( self$responses, self$covariates, self$offsets, self$weights, - config, + config_post=config_post, + config_optim=config_optim, nullModel = nullModel ) }, diff --git a/R/PLNfit-class.R b/R/PLNfit-class.R index 9b9f9e9a..98c5fa54 100644 --- a/R/PLNfit-class.R +++ b/R/PLNfit-class.R @@ -66,8 +66,9 @@ PLNfit <- R6Class( torch_elbo = function(data, params, index=torch_tensor(1:self$n)) { S2 <- torch_square(params$S[index]) Z <- data$O[index] + params$M[index] + torch_mm(data$X[index], params$B) + A <- torch_exp(Z + .5 * S2) res <- .5 * sum(data$w[index]) * torch_logdet(private$torch_Sigma(data, params, index)) + - sum(data$w[index,NULL] * (torch_exp(Z + .5 * S2) - data$Y[index] * Z - .5 * torch_log(S2))) + sum(data$w[index,NULL] * (A - data$Y[index] * Z - .5 * torch_log(S2))) res }, @@ -84,21 +85,26 @@ PLNfit <- R6Class( torch_vloglik = function(data, params) { S2 <- torch_square(params$S) - Ji <- .5 * self$p - rowSums(.logfactorial(as.matrix(data$Y))) + as.numeric( - .5 * torch_logdet(params$Omega) + - torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - - .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) - ) - attr(Ji, "weights") <- as.numeric(data$w) + + Ji_tmp = .5 * torch_logdet(params$Omega) + + torch_sum(data$Y * params$Z - params$A + .5 * torch_log(S2), dim = 2) - + .5 * torch_sum(torch_mm(params$M, params$Omega) * params$M + S2 * torch_diag(params$Omega), dim = 2) + Ji <- - torch_sum(.logfactorial_torch(data$Y), dim = 2) + Ji_tmp + Ji <- .5 * self$p + as.numeric(Ji$cpu()) + + attr(Ji, "weights") <- as.numeric(data$w$cpu()) Ji }, #' @import torch torch_optimize = function(data, params, config) { + #config$device = "mps" + if (config$trace > 1) + message (paste("optimizing with device: ", config$device)) ## Conversion of data and parameters to torch tensors (pointers) - data <- lapply(data, torch_tensor) # list with Y, X, O, w - params <- lapply(params, torch_tensor, requires_grad = TRUE) # list with B, M, S + data <- lapply(data, torch_tensor, dtype = torch_float32(), device = config$device) # list with Y, X, O, w + params <- lapply(params, torch_tensor, dtype = torch_float32(), requires_grad = TRUE, device = config$device) # list with B, M, S ## Initialize optimizer optimizer <- switch(config$algorithm, @@ -115,11 +121,14 @@ PLNfit <- R6Class( batch_size <- floor(self$n/num_batch) objective <- double(length = config$num_epoch + 1) + #B_old = optimizer$param_groups[[1]]$params$B$clone() for (iterate in 1:num_epoch) { - B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) - + #B_old <- as.numeric(optimizer$param_groups[[1]]$params$B) # rearrange the data each epoch - permute <- torch::torch_randperm(self$n) + 1L + #permute <- torch::torch_randperm(self$n, device = "cpu") + 1L + permute = torch::torch_tensor(sample.int(self$n), dtype = torch_long(), device=config$device) + + #print (paste("num batches", num_batch)) for (batch_idx in 1:num_batch) { # here index is a vector of the indices in the batch index <- permute[(batch_size*(batch_idx - 1) + 1):(batch_idx*batch_size)] @@ -133,9 +142,7 @@ PLNfit <- R6Class( ## assess convergence objective[iterate + 1] <- loss$item() - B_new <- as.numeric(optimizer$param_groups[[1]]$params$B) delta_f <- abs(objective[iterate] - objective[iterate + 1]) / abs(objective[iterate + 1]) - delta_x <- sum(abs(B_old - B_new))/sum(abs(B_new)) ## Error message if objective diverges if (!is.finite(loss$item())) { @@ -144,11 +151,12 @@ PLNfit <- R6Class( } ## display progress - if (config$trace > 1 && (iterate %% 50 == 0)) + if (config$trace > 1 && (iterate %% 50 == 1)) cat('\niteration: ', iterate, 'objective', objective[iterate + 1], - 'delta_f' , round(delta_f, 6), 'delta_x', round(delta_x, 6)) + 'delta_f' , round(delta_f, 6)) ## Check for convergence + #print (delta_f) if (delta_f < config$ftol_rel) status <- 3 if (delta_x < config$xtol_rel) status <- 4 if (status %in% c(3,4)) { @@ -162,7 +170,10 @@ PLNfit <- R6Class( params$Z <- data$O + params$M + torch_matmul(data$X, params$B) params$A <- torch_exp(params$Z + torch_pow(params$S, 2)/2) - out <- lapply(params, as.matrix) + out <- lapply(params, function(x) { + x = x$cpu() + as.matrix(x)} + ) out$Ji <- private$torch_vloglik(data, params) out$monitoring <- list( objective = objective, @@ -178,7 +189,7 @@ PLNfit <- R6Class( ## PRIVATE METHODS FOR VARIANCE OF THE ESTIMATORS ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% - variance_variational = function(X) { + variance_variational = function(X, config = config_default_nlopt) { ## Variance of B for n data points fisher <- Matrix::bdiag(lapply(1:self$p, function(j) { crossprod(X, private$A[, j] * X) # t(X) %*% diag(A[, i]) %*% X @@ -208,6 +219,44 @@ PLNfit <- R6Class( invisible(list(var_B = var_B, var_Omega = var_Omega)) }, + compute_vcov_from_resamples = function(resamples){ + B_list = resamples %>% map("B") + #print (B_list) + vcov_B = lapply(seq(1, ncol(private$B)), function(B_col){ + param_ests_for_col = B_list %>% map(~.x[, B_col]) + param_ests_for_col = do.call(rbind, param_ests_for_col) + #print (param_ests_for_col) + row_vcov = cov(param_ests_for_col) + }) + #print ("vcov blocks") + #print (vcov_B) + + #B_vcov <- resamples %>% map("B") %>% map(~( . )) %>% reduce(cov) + + #var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>% + # `dimnames<-`(dimnames(private$B)) + #B_hat <- private$B[,] ## strips attributes while preserving names + + vcov_B = Matrix::bdiag(vcov_B) %>% as.matrix() + + rownames(vcov_B) <- colnames(vcov_B) <- + expand.grid(covariates = rownames(private$B), + responses = colnames(private$B)) %>% rev() %>% + ## Hack to make sure that species is first and varies slowest + apply(1, paste0, collapse = "_") + + #print (pheatmap::pheatmap(vcov_B, cluster_rows=FALSE, cluster_cols=FALSE)) + + + #names = lapply(bootstrapped_df$cov_mat, function(m){ colnames(m)}) %>% unlist() + #rownames(bootstrapped_vhat) = names + #colnames(bootstrapped_vhat) = names + + vcov_B = methods::as(vcov_B, "dgCMatrix") + + return(vcov_B) + }, + variance_jackknife = function(Y, X, O, w, config = config_default_nlopt) { jacks <- future.apply::future_lapply(seq_len(self$n), function(i) { data <- list(Y = Y[-i, , drop = FALSE], @@ -222,7 +271,7 @@ PLNfit <- R6Class( config = config) optim_out <- do.call(private$optimizer$main, args) optim_out[c("B", "Omega")] - }, future.seed = TRUE) + }) B_jack <- jacks %>% map("B") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("B") %>% map(~( (. - B_jack)^2)) %>% reduce(`+`) %>% @@ -231,6 +280,9 @@ PLNfit <- R6Class( attr(private$B, "bias") <- (self$n - 1) * (B_jack - B_hat) attr(private$B, "variance_jackknife") <- (self$n - 1) / self$n * var_jack + vcov_jacks = private$compute_vcov_from_resamples(jacks) + attr(private$B, "vcov_jackknife") <- vcov_jacks + Omega_jack <- jacks %>% map("Omega") %>% reduce(`+`) / self$n var_jack <- jacks %>% map("Omega") %>% map(~( (. - Omega_jack)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$Omega)) @@ -246,19 +298,31 @@ PLNfit <- R6Class( X = X[resample, , drop = FALSE], O = O[resample, , drop = FALSE], w = w[resample]) + if (config$backend == "torch") # Convert data to torch tensors + data <- lapply(data, torch_tensor, device = config$device) # list with Y, X, O, w + + #print (data$Y$device) + args <- list(data = data, # params = list(B = private$B, M = matrix(0,self$n,self$p), S = private$S[resample, ]), params = do.call(compute_PLN_starting_point, data), config = config) + if (config$backend == "torch") # Convert data to torch tensors + args$params <- lapply(args$params, torch_tensor, requires_grad = TRUE, device = config$device) # list with B, M, S + optim_out <- do.call(private$optimizer$main, args) + #print (optim_out) optim_out[c("B", "Omega", "monitoring")] - }, future.seed = TRUE) + }) B_boots <- boots %>% map("B") %>% reduce(`+`) / n_resamples attr(private$B, "variance_bootstrap") <- boots %>% map("B") %>% map(~( (. - B_boots)^2)) %>% reduce(`+`) %>% `dimnames<-`(dimnames(private$B)) / n_resamples + vcov_boots = private$compute_vcov_from_resamples(boots) + attr(private$B, "vcov_bootstrap") <- vcov_boots + Omega_boots <- boots %>% map("Omega") %>% reduce(`+`) / n_resamples attr(private$Omega, "variance_bootstrap") <- boots %>% map("Omega") %>% map(~( (. - Omega_boots)^2)) %>% reduce(`+`) %>% @@ -391,14 +455,15 @@ PLNfit <- R6Class( }, #' @description Update R2, fisher and std_err fields after optimization - #' @param config a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_optim a list for controlling the optimization (optional bootstrap, jackknife, R2, etc.). See details #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE #' * trace integer for verbosity. should be > 1 to see output in post-treatments - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { ## PARAMATERS DIMNAMES ## Set names according to those of the data matrices. If missing, use sensible defaults if (is.null(colnames(responses))) @@ -415,24 +480,27 @@ PLNfit <- R6Class( ## OPTIONAL POST-TREATMENT (potentially costly) ## 1. compute and store approximated R2 with Poisson-based deviance - if (config$rsquared) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") + if (config_post$rsquared) { + if(config_post$trace > 1) cat("\n\tComputing approximate R^2...") private$approx_r2(responses, covariates, offsets, weights, nullModel) } ## 2. compute and store matrix of standard variances for B and Omega with rough variational approximation - if (config$variational_var) { - if(config$trace > 1) cat("\n\tComputing variational estimator of the variance...") - private$variance_variational(covariates) + if (config_post$variational_var) { + if(config_post$trace > 1) cat("\n\tComputing variational estimator of the variance...") + private$variance_variational(covariates, config = config_optim) } ## 3. Jackknife estimation of bias and variance - if (config$jackknife) { - if(config$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") - private$variance_jackknife(responses, covariates, offsets, weights) + if (config_post$jackknife) { + if(config_post$trace > 1) cat("\n\tComputing jackknife estimator of the variance...") + private$variance_jackknife(responses, covariates, offsets, weights, config = config_optim) } ## 4. Bootstrap estimation of variance - if (config$bootstrap > 0) { - if(config$trace > 1) cat("\n\tComputing bootstrap estimator of the variance...") - private$variance_bootstrap(responses, covariates, offsets, weights, config$bootstrap) + if (config_post$bootstrap > 0) { + if(config_post$trace > 1) { + cat("\n\tComputing bootstrap estimator of the variance...") + #print (str(config_optim)) + } + private$variance_bootstrap(responses, covariates, offsets, weights, n_resamples=config_post$bootstrap, config = config_optim) } }, @@ -820,18 +888,19 @@ PLNfit_fixedcov <- R6Class( }, #' @description Update R2, fisher and std_err fields after optimization - #' @param config a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). See details + #' @param config_optim a list for controlling the optimization parameter. See details #' @details The list of parameters `config` controls the post-treatment processing, with the following entries: #' * trace integer for verbosity. should be > 1 to see output in post-treatments #' * jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. #' * bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). #' * variational_var boolean indicating whether variational Fisher information matrix should be computed to estimate the variance of the model parameters (highly underestimated). Default is FALSE. #' * rsquared boolean indicating whether approximation of R2 based on deviance should be computed. Default is TRUE - postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config, nullModel = NULL) { - super$postTreatment(responses, covariates, offsets, weights, config, nullModel) + postTreatment = function(responses, covariates, offsets, weights = rep(1, nrow(responses)), config_post, config_optim, nullModel = NULL) { + super$postTreatment(responses, covariates, offsets, weights, config_post, config_optim, nullModel) ## 6. compute and store matrix of standard variances for B with sandwich correction approximation - if (config$sandwich_var) { - if(config$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") + if (config_post$sandwich_var) { + if(config_post$trace > 1) cat("\n\tComputing sandwich estimator of the variance...") private$vcov_sandwich_B(responses, covariates) } } diff --git a/R/PLNmixture.R b/R/PLNmixture.R index 7231fe1d..7ecbe61d 100644 --- a/R/PLNmixture.R +++ b/R/PLNmixture.R @@ -59,8 +59,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt ## Post-treatments: Compute pseudo-R2, rearrange criteria and the visualization for PCA if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNmixture; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -75,6 +74,7 @@ PLNmixture <- function(formula, data, subset, clusters = 1:5, control = PLNmixt #' @param smoothing The smoothing to apply. Either, 'none', forward', 'backward' or 'both'. Default is 'both'. #' @param init_cl The initial clustering to apply. Either, 'kmeans', CAH' or a user defined clustering given as a list of clusterings, the size of which is equal to the number of clusters considered. Default is 'kmeans'. #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatments (optional bootstrap, jackknife, R2, etc.). #' @param trace a integer for verbosity. #' @param inception Set up the parameters initialization: by default, the model is initialized with a multivariate linear model applied on #' log-transformed data, and with the same formula as the one provided by the user. However, the user can provide a PLNfit (typically obtained from a previous fit), @@ -95,10 +95,16 @@ PLNmixture_param <- function( init_cl = "kmeans" , smoothing = "both" , config_optim = list() , + config_post = list() , inception = NULL # pretrained PLNfit used as initialization ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLNmixture + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -123,5 +129,6 @@ PLNmixture_param <- function( init_cl = init_cl , smoothing = smoothing , config_optim = config_opt , + config_post = config_pst , inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNmixturefit-class.R b/R/PLNmixturefit-class.R index 23363380..6f0c65ff 100644 --- a/R/PLNmixturefit-class.R +++ b/R/PLNmixturefit-class.R @@ -280,8 +280,9 @@ PLNmixturefit <- ## %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% ## Post treatment -------------------- #' @description Update fields after optimization - #' @param config a list for controlling the post-treatment - postTreatment = function(responses, covariates, offsets, weights, config, nullModel) { + #' @param config_post a list for controlling the post-treatment + #' @param config_optim a list for controlling the optimization during the post-treatment computations + postTreatment = function(responses, covariates, offsets, weights, config_post, config_optim, nullModel) { ## restoring the full design matrix (group means + covariates) mu_k <- matrix(1, self$n, ncol = 1); colnames(mu_k) <- 'Intercept' @@ -292,7 +293,8 @@ PLNmixturefit <- mu_k, offsets, private$tau[,k_], - config, + config_post, + config_optim, nullModel = nullModel ) }, diff --git a/R/PLNnetwork.R b/R/PLNnetwork.R index 4d644a69..a6f70c0e 100644 --- a/R/PLNnetwork.R +++ b/R/PLNnetwork.R @@ -41,8 +41,7 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control ## Post-treatments if (control$trace > 0) cat("\n Post-treatments") - config_post <- config_post_default_PLNnetwork; config_post$trace <- control$trace - myPLN$postTreatment(config_post) + myPLN$postTreatment(control$config_post, control$config_optim) if (control$trace > 0) cat("\n DONE!\n") myPLN @@ -53,7 +52,9 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' Helper to define list of parameters to control the PLN fit. All arguments have defaults. #' #' @param backend optimization back used, either "nlopt" or "torch". Default is "nlopt" +#' @param inception_cov Covariance structure used for the inception model used to initialize the PLNfamily. Defaults to "full" and can be constrained to "diagonal" and "spherical". #' @param config_optim a list for controlling the optimizer (either "nlopt" or "torch" backend). See details +#' @param config_post a list for controlling the post-treatment (optional bootstrap, jackknife, R2, etc). #' @param trace a integer for verbosity. #' @param n_penalties an integer that specifies the number of values for the penalty grid when internally generated. Ignored when penalties is non `NULL` #' @param min_ratio the penalty grid ranges from the minimal value that produces a sparse to this value multiplied by `min_ratio`. Default is 0.1. @@ -72,18 +73,25 @@ PLNnetwork <- function(formula, data, subset, weights, penalties = NULL, control #' @seealso [PLN_param()] #' @export PLNnetwork_param <- function( - backend = "nlopt", + backend = c("nlopt", "torch"), + inception_cov = c("full", "spherical", "diagonal"), trace = 1 , n_penalties = 30 , min_ratio = 0.1 , penalize_diagonal = TRUE , penalty_weights = NULL , - config_optim = list(), + config_post = list(), + config_optim = list(), inception = NULL ) { if (!is.null(inception)) stopifnot(isPLNfit(inception)) + ## post-treatment config + config_pst <- config_post_default_PLNnetwork + config_pst[names(config_post)] <- config_post + config_pst$trace <- trace + ## optimization config backend <- match.arg(backend) stopifnot(backend %in% c("nlopt", "torch")) @@ -95,6 +103,7 @@ PLNnetwork_param <- function( stopifnot(config_optim$algorithm %in% available_algorithms_torch) config_opt <- config_default_torch } + inception_cov <- match.arg(inception_cov) config_opt$trace <- trace config_opt$ftol_out <- 1e-5 config_opt$maxit_out <- 20 @@ -103,6 +112,7 @@ PLNnetwork_param <- function( structure(list( backend = backend , trace = trace , + inception_cov = inception_cov , n_penalties = n_penalties , min_ratio = min_ratio , penalize_diagonal = penalize_diagonal, @@ -110,6 +120,7 @@ PLNnetwork_param <- function( jackknife = FALSE , bootstrap = 0 , variance = TRUE , + config_post = config_pst , config_optim = config_opt , inception = inception ), class = "PLNmodels_param") } diff --git a/R/PLNnetworkfamily-class.R b/R/PLNnetworkfamily-class.R index d62d2053..00f2cec8 100644 --- a/R/PLNnetworkfamily-class.R +++ b/R/PLNnetworkfamily-class.R @@ -45,7 +45,18 @@ PLNnetworkfamily <- R6Class( ## A basic model for inception, useless one is defined by the user ### TODO check if it is useful if (is.null(control$inception)) { - myPLN <- PLNfit$new(responses, covariates, offsets, weights, formula, control) + + # CHECK_ME_TORCH_GPU + # This appears to be in torch_gpu only. The commented out line below is + # in both PLNmodels/master and PLNmodels/dev. + myPLN <- switch( + control$inception_cov, + "spherical" = PLNfit_spherical$new(responses, covariates, offsets, weights, formula, control), + "diagonal" = PLNfit_diagonal$new(responses, covariates, offsets, weights, formula, control), + PLNfit$new(responses, covariates, offsets, weights, formula, control) # defaults to full + ) + ## Allow inception with spherical / diagonal / full PLNfit before switching back to PLNfit_fixedcov + ## for the inner-outer loop of PLNnetwork. myPLN$optimize(responses, covariates, offsets, weights, control$config_optim) control$inception <- myPLN } @@ -68,8 +79,13 @@ PLNnetworkfamily <- R6Class( ## Get an appropriate grid of penalties if (is.null(penalties)) { if (control$trace > 1) cat("\n Recovering an appropriate grid of penalties.") + # CHECK_ME_TORCH_GPU + # This appears to be in torch_gpu only. The commented out line below is + # in both PLNmodels/master and PLNmodels/dev. + # changed it to other one max_pen <- list_penalty_weights %>% - map(~ control$inception$model_par$Sigma / .x) %>% + map(~ as.matrix(myPLN$model_par$Sigma) / .x) %>% + # map(~ control$inception$model_par$Sigma / .x) %>% map_dbl(~ max(abs(.x[upper.tri(.x, diag = control$penalize_diagonal)]))) %>% max() penalties <- 10^seq(log10(max_pen), log10(max_pen*control$min_ratio), len = control$n_penalties) diff --git a/R/utils.R b/R/utils.R index 87e59712..7b5c293d 100644 --- a/R/utils.R +++ b/R/utils.R @@ -4,6 +4,7 @@ available_algorithms_torch <- c("RPROP", "RMSPROP", "ADAM", "ADAGRAD") config_default_nlopt <- list( algorithm = "CCSAQ", + backend = "nlopt", maxeval = 10000 , ftol_rel = 1e-8 , xtol_rel = 1e-6 , @@ -15,6 +16,7 @@ config_default_nlopt <- config_default_torch <- list( algorithm = "RPROP", + backend = "torch", maxeval = 10000 , num_epoch = 1000 , num_batch = 1 , @@ -26,7 +28,8 @@ config_default_torch <- step_sizes = c(1e-3, 50), etas = c(0.5, 1.2), centered = FALSE, - trace = 1 + trace = 1, + device = "cpu" ) config_post_default_PLN <- @@ -107,6 +110,11 @@ trace <- function(x) sum(diag(x)) x } +.logfactorial_torch <- function(n){ + n[n == 0] <- 1 ## 0! = 1! + n*torch_log(n) - n + torch_log(8*torch_pow(n,3) + 4*torch_pow(n,2) + n + 1/30)/6 + torch_log(pi)/2 +} + .logfactorial <- function(n) { # Ramanujan's formula n[n == 0] <- 1 ## 0! = 1! n*log(n) - n + log(8*n^3 + 4*n^2 + n + 1/30)/6 + log(pi)/2 diff --git a/man/PLNLDA_param.Rd b/man/PLNLDA_param.Rd index cc8b0d40..8ab56640 100644 --- a/man/PLNLDA_param.Rd +++ b/man/PLNLDA_param.Rd @@ -62,7 +62,7 @@ When "torch" backend is used (only for PLN and PLNLDA for now), the following en \item "centered" if TRUE, compute the centered RMSProp where the gradient is normalized by an estimation of its variance weight_decay (L2 penalty). Default to FALSE. Only used in RMSPROP } -The list of parameters \code{config_post} controls the post-treatment processing (for PLN and PLNLDA), with the following entries: +The list of parameters \code{config_post} controls the post-treatment processing (for most \verb{PLN*()} functions), with the following entries (defaults may vary depending on the specific function, check \verb{config_post_default_*} for defaults values): \itemize{ \item jackknife boolean indicating whether jackknife should be performed to evaluate bias and variance of the model parameters. Default is FALSE. \item bootstrap integer indicating the number of bootstrap resamples generated to evaluate the variance of the model parameters. Default is 0 (inactivated). diff --git a/man/PLNLDAfit.Rd b/man/PLNLDAfit.Rd index ae7e21af..eb64369c 100644 --- a/man/PLNLDAfit.Rd +++ b/man/PLNLDAfit.Rd @@ -142,7 +142,14 @@ latent space, update corresponding fields \subsection{Method \code{postTreatment()}}{ Update R2, fisher and std_err fields and visualization \subsection{Usage}{ -\if{html}{\out{