diff --git a/R/calculate.R b/R/calculate.R index cfda156b..82292124 100644 --- a/R/calculate.R +++ b/R/calculate.R @@ -158,10 +158,8 @@ calculate <- function( trace_batch_size = 100, compute_options = cpu_only() ) { - # set device to be CPU/GPU for the entire run with(tf$device(compute_options), { - # message users about random seeds and GPU usage if they are using GPU message_if_using_gpu(compute_options) @@ -200,7 +198,6 @@ calculate <- function( # checks and RNG seed setting if we're sampling # REFACTOR: check_rng_seed(nsim, seed, compute_options) if (!is.null(nsim)) { - # check nsim is valid nsim <- check_positive_integer(nsim, "nsim") @@ -210,27 +207,25 @@ calculate <- function( x = ".Random.seed", envir = .GlobalEnv, inherits = FALSE - ) + ) if (no_global_random_seed) { runif(1) } - r_seed <- get(".Random.seed", envir = .GlobalEnv) on.exit(assign(".Random.seed", r_seed, envir = .GlobalEnv)) tensorflow::set_random_seed( seed = seed, disable_gpu = is_using_cpu(compute_options) - ) + ) } - if (is.null(seed)){ + if (is.null(seed)) { tensorflow::set_random_seed( seed = get_seed(), disable_gpu = is_using_cpu(compute_options) ) } - } # set precision @@ -259,7 +254,6 @@ calculate <- function( } if (!is.greta_mcmc_list(result)) { - # if it's not mcmc samples, make sure the results are in the right order # (tensorflow order seems to be platform specific?!?) order <- match(names(result), names(target)) @@ -278,12 +272,13 @@ calculate <- function( #' @importFrom coda thin #' @importFrom stats start end -calculate_greta_mcmc_list <- function(target, - values, - nsim, - tf_float, - trace_batch_size) { - +calculate_greta_mcmc_list <- function( + target, + values, + nsim, + tf_float, + trace_batch_size +) { # assign the free state stochastic <- !is.null(nsim) @@ -321,13 +316,11 @@ calculate_greta_mcmc_list <- function(target, # if they didn't specify nsim, check we can deterministically compute the # targets from the draws if (!stochastic) { - # see if the new dag introduces any new variables check_dag_introduces_new_variables(dag, mcmc_dag) # see if any of the targets are stochastic and not sampled in the mcmc check_targets_stochastic_and_not_sampled(target, mcmc_dag_variables) - } dag$target_nodes <- lapply(target, get_node) @@ -356,19 +349,20 @@ calculate_greta_mcmc_list <- function(target, # add the batch size to the data list # assign # pass these values in as the free state - trace <- dag$trace_values(draws, + trace <- dag$trace_values( + draws, trace_batch_size = trace_batch_size, flatten = FALSE ) # hopefully values is already a list of the correct dimensions...
} else { - # for deterministic posterior prediction, just trace the target for each # chain - values <- lapply(draws, - #double check the trace value part - can pronbanly just do that here + values <- lapply( + draws, + # double check the trace value part - can probably just do that here dag$trace_values, trace_batch_size = trace_batch_size ) @@ -394,7 +388,6 @@ calculate_list <- function(target, values, nsim, tf_float, env) { values_exist <- !identical(values, list()) if (values_exist) { - # check the list of values makes sense, and return these and the # corresponding greta arrays (looked up by name in environment env) values_list <- check_values_list(values, env) @@ -408,11 +401,8 @@ calculate_list <- function(target, values, nsim, tf_float, env) { stochastic <- !is.null(nsim) if (stochastic) { - check_if_unsampleable_and_unfixed(fixed_greta_arrays, dag) - } else { - # check there are no unspecified variables on which the target depends lapply(target, check_dependencies_satisfied, fixed_greta_arrays, dag, env) } @@ -420,7 +410,7 @@ calculate_list <- function(target, values, nsim, tf_float, env) { # TF1/2 check todo # need to wrap this in tf_function I think? if (Sys.getenv("GRETA_DEBUG") == "true") { - browser() + browser() } values <- calculate_target_tensor_list( dag = dag, @@ -438,12 +428,12 @@ calculate_list <- function(target, values, nsim, tf_float, env) { # simultaneously # assign("calculate_target_tensor_list", target_tensor_list, envir = tfe) -# # add values or data not specified by the user -# data_list <- dag$get_tf_data_list() -# missing <- !names(data_list) %in% names(values) -# -# # send list to tf environment and roll into a dict -# + # # add values or data not specified by the user + # data_list <- dag$get_tf_data_list() + # missing <- !names(data_list) %in% names(values) + # + # # send list to tf environment and roll into a dict + # } diff --git a/R/callbacks.R b/R/callbacks.R index 0e3978fe..b78a9539 100644 --- a/R/callbacks.R +++ b/R/callbacks.R @@ -26,7 +26,6 @@ render_progress <- function(reads) { reads[is.na(reads)] <- "" some_results <- any(nchar(reads) > 0) if (some_results) { - # optionally add blanks to put lines at the edges if (length(reads) > 1) { reads <- c("", reads, "") @@ -50,7 +49,8 @@ percentages <- function() { } progress_bars <- function() { - reads <- lapply(greta_stash$progress_bar_log_files, + reads <- lapply( + greta_stash$progress_bar_log_files, read_progress_log_file, skip = 1 ) @@ -58,16 +58,14 @@ progress_bars <- function() { } # determine the type of progress information -set_progress_bar_type <- function(n_chain){ - -if (bar_width(n_chain) < 42) { - progress_callback <- percentages -} else { - progress_callback <- progress_bars -} - -greta_stash$callbacks$parallel_progress <- progress_callback +set_progress_bar_type <- function(n_chain) { + if (bar_width(n_chain) < 42) { + progress_callback <- percentages + } else { + progress_callback <- progress_bars + } + greta_stash$callbacks$parallel_progress <- progress_callback } # register some diff --git a/R/checkers.R b/R/checkers.R index dd5346fe..53574f82 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -7,12 +7,9 @@ #' @importFrom cli cli_process_start #' @importFrom cli cli_process_done #' @importFrom cli cli_process_failed -check_tf_version <- function(alert = c("none", - "error", - "warn", - "message", - "startup")) { - +check_tf_version <- function( + alert = c("none", "error", "warn", "message", "startup") +) { # temporarily turn off the reticulate autoconfigure functionality ac_flag <-
Sys.getenv("RETICULATE_AUTOCONFIGURE") on.exit( @@ -25,7 +22,6 @@ check_tf_version <- function(alert = c("none", alert <- match.arg(alert) if (!greta_stash$python_has_been_initialised) { - cli_process_start( msg = "Initialising python and checking dependencies, this may take a \\ moment." @@ -41,16 +37,14 @@ check_tf_version <- function(alert = c("none", py_not_init <- !greta_stash$python_has_been_initialised requirements_valid_py_not_init <- all(requirements_valid) && py_not_init if (requirements_valid_py_not_init) { - cli_process_done( - msg_done = "Initialising python and checking dependencies ... done!") + msg_done = "Initialising python and checking dependencies ... done!" + ) cat("\n") greta_stash$python_has_been_initialised <- TRUE - } if (!all(requirements_valid)) { - cli_process_failed() cli_msg <- c( @@ -80,15 +74,11 @@ check_tf_version <- function(alert = c("none", } invisible(all(requirements_valid)) - } # check dimensions of arguments to ops, and return the maximum dimension -check_dims <- function(..., - target_dim = NULL, - call = rlang::caller_env()) { - +check_dims <- function(..., target_dim = NULL, call = rlang::caller_env()) { # coerce args to greta arrays elem_list <- list(...) elem_list <- lapply(elem_list, as.greta_array) @@ -110,19 +100,16 @@ check_dims <- function(..., # if they're non-scalar, but have the same dimensions, that's fine too if (!all(match_first)) { - # otherwise it's not fine cli::cli_abort( message = "incompatible dimensions: {dims_text}", call = call ) - } } # if there's a target dimension, make sure they all match it if (!is.null(target_dim)) { - # make sure it's 2D is_1d <- length(target_dim) == 1 if (is_1d) { @@ -133,13 +120,11 @@ check_dims <- function(..., # if they are all scalars, that's fine too if (!all(scalars)) { - # check all arguments against this matches_target <- are_identical(dim_list[!scalars], target_dim) # error if not if (!all(matches_target)) { - cli::cli_abort( c( "incorrect array dimensions", @@ -148,14 +133,11 @@ check_dims <- function(..., "but input dimensions were {dims_text}." 
) ) - } - } output_dim <- target_dim } else { - # otherwise, find the correct output dimension dim_lengths <- lengths(dim_list) dim_list <- lapply(dim_list, pad_vector, to_length = max(dim_lengths)) @@ -166,8 +148,7 @@ check_dims <- function(..., } # make sure a greta array is 2D -check_2d_multivariate <- function(x, - call = rlang::caller_env()) { +check_2d_multivariate <- function(x, call = rlang::caller_env()) { if (!is_2d(x)) { cli::cli_abort( message = c( @@ -181,10 +162,7 @@ check_2d_multivariate <- function(x, } } -check_square <- function(x = NULL, - dim = NULL, - call = rlang::caller_env()) { - +check_square <- function(x = NULL, dim = NULL, call = rlang::caller_env()) { # allows for specifying x or named dim = dim dim <- dim %||% dim(x) ndim <- length(dim) @@ -201,8 +179,10 @@ check_square <- function(x = NULL, } } -check_sigma_square_2d_greta_array <- function(sigma, - call = rlang::caller_env()){ +check_sigma_square_2d_greta_array <- function( + sigma, + call = rlang::caller_env() +) { # check dimensions of Sigma not_square <- nrow(sigma) != ncol(sigma) not_2d <- n_dim(sigma) != 2 @@ -219,9 +199,11 @@ check_sigma_square_2d_greta_array <- function(sigma, } } -check_mean_sigma_have_same_dimensions <- function(mean, - sigma, - call = rlang::caller_env()) { +check_mean_sigma_have_same_dimensions <- function( + mean, + sigma, + call = rlang::caller_env() +) { dim_mean <- ncol(mean) dim_sigma <- nrow(sigma) @@ -237,8 +219,8 @@ check_mean_sigma_have_same_dimensions <- function(mean, } check_chol2symm_square_symmetric_upper_tri_matrix <- function( - x, - call = rlang::caller_env() + x, + call = rlang::caller_env() ) { dim <- dim(x) is_square <- dim[1] == dim[2] @@ -255,8 +237,8 @@ check_chol2symm_square_symmetric_upper_tri_matrix <- function( } check_chol2symm_2d_square_upper_tri_greta_array <- function( - x, - call = rlang::caller_env() + x, + call = rlang::caller_env() ) { dim <- dim(x) is_square <- dim[1] == dim[2] @@ -276,11 +258,12 @@ check_chol2symm_2d_square_upper_tri_greta_array <- function( # matrices and column vectors, respectively, where number of rows implies the # number of realisations) and an optional target number of realisations, error # if there's a mismatch, and otherwise return the output number of realisations -check_n_realisations <- function(vectors = list(), - scalars = list(), - target = NULL, - call = rlang::caller_env()) { - +check_n_realisations <- function( + vectors = list(), + scalars = list(), + target = NULL, + call = rlang::caller_env() +) { # get the number of rows in the vector and scalar objects nrows <- lapply(c(vectors, scalars), nrow) @@ -293,7 +276,6 @@ check_n_realisations <- function(vectors = list(), # if they're non-scalar, but have the same dimensions, that's fine too if (!all(match_first)) { - # otherwise it's not fine cli::cli_abort( message = c( @@ -308,7 +290,6 @@ check_n_realisations <- function(vectors = list(), # if there's a target number of realisations, check it's valid and make sure # they all match it if (!is.null(target)) { - # make sure it's a scalar not_scalar <- length(target) != 1 || target < 1 if (not_scalar) { @@ -329,7 +310,6 @@ check_n_realisations <- function(vectors = list(), # if they are all scalars, that's fine too if (!all(single_rows)) { - # check all arguments against this matches_target <- are_identical(nrows[!single_rows], target) @@ -347,7 +327,6 @@ check_n_realisations <- function(vectors = list(), n_realisations <- target } else { - # otherwise, find the correct output dimension n_realisations <- 
max(unlist(nrows)) } @@ -358,18 +337,18 @@ check_n_realisations <- function(vectors = list(), # check the dimension of multivariate parameters matches, and matches the # optional target dimension -check_dimension <- function(vectors = list(), - squares = list(), - target = NULL, - min_dimension = 2L, - call = rlang::caller_env()) { - +check_dimension <- function( + vectors = list(), + squares = list(), + target = NULL, + min_dimension = 2L, + call = rlang::caller_env() +) { # get the number of columns in the vector and scalar objects ncols <- lapply(c(vectors, squares), ncol) # if there's a target dimension, check then use that: if (!is.null(target)) { - # make sure it's a scalar positive_scalar <- length(target) != 1 || target < 1 || !is.finite(target) if (positive_scalar) { @@ -385,7 +364,6 @@ check_dimension <- function(vectors = list(), dimension <- as.integer(target) } else { - # otherwise, get it from the first parameter dimension <- ncols[[1]] } @@ -430,13 +408,14 @@ check_dimension <- function(vectors = list(), # the objects passed in can either be vector-like (like 'mean'), # scalar-like (like 'size'), or square (like 'Sigma'). -check_multivariate_dims <- function(vectors = list(), - scalars = list(), - squares = list(), - n_realisations = NULL, - dimension = NULL, - min_dimension = 2L) { - +check_multivariate_dims <- function( + vectors = list(), + scalars = list(), + squares = list(), + n_realisations = NULL, + dimension = NULL, + min_dimension = 2L +) { # coerce args to greta arrays vectors <- lapply(vectors, as.greta_array) scalars <- lapply(scalars, as.greta_array) @@ -467,8 +446,7 @@ check_multivariate_dims <- function(vectors = list(), # check truncation for different distributions -check_positive <- function(truncation, - call = rlang::caller_env()) { +check_positive <- function(truncation, call = rlang::caller_env()) { bound_is_negative <- truncation[1] < 0 if (bound_is_negative) { cli::cli_abort( @@ -481,8 +459,7 @@ check_positive <- function(truncation, } } -check_unit <- function(truncation, - call = rlang::caller_env()) { +check_unit <- function(truncation, call = rlang::caller_env()) { bounds_not_btn_0_1 <- truncation[1] < 0 | truncation[2] > 1 if (bounds_not_btn_0_1) { cli::cli_abort( @@ -498,9 +475,7 @@ check_unit <- function(truncation, # check whether the function calling this is being used as the 'family' argument # of another modelling function -check_in_family <- function(function_name, - arg, - call = rlang::caller_env()) { +check_in_family <- function(function_name, arg, call = rlang::caller_env()) { if (missing(arg)) { # if the first argument is missing, the user might be doing # `family = binomial()` or similar @@ -509,11 +484,17 @@ check_in_family <- function(function_name, # if the first argument is one of these text strings, the user might be # doing `family = binomial("logit")` or similar links <- c( - "logit", "probit", "cloglog", "cauchit", - "log", "identity", "sqrt" + "logit", + "probit", + "cloglog", + "cauchit", + "log", + "identity", + "sqrt" ) arg_is_link <- inherits(arg, "character") && - length(arg) == 1 && arg %in% links + length(arg) == 1 && + arg %in% links } # if it's being executed in an environment where it's named 'family', the user @@ -541,7 +522,6 @@ check_in_family <- function(function_name, #' @importFrom future plan future check_future_plan <- function(call = rlang::caller_env()) { - plan_info <- future::plan() plan_is <- list( @@ -553,10 +533,8 @@ check_future_plan <- function(call = rlang::caller_env()) { # if running in
parallel if (plan_is$parallel) { - # if it's a cluster, check there's no forking if (plan_is$cluster) { - test_if_forked_cluster() f <- future::future(NULL, lazy = FALSE) @@ -571,7 +549,6 @@ check_future_plan <- function(call = rlang::caller_env()) { } } } else { - # if multi*, check it's multisession if (!plan_is$multisession) { cli::cli_abort( @@ -587,11 +564,12 @@ check_future_plan <- function(call = rlang::caller_env()) { } # check a list of greta arrays and return a list with names scraped from call -check_greta_arrays <- function(greta_array_list, - fun_name, - hint = NULL, - call = rlang::caller_env()) { - +check_greta_arrays <- function( + greta_array_list, + fun_name, + hint = NULL, + call = rlang::caller_env() +) { # check they are greta arrays are_greta_arrays <- are_greta_array(greta_array_list) @@ -632,10 +610,7 @@ check_greta_arrays <- function(greta_array_list, # check the provided list of greta array fixed values (as used in calculate and # simulate) is valid -check_values_list <- function(values, - env, - call = rlang::caller_env()) { - +check_values_list <- function(values, env, call = rlang::caller_env()) { # get the values and their names names <- names(values) stopifnot(length(names) == length(values)) @@ -668,11 +643,7 @@ check_values_list <- function(values, } # make sure the values have the correct dimensions - values <- mapply(assign_dim, - values, - fixed_greta_arrays, - SIMPLIFY = FALSE - ) + values <- mapply(assign_dim, values, fixed_greta_arrays, SIMPLIFY = FALSE) list( fixed_greta_arrays = fixed_greta_arrays, @@ -682,11 +653,13 @@ check_values_list <- function(values, # check that all the variable greta arrays on which the target greta array # depends are in the list fixed_greta_arrays (for use in calculate_list) -check_dependencies_satisfied <- function(target, - fixed_greta_arrays, - dag, - env, - call = rlang::caller_env()) { +check_dependencies_satisfied <- function( + target, + fixed_greta_arrays, + dag, + env, + call = rlang::caller_env() +) { dependency_names <- function(x) { get_node(x)$parent_names(recursive = TRUE) } @@ -713,14 +686,14 @@ check_dependencies_satisfied <- function(target, # if there are any undefined variables if (any(is_variable)) { - # try to find the associated greta arrays to provide a more informative # error message greta_arrays <- all_greta_arrays(env, include_data = FALSE) - greta_array_node_names <- vapply(greta_arrays, - function(x) get_node(x)$unique_name, - FUN.VALUE = "" + greta_array_node_names <- vapply( + greta_arrays, + function(x) get_node(x)$unique_name, + FUN.VALUE = "" ) unmet_variables <- unmet_nodes[is_variable] @@ -760,12 +733,10 @@ check_dependencies_satisfied <- function(target, message = msg, call = call ) - } } -check_cum_op <- function(x, - call = rlang::caller_env()) { +check_cum_op <- function(x, call = rlang::caller_env()) { dims <- dim(x) x_not_column_vector <- length(dims) > 2 | dims[2] != 1 if (x_not_column_vector) { @@ -781,10 +752,7 @@ check_cum_op <- function(x, #' @importFrom future availableCores -check_n_cores <- function(n_cores, - samplers, - plan_is) { - +check_n_cores <- function(n_cores, samplers, plan_is) { # if the plan is remote, and the user hasn't specified the number of cores, # leave it as all of them if (is.null(n_cores) & !plan_is$local) { @@ -815,9 +783,7 @@ check_n_cores <- function(n_cores, as.integer(n_cores) } -check_positive_integer <- function(x, - name = "", - call = rlang::caller_env()) { +check_positive_integer <- function(x, name = "", call = rlang::caller_env()) {
suppressWarnings(x <- as.integer(x)) not_positive_integer <- length(x) != 1 | is.na(x) | x < 1 @@ -835,8 +801,7 @@ check_positive_integer <- function(x, } # batch sizes must be positive numerics, rounded off to integers -check_trace_batch_size <- function(x, - call = rlang::caller_env()) { +check_trace_batch_size <- function(x, call = rlang::caller_env()) { valid <- is.numeric(x) && length(x) == 1 && x >= 1 if (!valid) { cli::cli_abort( @@ -848,8 +813,7 @@ check_trace_batch_size <- function(x, x } -check_if_greta_array_in_mcmc <- function(x, - call = rlang::caller_env()){ +check_if_greta_array_in_mcmc <- function(x, call = rlang::caller_env()) { if (!is.greta_model(x) && is.greta_array(x)) { cli::cli_abort( message = c( @@ -863,8 +827,7 @@ check_if_greta_array_in_mcmc <- function(x, } } -check_if_greta_model <- function(x, - call = rlang::caller_env()) { +check_if_greta_model <- function(x, call = rlang::caller_env()) { if (!is.greta_model(x)) { cli::cli_abort( message = c( @@ -877,8 +840,6 @@ check_if_greta_model <- function(x, } - - complex_error <- function(z) { cli::cli_abort( "{.pkg greta} does not yet support complex numbers" @@ -900,9 +861,11 @@ Conj.greta_array <- complex_error #' @export Mod.greta_array <- complex_error -check_if_unsampleable_and_unfixed <- function(fixed_greta_arrays, - dag, - call = rlang::caller_env()) { +check_if_unsampleable_and_unfixed <- function( + fixed_greta_arrays, + dag, + call = rlang::caller_env() +) { # check there are no variables without distributions (or whose children have # distributions - for lkj & wishart) that aren't given fixed values variables <- dag$node_list[dag$node_types == "variable"] @@ -935,7 +898,7 @@ check_if_unsampleable_and_unfixed <- function(fixed_greta_arrays, } } -check_if_array_is_empty_list <- function(target, call = rlang::caller_env()){ +check_if_array_is_empty_list <- function(target, call = rlang::caller_env()) { no_greta_arrays_provided <- identical(target, list()) if (no_greta_arrays_provided) { cli::cli_abort( @@ -948,9 +911,11 @@ check_if_array_is_empty_list <- function(target, call = rlang::caller_env()){ } } -check_if_lower_upper_numeric <- function(lower, - upper, - call = rlang::caller_env()) { +check_if_lower_upper_numeric <- function( + lower, + upper, + call = rlang::caller_env() +) { if (!is.numeric(lower) | !is.numeric(upper)) { cli::cli_abort( message = c( @@ -965,8 +930,10 @@ check_if_lower_upper_numeric <- function(lower, } } -check_if_lower_upper_has_bad_limits <- function(bad_limits, - call = rlang::caller_env()) { +check_if_lower_upper_has_bad_limits <- function( + bad_limits, + call = rlang::caller_env() +) { if (bad_limits) { cli::cli_abort( message = "lower and upper must either be -Inf (lower only), \\ @@ -976,9 +943,7 @@ check_if_lower_upper_has_bad_limits <- function(bad_limits, } } -check_if_upper_gt_lower <- function(lower, - upper, - call = rlang::caller_env()) { +check_if_upper_gt_lower <- function(lower, upper, call = rlang::caller_env()) { if (any(lower >= upper)) { cli::cli_abort( message = c( @@ -993,9 +958,9 @@ check_if_upper_gt_lower <- function(lower, check_targets_stochastic_and_not_sampled <- function( - target, - mcmc_dag_variables, - call = rlang::caller_env() + target, + mcmc_dag_variables, + call = rlang::caller_env() ) { target_nodes <- lapply(target, get_node) target_node_names <- extract_unique_names(target_nodes) @@ -1021,9 +986,11 @@ check_targets_stochastic_and_not_sampled <- function( } # see if the new dag introduces any new variables -check_dag_introduces_new_variables 
<- function(dag, - mcmc_dag, - call = rlang::caller_env()) { +check_dag_introduces_new_variables <- function( + dag, + mcmc_dag, + call = rlang::caller_env() +) { new_types <- dag$node_types[!connected_to_draws(dag, mcmc_dag)] any_new_variables <- any(new_types == "variable") if (any_new_variables) { @@ -1042,9 +1009,11 @@ check_dag_introduces_new_variables <- function(dag, } } -check_commanality_btn_dags <- function(dag, - mcmc_dag, - call = rlang::caller_env()) { +check_commanality_btn_dags <- function( + dag, + mcmc_dag, + call = rlang::caller_env() +) { target_not_connected_to_mcmc <- !any(connected_to_draws(dag, mcmc_dag)) if (target_not_connected_to_mcmc) { cli::cli_abort( @@ -1055,16 +1024,17 @@ check_commanality_btn_dags <- function(dag, } } -check_finite_positive_scalar_integer <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_finite_positive_scalar_integer <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { is_not_finite_positive_scalar_integer <- !is.numeric(x) || length(x) != 1 || !is.finite(x) || x <= 0 if (is_not_finite_positive_scalar_integer) { - cli::cli_abort( message = c( "{.arg {arg}} must be a finite, positive, scalar integer", @@ -1075,12 +1045,13 @@ check_finite_positive_scalar_integer <- function(x, call = call ) } - } -check_if_greta_mcmc_list <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_if_greta_mcmc_list <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { if (!is.greta_mcmc_list(x)) { cli::cli_abort( message = c( @@ -1093,9 +1064,11 @@ check_if_greta_mcmc_list <- function(x, } } -check_if_model_info <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_if_model_info <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { valid <- !is.null(x) if (!valid) { @@ -1109,13 +1082,14 @@ check_if_model_info <- function(x, call = call ) } - } -check_2d <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - if (!is_2d(x)){ +check_2d <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { + if (!is_2d(x)) { cli::cli_abort( message = c( "{.arg {arg}} must be two dimensional", @@ -1126,12 +1100,13 @@ check_2d <- function(x, } } -check_positive_scalar <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - +check_positive_scalar <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { not_positive_scalar <- !is.numeric(x) || !length(x) == 1 || x <= 0 - if (not_positive_scalar){ + if (not_positive_scalar) { cli::cli_abort( message = c( "{.arg {arg}} must be a positive scalar value, or a scalar \\ @@ -1142,11 +1117,13 @@ check_positive_scalar <- function(x, } } -check_scalar <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_scalar <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { scalar <- is_scalar(x) - if (!scalar){ + if (!scalar) { cli::cli_abort( message = c( "{.arg {arg}} must be a scalar", @@ -1158,13 +1135,15 @@ check_scalar <- function(x, } } -check_finite <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_finite <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { not_finite <- !is.finite(x) - if (not_finite){ + if (not_finite) { cli::cli_abort( - message = c( + message = c( "{.arg {arg}} must be a finite scalar", "But its value is:", "{.arg {arg}}: {x}" ), @@ -1174,12
+1153,13 @@ check_finite <- function(x, } } -check_x_gte_y <- function(x, - y, - x_arg = rlang::caller_arg(x), - y_arg = rlang::caller_arg(y), - call = rlang::caller_env()){ - +check_x_gte_y <- function( + x, + y, + x_arg = rlang::caller_arg(x), + y_arg = rlang::caller_arg(y), + call = rlang::caller_env() +) { x_gte_y <- x >= y if (x_gte_y) { @@ -1193,13 +1173,14 @@ check_x_gte_y <- function(x, call = call ) } - } -check_numeric_length_1 <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_numeric_length_1 <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { good_type <- is.numeric(x) && length(x) == 1 if (!good_type) { @@ -1214,14 +1195,15 @@ check_numeric_length_1 <- function(x, call = call ) } - } -check_both_2d <- function(x, - y, - x_arg = rlang::caller_arg(x), - y_arg = rlang::caller_arg(y), - call = rlang::caller_env()){ +check_both_2d <- function( + x, + y, + x_arg = rlang::caller_arg(x), + y_arg = rlang::caller_arg(y), + call = rlang::caller_env() +) { if (!is_2d(x) | !is_2d(y)) { cli::cli_abort( message = c( @@ -1235,10 +1217,7 @@ check_both_2d <- function(x, } } -check_compatible_dimensions <- function(x, - y, - call = rlang::caller_env()){ - +check_compatible_dimensions <- function(x, y, call = rlang::caller_env()) { incompatible_dimensions <- dim(x)[2] != dim(y)[1] if (incompatible_dimensions) { cli::cli_abort( @@ -1252,9 +1231,11 @@ check_compatible_dimensions <- function(x, } } -check_distribution_support <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_distribution_support <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { n_supports <- length(unique(x)) if (n_supports != 1) { supports_text <- vapply( @@ -1273,12 +1254,13 @@ check_distribution_support <- function(x, call = call ) } - } -check_not_multivariate_univariate <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_not_multivariate_univariate <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { is_multivariate_and_univariate <- !all(x) & !all(!x) if (is_multivariate_and_univariate) { cli::cli_abort( @@ -1291,11 +1273,12 @@ check_not_multivariate_univariate <- function(x, } } -check_not_discrete_continuous <- function(x, - name, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - +check_not_discrete_continuous <- function( + x, + name, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { is_discrete_and_continuous <- !all(x) & !all(!x) if (is_discrete_and_continuous) { cli::cli_abort( @@ -1308,10 +1291,12 @@ check_not_discrete_continuous <- function(x, } } -check_num_distributions <- function(n_distributions, - at_least, - name, - call = rlang::caller_env()){ +check_num_distributions <- function( + n_distributions, + at_least, + name, + call = rlang::caller_env() +) { if (n_distributions < at_least) { cli::cli_abort( message = c( @@ -1322,15 +1307,15 @@ check_num_distributions <- function(n_distributions, call = call ) } - } -check_weights_dim <- function(weights_dim, - dim, - n_distributions, - arg = rlang::caller_arg(weights_dim), - call = rlang::caller_env()){ - +check_weights_dim <- function( + weights_dim, + dim, + n_distributions, + arg = rlang::caller_arg(weights_dim), + call = rlang::caller_env() +) { # weights should have n_distributions as the first dimension if (weights_dim[1] != n_distributions) { cli::cli_abort( @@ -1366,11 +1351,9 @@ check_weights_dim <- function(weights_dim, call = call ) } - } 
-check_initials_are_named <- function(values, - call = rlang::caller_env()){ +check_initials_are_named <- function(values, call = rlang::caller_env()) { names <- names(values) initials_not_all_named <- length(names) != length(values) if (initials_not_all_named) { @@ -1381,8 +1364,7 @@ check_initials_are_named <- function(values, } } -check_initials_are_numeric <- function(values, - call = rlang::caller_env()){ +check_initials_are_numeric <- function(values, call = rlang::caller_env()) { are_numeric <- vapply(values, is.numeric, FUN.VALUE = FALSE) if (!all(are_numeric)) { cli::cli_abort( @@ -1392,10 +1374,11 @@ check_initials_are_numeric <- function(values, } } -check_initial_values_match_chains <- function(initial_values, - n_chains, - call = rlang::caller_env()){ - +check_initial_values_match_chains <- function( + initial_values, + n_chains, + call = rlang::caller_env() +) { initials <- initial_values not_initials_but_list <- !is.initials(initials) && is.list(initials) if (not_initials_but_list) { @@ -1417,13 +1400,13 @@ check_initial_values_match_chains <- function(initial_values, ) } } - } -check_initial_values_correct_dim <- function(target_dims, - replacement_dims, - call = rlang::caller_env()){ - +check_initial_values_correct_dim <- function( + target_dims, + replacement_dims, + call = rlang::caller_env() +) { same_dims <- mapply(identical, target_dims, replacement_dims) if (!all(same_dims)) { @@ -1433,12 +1416,12 @@ check_initial_values_correct_dim <- function(target_dims, call = call ) } - } -check_initial_values_correct_class <- function(initial_values, - call = rlang::caller_env()){ - +check_initial_values_correct_class <- function( + initial_values, + call = rlang::caller_env() +) { initials <- initial_values not_initials_but_list <- !is.initials(initials) && is.list(initials) not_initials_not_list <- !is.initials(initials) && !is.list(initials) @@ -1456,11 +1439,9 @@ check_initial_values_correct_class <- function(initial_values, call = call ) } - } -check_nodes_all_variable <- function(nodes, - call = rlang::caller_env()){ +check_nodes_all_variable <- function(nodes, call = rlang::caller_env()) { types <- lapply(nodes, node_type) are_variables <- are_identical(types, "variable") @@ -1469,11 +1450,12 @@ check_nodes_all_variable <- function(nodes, "Initial values can only be set for variable {.cls greta_array}s" ) } - } -check_greta_arrays_associated_with_model <- function(tf_names, - call = rlang::caller_env()){ +check_greta_arrays_associated_with_model <- function( + tf_names, + call = rlang::caller_env() +) { missing_names <- is.na(tf_names) if (any(missing_names)) { bad <- names(tf_names)[missing_names] @@ -1487,9 +1469,7 @@ check_greta_arrays_associated_with_model <- function(tf_names, } } -check_not_data_greta_arrays <- function(model, - call = rlang::caller_env()){ - +check_not_data_greta_arrays <- function(model, call = rlang::caller_env()) { # find variable names to label samples target_greta_arrays <- model$target_greta_arrays names <- names(target_greta_arrays) @@ -1513,7 +1493,7 @@ check_not_data_greta_arrays <- function(model, } } -check_diagrammer_installed <- function(call = rlang::caller_env()){ +check_diagrammer_installed <- function(call = rlang::caller_env()) { if (!is_DiagrammeR_installed()) { cli::cli_abort( message = c( @@ -1527,9 +1507,10 @@ check_diagrammer_installed <- function(call = rlang::caller_env()){ } } -check_unfixed_discrete_distributions <- function(dag, - call = rlang::caller_env()){ - +check_unfixed_discrete_distributions <- function( + dag, + 
call = rlang::caller_env() +) { # check for unfixed discrete distributions distributions <- dag$node_list[dag$node_types == "distribution"] bad_nodes <- vapply( @@ -1549,11 +1530,8 @@ check_unfixed_discrete_distributions <- function(dag, } } -check_greta_array_type <- function(x, - optional, - call = rlang::caller_env()){ - - if (!is.numeric(x) && !is.logical(x) && !optional){ +check_greta_array_type <- function(x, optional, call = rlang::caller_env()) { + if (!is.numeric(x) && !is.logical(x) && !optional) { cli::cli_abort( message = c( "{.cls greta_array} must contain the same type", @@ -1566,10 +1544,12 @@ check_greta_array_type <- function(x, } } -check_greta_data_frame <- function(x, - optional, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_greta_data_frame <- function( + x, + optional, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { classes <- vapply(x, class, "") valid <- classes %in% c("numeric", "integer", "logical") @@ -1588,9 +1568,7 @@ check_greta_data_frame <- function(x, } } -check_ncols_match <- function(x1, - x2, - call = rlang::caller_env()){ +check_ncols_match <- function(x1, x2, call = rlang::caller_env()) { if (ncol(x1) != ncol(x2)) { cli::cli_abort( message = c( @@ -1603,7 +1581,7 @@ check_ncols_match <- function(x1, } } -check_fields_installed <- function(){ +check_fields_installed <- function() { fields_installed <- requireNamespace("fields", quietly = TRUE) if (!fields_installed) { cli::cli_abort( @@ -1617,8 +1595,7 @@ check_fields_installed <- function(){ } } -check_2_by_1 <- function(x, - call = rlang::caller_env()){ +check_2_by_1 <- function(x, call = rlang::caller_env()) { dim_x <- dim(x) is_2_by_1 <- is_2d(x) && dim_x[2] == 1L if (!is_2_by_1) { @@ -1633,8 +1610,7 @@ check_2_by_1 <- function(x, } -check_transpose <- function(x, - call = rlang::caller_env()){ +check_transpose <- function(x, call = rlang::caller_env()) { if (x) { cli::cli_abort( message = "{.arg transpose} must be FALSE for {.cls greta_array}s", @@ -1643,12 +1619,13 @@ check_transpose <- function(x, } } -check_x_matches_ncol <- function(x, - ncol_of, - x_arg = rlang::caller_arg(x), - ncol_of_arg = rlang::caller_arg(ncol_of), - call = rlang::caller_env()){ - +check_x_matches_ncol <- function( + x, + ncol_of, + x_arg = rlang::caller_arg(x), + ncol_of_arg = rlang::caller_arg(ncol_of), + call = rlang::caller_env() +) { if (x != ncol(ncol_of)) { cli::cli_abort( message = "{.arg {x}} must equal {.code ncol({ncol_of_arg})} for \\ @@ -1658,10 +1635,12 @@ check_x_matches_ncol <- function(x, } } -check_stats_dim_matches_x_dim <- function(x, - margin, - stats, - call = rlang::caller_env()){ +check_stats_dim_matches_x_dim <- function( + x, + margin, + stats, + call = rlang::caller_env() +) { stats_dim_matches_x_dim <- dim(x)[margin] == dim(stats)[1] if (!stats_dim_matches_x_dim) { cli::cli_abort( @@ -1675,10 +1654,11 @@ check_stats_dim_matches_x_dim <- function(x, } # STATS must be a column array -check_is_column_array <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - +check_is_column_array <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { is_column_array <- is_2d(x) && dim(x)[2] == 1 if (!is_column_array) { cli::cli_abort( @@ -1693,12 +1673,13 @@ check_is_column_array <- function(x, } } -check_rows_equal <- function(a, - b, - a_arg = rlang::caller_arg(a), - b_arg = rlang::caller_arg(b), - call = rlang::caller_env()){ - +check_rows_equal <- function( + a, + b, + a_arg = rlang::caller_arg(a), + b_arg = 
rlang::caller_arg(b), + call = rlang::caller_env() +) { # b must have the right number of rows rows_not_equal <- dim(b)[1] != dim(a)[1] if (rows_not_equal) { @@ -1714,9 +1695,7 @@ check_rows_equal <- function(a, } } -check_final_dim <- function(dim, - thing, - call = rlang::caller_env()){ +check_final_dim <- function(dim, thing, call = rlang::caller_env()) { # dimension of the free state version n_dim <- length(dim) last_dim <- dim[n_dim] @@ -1732,12 +1711,13 @@ check_final_dim <- function(dim, call = call ) } - } -check_param_greta_array <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_param_greta_array <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { if (is.greta_array(x)) { cli::cli_abort( message = "{.arg {arg}} must be fixed, they cannot be another \\ @@ -1747,9 +1727,11 @@ check_param_greta_array <- function(x, } } -check_not_greta_array <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_not_greta_array <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { if (is.greta_array(x)) { cli::cli_abort( "{.arg {arg}} cannot be a {.cls greta_array}" @@ -1758,17 +1740,16 @@ check_not_greta_array <- function(x, } # if it errored -check_for_errors <- function(res, - call = rlang::caller_env()){ - +check_for_errors <- function(res, call = rlang::caller_env()) { if (inherits(res, "error")) { - # check for known numerical errors - numerical_errors <- vapply(greta_stash$numerical_messages, - grepl, - res$message, - FUN.VALUE = 0 - ) == 1 + numerical_errors <- vapply( + greta_stash$numerical_messages, + grepl, + res$message, + FUN.VALUE = 0 + ) == + 1 # if it was just a numerical error, quietly return a bad value if (!any(numerical_errors)) { @@ -1781,12 +1762,9 @@ check_for_errors <- function(res, ) } } - } -check_dim_length <- function(dim, - call = rlang::caller_env()){ - +check_dim_length <- function(dim, call = rlang::caller_env()) { ndim <- length(dim) ndim_gt2 <- ndim > 2 if (ndim_gt2) { @@ -1801,20 +1779,19 @@ check_dim_length <- function(dim, } } -check_is_distribution_node <- function(distribution, - call = rlang::caller_env()){ +check_is_distribution_node <- function( + distribution, + call = rlang::caller_env() +) { if (!is.distribution_node(distribution)) { cli::cli_abort( message = c("Invalid distribution"), call = call ) } - } -check_values_dim <- function(value, - dim, - call = rlang::caller_env()){ +check_values_dim <- function(value, dim, call = rlang::caller_env()) { values_have_wrong_dim <- !is.null(value) && !all.equal(dim(value), dim) if (values_have_wrong_dim) { cli::cli_abort( @@ -1822,12 +1799,10 @@ check_values_dim <- function(value, call = call ) } - } # check they are all scalar -check_dot_nodes_scalar <- function(dot_nodes, - call = rlang::caller_env()){ +check_dot_nodes_scalar <- function(dot_nodes, call = rlang::caller_env()) { are_scalar <- vapply(dot_nodes, is_scalar, logical(1)) if (!all(are_scalar)) { cli::cli_abort( @@ -1836,13 +1811,13 @@ check_dot_nodes_scalar <- function(dot_nodes, call = call ) } - } -inform_if_one_set_of_initials <- function(initial_values, - n_chains, - call = rlang::caller_env()){ - +inform_if_one_set_of_initials <- function( + initial_values, + n_chains, + call = rlang::caller_env() +) { is_blank <- identical(initial_values, initials()) one_set_of_initials <- !is_blank & n_chains > 1 @@ -1856,8 +1831,7 @@ inform_if_one_set_of_initials <- function(initial_values, # the user might pass greta arrays with groups of 
nodes that are unconnected # to one another. Need to check there are densities in each graph -check_subgraphs <- function(dag, - call = rlang::caller_env()){ +check_subgraphs <- function(dag, call = rlang::caller_env()) { # get and check the types types <- dag$node_types @@ -1924,13 +1898,14 @@ check_subgraphs <- function(dag, ) } } - } -check_has_representation <- function(repr, - name, - error, - call = rlang::caller_env()){ +check_has_representation <- function( + repr, + name, + error, + call = rlang::caller_env() +) { not_represented <- error && is.null(repr) if (not_represented) { cli::cli_abort( @@ -1940,10 +1915,12 @@ check_has_representation <- function(repr, } } -check_has_anti_representation <- function(repr, - name, - error, - call = rlang::caller_env()){ +check_has_anti_representation <- function( + repr, + name, + error, + call = rlang::caller_env() +) { not_anti_represented <- error && is.null(repr) if (not_anti_represented) { cli::cli_abort( @@ -1953,9 +1930,11 @@ check_has_anti_representation <- function(repr, } } -check_is_greta_array <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ +check_is_greta_array <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { if (!is.greta_array(x)) { cli::cli_abort( message = c( @@ -1967,9 +1946,11 @@ check_is_greta_array <- function(x, } } -check_missing_infinite_values <- function(x, - optional, - call = rlang::caller_env()){ +check_missing_infinite_values <- function( + x, + optional, + call = rlang::caller_env() +) { contains_missing_or_inf <- !optional & !all(is.finite(x)) if (contains_missing_or_inf) { cli::cli_abort( @@ -1981,10 +1962,11 @@ check_missing_infinite_values <- function(x, } } -check_truncation_implemented <- function(tfp_distribution, - distribution_node, - call = rlang::caller_env()){ - +check_truncation_implemented <- function( + tfp_distribution, + distribution_node, + call = rlang::caller_env() +) { cdf <- tfp_distribution$cdf quantile <- tfp_distribution$quantile @@ -1998,12 +1980,13 @@ check_truncation_implemented <- function(tfp_distribution, call = call ) } - } -check_sampling_implemented <- function(sample, - distribution_node, - call = rlang::caller_env()){ +check_sampling_implemented <- function( + sample, + distribution_node, + call = rlang::caller_env() +) { if (is.null(sample)) { cli::cli_abort( "Sampling is not yet implemented for \\ @@ -2012,9 +1995,7 @@ check_sampling_implemented <- function(sample, } } -check_timeout <- function(it, - maxit, - call = rlang::caller_env()){ +check_timeout <- function(it, maxit, call = rlang::caller_env()) { # check we didn't time out if (it == maxit) { cli::cli_abort( @@ -2027,36 +2008,34 @@ check_timeout <- function(it, call = call ) } - } -inform_if_remote_machine <- function(plan_is, - samplers){ +inform_if_remote_machine <- function(plan_is, samplers) { is_remote_machine <- plan_is$parallel & !plan_is$local if (is_remote_machine) { - cli::cli_inform( message = c( "running {length(samplers)} \\ {?sampler on a remote machine/samplers on remote machines}" ) ) - } } -inform_if_local_parallel_multiple_samplers <- function(plan_is, - samplers, - n_cores, - compute_options){ - +inform_if_local_parallel_multiple_samplers <- function( + plan_is, + samplers, + n_cores, + compute_options +) { local_parallel_multiple_samplers <- plan_is$parallel & plan_is$local & length(samplers) > 1 if (local_parallel_multiple_samplers) { cores_text <- compute_text(n_cores, compute_options) cli::cli_inform( - message = c(" + message = c( + 
" running {length(samplers)} samplers in parallel, {cores_text} \n\n" ) @@ -2065,7 +2044,6 @@ inform_if_local_parallel_multiple_samplers <- function(plan_is, } - checks_module <- module( check_tf_version, check_dims, @@ -2089,4 +2067,3 @@ checks_module <- module( check_if_greta_mcmc_list, check_2d_multivariate ) - diff --git a/R/chol2symm.R b/R/chol2symm.R index 3285cdfa..3919035d 100644 --- a/R/chol2symm.R +++ b/R/chol2symm.R @@ -28,7 +28,6 @@ chol2symm <- function(x) { #' @export chol2symm.default <- function(x) { - check_chol2symm_square_symmetric_upper_tri_matrix(x) t(x) %*% x @@ -44,7 +43,9 @@ chol2symm.greta_array <- function(x) { check_chol2symm_2d_square_upper_tri_greta_array(x) # sum the elements - op("chol2symm", x, + op( + "chol2symm", + x, tf_operation = "tf_chol2symm", representations = list(cholesky = x) ) diff --git a/R/conda_greta_env.R b/R/conda_greta_env.R index adfbeb21..d2d92e54 100644 --- a/R/conda_greta_env.R +++ b/R/conda_greta_env.R @@ -10,7 +10,7 @@ using_greta_conda_env <- function() { grepl("greta-env-tf2", config$python) } -have_greta_conda_env <- function(){ +have_greta_conda_env <- function() { tryCatch( expr = "greta-env-tf2" %in% reticulate::conda_list()$name, error = function(e) FALSE diff --git a/R/dag_class.R b/R/dag_class.R index 69c84a44..12e89669 100644 --- a/R/dag_class.R +++ b/R/dag_class.R @@ -18,9 +18,11 @@ dag_class <- R6Class( trace_names = NULL, # create a dag from some target nodes - initialize = function(target_greta_arrays, - tf_float = "float32", - compile = FALSE) { + initialize = function( + target_greta_arrays, + tf_float = "float32", + compile = FALSE + ) { # build the dag self$build_dag(target_greta_arrays) @@ -39,13 +41,13 @@ dag_class <- R6Class( self$define_tf_log_prob_function() }, - define_tf_trace_values_batch = function(){ + define_tf_trace_values_batch = function() { self$tf_trace_values_batch <- tensorflow::tf_function( f = self$define_trace_values_batch ) }, - define_tf_log_prob_function = function(){ + define_tf_log_prob_function = function() { self$tf_log_prob_function <- tensorflow::tf_function( # TF1/2 check # need to check in on all cases of `tensorflow::tf_function()` @@ -56,11 +58,11 @@ dag_class <- R6Class( tf_log_prob_function = NULL, - tf_log_prob_function_adjusted = function(free_state){ + tf_log_prob_function_adjusted = function(free_state) { self$tf_log_prob_function(free_state)$adjusted }, - tf_log_prob_function_unadjusted = function(free_state){ + tf_log_prob_function_unadjusted = function(free_state) { self$tf_log_prob_function(free_state)$unadjusted }, @@ -88,7 +90,6 @@ dag_class <- R6Class( # get the TF names for different node types get_tf_names = function(types = NULL) { - # get tf basenames names <- self$node_tf_names if (!is.null(types)) { @@ -103,7 +104,6 @@ dag_class <- R6Class( # look up the TF name for a single node tf_name = function(node) { - # get tf basename from node name name <- self$node_tf_names[node$unique_name] if (length(name) == 0) { @@ -185,10 +185,11 @@ dag_class <- R6Class( # how to define the node if we're sampling everything (no free state) how_to_define_all_sampling = function(node) { - switch(node_type(node), - data = ifelse(has_distribution(node), "sampling", "forward"), - operation = ifelse(has_distribution(node), "sampling", "forward"), - "sampling" + switch( + node_type(node), + data = ifelse(has_distribution(node), "sampling", "forward"), + operation = ifelse(has_distribution(node), "sampling", "forward"), + "sampling" ) }, @@ -196,16 +197,17 @@ dag_class <- R6Class( # from an 
existing free state), or in sampling mode (generate a random # version of itself) how_to_define = function(node) { - switch(self$mode, + switch( + self$mode, - # if doing inference, everything is push-forward - all_forward = "forward", + # if doing inference, everything is push-forward + all_forward = "forward", - # sampling from prior most nodes are in sampling mode - all_sampling = self$how_to_define_all_sampling(node), + # sampling from prior most nodes are in sampling mode + all_sampling = self$how_to_define_all_sampling(node), - # sampling from posterior some nodes defined forward, others sampled - hybrid = self$how_to_define_hybrid(node) + # sampling from posterior some nodes defined forward, others sampled + hybrid = self$how_to_define_hybrid(node) ) }, define_batch_size = function() { @@ -227,11 +229,12 @@ dag_class <- R6Class( # put this in the greta stash, so it can be accessed by other (sub-)dags # if needed, e.g. when using as_tf_function() assign("batch_size", self$tf_environment$batch_size, envir = greta_stash) - }, - define_free_state = function(type = c("variable", "placeholder"), - name = "free_state") { + define_free_state = function( + type = c("variable", "placeholder"), + name = "free_state" + ) { type <- match.arg(type) tfe <- self$tf_environment @@ -239,7 +242,6 @@ dag_class <- R6Class( vals <- unlist_tf(vals) if (type == "variable") { - # TF1/2 check # tf$Variable seems to have trouble assigning values, if created with # numeric (rather than logical) NAs @@ -260,10 +262,7 @@ dag_class <- R6Class( # ) } - assign(name, - free_state, - envir = tfe - ) + assign(name, free_state, envir = tfe) }, # split the overall free state vector into free versions of variables @@ -292,14 +291,13 @@ dag_class <- R6Class( # define the body of the tensorflow graph in the environment env; without # defining the free_state, or the densities etc. 
define_tf_body = function(target_nodes = self$node_list) { - # if in forward or hybrid mode, split up the free state if (self$mode %in% c("all_forward", "hybrid")) { self$split_free_state() } # define all nodes in the environment and on the graph - lapply(target_nodes, function(x){ + lapply(target_nodes, function(x) { x$define_tf(self) }) @@ -324,7 +322,6 @@ dag_class <- R6Class( } self$define_tf_body(target_nodes = target_nodes) - }, # define tensor for overall log density and gradients @@ -339,10 +336,11 @@ dag_class <- R6Class( target_nodes <- target_nodes[has_target] # get the densities, evaluated at these targets - densities <- mapply(self$evaluate_density, - distribution_nodes, - target_nodes, - SIMPLIFY = FALSE + densities <- mapply( + self$evaluate_density, + distribution_nodes, + target_nodes, + SIMPLIFY = FALSE ) # reduce_sum each of them (skipping the batch dimension) @@ -353,10 +351,7 @@ dag_class <- R6Class( joint_density <- tf$add_n(summed_densities) # assign overall density to environment - assign("joint_density", - joint_density, - envir = self$tf_environment - ) + assign("joint_density", joint_density, envir = self$tf_environment) # define adjusted joint density @@ -373,9 +368,10 @@ dag_class <- R6Class( total_adj <- tf$add_n(adj) # assign overall density to environment - assign("joint_density_adj", - joint_density + total_adj, - envir = self$tf_environment + assign( + "joint_density_adj", + joint_density + total_adj, + envir = self$tf_environment ) }, @@ -401,11 +397,12 @@ dag_class <- R6Class( bounds = distribution_node$bounds ) }, - tf_evaluate_density = function(tfp_distribution, - tf_target, - truncation = NULL, - bounds = NULL) { - + tf_evaluate_density = function( + tfp_distribution, + tf_target, + truncation = NULL, + bounds = NULL + ) { # get the uncorrected log density ld <- tfp_distribution$log_prob(tf_target) @@ -420,24 +417,22 @@ dag_class <- R6Class( ## TODO add explaining variables if (all(lower == bounds[1])) { - # if only upper is constrained, just need the cdf at the upper offset <- tfp_distribution$log_cdf(fl(upper)) } else if (all(upper == bounds[2])) { - # if only lower is constrained, get the log of the integral above it offset <- tf$math$log(fl(1) - tfp_distribution$cdf(fl(lower))) } else { - # if both are constrained, get the log of the integral between them - offset <- tf$math$log(tfp_distribution$cdf(fl(upper)) - - tfp_distribution$cdf(fl(lower))) + offset <- tf$math$log( + tfp_distribution$cdf(fl(upper)) - + tfp_distribution$cdf(fl(lower)) + ) } ld <- ld - offset } - ld }, @@ -454,29 +449,31 @@ dag_class <- R6Class( "adjusted", "unadjusted" ) - ){ + ) { which_objective <- match.arg(which_objective) ga_names <- names(nodes) ## TF1/2 retracing ## This is a location where retracing happens in `opt` hessian_list <- lapply(X = nodes, - self$calculate_one_hessian, - free_state = free_state, - which_objective = which_objective) + hessian_list <- lapply( + X = nodes, + self$calculate_one_hessian, + free_state = free_state, + which_objective = which_objective + ) # assign names and return names(hessian_list) <- ga_names hessian_list }, calculate_one_hessian = function( - node, - free_state, - which_objective = c( - "adjusted", - "unadjusted" - ) + node, + free_state, + which_objective = c( + "adjusted", + "unadjusted" + ) ) { which_objective <- match.arg(which_objective) @@ -507,9 +504,10 @@ dag_class <- R6Class( ) # return either of the densities, or a list of both - y <- switch(which_objective, - adjusted = objectives$adjusted, - unadjusted =
objectives$unadjusted + y <- switch( + which_objective, + adjusted = objectives$adjusted, + unadjusted = objectives$unadjusted ) }) g <- tape_2$gradient(y, xs) @@ -520,18 +518,19 @@ dag_class <- R6Class( hessian <- array(h$numpy(), dim = hessian_dims(ga_dim)) hessian - }, # return a function to obtain the model log probability from a tensor for # the free state - generate_log_prob_function = function(which = c( - "both", - "adjusted", - "unadjusted" - )) { + generate_log_prob_function = function( + which = c( + "both", + "adjusted", + "unadjusted" + ) + ) { which <- match.arg(which) # we can only pass the free_state parameter through @@ -557,10 +556,11 @@ dag_class <- R6Class( ) # return either of the densities, or a list of both - result <- switch(which, - adjusted = objectives$adjusted, - unadjusted = objectives$unadjusted, - both = objectives + result <- switch( + which, + adjusted = objectives$adjusted, + unadjusted = objectives$unadjusted, + both = objectives ) result @@ -570,7 +570,6 @@ dag_class <- R6Class( # return the expected parameter format either in free state vector form, or # list of transformed parameters example_parameters = function(free = TRUE) { - # find all variable nodes in the graph nodes <- self$node_list[self$node_types == "variable"] names(nodes) <- self$get_tf_names(types = "variable") @@ -583,9 +582,10 @@ dag_class <- R6Class( } # remove any of these that don't need a free state here (for calculate()) - stateless_names <- vapply(self$variables_without_free_state, - self$tf_name, - FUN.VALUE = character(1) + stateless_names <- vapply( + self$variables_without_free_state, + self$tf_name, + FUN.VALUE = character(1) ) keep <- !names(parameters) %in% stateless_names parameters <- parameters[keep] @@ -613,7 +613,7 @@ dag_class <- R6Class( }, tf_trace_values_batch = NULL, - trace_values_batch = function(free_state_batch){ + trace_values_batch = function(free_state_batch) { lapply( X = self$tf_trace_values_batch(free_state_batch), FUN = as.array @@ -621,7 +621,6 @@ dag_class <- R6Class( }, define_trace_values_batch = function(free_state_batch) { - # update the parameters & build the feed dict target_tf_names <- lapply( self$target_nodes, @@ -640,18 +639,17 @@ dag_class <- R6Class( # we now make all of the operations define themselves now self$define_tf() - target_tensors <- lapply(target_tf_names, - get, - envir = tfe - ) + target_tensors <- lapply(target_tf_names, get, envir = tfe) return(target_tensors) }, # return the current values of the traced nodes, as a named vector - trace_values = function(free_state, - flatten = TRUE, - trace_batch_size = Inf) { + trace_values = function( + free_state, + flatten = TRUE, + trace_batch_size = Inf + ) { # get the number of samples to trace n_samples <- nrow(free_state) indices <- seq_len(n_samples) @@ -690,13 +688,11 @@ dag_class <- R6Class( out <- trace_list } - out }, # for all the nodes in this dag, return a vector of membership to sub-graphs subgraph_membership = function() { - # convert adjacency matrix into absolute connectedness matrix using matrix # powers.
Inspired by Method 2 here: # http://raphael.candelier.fr/?blog=Adj2cluster @@ -726,10 +722,12 @@ dag_class <- R6Class( # find the cluster IDs n <- nrow(r) neighbours <- lapply(seq_len(n), function(i) which(r[i, ])) - cluster_names <- vapply(X = neighbours, - FUN = paste, - FUN.VALUE = character(1), - collapse = "_") + cluster_names <- vapply( + X = neighbours, + FUN = paste, + FUN.VALUE = character(1), + collapse = "_" + ) cluster_id <- match(cluster_names, unique(cluster_names)) # name them @@ -739,7 +737,6 @@ dag_class <- R6Class( # get the tfp distribution object for a distribution node get_tfp_distribution = function(distrib_node) { - # build the tfp distribution object for the distribution, and use it # to get the tensor for the sample distrib_constructor <- self$get_tf_object(distrib_node) @@ -763,11 +760,9 @@ dag_class <- R6Class( truncation <- distribution_node$truncation if (is.null(truncation)) { - # if we're not dealing with truncation, sample directly tensor <- sample(seed = get_seed()) } else { - # if we're dealing with truncation (therefore univariate and continuous) # sample a random uniform (tensor), and pass through the truncated # quantile (inverse cdf) function @@ -808,7 +803,6 @@ dag_class <- R6Class( types }, adjacency_matrix = function(value) { - # make dag matrix n_node <- length(self$node_list) node_names <- names(self$node_list) @@ -835,7 +829,6 @@ dag_class <- R6Class( target_name <- self$node_list[[i]]$target$unique_name if (!is.null(target_name)) { - # switch the target from child to parent of the distribution parents[[i]] <- parents[[i]][parents[[i]] != target_name] children[[i]] <- c(children[[i]], target_name) diff --git a/R/distribution.R b/R/distribution.R index beafd402..990ef245 100644 --- a/R/distribution.R +++ b/R/distribution.R @@ -33,7 +33,8 @@ #' # get the distribution over y #' distribution(y) #' } -`distribution<-` <- function(greta_array, value) { # nolint +`distribution<-` <- function(greta_array, value) { + # nolint # stash the old greta array to return greta_array_tmp <- greta_array @@ -122,7 +123,6 @@ #' @rdname distribution #' @export distribution <- function(greta_array) { - # only for greta arrays check_is_greta_array(greta_array) @@ -130,7 +130,6 @@ distribution <- function(greta_array) { if (is.distribution_node(get_node(greta_array)$distribution)) { distrib <- greta_array } else { - # otherwise return NULL distrib <- NULL } diff --git a/R/extract_replace_combine.R b/R/extract_replace_combine.R index 9ce34d86..20cda1c1 100644 --- a/R/extract_replace_combine.R +++ b/R/extract_replace_combine.R @@ -79,7 +79,6 @@ NULL # extract syntax for greta_array objects #' @export `[.greta_array` <- function(x, ...) { - # store the full call to mimic on a dummy array, plus the array's dimensions call <- sys.call() dims_in <- dim(x) @@ -148,7 +147,8 @@ NULL index <- flatten_rowwise(dummy_out) # create operation node, passing call and dims as additional arguments - op("extract", + op( + "extract", x, dim = dims_out, operation_args = list( @@ -163,7 +163,8 @@ NULL # replace syntax for greta array objects #' @export -`[<-.greta_array` <- function(x, ..., value) { # nolint +`[<-.greta_array` <- function(x, ..., value) { + # nolint node <- get_node(x) @@ -230,7 +231,8 @@ NULL } # create operation node, passing call and dims as additional arguments - op("replace", + op( + "replace", x, replacement, dim = dims, @@ -272,10 +274,7 @@ cbind.greta_array <- function(...) 
{ # output dimensions dims <- c(rows[1], sum(cols)) - op("cbind", ..., - dim = dims, - tf_operation = "tf_cbind" - ) + op("cbind", ..., dim = dims, tf_operation = "tf_cbind") } #' @export @@ -307,31 +306,42 @@ rbind.greta_array <- function(...) { # output dimensions dims <- c(sum(rows), cols[1]) - op("rbind", ..., - dim = dims, - tf_operation = "tf_rbind" - ) + op("rbind", ..., dim = dims, tf_operation = "tf_rbind") } # nolint start #' @rdname overloaded #' @export -abind <- function(..., - along = N, rev.along = NULL, new.names = NULL, - force.array = TRUE, make.names = use.anon.names, - use.anon.names = FALSE, use.first.dimnames = FALSE, - hier.names = FALSE, use.dnns = FALSE) { +abind <- function( + ..., + along = N, + rev.along = NULL, + new.names = NULL, + force.array = TRUE, + make.names = use.anon.names, + use.anon.names = FALSE, + use.first.dimnames = FALSE, + hier.names = FALSE, + use.dnns = FALSE +) { UseMethod("abind") } # nolint end # nolint start #' @export -abind.default <- function(..., - along = N, rev.along = NULL, new.names = NULL, - force.array = TRUE, make.names = use.anon.names, - use.anon.names = FALSE, use.first.dimnames = FALSE, - hier.names = FALSE, use.dnns = FALSE) { +abind.default <- function( + ..., + along = N, + rev.along = NULL, + new.names = NULL, + force.array = TRUE, + make.names = use.anon.names, + use.anon.names = FALSE, + use.first.dimnames = FALSE, + hier.names = FALSE, + use.dnns = FALSE +) { # nolint end # error nicely if they don't have abind installed @@ -354,12 +364,18 @@ abind.default <- function(..., # nolint start #' @export -abind.greta_array <- function(..., - along = N, rev.along = NULL, new.names = NULL, - force.array = TRUE, make.names = use.anon.names, - use.anon.names = FALSE, - use.first.dimnames = FALSE, hier.names = FALSE, - use.dnns = FALSE) { +abind.greta_array <- function( + ..., + along = N, + rev.along = NULL, + new.names = NULL, + force.array = TRUE, + make.names = use.anon.names, + use.anon.names = FALSE, + use.first.dimnames = FALSE, + hier.names = FALSE, + use.dnns = FALSE +) { # nolint end # warn if any of the arguments have been changed @@ -399,8 +415,12 @@ abind.greta_array <- function(..., # rationalise along, and pad N if we're prepending/appending a dimension ## TODO add explaining variable here - if (along < 1 || along > n || (along > floor(along) && - along < ceiling(along))) { + if ( + along < 1 || + along > n || + (along > floor(along) && + along < ceiling(along)) + ) { n <- n + 1 along <- max(1, min(n + 1, ceiling(along))) } @@ -504,7 +524,6 @@ c.greta_array <- function(...) { #' @export rep.greta_array <- function(x, ...) { - # get the index idx <- rep(seq_along(x), ...) 
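For orientation, the methods reformatted above are greta's overloads of base R's extract, replace, and combine generics, so they can be exercised directly from user code. A minimal sketch, assuming only the exported greta API (normal(), zeros(), and the overloaded generics themselves); illustrative, not part of the patch:

library(greta)

x <- normal(0, 1, dim = c(3, 4))  # a 3 x 4 greta array of standard normals
sub <- x[1:2, ]                   # `[.greta_array`: extraction creates a new op node
z <- zeros(3, 4)
z[, 1] <- x[, 1]                  # `[<-.greta_array`: replacement on a data greta array
wide <- cbind(x, z)               # cbind.greta_array: 3 x 8 result
tall <- rbind(x, z)               # rbind.greta_array: 6 x 4 result
flat <- c(x, z)                   # c.greta_array: 24 x 1 column vector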
@@ -525,7 +544,8 @@ length.greta_array <- function(x) { # reshape greta arrays #' @export -`dim<-.greta_array` <- function(x, value) { # nolint +`dim<-.greta_array` <- function(x, value) { + # nolint dims <- value @@ -567,7 +587,7 @@ length.greta_array <- function(x) { n_elem_not_match <- prod_dims != len if (!is_scalar && n_elem_not_match) { cli::cli_abort( - "dims [product {prod_dims}] do not match the length of object [{len}]" + "dims [product {prod_dims}] do not match the length of object [{len}]" ) } @@ -578,10 +598,10 @@ length.greta_array <- function(x) { unmatch_dim <- !identical(dim(x), dims) if (unmatch_dim && is_scalar) { - # if the dims don't match, but x is a scalar, expand it to the required # dimension - op("expand_dim", + op( + "expand_dim", x, operation_args = list(dims = dims), tf_operation = "tf_expand_dim", @@ -591,7 +611,8 @@ length.greta_array <- function(x) { } else { # otherwise, if the dimensions don't match, but the number of elements do, # just change the dimensions - op("set_dim", + op( + "set_dim", x, operation_args = list(dims = dims), tf_operation = "tf_set_dim", @@ -600,7 +621,6 @@ length.greta_array <- function(x) { ) } - # otherwise just reorder them } @@ -609,27 +629,27 @@ length.greta_array <- function(x) { #' @export #' @importFrom utils head #' @importFrom utils head.matrix -head.greta_array <- function(x, n = 6L, ...) { # nolint +head.greta_array <- function(x, n = 6L, ...) { + # nolint stopifnot(length(n) == 1L) ans <- head.matrix(x, n, ...) ans - } #' @export #' @importFrom utils tail #' @importFrom utils tail.matrix -tail.greta_array <- function(x, n = 6L, ...) { # nolint +tail.greta_array <- function(x, n = 6L, ...) { + # nolint stopifnot(length(n) == 1L) ans <- tail.matrix(x, n, ...) ans - } #' @rdname overloaded @@ -667,8 +687,5 @@ diag.greta_array <- function(x = 1, nrow, ncol) { dims <- c(dim[1], 1) # return the extraction op - op("diag", x, - dim = dims, - tf_operation = "tf$linalg$diag_part" - ) + op("diag", x, dim = dims, tf_operation = "tf$linalg$diag_part") } diff --git a/R/functions.R b/R/functions.R index fcbaa584..02c86190 100644 --- a/R/functions.R +++ b/R/functions.R @@ -122,7 +122,6 @@ NULL #' @export log.greta_array <- function(x, base = lifecycle::deprecated()) { - if (lifecycle::is_present(base)) { lifecycle::deprecate_warn( when = "0.5.1", @@ -136,7 +135,9 @@ log.greta_array <- function(x, base = lifecycle::deprecated()) { if (has_representation(x, "log")) { result <- copy_representation(x, "log") } else { - result <- op("log", x, + result <- op( + "log", + x, tf_operation = "tf$math$log", representations = list(exp = x) ) @@ -151,7 +152,9 @@ exp.greta_array <- function(x) { result <- copy_representation(x, "exp") } else { # otherwise exponentiate it, and store the log representation - result <- op("exp", x, + result <- op( + "exp", + x, tf_operation = "tf$math$exp", representations = list(log = x) ) @@ -321,11 +324,7 @@ t.greta_array <- function(x) { dims <- rev(dim(x)) - op("transpose", - x, - dim = dims, - tf_operation = "tf_transpose" - ) + op("transpose", x, dim = dims, tf_operation = "tf_transpose") } #' @export @@ -345,7 +344,9 @@ aperm.greta_array <- function(a, perm = NULL, ...) 
{ ) } - op("aperm", a, + op( + "aperm", + a, dim = dim(a)[perm], tf_operation = "tf$transpose", operation_args = list(perm = c(0L, perm)) ) } @@ -387,18 +388,11 @@ chol.greta_array <- function(x, ..., force_cholesky = FALSE) { ) } - result <- op("chol", x, - dim = dim, - tf_operation = "tf_chol" - ) - + result <- op("chol", x, dim = dim, tf_operation = "tf_chol") } - if (force_cholesky){ - result <- op("chol", x, - dim = dim(x), - tf_operation = "tf_chol" - ) + if (force_cholesky) { + result <- op("chol", x, dim = dim(x), tf_operation = "tf_chol") } result @@ -406,7 +400,6 @@ chol.greta_array <- function(x, ..., force_cholesky = FALSE) { #' @export solve.greta_array <- function(a, b, ...) { - check_2d(a) # check the matrix is square @@ -418,21 +411,15 @@ solve.greta_array <- function(a, b, ...) { u <- representation(a, "cholesky") result <- chol2inv(u) } else { - result <- op("solve", a, - tf_operation = "tf$linalg$inv" - ) + result <- op("solve", a, tf_operation = "tf$linalg$inv") } } else { - check_2d(b) # b must have the right number of rows check_rows_equal(a, b) # ... and solve the linear equations - result <- op("solve", a, b, - dim = dim(b), - tf_operation = "tf$linalg$solve" - ) + result <- op("solve", a, b, dim = dim(b), tf_operation = "tf$linalg$solve") } result @@ -466,92 +453,81 @@ chol2inv.greta_array <- function(x, size = NCOL(x), LINPACK = FALSE) { ) } - op("chol2inv", x, - tf_operation = "tf_chol2inv" - ) + op("chol2inv", x, tf_operation = "tf_chol2inv") } # nolint end #' @rdname overloaded #' @export -cov2cor <- function(V) { # nolint +cov2cor <- function(V) { + # nolint UseMethod("cov2cor", V) } #' @export -cov2cor.default <- function(V) { # nolint +cov2cor.default <- function(V) { + # nolint stats::cov2cor(V) } #' @export -cov2cor.greta_array <- function(V) { # nolint - op("cov2cor", V, - tf_operation = "tf_cov2cor" - ) +cov2cor.greta_array <- function(V) { + # nolint + op("cov2cor", V, tf_operation = "tf_cov2cor") } # sum, prod, min, mean, max #' @export -sum.greta_array <- function(..., na.rm = TRUE) { # nolint +sum.greta_array <- function(..., na.rm = TRUE) { + # nolint # combine all elements into a column vector vec <- c(...) # sum the elements - op("sum", vec, - dim = c(1, 1), - tf_operation = "tf_sum" - ) + op("sum", vec, dim = c(1, 1), tf_operation = "tf_sum") } #' @export -prod.greta_array <- function(..., na.rm = TRUE) { # nolint +prod.greta_array <- function(..., na.rm = TRUE) { + # nolint # combine all elements into a column vector vec <- c(...) # multiply the elements - op("prod", vec, - dim = c(1, 1), - tf_operation = "tf_prod" - ) + op("prod", vec, dim = c(1, 1), tf_operation = "tf_prod") } #' @export -min.greta_array <- function(..., na.rm = TRUE) { # nolint +min.greta_array <- function(..., na.rm = TRUE) { + # nolint # combine all elements into a column vector vec <- c(...) # find the smallest element - op("min", vec, - dim = c(1, 1), - tf_operation = "tf_min" - ) + op("min", vec, dim = c(1, 1), tf_operation = "tf_min") } #' @export -mean.greta_array <- function(x, trim = 0, na.rm = TRUE, ...) { # nolint +mean.greta_array <- function(x, trim = 0, na.rm = TRUE, ...) { + # nolint # average the elements - op("mean", x, - dim = c(1, 1), - tf_operation = "tf_mean" - ) + op("mean", x, dim = c(1, 1), tf_operation = "tf_mean") } #' @export -max.greta_array <- function(..., na.rm = TRUE) { # nolint +max.greta_array <- function(..., na.rm = TRUE) { + # nolint # combine all elements into a column vector vec <- c(...)
# find the largest element - op("max", vec, - dim = c(1, 1), - tf_operation = "tf_max" - ) + op("max", vec, dim = c(1, 1), tf_operation = "tf_max") } #' @export @@ -591,10 +567,7 @@ rowcol_idx <- function(x, dims, which = c("col", "row")) { ) } - switch(which, - row = (dims + 1):n_dim(x), - col = seq_len(dims) - ) + switch(which, row = (dims + 1):n_dim(x), col = seq_len(dims)) } # get output dimension for colSums, rowSums, colMeans, rowMeans @@ -639,7 +612,9 @@ colMeans.default <- function(x, na.rm = FALSE, dims = 1L) { #' @export colMeans.greta_array <- function(x, na.rm = FALSE, dims = 1L) { - op("colMeans", x, + op( + "colMeans", + x, operation_args = list(dims = dims), tf_operation = "tf_colmeans", dim = rowcol_dim(x, dims, "col") @@ -659,7 +634,9 @@ rowMeans.default <- function(x, na.rm = FALSE, dims = 1L) { #' @export rowMeans.greta_array <- function(x, na.rm = FALSE, dims = 1L) { - op("rowMeans", x, + op( + "rowMeans", + x, operation_args = list(dims = dims), tf_operation = "tf_rowmeans", dim = rowcol_dim(x, dims, "row") @@ -679,7 +656,9 @@ colSums.default <- function(x, na.rm = FALSE, dims = 1L) { #' @export colSums.greta_array <- function(x, na.rm = FALSE, dims = 1L) { - op("colSums", x, + op( + "colSums", + x, operation_args = list(dims = dims), tf_operation = "tf_colsums", dim = rowcol_dim(x, dims, "col") @@ -699,7 +678,9 @@ rowSums.default <- function(x, na.rm = FALSE, dims = 1L) { #' @export rowSums.greta_array <- function(x, na.rm = FALSE, dims = 1L) { - op("rowSums", x, + op( + "rowSums", + x, operation_args = list(dims = dims), tf_operation = "tf_rowsums", dim = rowcol_dim(x, dims, "row") @@ -725,11 +706,14 @@ sweep.default <- base::sweep # nolint start #' @export -sweep.greta_array <- function(x, - MARGIN, - STATS, - FUN = c("-", "+", "/", "*"), - check.margin = TRUE, ...) { +sweep.greta_array <- function( + x, + MARGIN, + STATS, + FUN = c("-", "+", "/", "*"), + check.margin = TRUE, + ... +) { # nolint end # only allow these four functions @@ -748,7 +732,10 @@ sweep.greta_array <- function(x, # STATS must have the same dimension as the correct dim of x check_stats_dim_matches_x_dim(x, margin, stats) - op("sweep", x, stats, + op( + "sweep", + x, + stats, operation_args = list(margin = margin, fun = fun), tf_operation = "tf_sweep", dim = dim(x) @@ -760,23 +747,27 @@ sweep.greta_array <- function(x, #' @importFrom tensorflow %as% setClass("greta_array") setMethod( - "kronecker", signature(X = "greta_array", Y = "greta_array"), - function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, - ...) { + "kronecker", + signature(X = "greta_array", Y = "greta_array"), + function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, ...) { # nolint end fun <- match.arg(FUN) check_2d(X) check_2d(Y) - tf_fun_name <- switch(fun, + tf_fun_name <- switch( + fun, `*` = "multiply", `/` = "truediv", `+` = "add", `-` = "subtract" ) - op("kronecker", X, Y, + op( + "kronecker", + X, + Y, tf_operation = "tf_kronecker", operation_args = list(tf_fun_name = tf_fun_name), dim = dim(X) * dim(Y) @@ -787,9 +778,9 @@ setMethod( # nolint start #' @import methods setMethod( - kronecker, signature(X = "array", Y = "greta_array"), - function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, - ...) { + kronecker, + signature(X = "array", Y = "greta_array"), + function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, ...)
{ # nolint end kronecker(as.greta_array(X), Y, FUN, make.dimnames = FALSE) } @@ -798,9 +789,9 @@ setMethod( # nolint start #' @import methods setMethod( - kronecker, signature(X = "greta_array", Y = "array"), - function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, - ...) { + kronecker, + signature(X = "greta_array", Y = "array"), + function(X, Y, FUN = c("*", "/", "+", "-"), make.dimnames = FALSE, ...) { # nolint end kronecker(X, as.greta_array(Y), FUN, make.dimnames = FALSE) } @@ -809,38 +800,42 @@ setMethod( # nolint start #' @rdname overloaded #' @export -backsolve <- function(r, x, k = ncol(r), - upper.tri = TRUE, - transpose = FALSE) { +backsolve <- function(r, x, k = ncol(r), upper.tri = TRUE, transpose = FALSE) { # nolint end UseMethod("backsolve", x) } # nolint start #' @export -backsolve.default <- function(r, x, k = ncol(r), - upper.tri = TRUE, - transpose = FALSE) { +backsolve.default <- function( + r, + x, + k = ncol(r), + upper.tri = TRUE, + transpose = FALSE +) { # nolint end - base::backsolve(r, x, - k = ncol(r), - upper.tri = TRUE, - transpose = FALSE - ) + base::backsolve(r, x, k = ncol(r), upper.tri = TRUE, transpose = FALSE) } # define this explicitly so CRAN doesn't think we're using .Internal # nolint start #' @export -backsolve.greta_array <- function(r, x, - k = ncol(r), - upper.tri = TRUE, - transpose = FALSE) { +backsolve.greta_array <- function( + r, + x, + k = ncol(r), + upper.tri = TRUE, + transpose = FALSE +) { # nolint end check_x_matches_ncol(x = k, ncol_of = r) check_transpose(transpose) - op("backsolve", r, x, + op( + "backsolve", + r, + x, operation_args = list(lower = !upper.tri), tf_operation = "tf$linalg$triangular_solve", dim = dim(x) @@ -850,9 +845,13 @@ backsolve.greta_array <- function(r, x, # nolint start #' @rdname overloaded #' @export -forwardsolve <- function(l, x, k = ncol(l), - upper.tri = FALSE, - transpose = FALSE) { +forwardsolve <- function( + l, + x, + k = ncol(l), + upper.tri = FALSE, + transpose = FALSE +) { # nolint end UseMethod("forwardsolve", x) } @@ -860,28 +859,34 @@ forwardsolve <- function(l, x, k = ncol(l), # define this explicitly so CRAN doesn't think we're using .Internal # nolint start #' @export -forwardsolve.default <- function(l, x, k = ncol(l), - upper.tri = FALSE, - transpose = FALSE) { +forwardsolve.default <- function( + l, + x, + k = ncol(l), + upper.tri = FALSE, + transpose = FALSE +) { # nolint end - base::forwardsolve(l, x, - k = ncol(l), - upper.tri = FALSE, - transpose = FALSE - ) + base::forwardsolve(l, x, k = ncol(l), upper.tri = FALSE, transpose = FALSE) } # nolint start #' @export -forwardsolve.greta_array <- function(l, x, - k = ncol(l), - upper.tri = FALSE, - transpose = FALSE) { +forwardsolve.greta_array <- function( + l, + x, + k = ncol(l), + upper.tri = FALSE, + transpose = FALSE +) { # nolint end check_x_matches_ncol(x = k, ncol_of = l) check_transpose(transpose) - op("forwardsolve", l, x, + op( + "forwardsolve", + l, + x, operation_args = list(lower = !upper.tri), tf_operation = "tf$linalg$triangular_solve", dim = dim(x) @@ -891,12 +896,14 @@ forwardsolve.greta_array <- function(l, x, #' @rdname overloaded #' @export -apply <- function(X, MARGIN, FUN, ...) { # nolint +apply <- function(X, MARGIN, FUN, ...) { + # nolint UseMethod("apply", X) } #' @export -apply.default <- function(X, MARGIN, FUN, ...) { # nolint +apply.default <- function(X, MARGIN, FUN, ...) { + # nolint base::apply( X = X, MARGIN = MARGIN, @@ -907,12 +914,20 @@ apply.default <- function(X, MARGIN, FUN, ...) 
{ # nolint # nolint start #' @export -apply.greta_array <- function(X, MARGIN, - FUN = c( - "sum", "max", "mean", "min", "prod", - "cumsum", "cumprod" - ), - ...) { +apply.greta_array <- function( + X, + MARGIN, + FUN = c( + "sum", + "max", + "mean", + "min", + "prod", + "cumsum", + "cumprod" + ), + ... +) { # nolint end fun <- match.arg(FUN) @@ -955,7 +970,9 @@ apply.greta_array <- function(X, MARGIN, tf_fun_name <- paste("reduce", tf_fun_name, sep = "_") } - out <- op("apply", new_x, + out <- op( + "apply", + new_x, operation_args = list( axis = -2L, tf_fun_name = tf_fun_name @@ -975,14 +992,21 @@ apply.greta_array <- function(X, MARGIN, #' @rdname overloaded #' @export -tapply <- function(X, INDEX, FUN, ...) { # nolint +tapply <- function(X, INDEX, FUN, ...) { + # nolint UseMethod("tapply", X) } # nolint start #' @export -tapply.default <- function(X, INDEX, FUN = NULL, ..., - default = NA, simplify = TRUE) { +tapply.default <- function( + X, + INDEX, + FUN = NULL, + ..., + default = NA, + simplify = TRUE +) { # nolint end base::tapply( X = X, @@ -996,9 +1020,12 @@ tapply.default <- function(X, INDEX, FUN = NULL, ..., # nolint start #' @export -tapply.greta_array <- function(X, INDEX, - FUN = c("sum", "max", "mean", "min", "prod"), - ...) { +tapply.greta_array <- function( + X, + INDEX, + FUN = c("sum", "max", "mean", "min", "prod"), + ... +) { # nolint end x <- X @@ -1014,7 +1041,9 @@ tapply.greta_array <- function(X, INDEX, check_2_by_1(x) - op("tapply", x, + op( + "tapply", + x, operation_args = list( segment_ids = id, num_segments = len, @@ -1027,14 +1056,14 @@ tapply.greta_array <- function(X, INDEX, #' @rdname overloaded #' @export -eigen <- function(x, symmetric, only.values, EISPACK) { # nolint +eigen <- function(x, symmetric, only.values, EISPACK) { + # nolint UseMethod("eigen") } # nolint start #' @export -eigen.default <- function(x, symmetric, - only.values = FALSE, EISPACK = FALSE) { +eigen.default <- function(x, symmetric, only.values = FALSE, EISPACK = FALSE) { # nolint end base::eigen( x = x, @@ -1046,8 +1075,12 @@ eigen.default <- function(x, symmetric, # nolint start #' @export -eigen.greta_array <- function(x, symmetric, - only.values = FALSE, EISPACK = FALSE) { +eigen.greta_array <- function( + x, + symmetric, + only.values = FALSE, + EISPACK = FALSE +) { # nolint end x <- as.greta_array(x) @@ -1069,32 +1102,31 @@ eigen.greta_array <- function(x, symmetric, # they just want the eigenvalues, use that tf method if (only.values) { - values <- op("eigenvalues", x, + values <- op( + "eigenvalues", + x, dim = nrow(x), tf_operation = "tf_only_eigenvalues" ) vectors <- NULL } else { - # if we're doing the whole eigendecomposition, do it in three operations # a wacky greta array which apparently has the same dimension as x; but in # fact is a list of the two elements. 
But that's OK so long as the user # never sees it - eig <- op("eigen", x, - tf_operation = "tf$linalg$eigh" - ) + eig <- op("eigen", x, tf_operation = "tf$linalg$eigh") # get the eigenvalues and vectors as actual, sane greta arrays - values <- op("values", eig, + values <- op( + "values", + eig, dim = c(nrow(eig), 1L), tf_operation = "tf_extract_eigenvalues" ) - vectors <- op("vectors", eig, - tf_operation = "tf_extract_eigenvectors" - ) + vectors <- op("vectors", eig, tf_operation = "tf_extract_eigenvectors") } list( @@ -1141,12 +1173,8 @@ rdist.greta_array <- function(x1, x2 = NULL, compact = FALSE) { # square self-distance matrix if (is.null(x2)) { - op("rdist", x1, - tf_operation = "tf_self_distance", - dim = c(n1, n1) - ) + op("rdist", x1, tf_operation = "tf_self_distance", dim = c(n1, n1)) } else { - # possibly non-square pairwise distance matrix x2 <- as.greta_array(x2) @@ -1162,9 +1190,6 @@ rdist.greta_array <- function(x1, x2 = NULL, compact = FALSE) { n2 <- nrow(x2) - op("rdist", x1, x2, - tf_operation = "tf_distance", - dim = c(n1, n2) - ) + op("rdist", x1, x2, tf_operation = "tf_distance", dim = c(n1, n2)) } } diff --git a/R/greta-sitrep.R b/R/greta-sitrep.R index d85d650c..46f63abd 100644 --- a/R/greta-sitrep.R +++ b/R/greta-sitrep.R @@ -19,8 +19,7 @@ #' \dontrun{ #' greta_sitrep() #' } -greta_sitrep <- function(verbosity = c("minimal", "detailed", "quiet")){ - +greta_sitrep <- function(verbosity = c("minimal", "detailed", "quiet")) { verbosity <- rlang::arg_match( arg = verbosity, values = c("minimal", "detailed", "quiet"), @@ -33,23 +32,19 @@ greta_sitrep <- function(verbosity = c("minimal", "detailed", "quiet")){ detailed = detailed_sitrep(), quiet = quiet_sitrep() ) - } -minimal_sitrep <- function(){ - +minimal_sitrep <- function() { check_if_python_available() check_if_tf_available() check_if_tfp_available() check_if_greta_conda_env_available() check_greta_ready_to_use() - } -detailed_sitrep <- function(){ - +detailed_sitrep <- function() { config_info <- reticulate::py_config() cli::cli_h1("R") @@ -70,13 +65,17 @@ detailed_sitrep <- function(){ cli::cli_ul("path: {.path {conda_env_path}}") conda_modules <- conda_list_env_modules() - tf_in_conda <- nzchar(grep("^(tensorflow)(\\s|$)", - conda_modules, - value = TRUE)) + tf_in_conda <- nzchar(grep( + "^(tensorflow)(\\s|$)", + conda_modules, + value = TRUE + )) - tfp_in_conda <- nzchar(grep("^(tensorflow-probability)(\\s|$)", - conda_modules, - value = TRUE)) + tfp_in_conda <- nzchar(grep( + "^(tensorflow-probability)(\\s|$)", + conda_modules, + value = TRUE + )) cli::cli_h1("{.pkg TensorFlow}") check_if_tf_available() @@ -96,51 +95,49 @@ detailed_sitrep <- function(){ "{.code system(paste('conda list -n', 'greta-env-tf2'), intern = TRUE)}" ) ) - } -quiet_sitrep <- function(){ - +quiet_sitrep <- function() { suppressMessages(check_greta_ready_to_use()) - } -conda_list_env_modules <- function(){ +conda_list_env_modules <- function() { system(paste("conda list -n", "greta-env-tf2"), intern = TRUE) } -check_if_python_available <- function(min_version = "3.3"){ +check_if_python_available <- function(min_version = "3.3") { check_if_software_available( software_available = have_python(), version = py_version(), software_name = "python" - ) + ) } -check_if_tf_available <- function(){ +check_if_tf_available <- function() { check_if_software_available( software_available = have_tf(), version = version_tf(), software_name = "TensorFlow" - ) + ) } -check_if_tfp_available <- function(){ +check_if_tfp_available <- function() { 
check_if_software_available( software_available = have_tfp(), version = version_tfp(), software_name = "TensorFlow Probability" - ) + ) } -check_if_greta_conda_env_available <- function(){ - check_if_software_available(software_available = have_greta_conda_env(), - software_name = "greta conda environment") - +check_if_greta_conda_env_available <- function() { + check_if_software_available( + software_available = have_greta_conda_env(), + software_name = "greta conda environment" + ) } -software_availability <- function(){ +software_availability <- function() { software_available <- c( python = have_python(), tf = have_tf(), @@ -151,8 +148,7 @@ software_availability <- function(){ } -get_current_ideal_deps <- function(){ - +get_current_ideal_deps <- function() { software_version <- data.frame( software = c( "python", @@ -179,17 +175,15 @@ get_current_ideal_deps <- function(){ ) software_version - } -check_greta_ready_to_use <- function(software_available){ - +check_greta_ready_to_use <- function(software_available) { software_available <- software_availability() greta_env_not_available <- !software_available["greta_env"] - other_software_ready <- all(software_available[1:3]) + other_software_ready <- all(software_available[1:3]) deps_avail_not_greta_env <- greta_env_not_available && other_software_ready - if (deps_avail_not_greta_env){ + if (deps_avail_not_greta_env) { check_tf_version("none") cli::cli_alert_info( c( @@ -197,31 +191,29 @@ check_greta_ready_to_use <- function(software_available){ "i" = "{.pkg greta} is ready to use!" ), wrap = TRUE - ) + ) } if (!all(software_available)) { check_tf_version("warn") } else if (all(software_available)) { software_version <- get_current_ideal_deps() - if (all(software_version$match)){ + if (all(software_version$match)) { check_tf_version("none") - cli::cli_alert_info("{.pkg greta} is ready to use!", - wrap = TRUE) + cli::cli_alert_info("{.pkg greta} is ready to use!", wrap = TRUE) } else { check_tf_version("warn") } - } - } -check_if_software_available <- function(software_available, - version = NULL, - ideal_version = NULL, - software_name){ - +check_if_software_available <- function( + software_available, + version = NULL, + ideal_version = NULL, + software_name +) { cli::cli_process_start("checking if {.pkg {software_name}} available") # if the software is detected @@ -232,8 +224,7 @@ check_if_software_available <- function(software_available, } if (software_available) { - - if (is.null(ideal_version) & !is.null(version)){ + if (is.null(ideal_version) & !is.null(version)) { cli::cli_process_done( msg_done = "{.pkg {software_name}} (v{version}) available" ) @@ -241,16 +232,16 @@ check_if_software_available <- function(software_available, # if it has a version and ideal version has_ideal_version <- !is.null(version) & !is.null(ideal_version) - if (has_ideal_version){ + if (has_ideal_version) { version_chr <- paste0(version) version_match <- compareVersion(version_chr, ideal_version) == 0 - if (version_match){ + if (version_match) { cli::cli_process_done( msg_done = "{.pkg {software_name}} (v{version}) available" ) } - if (!version_match){ + if (!version_match) { cli::cli_process_failed( msg_failed = "{.pkg {software_name}} available, \\ however {.strong {ideal_version}} is needed and \\ @@ -258,7 +249,7 @@ check_if_software_available <- function(software_available, ) } # if there is no version for the software - } else if (is.null(version)){ + } else if (is.null(version)) { cli::cli_process_done( msg_done = "{.pkg {software_name}} available" ) 
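The helpers above all feed the exported greta_sitrep() entry point defined earlier in this file; an illustrative sketch of its three verbosity levels (not part of the patch):

library(greta)

greta_sitrep()                        # "minimal": check python, TF, TFP, and the conda env
greta_sitrep(verbosity = "detailed")  # also report versions, paths, and conda modules
greta_sitrep(verbosity = "quiet")     # run the same checks without messaging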
@@ -275,8 +266,9 @@ compare_version_vec <- Vectorize( # find out whether the user has conda installed and visible #' @importFrom reticulate conda_binary have_conda <- function() { - conda_bin <- tryCatch(reticulate::conda_binary("auto"), - error = function(e) NULL + conda_bin <- tryCatch( + reticulate::conda_binary("auto"), + error = function(e) NULL ) !is.null(conda_bin) } @@ -294,32 +286,26 @@ have_tfp <- function() { is_tfp_available <- py_module_available("tensorflow_probability") if (is_tfp_available) { - pkg <- reticulate::import("pkg_resources") tfp_version <- pkg$get_distribution("tensorflow_probability")$version is_tfp_available <- utils::compareVersion("0.15.0", tfp_version) <= 0 - } return(is_tfp_available) - } have_tf <- function() { is_tf_available <- py_module_available("tensorflow") if (is_tf_available) { - tf_version <- suppressMessages(tf$`__version__`) is_tf_available <- utils::compareVersion("2.9.0", tf_version) <= 0 - } return(is_tf_available) - } -version_tf <- function(){ +version_tf <- function() { if (have_tf()) { tf$`__version__` } else { @@ -327,7 +313,7 @@ version_tf <- function(){ -version_tfp <- function(){ +version_tfp <- function() { if (have_tfp()) { tfp$`__version__` } else { diff --git a/R/greta_array_class.R b/R/greta_array_class.R index d98ad329..f82ed490 100644 --- a/R/greta_array_class.R +++ b/R/greta_array_class.R @@ -10,7 +10,12 @@ as.greta_array <- function(x, optional = FALSE, original_x = x, ...) { # safely handle self-coercion #' @export -as.greta_array.greta_array <- function(x, optional = FALSE, original_x = x, ...) { +as.greta_array.greta_array <- function( + x, + optional = FALSE, + original_x = x, + ... +) { x } @@ -18,21 +23,22 @@ as.greta_array.greta_array <- function(x, optional = FALSE, original_x = x, ...) #' @export as.greta_array.logical <- function(x, optional = FALSE, original_x = x, ...) { x[] <- as.numeric(x[]) - as.greta_array.numeric(x, - optional = optional, - original_x = original_x, - ... - ) + as.greta_array.numeric(x, optional = optional, original_x = original_x, ...) } # coerce dataframes if all columns can safely be converted to numeric, error # otherwise #' @export -as.greta_array.data.frame <- function(x, optional = FALSE, - original_x = x, ...) { +as.greta_array.data.frame <- function( + x, + optional = FALSE, + original_x = x, + ... +) { check_greta_data_frame(x, optional) - as.greta_array.numeric(as.matrix(x), + as.greta_array.numeric( + as.matrix(x), optional = optional, original_x = original_x, ... @@ -43,36 +49,26 @@ as.greta_array.data.frame <- function(x, optional = FALSE, # or numeric #' @export as.greta_array.matrix <- function(x, optional = FALSE, original_x = x, ...) { - check_greta_array_type(x, optional) if (!is.numeric(x) && is.logical(x)) { - x[] <- as.numeric(x[]) - } + x[] <- as.numeric(x[]) + } - as.greta_array.numeric(x, - optional = optional, - original_x = original_x, - ... - ) + as.greta_array.numeric(x, optional = optional, original_x = original_x, ...) } # coerce logical arrays to numeric arrays, and error if they aren't logical # or numeric #' @export as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) { - check_greta_array_type(x, optional) if (!optional && !is.numeric(x) && is.logical(x)) { - x[] <- as.numeric(x[]) + x[] <- as.numeric(x[]) } - as.greta_array.numeric(x, - optional = optional, - original_x = original_x, - ... - ) + as.greta_array.numeric(x, optional = optional, original_x = original_x, ...)
} # finally, reject if there are any missing values, or set up the greta_array @@ -80,7 +76,8 @@ as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) { as.greta_array.numeric <- function(x, optional = FALSE, original_x = x, ...) { check_missing_infinite_values(x, optional) - as.greta_array.node(data_node$new(x), + as.greta_array.node( + data_node$new(x), optional = optional, original_x = original_x, ... @@ -123,7 +120,7 @@ print.greta_array <- function(x, ..., n = 10) { cli::cli_text("{.pkg greta} array {.cls {node_desc}}") cli::cli_text("\n") - if (is.unknowns(node$value())){ + if (is.unknowns(node$value())) { return(print(node$value(), ..., n = n)) } @@ -143,7 +140,7 @@ print.greta_array <- function(x, ..., n = 10) { return(invisible(x_val)) } - if (remaining_vals > 0 ) { + if (remaining_vals > 0) { cli::cli_alert_info( text = c( "i" = "{remaining_vals} more values\n", @@ -151,8 +148,6 @@ print.greta_array <- function(x, ..., n = 10) { ) ) } - - } @@ -177,7 +172,6 @@ summary.greta_array <- function(object, ...) { #' @export #' @method print summary.greta_array print.summary.greta_array <- function(x, ...) { - # array type type_text <- glue::glue( "'{x$type}' greta array" @@ -269,7 +263,7 @@ anti_representation <- function(x, name, error = TRUE) { } -has_anti_representation <- function(x, name){ +has_anti_representation <- function(x, name) { repr <- anti_representation(x, name, error = FALSE) !is.null(repr) } @@ -280,7 +274,8 @@ copy_representation <- function(x, name) { identity(repr) } -greta_array_module <- module(as.greta_array, +greta_array_module <- module( + as.greta_array, get_node, has_representation, representation, diff --git a/R/greta_create_conda_env.R b/R/greta_create_conda_env.R index 8e374313..60473104 100644 --- a/R/greta_create_conda_env.R +++ b/R/greta_create_conda_env.R @@ -16,9 +16,7 @@ #' #' @return nothing - creates a conda environment for a specific python version #' @export -greta_create_conda_env <- function(timeout = 5, - deps = greta_deps_spec()) { - +greta_create_conda_env <- function(timeout = 5, deps = greta_deps_spec()) { check_greta_deps_spec(deps) stdout_file <- create_temp_file("out-greta-conda") @@ -42,7 +40,7 @@ greta_create_conda_env <- function(timeout = 5, stderr_file = stderr_file, timeout = timeout, cli_start_msg = glue::glue( - "Creating 'greta-env-tf2' conda environment using python \\ + "Creating 'greta-env-tf2' conda environment using python \\ v{deps$python_version}, this may take a minute" ), cli_end_msg = "greta-env-tf2 environment created!" @@ -50,5 +48,4 @@ greta_create_conda_env <- function(timeout = 5, greta_stash$conda_create_notes <- install_conda_create$output_notes greta_stash$conda_create_error <- install_conda_create$output_error - } diff --git a/R/greta_install_miniconda.R b/R/greta_install_miniconda.R index ed523d6c..7a158710 100644 --- a/R/greta_install_miniconda.R +++ b/R/greta_install_miniconda.R @@ -10,7 +10,6 @@ #' @return nothing - installs miniconda. 
#' @export greta_install_miniconda <- function(timeout = 5) { - stdout_file <- create_temp_file("out-miniconda") stderr_file <- create_temp_file("err-miniconda") @@ -35,5 +34,4 @@ greta_install_miniconda <- function(timeout = 5) { greta_stash$miniconda_notes <- install_miniconda_process$output_notes greta_stash$miniconda_error <- install_miniconda_process$output_error - } diff --git a/R/greta_install_python_deps.R b/R/greta_install_python_deps.R index 2c50def3..a514e87f 100644 --- a/R/greta_install_python_deps.R +++ b/R/greta_install_python_deps.R @@ -1,6 +1,4 @@ -greta_install_python_deps <- function(timeout = 5, - deps = greta_deps_spec()) { - +greta_install_python_deps <- function(timeout = 5, deps = greta_deps_spec()) { stdout_file <- create_temp_file("out-python-deps") stderr_file <- create_temp_file("err-python-deps") @@ -10,7 +8,7 @@ greta_install_python_deps <- function(timeout = 5, msg = "Installing TF (v{deps$tf_version})", msg_done = "Installed TF (v{deps$tf_version})!", msg_failed = "Error installing TF (v{deps$tf_version})" - ) + ) tensorflow::install_tensorflow( version = deps$tf_version, envname = "greta-env-tf2", @@ -27,12 +25,12 @@ greta_install_python_deps <- function(timeout = 5, pip = TRUE, envname = "greta-env-tf2", method = "conda" - ) - }, + ) + }, args = list(deps = deps), stdout = stdout_file, stderr = stderr_file - ) + ) install_python_modules <- new_install_process( callr_process = callr_conda_install, @@ -40,7 +38,7 @@ greta_install_python_deps <- function(timeout = 5, stdout_file = stdout_file, stderr_file = stderr_file, cli_start_msg = glue::glue( - "Installing python modules into 'greta-env-tf2' conda environment, \\ + "Installing python modules into 'greta-env-tf2' conda environment, \\ this may take a few minutes" ), cli_end_msg = "Python modules installed!" @@ -48,5 +46,4 @@ greta_install_python_deps <- function(timeout = 5, greta_stash$conda_install_notes <- install_python_modules$output_notes greta_stash$conda_install_error <- install_python_modules$output_error - } diff --git a/R/greta_mcmc_list.R b/R/greta_mcmc_list.R index 88980c87..9524760a 100644 --- a/R/greta_mcmc_list.R +++ b/R/greta_mcmc_list.R @@ -2,7 +2,6 @@ # mcmc.list as_greta_mcmc_list <- function(x, model_info) { - # add the raw draws as an attribute attr(x, "model_info") <- model_info class(x) <- c("greta_mcmc_list", class(x)) @@ -28,7 +27,7 @@ as.mcmc.list.greta_mcmc_list <- function(x, ...) { #' @returns logical TRUE/FALSE #' #' @export -is.greta_mcmc_list <- function(x, ...){ +is.greta_mcmc_list <- function(x, ...) { inherits(x, "greta_mcmc_list") } @@ -60,8 +59,7 @@ window.greta_mcmc_list <- function(x, start, end, thin, ...) { #' #' @return printed MCMC output #' @export -print.greta_mcmc_list <- function(x, ..., n = 5){ - +print.greta_mcmc_list <- function(x, ..., n = 5) { n_warmup <- n_warmup(x) n_chain <- coda::nchain(x) n_iter <- coda::niter(x) @@ -95,19 +93,20 @@ print.greta_mcmc_list <- function(x, ..., n = 5){ print(draws_head) - if (more_draws_than_can_print){ - cli::cli_alert_info( - text = c( - "i" = "{remaining_draws} more draws\n", - "i" = "Use {.code print(n = ...)} to see more draws" + if (more_draws_than_can_print) { + cli::cli_alert_info( + text = c( + "i" = "{remaining_draws} more draws\n", + "i" = "Use {.code print(n = ...)} to see more draws" + ) ) - ) } cli::cli_rule() cli::cli_alert_info( - c("View {.pkg greta} draw chain {.param i} with:\n", + c( + "View {.pkg greta} draw chain {.param i} with:\n", "{.code greta_draws_object[[i]]}. 
\n", "E.g., view chain {.param 1} with: \n", "{.code greta_draws_object[[1]]}." @@ -115,8 +114,9 @@ print.greta_mcmc_list <- function(x, ..., n = 5){ ) cli::cli_alert_info( - c("To see a summary of draws, run:\n", - "{.code summary(greta_draws_object)}") + c( + "To see a summary of draws, run:\n", + "{.code summary(greta_draws_object)}" + ) ) - } diff --git a/R/greta_model_class.R b/R/greta_model_class.R index db01889a..5b089be8 100644 --- a/R/greta_model_class.R +++ b/R/greta_model_class.R @@ -49,16 +49,15 @@ NULL #' #' plot(m) #' } -model <- function(..., - precision = c("double", "single"), - compile = TRUE) { +model <- function(..., precision = c("double", "single"), compile = TRUE) { check_tf_version("error") # get the floating point precision # TODO # what does it choose as default if both double and single are listed # as default? - tf_float <- switch(match.arg(precision), + tf_float <- switch( + match.arg(precision), double = "float64", single = "float32" ) @@ -69,11 +68,11 @@ model <- function(..., # if no arrays were specified, find all of the non-data arrays no_arrays_specified <- identical(target_greta_arrays, list()) if (no_arrays_specified) { - target_greta_arrays <- all_greta_arrays(parent.frame(), + target_greta_arrays <- all_greta_arrays( + parent.frame(), include_data = FALSE ) } else { - # otherwise, find variable names for the provided nodes names <- substitute(list(...))[-1] names <- vapply(names, deparse, "") @@ -86,7 +85,8 @@ model <- function(..., # TF1/2 check # I don't think we need to use the `compile` flag in TF2 anymore # Well, it will be passed onto the tf_function creation step - dag <- dag_class$new(target_greta_arrays, + dag <- dag_class$new( + target_greta_arrays, tf_float = tf_float, compile = compile ) @@ -113,12 +113,14 @@ model <- function(..., #' @param ... extra arguments - not used. #' #' @export -as.greta_model <- function(x, ...) { # nolint +as.greta_model <- function(x, ...) { + # nolint UseMethod("as.greta_model", x) } #' @export -as.greta_model.dag_class <- function(x, ...) { # nolint +as.greta_model.dag_class <- function(x, ...) { + # nolint ans <- list(dag = x) class(ans) <- "greta_model" ans @@ -148,16 +150,14 @@ print.greta_model <- function(x, ...) { #' create it as an attribute `"dgr_graph"`. #' #' @export -plot.greta_model <- function(x, - y, - colour = "#996bc7", - ...) { +plot.greta_model <- function(x, y, colour = "#996bc7", ...) 
{ check_diagrammer_installed() # set up graph dag_mat <- x$dag$adjacency_matrix - gr <- DiagrammeR::from_adj_matrix(dag_mat, + gr <- DiagrammeR::from_adj_matrix( + dag_mat, mode = "directed", use_diag = FALSE ) @@ -189,7 +189,8 @@ plot.greta_model <- function(x, node_size[types == "operation"] <- 0.2 # get node labels - node_labels <- vapply(x$dag$node_list, + node_labels <- vapply( + x$dag$node_list, member, "plotting_label()", FUN.VALUE = "" @@ -201,14 +202,16 @@ plot.greta_model <- function(x, known_nodes <- known_nodes[known_nodes %in% names] known_idx <- match(known_nodes, names) - node_labels[known_idx] <- paste(names(known_nodes), + node_labels[known_idx] <- paste( + names(known_nodes), node_labels[known_idx], sep = "\n" ) # for the operation nodes, add the operation to the edges op_idx <- which(types == "operation") - op_names <- vapply(x$dag$node_list[op_idx], + op_names <- vapply( + x$dag$node_list[op_idx], member, "operation_name", FUN.VALUE = "" @@ -266,8 +269,8 @@ plot.greta_model <- function(x, keep <- !are_null(targets) distrib_idx <- distrib_idx[keep] - - target_names <- vapply(x$dag$node_list[distrib_idx], + target_names <- vapply( + x$dag$node_list[distrib_idx], member, "target$unique_name", FUN.VALUE = "" @@ -317,7 +320,6 @@ plot.greta_model <- function(x, ) ) - widget <- DiagrammeR::render_graph(gr) attr(widget, "dgr_graph") <- gr widget diff --git a/R/greta_stash.R b/R/greta_stash.R index adff7fcf..d85e5786 100644 --- a/R/greta_stash.R +++ b/R/greta_stash.R @@ -4,7 +4,7 @@ greta_note_msg <- cli::format_message( have been wiped. This likely means that installation has not happened, or \\ it has happened and you've restarted R. See `?install_greta_deps()` for \\ more information." - ) + ) ) greta_stash$install_miniconda_notes <- greta_note_msg diff --git a/R/inference.R b/R/inference.R index 7854f1d0..ba174c76 100644 --- a/R/inference.R +++ b/R/inference.R @@ -207,24 +207,22 @@ greta_stash$numerical_messages <- c( #' #' } mcmc <- function( - model, - sampler = hmc(), - n_samples = 1000, - thin = 1, - warmup = 1000, - chains = 4, - n_cores = NULL, - verbose = TRUE, - pb_update = 50, - one_by_one = FALSE, - initial_values = initials(), - trace_batch_size = 100, - compute_options = cpu_only() + model, + sampler = hmc(), + n_samples = 1000, + thin = 1, + warmup = 1000, + chains = 4, + n_cores = NULL, + verbose = TRUE, + pb_update = 50, + one_by_one = FALSE, + initial_values = initials(), + trace_batch_size = 100, + compute_options = cpu_only() ) { - # set device to be CPU/GPU for the entire run with(tf$device(compute_options), { - # message users about random seeds and GPU usage if they are using GPU message_if_using_gpu(compute_options) @@ -279,18 +277,19 @@ mcmc <- function( } #' @importFrom future future resolved value -run_samplers <- function(samplers, - n_samples, - thin, - warmup, - verbose, - pb_update, - one_by_one, - n_cores, - from_scratch, - trace_batch_size, - compute_options) { - +run_samplers <- function( + samplers, + n_samples, + thin, + warmup, + verbose, + pb_update, + one_by_one, + n_cores, + from_scratch, + trace_batch_size, + compute_options +) { # check the future plan is valid, and get information about it plan_is <- check_future_plan() @@ -308,8 +307,11 @@ run_samplers <- function(samplers, greta_stash$samplers <- samplers inform_if_local_parallel_multiple_samplers( - plan_is, samplers, n_cores, compute_options - ) + plan_is, + samplers, + n_cores, + compute_options + ) inform_if_remote_machine(plan_is, samplers) @@ -328,8 +330,12 @@ run_samplers <- 
function(samplers, # give the samplers somewhere to write their progress if (parallel_reporting) { sampler <- sampler_parallel_reporting( - n_chain, samplers, chains, n_samples, warmup - ) + n_chain, + samplers, + chains, + n_samples, + warmup + ) } if (plan_is$parallel) { @@ -364,7 +370,6 @@ run_samplers <- function(samplers, # loop until they are resolved, executing the callbacks if (parallel_reporting) { while (!all(vapply(samplers, resolved, FALSE))) { - # loop through callbacks executing them for (callback in greta_stash$callbacks) { callback() @@ -377,7 +382,6 @@ run_samplers <- function(samplers, cat("\n") } - # if we were running in parallel, retrieve the samplers and put them back in # the stash to return if (plan_is$parallel) { @@ -430,15 +434,17 @@ stashed_samples <- function() { thins <- lapply(samplers, member, "thin") # convert to mcmc objects, passing on thinning - free_state_draws <- mapply(prepare_draws, - draws = free_state_draws, - thin = thins, - SIMPLIFY = FALSE + free_state_draws <- mapply( + prepare_draws, + draws = free_state_draws, + thin = thins, + SIMPLIFY = FALSE ) - values_draws <- mapply(prepare_draws, - draws = values_draws, - thin = thins, - SIMPLIFY = FALSE + values_draws <- mapply( + prepare_draws, + draws = values_draws, + thin = thins, + SIMPLIFY = FALSE ) # convert to mcmc.list objects @@ -475,15 +481,17 @@ stashed_samples <- function() { #' used to generate the previous samples. It is not possible to change the #' sampler or extend the warmup period. #' -extra_samples <- function(draws, - n_samples = 1000, - thin = 1, - n_cores = NULL, - verbose = TRUE, - pb_update = 50, - one_by_one = FALSE, - trace_batch_size = 100, - compute_options = cpu_only()) { +extra_samples <- function( + draws, + n_samples = 1000, + thin = 1, + n_cores = NULL, + verbose = TRUE, + pb_update = 50, + one_by_one = FALSE, + trace_batch_size = 100, + compute_options = cpu_only() +) { model_info <- get_model_info(draws) samplers <- model_info$samplers @@ -552,11 +560,12 @@ to_free <- function(node, data) { stats::qlogis((x - lower) / (upper - lower)) } - fun <- switch(node$constraint, - scalar_all_none = identity, - scalar_all_high = high, - scalar_all_low = low, - scalar_all_both = both + fun <- switch( + node$constraint, + scalar_all_none = identity, + scalar_all_high = high, + scalar_all_low = low, + scalar_all_both = both ) fun(data) @@ -565,7 +574,6 @@ to_free <- function(node, data) { # convert a named list of initial values into the corresponding vector of values # on the free state parse_initial_values <- function(initials, dag) { - # skip if no inits provided if (identical(initials, initials())) { free_parameters <- dag$example_parameters(free = TRUE) @@ -574,14 +582,15 @@ parse_initial_values <- function(initials, dag) { } # find the elements we have been given initial values for - tf_names <- vapply(names(initials), - function(name, env) { - ga <- get(name, envir = env) - node <- get_node(ga) - dag$tf_name(node) - }, - env = parent.frame(4), - FUN.VALUE = "" + tf_names <- vapply( + names(initials), + function(name, env) { + ga <- get(name, envir = env) + node <- get_node(ga) + dag$tf_name(node) + }, + env = parent.frame(4), + FUN.VALUE = "" ) check_greta_arrays_associated_with_model(tf_names) @@ -620,16 +629,12 @@ parse_initial_values <- function(initials, dag) { # convert (possibly NULL) user-specified initial values into a list of the # correct length, with nice error messages prep_initials <- function(initial_values, n_chains, dag) { - # if the user passed a single set 
of initial values, repeat them for all # chains if (is.initials(initial_values)) { inform_if_one_set_of_initials(initial_values, n_chains) - initial_values <- replicate(n_chains, - initial_values, - simplify = FALSE - ) + initial_values <- replicate(n_chains, initial_values, simplify = FALSE) } # TODO: revisit logic here for errors and messages @@ -723,18 +728,18 @@ print.initials <- function(x, ...) { #' matrices/arrays for the parameters (w.r.t. `value`) #' } #' -opt <- function(model, - optimiser = bfgs(), - max_iterations = 100, - tolerance = 1e-6, - initial_values = initials(), - adjust = TRUE, - hessian = FALSE, - compute_options = cpu_only()) { - +opt <- function( + model, + optimiser = bfgs(), + max_iterations = 100, + tolerance = 1e-6, + initial_values = initials(), + adjust = TRUE, + hessian = FALSE, + compute_options = cpu_only() +) { # set device to be CPU/GPU for the entire run with(tf$device(compute_options), { - # message users about random seeds and GPU usage if they are using GPU message_if_using_gpu(compute_options) @@ -780,6 +785,4 @@ opt <- function(model, }) } -inference_module <- module(dag_class, - progress_bar = progress_bar_module -) +inference_module <- module(dag_class, progress_bar = progress_bar_module) diff --git a/R/inference_class.R b/R/inference_class.R index c986efe1..44f28e13 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -30,11 +30,12 @@ inference <- R6Class( traced_free_state = list(), # all recorded greta array values traced_values = list(), - initialize = function(initial_values, - model, - parameters = list(), - seed = get_seed()) { - + initialize = function( + initial_values, + model, + parameters = list(), + seed = get_seed() + ) { self$parameters <- parameters self$model <- model free_parameters <- model$dag$example_parameters(free = TRUE) @@ -52,15 +53,21 @@ inference <- R6Class( write_trace_to_log_file = function(last_burst_values) { if (file.exists(self$trace_log_file)) { # Append - write.table(last_burst_values, self$trace_log_file, - append = TRUE, - row.names = FALSE, col.names = FALSE + write.table( + last_burst_values, + self$trace_log_file, + append = TRUE, + row.names = FALSE, + col.names = FALSE ) } else { # Create file - write.table(last_burst_values, self$trace_log_file, - append = FALSE, - row.names = FALSE, col.names = TRUE + write.table( + last_burst_values, + self$trace_log_file, + append = FALSE, + row.names = FALSE, + col.names = TRUE ) } }, @@ -85,8 +92,7 @@ inference <- R6Class( # check and try to autofill a single set of initial values (single vector on # free state scale) - check_initial_values = function(inits, - call = rlang::caller_env()) { + check_initial_values = function(inits, call = rlang::caller_env()) { undefined <- is.na(inits) # try to fill in any that weren't specified @@ -104,19 +110,16 @@ inference <- R6Class( } self$check_reasonable_starting_values(valid, attempts) - } else { - # if they were all provided, check they can be used valid <- self$valid_parameters(inits) self$check_valid_parameters(valid) - } inits }, - check_reasonable_starting_values = function(valid, attempts){ + check_reasonable_starting_values = function(valid, attempts) { if (!valid) { cli::cli_abort( message = c( @@ -129,7 +132,7 @@ inference <- R6Class( } }, - check_valid_parameters = function(valid){ + check_valid_parameters = function(valid) { if (!valid) { cli::cli_abort( c( @@ -144,7 +147,6 @@ inference <- R6Class( # check and set a list of initial values set_initial_values = function(init_list) { - # check/autofill
them init_list <- lapply(init_list, self$check_initial_values) @@ -182,12 +184,12 @@ inference <- R6Class( # arrays for the latest batch of raw draws trace = function(free_state = TRUE, values = FALSE) { if (free_state) { - # append the free state trace for each chain - self$traced_free_state <- mapply(rbind, - self$traced_free_state, - self$last_burst_free_states, - SIMPLIFY = FALSE + self$traced_free_state <- mapply( + rbind, + self$traced_free_state, + self$last_burst_free_states, + SIMPLIFY = FALSE ) } @@ -197,10 +199,11 @@ inference <- R6Class( if (!is.null(self$trace_log_file)) { self$write_trace_to_log_file(last_burst_values) } - self$traced_values <- mapply(rbind, - self$traced_values, - last_burst_values, - SIMPLIFY = FALSE + self$traced_values <- mapply( + rbind, + self$traced_values, + last_burst_values, + SIMPLIFY = FALSE ) } }, @@ -208,7 +211,6 @@ inference <- R6Class( # given a matrix of free state values, get a matrix of values of the target # greta arrays trace_burst_values = function(free_states = self$last_burst_free_states) { - # can't use apply directly, as it will drop the variable name if there's # only one parameter being traced values_trace <- lapply( @@ -231,4 +233,3 @@ inference <- R6Class( } ) ) - diff --git a/R/install_greta_deps.R b/R/install_greta_deps.R index fe8c79c2..b563971a 100644 --- a/R/install_greta_deps.R +++ b/R/install_greta_deps.R @@ -71,11 +71,12 @@ #' @importFrom callr r_process #' @importFrom cli cli_alert_success #' @importFrom cli cli_ul -install_greta_deps <- function(deps = greta_deps_spec(), - timeout = 5, - restart = c("ask", "force", "no"), - ...) { - +install_greta_deps <- function( + deps = greta_deps_spec(), + timeout = 5, + restart = c("ask", "force", "no"), + ... +) { check_greta_deps_spec(deps) restart <- rlang::arg_match( @@ -118,12 +119,13 @@ install_greta_deps <- function(deps = greta_deps_spec(), write_greta_install_log(path = greta_logfile) - cli::cli_alert_success("Installation of {.pkg greta} dependencies \\ + cli::cli_alert_success( + "Installation of {.pkg greta} dependencies \\ is complete!", - wrap = TRUE) + wrap = TRUE + ) restart_or_not(restart) - } get_pkg_user_dir <- function() { @@ -134,13 +136,13 @@ get_pkg_user_dir <- function() { pkg_user_dir } -greta_default_logfile <- function(){ +greta_default_logfile <- function() { greta_user_dir <- get_pkg_user_dir() file.path(greta_user_dir, "greta-installation-logfile.html") } -restart_or_not <- function(restart){ +restart_or_not <- function(restart) { # Managing how to restart R # requires RStudio and also an interactive session has_rstudioapi_pkg <- requireNamespace("rstudioapi", quietly = TRUE) && @@ -156,7 +158,7 @@ restart_or_not <- function(restart){ # Where there is no rstudio/not interactive, suggest restarting. 
suggest_restart <- (restart == "force" | restart == "no") && - (!interactive() | !has_rstudioapi_pkg) + (!interactive() | !has_rstudioapi_pkg) if (suggest_restart) { cli::cli_inform( @@ -181,7 +183,6 @@ restart_or_not <- function(restart){ clean = TRUE ) } - } ## TODO @@ -238,10 +239,11 @@ restart_or_not <- function(restart){ #' python_version = "3.10" #' ) #' } -greta_deps_spec <- function(tf_version = "2.15.0", - tfp_version = "0.23.0", - python_version = "3.10"){ - +greta_deps_spec <- function( + tf_version = "2.15.0", + tfp_version = "0.23.0", + python_version = "3.10" +) { deps_list <- data.frame( tf_version = tf_version, tfp_version = tfp_version, @@ -261,11 +263,9 @@ greta_deps_spec <- function(tf_version = "2.15.0", check_greta_deps_config(deps_obj) deps_obj - } -check_greta_deps_spec <- function(deps, - call = rlang::caller_env()) { +check_greta_deps_spec <- function(deps, call = rlang::caller_env()) { if (!inherits(deps, "greta_deps_spec")) { cli::cli_abort( message = "{.arg deps} must be created by {.fun greta_deps_spec}.", @@ -279,7 +279,7 @@ check_greta_deps_spec <- function(deps, #' @param x greta python deps #' @param ... extra args, not used #' @export -print.greta_deps_spec <- function(x, ...){ +print.greta_deps_spec <- function(x, ...) { print.data.frame(x) } @@ -295,34 +295,29 @@ print.greta_deps_spec <- function(x, ...){ #' \dontrun{ #' my_deps <- greta_deps_receipt() #' } -greta_deps_receipt <- function(){ - +greta_deps_receipt <- function() { greta_deps_spec( tf_version = version_tf(), tfp_version = version_tfp(), python_version = as.character(py_version()) ) - } -check_greta_deps_range <- function(deps, - module, - call = rlang::caller_env()){ - +check_greta_deps_range <- function(deps, module, call = rlang::caller_env()) { greta_tf_tfp <- greta_deps_tf_tfp[[module]] version_provided <- numeric_version(deps[[module]]) - version_name <- switch(module, - tf_version = "TF", - tfp_version = "TFP") + version_name <- switch(module, tf_version = "TF", tfp_version = "TFP") - latest_version <- switch(module, - tf_version = numeric_version("2.15.0"), - tfp_version = numeric_version("0.23.0")) + latest_version <- switch( + module, + tf_version = numeric_version("2.15.0"), + tfp_version = numeric_version("0.23.0") + ) later_tf_tfp <- version_provided > latest_version - if (later_tf_tfp){ + if (later_tf_tfp) { gh_issue <- "https://github.com/greta-dev/greta/issues/675" cli::cli_abort( message = c( @@ -335,49 +330,51 @@ check_greta_deps_range <- function(deps, "i" = "Valid versions of TF, TFP, and Python are in \\ {.code greta_deps_tf_tfp}", "i" = "Inspect with:", - "{.run View(greta_deps_tf_tfp)}"), + "{.run View(greta_deps_tf_tfp)}" + ), call = call ) } valid <- version_provided %in% greta_tf_tfp if (!valid) { - closest_value <- closest_version(version_provided, greta_deps_tf_tfp[[module]]) + closest_value <- closest_version( + version_provided, + greta_deps_tf_tfp[[module]] + ) } - if (!valid){ - + if (!valid) { cli::cli_abort( - message = c("{.val {version_name}} version provided does not match \\ + message = c( + "{.val {version_name}} version provided does not match \\ supported versions", - "The version {.val {version_provided}} was not in \\ + "The version {.val {version_provided}} was not in \\ {.val {greta_deps_tf_tfp[[module]]}}", - "i" = "The nearest valid version that is supported by \\ + "i" = "The nearest valid version that is supported by \\ {.pkg greta} is: {.val {closest_value}}", - "i" = "Valid versions of TF, TFP, and Python are in \\ + "i" = "Valid versions of 
TF, TFP, and Python are in \\ {.code greta_deps_tf_tfp}", - "i" = "Inspect with:", - "{.run View(greta_deps_tf_tfp)}"), + "i" = "Inspect with:", + "{.run View(greta_deps_tf_tfp)}" + ), call = call ) } } check_greta_tf_range <- function(deps, call = rlang::caller_env()) { - check_greta_deps_range(deps = deps, - module = "tf_version", - call = call) + check_greta_deps_range(deps = deps, module = "tf_version", call = call) } check_greta_tfp_range <- function(deps, call = rlang::caller_env()) { - check_greta_deps_range(deps = deps, - module = "tfp_version", - call = call) + check_greta_deps_range(deps = deps, module = "tfp_version", call = call) } -check_greta_python_range <- function(version_provided, - call = rlang::caller_env()) { - +check_greta_python_range <- function( + version_provided, + call = rlang::caller_env() +) { py_version_min <- unique(greta_deps_tf_tfp$python_version_min) py_version_max <- unique(greta_deps_tf_tfp$python_version_max) py_versions <- sort(unique(c(py_version_min, py_version_max))) @@ -388,23 +385,21 @@ check_greta_python_range <- function(version_provided, outside_range <- outside_version_range(version_provided, py_versions) if (outside_range) { - closest_value <- paste0(closest_version(version_provided, c(py_versions))) cli::cli_abort( - message = c("Python version must be between \\ + message = c( + "Python version must be between \\ {.val {min_py}}-{.val {max_py}}", - "x" = "The version provided was {.val {version_provided}}.", - "i" = "Try: {.val {closest_value}}"), + "x" = "The version provided was {.val {version_provided}}.", + "i" = "Try: {.val {closest_value}}" + ), call = call ) } - } -check_greta_deps_config <- function(deps, - call = rlang::caller_env()){ - +check_greta_deps_config <- function(deps, call = rlang::caller_env()) { check_greta_deps_spec(deps) deps <- deps |> @@ -418,11 +413,13 @@ check_greta_deps_config <- function(deps, if (no_os_matches) { valid_os <- unique(greta_deps_tf_tfp$os) cli::cli_abort( - message = c("The os provided does not match one of {.val {valid_os}}", - "i" = "Valid versions of TF, TFP, and Python are in \\ + message = c( + "The os provided does not match one of {.val {valid_os}}", + "i" = "Valid versions of TF, TFP, and Python are in \\ {.code greta_deps_tf_tfp}", - "i" = "Inspect with:", - "{.run View(greta_deps_tf_tfp)}"), + "i" = "Inspect with:", + "{.run View(greta_deps_tf_tfp)}" + ), call = call ) } @@ -436,8 +433,7 @@ check_greta_deps_config <- function(deps, no_matches <- nrow(config_matches) == 0 # Build logic to prioritise valid TFP over others - if (no_matches){ - + if (no_matches) { tfp_matches <- subset(os_matches, tfp_version == deps$tfp_version) tf_matches <- subset(os_matches, tf_version == deps$tf_version) py_matches <- os_matches |> @@ -459,13 +455,15 @@ check_greta_deps_config <- function(deps, suggest_tf <- !all_valid && any_valid && tf_valid suggest_py <- !tfp_valid && !tf_valid && py_valid - if (!any_valid){ + if (!any_valid) { cli::cli_abort( - message = c("Config does not match any installation combinations.", - "i" = "Valid versions of TF, TFP, and Python are in \\ + message = c( + "Config does not match any installation combinations.", + "i" = "Valid versions of TF, TFP, and Python are in \\ {.code greta_deps_tf_tfp}", - "i" = "Inspect with:", - "{.run View(greta_deps_tf_tfp)}"), + "i" = "Inspect with:", + "{.run View(greta_deps_tf_tfp)}" + ), call = call ) } @@ -492,56 +490,55 @@ check_greta_deps_config <- function(deps, suggested_py <- as.character(max(suggested_match$python_version_max)) 
cli::cli_abort( - message = c("Provided {.code greta_deps_spec} does not match valid \\ + message = c( + "Provided {.code greta_deps_spec} does not match valid \\ installation combinations.", - "See below for a suggested config to use:", - "{.code greta_deps_spec(\\ + "See below for a suggested config to use:", + "{.code greta_deps_spec(\\ tf_version = {.val {suggested_tf}}, \\ tfp_version = {.val {suggested_tfp}}, \\ python_version = {.val {suggested_py}}\\ )}", - "i" = "Valid versions of TF, TFP, and Python are in \\ + "i" = "Valid versions of TF, TFP, and Python are in \\ {.code greta_deps_tf_tfp}", - "i" = "Inspect with:", - "{.run View(greta_deps_tf_tfp)}" + "i" = "Inspect with:", + "{.run View(greta_deps_tf_tfp)}" ), call = call ) - } - } -check_tfp_tf_semantic <- function(deps_obj, - call = rlang::caller_env()){ +check_tfp_tf_semantic <- function(deps_obj, call = rlang::caller_env()) { check_semantic(deps_obj$tf_version) check_semantic(deps_obj$tfp_version) } -split_dots <- function(x){ - strsplit(x = x, - split = ".", - fixed = TRUE)[[1]] +split_dots <- function(x) { + strsplit(x = x, split = ".", fixed = TRUE)[[1]] } -is_semantic <- function(x){ +is_semantic <- function(x) { separated <- split_dots(x) is_sem <- length(separated) == 3 is_sem } -check_semantic <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - +check_semantic <- function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() +) { not_semantic <- !is_semantic(x) - if (not_semantic){ + if (not_semantic) { cli::cli_abort( - message = c("{.arg {arg}} must be semantic.", - "We saw {.val {x}}, but we require three separating dots:", - "i" = "{.val 1.1.1}", - "x" = "{.val 1.1}"), + message = c( + "{.arg {arg}} must be semantic.", + "We saw {.val {x}}, but we require three separating dots:", + "i" = "{.val 1.1.1}", + "x" = "{.val 1.1}" + ), call = call ) } diff --git a/R/joint.R b/R/joint.R index 888bd310..a01372f2 100644 --- a/R/joint.R +++ b/R/joint.R @@ -101,7 +101,8 @@ joint_distribution <- R6Class( super$initialize("joint", dim, discrete = discrete[1]) for (i in seq_len(n_distributions)) { - self$add_parameter(distribs[[i]], + self$add_parameter( + distribs[[i]], glue::glue("distribution {i}"), shape_matches_output = FALSE ) @@ -111,7 +112,6 @@ joint_distribution <- R6Class( vble(self$bounds, dim = self$dim) }, tf_distrib = function(parameters, dag) { - # get information from the *nodes* for component distributions, not the tf # objects passed in here @@ -123,7 +123,6 @@ joint_distribution <- R6Class( names(tfp_distributions) <- NULL log_prob <- function(x) { - # split x on the joint dimension, and loop through computing the # densities last_dim <- n_dim(x) - 1L diff --git a/R/mixture.R b/R/mixture.R index 55dd4e73..629197c1 100644 --- a/R/mixture.R +++ b/R/mixture.R @@ -117,7 +117,8 @@ mixture_distribution <- R6Class( check_not_discrete_continuous(discrete, name = "mixture") # check the distributions are all either multivariate or univariate - multivariate <- vapply(distribs, + multivariate <- vapply( + distribs, member, "multivariate", FUN.VALUE = logical(1) @@ -152,14 +153,16 @@ mixture_distribution <- R6Class( self$bounds <- support # for any discrete ones, tell them they are fixed - super$initialize("mixture", + super$initialize( + "mixture", dim, discrete = discrete[1], multivariate = multivariate[1] ) for (i in seq_len(n_distributions)) { - self$add_parameter(distribs[[i]], + self$add_parameter( + distribs[[i]], glue::glue("distribution {i}"), shape_matches_output = FALSE ) 
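For orientation, the add_parameter() loops above are reached through the exported mixture() and joint() constructors; a minimal usage sketch (the component distributions and weights here are illustrative, not part of this diff):

library(greta)
# each component becomes a "distribution {i}" parameter of the
# mixture node via the add_parameter() loop above
x <- mixture(normal(0, 1), normal(0, 3), weights = c(0.3, 0.7))
# joint() registers its components through the same loop
y <- joint(normal(0, 1), exponential(1))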
@@ -171,7 +174,6 @@ mixture_distribution <- R6Class( vble(self$bounds, dim = self$dim) }, tf_distrib = function(parameters, dag) { - # get information from the *nodes* for component distributions, not the tf # objects passed in here @@ -200,7 +202,6 @@ mixture_distribution <- R6Class( log_weights <- log_weights - log_weights_sum log_prob <- function(x) { - # get component densities in an array log_probs <- mapply( dag$tf_evaluate_density, @@ -231,7 +232,6 @@ mixture_distribution <- R6Class( } sample <- function(seed) { - # draw samples from each component samples <- lapply(distribution_nodes, dag$draw_sample) names(samples) <- NULL @@ -264,7 +264,8 @@ mixture_distribution <- R6Class( # extract the relevant component indices <- tf$expand_dims(indices, n_batches) - draws <- tf$gather(samples_array, + draws <- tf$gather( + samples_array, indices, axis = collapse_axis, batch_dims = n_batches diff --git a/R/new_install_process.R b/R/new_install_process.R index 278d3622..01bcb7ed 100644 --- a/R/new_install_process.R +++ b/R/new_install_process.R @@ -1,9 +1,11 @@ -new_install_process <- function(callr_process, - timeout, - stdout_file = NULL, - stderr_file = NULL, - cli_start_msg = NULL, - cli_end_msg = NULL){ +new_install_process <- function( + callr_process, + timeout, + stdout_file = NULL, + stderr_file = NULL, + cli_start_msg = NULL, + cli_end_msg = NULL +) { cli::cli_process_start(cli_start_msg) # convert max timeout from milliseconds into minutes timeout_minutes <- timeout * 1000 * 60 @@ -34,10 +36,11 @@ new_install_process <- function(callr_process, cli_process_done(msg_done = cli_end_msg) return( - list(output_notes = output_notes, - status = status, - no_output = no_output, - output_error = output_error) + list( + output_notes = output_notes, + status = status, + no_output = no_output, + output_error = output_error + ) ) - } diff --git a/R/node_class.R b/R/node_class.R index 386fe048..49474c1b 100644 --- a/R/node_class.R +++ b/R/node_class.R @@ -15,7 +15,7 @@ node <- R6Class( dim = NA, distribution = NULL, initialize = function(dim = NULL, value = NULL) { - dim <- dim %||% c(1,1) + dim <- dim %||% c(1, 1) # coerce dim to integer dim <- as.integer(dim) @@ -37,7 +37,6 @@ node <- R6Class( register_family = function(dag) { ## TODO add explaining variable if (!(self$unique_name %in% names(dag$node_list))) { - # add self to list self$register(dag) @@ -59,13 +58,11 @@ node <- R6Class( } }, add_parent = function(node) { - # add to list of parents self$parents <- c(self$parents, node) node$add_child(self) }, remove_parent = function(node) { - # remove node from list of parents rem_idx <- which(self$parent_names(recursive = FALSE) == node$unique_name) self$parents <- self$parents[-rem_idx] @@ -81,7 +78,7 @@ node <- R6Class( parents <- c(parents, list(self$distribution)) } - if (mode == "sampling" & has_representation(self, "cholesky")){ + if (mode == "sampling" & has_representation(self, "cholesky")) { # remove cholesky representation node from parents parent_names <- extract_unique_names(parents) antirep_name <- get_node(self$representations$cholesky)$unique_name @@ -89,7 +86,7 @@ node <- R6Class( parents <- parents[match(parent_names_keep, parent_names)] } - if (mode == "sampling" & has_anti_representation(self, "chol2symm")){ + if (mode == "sampling" & has_anti_representation(self, "chol2symm")) { chol2symm_node <- get_node(self$anti_representations$chol2symm) parents <- c(parents, list(chol2symm_node)) } @@ -97,12 +94,10 @@ node <- R6Class( parents }, add_child = function(node) { - # add to 
list of children self$children <- c(self$children, node) }, remove_child = function(node) { - # remove node from list of parents rem_idx <- which(self$child_names() == node$unique_name) self$children <- self$children[-rem_idx] @@ -183,9 +178,9 @@ node <- R6Class( } # if defined already, skip if (!self$defined(dag)) { - # make sure parents are defined - parents_defined <- vapply(self$list_parents(dag), + parents_defined <- vapply( + self$list_parents(dag), function(x) x$defined(dag), FUN.VALUE = FALSE ) @@ -193,14 +188,14 @@ node <- R6Class( parents <- self$list_parents(dag) lapply( parents[which(!parents_defined)], - function(x){ + function(x) { x$define_tf(dag) } ) } # then define self - # stop("hi from the future ... parents are of class:", str(parents)) + # stop("hi from the future ... parents are of class:", str(parents)) self$tf(dag) } }, @@ -210,7 +205,6 @@ node <- R6Class( if (is.null(new_value)) { self$.value } else { - # get the dimension of the new value dim <- dim(new_value) @@ -227,7 +221,6 @@ node <- R6Class( } }, set_distribution = function(distribution) { - check_is_distribution_node(distribution) # add it @@ -251,7 +244,9 @@ node <- R6Class( text <- node_type(self) text <- node_type_colour(text) - dist_txt <- glue::glue("{self$distribution$distribution_name} distribution") + dist_txt <- glue::glue( + "{self$distribution$distribution_name} distribution" + ) if (has_distribution(self)) { text <- cli::cli_fmt( cli::cli_text( @@ -285,29 +280,30 @@ node <- R6Class( label }, - make_antirepresentations = function(representations){ + make_antirepresentations = function(representations) { mapply( FUN = self$make_one_anti_representation, representations, names(representations) - ) + ) }, - make_one_anti_representation = function(ga, name){ + make_one_anti_representation = function(ga, name) { node <- get_node(ga) anti_name <- self$find_anti_name(name) node$anti_representations[[anti_name]] <- as.greta_array(self) node }, - find_anti_name = function(name){ - switch(name, - cholesky = "chol2symm", - chol2symm = "chol", - exp = "log", - log = "exp", - probit = "iprobit", - iprobit = "probit", - logit = "ilogit", - ilogit = "logit" + find_anti_name = function(name) { + switch( + name, + cholesky = "chol2symm", + chol2symm = "chol", + exp = "log", + log = "exp", + probit = "iprobit", + iprobit = "probit", + logit = "ilogit", + ilogit = "logit" ) } ) diff --git a/R/node_types.R b/R/node_types.R index 1a68e4d9..4c166c5a 100644 --- a/R/node_types.R +++ b/R/node_types.R @@ -83,16 +83,17 @@ operation_node <- R6Class( operation_args = NA, arguments = list(), tf_function_env = NA, - initialize = function(operation, - ..., - dim = NULL, - operation_args = list(), - tf_operation = NULL, - value = NULL, - representations = list(), - tf_function_env = parent.frame(3), - expand_scalars = FALSE) { - + initialize = function( + operation, + ..., + dim = NULL, + operation_args = list(), + tf_operation = NULL, + value = NULL, + representations = list(), + tf_function_env = parent.frame(3), + expand_scalars = FALSE + ) { # coerce all arguments to nodes, and remember the operation dots <- lapply(list(...), as.greta_array) @@ -132,7 +133,6 @@ operation_node <- R6Class( super$initialize(dim, value) }, add_argument = function(argument) { - # guess at a name, coerce to a node, and add as a parent parameter <- to_node(argument) self$add_parent(parameter) @@ -153,7 +153,6 @@ operation_node <- R6Class( } if (mode == "forward") { - # fetch the tensors from the environment arg_tf_names <- 
lapply(self$list_parents(dag), dag$tf_name) tf_args <- lapply(arg_tf_names, get, envir = tfe) @@ -165,7 +164,8 @@ operation_node <- R6Class( } # get the tensorflow function and apply it to the args - operation <- eval(parse(text = self$operation), + operation <- eval( + parse(text = self$operation), envir = self$tf_function_env ) tensor <- do.call(operation, tf_args) @@ -186,10 +186,12 @@ variable_node <- R6Class( lower = -Inf, upper = Inf, free_value = NULL, - initialize = function(lower = -Inf, - upper = Inf, - dim = NULL, - free_dim = prod(dim)) { + initialize = function( + lower = -Inf, + upper = Inf, + dim = NULL, + free_dim = prod(dim) + ) { check_if_lower_upper_numeric(lower, upper) # replace values of lower and upper with finite values for dimension @@ -222,7 +224,8 @@ variable_node <- R6Class( self$constraint <- "scalar_mixed" } - bad_limits <- switch(self$constraint, + bad_limits <- switch( + self$constraint, scalar_all_low = !all(is.finite(upper)), scalar_all_high = !all(is.finite(lower)), scalar_all_both = !all(is.finite(lower)) | !all(is.finite(upper)), @@ -264,20 +267,17 @@ variable_node <- R6Class( if (is.null(distrib_node)) { # does it have an anti-representation where it is the cholesky? - # the antirepresentation of cholesky is chol2symm - # if yes, we take antirep and get it to `tf`, then get the tf_name + # the antirepresentation of cholesky is chol2symm + # if yes, we take antirep and get it to `tf`, then get the tf_name chol2symm_ga <- self$anti_representations$chol2symm chol2symm_existing <- !is.null(chol2symm_ga) if (chol2symm_existing) { - chol2symm_node <- get_node(chol2symm_ga) chol2symm_name <- dag$tf_name(chol2symm_node) chol2symm_tensor <- get(chol2symm_name, envir = dag$tf_environment) tensor <- tf_chol(chol2symm_tensor) - } - } else { tensor <- dag$draw_sample(self$distribution) } @@ -290,10 +290,7 @@ variable_node <- R6Class( # create the log jacobian adjustment for the free state tf_adj <- self$tf_adjustment(dag) adj_name <- glue::glue("{tf_name}_adj") - assign(adj_name, - tf_adj, - envir = dag$tf_environment - ) + assign(adj_name, tf_adj, envir = dag$tf_environment) # map from the free to constrained state in a new tensor tf_free <- get(free_name, envir = dag$tf_environment) @@ -301,10 +298,7 @@ variable_node <- R6Class( } # assign to environment variable - assign(tf_name, - tensor, - envir = dag$tf_environment - ) + assign(tf_name, tensor, envir = dag$tf_environment) }, create_tf_bijector = function() { dim <- self$dim @@ -312,15 +306,18 @@ variable_node <- R6Class( upper <- flatten_rowwise(self$upper) constraints <- flatten_rowwise(self$constraint_array) - switch(self$constraint, + switch( + self$constraint, scalar_all_none = tf_scalar_bijector(dim), scalar_all_low = tf_scalar_neg_bijector(dim, upper = upper), scalar_all_high = tf_scalar_pos_bijector(dim, lower = lower), - scalar_all_both = tf_scalar_neg_pos_bijector(dim, + scalar_all_both = tf_scalar_neg_pos_bijector( + dim, lower = lower, upper = upper ), - scalar_mixed = tf_scalar_mixed_bijector(dim, + scalar_mixed = tf_scalar_mixed_bijector( + dim, lower = lower, upper = upper, constraints = constraints @@ -360,7 +357,8 @@ variable_node <- R6Class( ljd <- tf$expand_dims(ljd, 0L) tiling <- tf$stack( list(tf$shape(free)[0]), - axis = 0L) + axis = 0L + ) ljd <- tf$tile(ljd, tiling) } @@ -369,7 +367,6 @@ variable_node <- R6Class( # create a tensor giving the log jacobian adjustment for this variable tf_adjustment = function(dag) { - # find free version of node free_tensor_name <- 
glue::glue("{dag$tf_name(self)}_free") free_tensor <- get(free_tensor_name, envir = dag$tf_environment) @@ -394,12 +391,14 @@ distribution_node <- R6Class( truncation = NULL, parameters = list(), parameter_shape_matches_output = logical(), - initialize = function(name = "no distribution", - dim = NULL, - truncation = NULL, - discrete = FALSE, - multivariate = FALSE, - truncatable = TRUE) { + initialize = function( + name = "no distribution", + dim = NULL, + truncation = NULL, + discrete = FALSE, + multivariate = FALSE, + truncatable = TRUE + ) { super$initialize(dim) # for all distributions, set name, store dims, and set whether discrete @@ -416,9 +415,11 @@ distribution_node <- R6Class( # distributions) set the truncation can_be_truncated <- !self$multivariate & !self$discrete & self$truncatable - if (!is.null(truncation) & - !identical(truncation, self$bounds) & - can_be_truncated) { + if ( + !is.null(truncation) & + !identical(truncation, self$bounds) & + can_be_truncated + ) { self$truncation <- truncation } @@ -476,23 +477,17 @@ distribution_node <- R6Class( # optional function to reset the flags for target representations whenever a # target is changed reset_target_flags = function() { - }, # replace the existing target node with a new one remove_target = function() { - # remove x from parents self$remove_parent(self$target) self$target <- NULL }, tf = function(dag) { - # assign the distribution object constructor function to the environment - assign(dag$tf_name(self), - self$tf_distrib, - envir = dag$tf_environment - ) + assign(dag$tf_name(self), self$tf_distrib, envir = dag$tf_environment) }, # which node to use as the *tf* target (overwritten by some distributions) @@ -504,11 +499,12 @@ distribution_node <- R6Class( # have the same shape as the output (e.g. this is true for binomial's prob # parameter, but not for size) by default, assume a scalar (row) parameter # can be expanded up to the distribution size - add_parameter = function(parameter, - name, - shape_matches_output = TRUE, - expand_now = TRUE) { - + add_parameter = function( + parameter, + name, + shape_matches_output = TRUE, + expand_now = TRUE + ) { # record whether this parameter can be scaled up self$parameter_shape_matches_output[[name]] <- shape_matches_output @@ -525,15 +521,16 @@ distribution_node <- R6Class( # try to expand a greta array for a parameter up to the required dimension expand_parameter = function(parameter, dim) { - # can this realisation of the parameter be expanded? - expandable_shape <- ifelse(self$multivariate, + expandable_shape <- ifelse( + self$multivariate, is_row(parameter), is_scalar(parameter) ) # should we expand it now? 
- expanded_target <- ifelse(self$multivariate, + expanded_target <- ifelse( + self$multivariate, !identical(dim[1], 1L), !identical(dim, c(1L, 1L)) ) @@ -562,7 +559,8 @@ distribution_node <- R6Class( parameter <- as.greta_array(self$parameters[[name]]) expanded <- self$expand_parameter(parameter, dim) - self$add_parameter(expanded, + self$add_parameter( + expanded, name, self$parameter_shape_matches_output[[name]], expand_now = FALSE diff --git a/R/operators.R b/R/operators.R index a0dca594..2420b5e4 100644 --- a/R/operators.R +++ b/R/operators.R @@ -64,68 +64,52 @@ NULL #' @export `+.greta_array` <- function(e1, e2) { check_dims(e1, e2) - op("add", e1, e2, - tf_operation = "tf$add", - expand_scalars = TRUE - ) + op("add", e1, e2, tf_operation = "tf$add", expand_scalars = TRUE) } #' @export `-.greta_array` <- function(e1, e2) { # handle unary minus if (missing(e2)) { - op("minus", e1, - tf_operation = "tf$negative" - ) + op("minus", e1, tf_operation = "tf$negative") } else { check_dims(e1, e2) - op("subtract", e1, e2, - tf_operation = "tf$subtract", - expand_scalars = TRUE - ) + op("subtract", e1, e2, tf_operation = "tf$subtract", expand_scalars = TRUE) } } #' @export `*.greta_array` <- function(e1, e2) { check_dims(e1, e2) - op("multiply", e1, e2, - tf_operation = "tf$multiply", - expand_scalars = TRUE - ) + op("multiply", e1, e2, tf_operation = "tf$multiply", expand_scalars = TRUE) } #' @export `/.greta_array` <- function(e1, e2) { check_dims(e1, e2) - op("divide", e1, e2, - tf_operation = "tf$truediv", - expand_scalars = TRUE - ) + op("divide", e1, e2, tf_operation = "tf$truediv", expand_scalars = TRUE) } #' @export `^.greta_array` <- function(e1, e2) { check_dims(e1, e2) - op("power", e1, e2, - tf_operation = "tf$pow", - expand_scalars = TRUE - ) + op("power", e1, e2, tf_operation = "tf$pow", expand_scalars = TRUE) } #' @export `%%.greta_array` <- function(e1, e2) { check_dims(e1, e2) - op("`modulo`", e1, e2, - tf_operation = "tf$math$mod", - expand_scalars = TRUE - ) + op("`modulo`", e1, e2, tf_operation = "tf$math$mod", expand_scalars = TRUE) } #' @export -`%/%.greta_array` <- function(e1, e2) { # nolint +`%/%.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("`integer divide`", e1, e2, + op( + "`integer divide`", + e1, + e2, tf_operation = "tf$math$floordiv", expand_scalars = TRUE ) @@ -135,19 +119,21 @@ NULL # would rather get S4 version working properly, but uuurgh S4. 
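The masking below exists because base R's `%*%` does not do S3 dispatch (S4 would be the sanctioned route, as the comment above notes), so greta shadows it with a wrapper that coerces whichever operand is not a greta array before calling UseMethod(). A toy illustration of the failure mode being worked around (hypothetical class, not greta code):

m <- matrix(1, 2, 2)
v <- structure(c(1, 1), class = "toy")
`%*%.toy` <- function(x, y) "toy dispatch"
m %*% v  # base `%*%` computes numerically; `%*%.toy` is never called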
#' @export -`%*%.default` <- function(x, y) { # nolint +`%*%.default` <- function(x, y) { + # nolint .Primitive("%*%")(x, y) } #' @rdname overloaded #' @export -`%*%` <- function(x, y) { # nolint +`%*%` <- function(x, y) { + # nolint # if y is a greta array, coerce x before dispatch if (is.greta_array(y) & !is.greta_array(x)) { as_data(x) %*% y - # if y is not a greta array and x is, coerce y before dispatch - } else if (!is.greta_array(y) & is.greta_array(x)){ + # if y is not a greta array and x is, coerce y before dispatch + } else if (!is.greta_array(y) & is.greta_array(x)) { x %*% as_data(y) } else { UseMethod("%*%", x) @@ -155,12 +141,16 @@ NULL } #' @export -`%*%.greta_array` <- function(x, y) { # nolint +`%*%.greta_array` <- function(x, y) { + # nolint - check_both_2d(x,y) + check_both_2d(x, y) check_compatible_dimensions(x, y) - op("matrix multiply", x, y, + op( + "matrix multiply", + x, + y, dim = c(nrow(x), ncol(y)), tf_operation = "tf$matmul" ) @@ -169,81 +159,63 @@ NULL # logical operators #' @export `!.greta_array` <- function(e1) { - op("not", e1, - tf_operation = "tf_not" - ) + op("not", e1, tf_operation = "tf_not") } #' @export -`&.greta_array` <- function(e1, e2) { # nolint +`&.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("and", e1, e2, - tf_operation = "tf_and", - expand_scalars = TRUE - ) + op("and", e1, e2, tf_operation = "tf_and", expand_scalars = TRUE) } #' @export -`|.greta_array` <- function(e1, e2) { # nolint +`|.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("or", e1, e2, - tf_operation = "tf_or", - expand_scalars = TRUE - ) + op("or", e1, e2, tf_operation = "tf_or", expand_scalars = TRUE) } # relational operators #' @export -`<.greta_array` <- function(e1, e2) { # nolint +`<.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("less", e1, e2, - tf_operation = "tf_lt", - expand_scalars = TRUE - ) + op("less", e1, e2, tf_operation = "tf_lt", expand_scalars = TRUE) } #' @export -`>.greta_array` <- function(e1, e2) { # nolint +`>.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("greater", e1, e2, - tf_operation = "tf_gt", - expand_scalars = TRUE - ) + op("greater", e1, e2, tf_operation = "tf_gt", expand_scalars = TRUE) } #' @export -`<=.greta_array` <- function(e1, e2) { # nolint +`<=.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("less/equal", e1, e2, - tf_operation = "tf_lte", - expand_scalars = TRUE - ) + op("less/equal", e1, e2, tf_operation = "tf_lte", expand_scalars = TRUE) } #' @export -`>=.greta_array` <- function(e1, e2) { # nolint +`>=.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("greater/equal", e1, e2, - tf_operation = "tf_gte", - expand_scalars = TRUE - ) + op("greater/equal", e1, e2, tf_operation = "tf_gte", expand_scalars = TRUE) } #' @export -`==.greta_array` <- function(e1, e2) { # nolint +`==.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("equal", e1, e2, - tf_operation = "tf_eq", - expand_scalars = TRUE - ) + op("equal", e1, e2, tf_operation = "tf_eq", expand_scalars = TRUE) } #' @export -`!=.greta_array` <- function(e1, e2) { # nolint +`!=.greta_array` <- function(e1, e2) { + # nolint check_dims(e1, e2) - op("not equal", e1, e2, - tf_operation = "tf_neq", - expand_scalars = TRUE - ) + op("not equal", e1, e2, tf_operation = "tf_neq", expand_scalars = TRUE) } diff --git a/R/optimiser_class.R b/R/optimiser_class.R index 38aa1c42..6df044d8 100644 --- a/R/optimiser_class.R +++ b/R/optimiser_class.R 
@@ -20,15 +20,15 @@ optimiser <- R6Class( # set up the model initialize = function( - initial_values, - model, - name, - method, - parameters, - other_args, - max_iterations, - tolerance, - adjust + initial_values, + model, + name, + method, + parameters, + other_args, + max_iterations, + tolerance, + adjust ) { super$initialize( initial_values, @@ -112,7 +112,6 @@ tf_optimiser <- R6Class( "tf_optimiser", inherit = optimiser, public = list( - # create an op to minimise the objective run_tf_minimiser = function() { dag <- self$model$dag @@ -138,8 +137,10 @@ tf_optimiser <- R6Class( # TF1/2 todo # get this to work inside TF with TF while loop - while (self$it < self$max_iterations & - all(self$diff > self$tolerance)) { + while ( + self$it < self$max_iterations & + all(self$diff > self$tolerance) + ) { # add 1 because python indexing self$it <- as.numeric(tfe$tf_optimiser$iterations) + 1 ## TF1/2 For Keras 3.0, this is the new syntax @@ -170,10 +171,12 @@ tf_optimiser <- R6Class( } }, - check_numerical_overflow = function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ - if (!is.finite(x)){ + check_numerical_overflow = function( + x, + arg = rlang::caller_arg(x), + call = rlang::caller_env() + ) { + if (!is.finite(x)) { cli::cli_abort( message = c( "Detected numerical overflow during optimisation", @@ -186,8 +189,6 @@ tf_optimiser <- R6Class( ) } } - - ) ) @@ -195,7 +196,6 @@ tfp_optimiser <- R6Class( "tfp_optimiser", inherit = optimiser, public = list( - run_tfp_minimiser = function() { dag <- self$model$dag tfe <- dag$tf_environment @@ -213,7 +213,7 @@ tfp_optimiser <- R6Class( } # bfgs uses value_and_gradient - value_and_gradient <- function(x){ + value_and_gradient <- function(x) { tfp$math$value_and_gradient( function(x) objective(x), x @@ -221,7 +221,6 @@ tfp_optimiser <- R6Class( } self$run_minimiser <- function(inits) { - self$parameters$max_iterations <- self$max_iterations # TF1/2 todo # will be better in the long run to have some kind of @@ -232,12 +231,12 @@ tfp_optimiser <- R6Class( } else if (self$name == "nelder_mead") { # nelder_mead uses different args, so we must change the ags in place self$parameters$batch_evaluate_objective <- FALSE - self$parameters$objective_function <- function(x){ + self$parameters$objective_function <- function(x) { x_expand <- tf$expand_dims(x, axis = 0L) val <- objective(x_expand) tf$squeeze(val) } - self$parameters$initial_vertex <- fl(inits[1,]) + self$parameters$initial_vertex <- fl(inits[1, ]) } tfe$tf_optimiser <- do.call( @@ -260,7 +259,6 @@ tf_compat_optimiser <- R6Class( "tf_compat_optimiser", inherit = optimiser, public = list( - # some of the optimisers are very fussy about dtypes, so convert them now sanitise_dtypes = function() { self$set_dtype("global_step", tf$int64) @@ -308,8 +306,10 @@ tf_compat_optimiser <- R6Class( # TF1/2 todo # get this to work inside TF with TF while loop - while (self$it < self$max_iterations & - all(self$diff > self$tolerance)) { + while ( + self$it < self$max_iterations & + all(self$diff > self$tolerance) + ) { # add 1 because python indexing self$it <- self$it + 1 diff --git a/R/optimisers.R b/R/optimisers.R index 0baa015b..ee52f5e2 100644 --- a/R/optimisers.R +++ b/R/optimisers.R @@ -44,7 +44,7 @@ optimiser_defunct_error <- function(optimiser) { in {.pkg greta} 0.5.0.", "Please use a different optimiser.", "See {.code ?optimisers} for detail on which optimizers are removed." 
- ) + ) ) } @@ -59,11 +59,7 @@ optimiser_deprecation_warning <- function(version = "0.4.0") { ) } -new_optimiser <- function(name, - method, - parameters, - class, - other_args){ +new_optimiser <- function(name, method, parameters, class, other_args) { obj <- list( name = name, method = method, @@ -77,10 +73,12 @@ new_optimiser <- function(name, obj } -define_tf_optimiser <- function(name, - method, - parameters = list(), - other_args = list()) { +define_tf_optimiser <- function( + name, + method, + parameters = list(), + other_args = list() +) { new_optimiser( name = name, method = method, @@ -90,10 +88,12 @@ define_tf_optimiser <- function(name, ) } -define_tf_compat_optimiser <- function(name, - method, - parameters = list(), - other_args = list()) { +define_tf_compat_optimiser <- function( + name, + method, + parameters = list(), + other_args = list() +) { new_optimiser( name = name, method = method, @@ -103,10 +103,12 @@ define_tf_compat_optimiser <- function(name, ) } -define_tfp_optimiser <- function(name, - method, - parameters = list(), - other_args = list()) { +define_tfp_optimiser <- function( + name, + method, + parameters = list(), + other_args = list() +) { new_optimiser( name = name, method = method, @@ -167,16 +169,16 @@ define_tfp_optimiser <- function(name, #' @export #' nelder_mead <- function( - objective_function = NULL, - initial_vertex = NULL, - step_sizes = NULL, - func_tolerance = 1e-08, - position_tolerance = 1e-08, - reflection = NULL, - expansion = NULL, - contraction = NULL, - shrinkage = NULL) { - + objective_function = NULL, + initial_vertex = NULL, + step_sizes = NULL, + func_tolerance = 1e-08, + position_tolerance = 1e-08, + reflection = NULL, + expansion = NULL, + contraction = NULL, + shrinkage = NULL +) { define_tfp_optimiser( name = "nelder_mead", method = "tfp$optimizer$nelder_mead_minimize", @@ -249,16 +251,18 @@ nelder_mead <- function( #' smaller than this value, the algorithm is stopped. #' #' @export -bfgs <- function(value_and_gradients_function = NULL, - initial_position = NULL, - tolerance = 1e-08, - x_tolerance = 0L, - f_relative_tolerance = 0L, - initial_inverse_hessian_estimate = NULL, - stopping_condition = NULL, - validate_args = TRUE, - max_line_search_iterations = 50L, - f_absolute_tolerance = 0L) { +bfgs <- function( + value_and_gradients_function = NULL, + initial_position = NULL, + tolerance = 1e-08, + x_tolerance = 0L, + f_relative_tolerance = 0L, + initial_inverse_hessian_estimate = NULL, + stopping_condition = NULL, + validate_args = TRUE, + max_line_search_iterations = 50L, + f_absolute_tolerance = 0L +) { define_tfp_optimiser( name = "bfgs", method = "tfp$optimizer$bfgs_minimize", @@ -345,9 +349,11 @@ slsqp <- function() { #' relevant direction and dampens oscillations. Defaults to 0, which is #' vanilla gradient descent. #' @param nesterov Whether to apply Nesterov momentum. Defaults to FALSE. 
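The momentum and nesterov arguments documented above implement the standard SGD update rule; a base-R sketch of a single step (names are illustrative, this is not the underlying keras implementation):

sgd_step <- function(theta, velocity, grad, learning_rate = 0.01, momentum = 0) {
  # momentum = 0 collapses this to vanilla gradient descent
  velocity <- momentum * velocity - learning_rate * grad(theta)
  theta + velocity
}
sgd_step(theta = 1, velocity = 0, grad = function(x) 2 * x)  # one step on x^2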
-gradient_descent <- function(learning_rate = 0.01, - momentum = 0, - nesterov = FALSE) { +gradient_descent <- function( + learning_rate = 0.01, + momentum = 0, + nesterov = FALSE +) { define_tf_optimiser( name = "gradient_descent", method = "tf$keras$optimizers$legacy$SGD", @@ -384,9 +390,11 @@ adadelta <- function(learning_rate = 0.001, rho = 1, epsilon = 1e-08) { #' @param initial_accumulator_value initial value of the 'accumulator' used to #' tune the algorithm #' -adagrad <- function(learning_rate = 0.8, - initial_accumulator_value = 0.1, - epsilon = 1e-08) { +adagrad <- function( + learning_rate = 0.8, + initial_accumulator_value = 0.1, + epsilon = 1e-08 +) { define_tf_optimiser( name = "adagrad", # method = "tf$keras$optimizers$Adagrad", @@ -414,12 +422,13 @@ adagrad <- function(learning_rate = 0.8, #' @note This optimizer isn't supported in TF2, so proceed with caution. See #' the [TF docs on AdagradDAOptimiser](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/AdagradDAOptimizer) for more detail. #' -adagrad_da <- function(learning_rate = 0.8, - global_step = 1L, - initial_gradient_squared_accumulator_value = 0.1, - l1_regularization_strength = 0, - l2_regularization_strength = 0) { - +adagrad_da <- function( + learning_rate = 0.8, + global_step = 1L, + initial_gradient_squared_accumulator_value = 0.1, + l1_regularization_strength = 0, + l2_regularization_strength = 0 +) { optimiser_deprecation_warning(version = "0.6.0") define_tf_compat_optimiser( @@ -428,8 +437,7 @@ adagrad_da <- function(learning_rate = 0.8, parameters = list( learning_rate = learning_rate, global_step = global_step, - initial_gradient_squared_accumulator_value = - initial_gradient_squared_accumulator_value, + initial_gradient_squared_accumulator_value = initial_gradient_squared_accumulator_value, l1_regularization_strength = l1_regularization_strength, l2_regularization_strength = l2_regularization_strength ) @@ -445,11 +453,13 @@ adagrad_da <- function(learning_rate = 0.8, #' @param amsgrad Boolean. Whether to apply AMSGrad variant of this algorithm #' from the paper "On the Convergence of Adam and beyond". Defaults to FALSE. #' -adam <- function(learning_rate = 0.1, - beta_1 = 0.9, - beta_2 = 0.999, - amsgrad = FALSE, - epsilon = 1e-08) { +adam <- function( + learning_rate = 0.1, + beta_1 = 0.9, + beta_2 = 0.999, + amsgrad = FALSE, + epsilon = 1e-08 +) { define_tf_optimiser( name = "adam", # method = "tf$keras$optimizers$Adam", @@ -467,10 +477,12 @@ adam <- function(learning_rate = 0.1, #' @rdname optimisers #' @export #' -adamax <- function(learning_rate = 0.001, - beta_1 = 0.9, - beta_2 = 0.999, - epsilon = 1e-07){ +adamax <- function( + learning_rate = 0.001, + beta_1 = 0.9, + beta_2 = 0.999, + epsilon = 1e-07 +) { define_tf_optimiser( name = "adamax", # method = "tf$keras$optimizers$Adamax", @@ -495,13 +507,15 @@ adamax <- function(learning_rate = 0.001, #' @param beta A float value, representing the beta value from the paper by #' [McMahan et al 2013](https://static.googleusercontent.com/media/research.google.com/en//pubs/archive/41159.pdf). 
Defaults to 0 #' -ftrl <- function(learning_rate = 1, - learning_rate_power = -0.5, - initial_accumulator_value = 0.1, - l1_regularization_strength = 0, - l2_regularization_strength = 0, - l2_shrinkage_regularization_strength = 0, - beta = 0) { +ftrl <- function( + learning_rate = 1, + learning_rate_power = -0.5, + initial_accumulator_value = 0.1, + l1_regularization_strength = 0, + l2_regularization_strength = 0, + l2_shrinkage_regularization_strength = 0, + beta = 0 +) { define_tf_optimiser( name = "ftrl", # method = "tf$keras$optimizers$Ftrl", @@ -524,10 +538,11 @@ ftrl <- function(learning_rate = 1, #' @note This optimizer isn't supported in TF2, so proceed with caution. See #' the [TF docs on ProximalGradientDescentOptimizer](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalGradientDescentOptimizer) for more detail. #' -proximal_gradient_descent <- function(learning_rate = 0.01, - l1_regularization_strength = 0, - l2_regularization_strength = 0) { - +proximal_gradient_descent <- function( + learning_rate = 0.01, + l1_regularization_strength = 0, + l2_regularization_strength = 0 +) { optimiser_deprecation_warning(version = "0.6.0") define_tf_compat_optimiser( @@ -548,11 +563,12 @@ proximal_gradient_descent <- function(learning_rate = 0.01, #' the [TF docs on ProximalAdagradOptimizer](https://www.tensorflow.org/api_docs/python/tf/compat/v1/train/ProximalAdagradOptimizer) for more detail. #' -proximal_adagrad <- function(learning_rate = 1, - initial_accumulator_value = 0.1, - l1_regularization_strength = 0, - l2_regularization_strength = 0) { - +proximal_adagrad <- function( + learning_rate = 1, + initial_accumulator_value = 0.1, + l1_regularization_strength = 0, + l2_regularization_strength = 0 +) { optimiser_deprecation_warning(version = "0.6.0") define_tf_compat_optimiser( @@ -570,11 +586,12 @@ proximal_adagrad <- function(learning_rate = 1, #' @rdname optimisers #' @export #' -nadam <- function(learning_rate = 0.001, - beta_1 = 0.9, - beta_2 = 0.999, - epsilon = 1e-07){ - +nadam <- function( + learning_rate = 0.001, + beta_1 = 0.9, + beta_2 = 0.999, + epsilon = 1e-07 +) { define_tf_optimiser( name = "nadam", # method = "tf$keras$optimizers$Nadam", @@ -586,7 +603,6 @@ nadam <- function(learning_rate = 0.001, epsilon = epsilon ) ) - } #' @rdname optimisers @@ -596,11 +612,13 @@ nadam <- function(learning_rate = 0.001, #' variance of the gradient; if FALSE, by the uncentered second moment. #' Setting this to TRUE may help with training, but is slightly more #' expensive in terms of computation and memory. Defaults to FALSE. 
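All of these constructors just bundle a name, a method string, and a parameter list via define_tf_optimiser()/define_tf_compat_optimiser(); nothing runs until the object is handed to opt(). A usage sketch, left commented since it assumes a greta model m defined elsewhere:

# library(greta)
# o <- opt(m, optimiser = adam(learning_rate = 0.1), max_iterations = 200)
# o$par    # parameter values at the optimum
# o$value  # objective at the optimum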
-rms_prop <- function(learning_rate = 0.1, - rho = 0.9, - momentum = 0, - epsilon = 1e-10, - centered = FALSE) { +rms_prop <- function( + learning_rate = 0.1, + rho = 0.9, + momentum = 0, + epsilon = 1e-10, + centered = FALSE +) { define_tf_optimiser( name = "rms_prop", # method = "tf$keras$optimizers$RMSprop", diff --git a/R/package.R b/R/package.R index 94b741be..3b6df155 100644 --- a/R/package.R +++ b/R/package.R @@ -46,13 +46,15 @@ reticulate::py_version # clear CRAN checks spotting floating global variables #' @importFrom utils globalVariables utils::globalVariables( - c("N", - "greta_deps_tf_tfp", - "greta_logfile", - "os", - "python_version_max", - "python_version_min", - "tf_version", - "tfp_version", - "greta") + c( + "N", + "greta_deps_tf_tfp", + "greta_logfile", + "os", + "python_version_max", + "python_version_min", + "tf_version", + "tfp_version", + "greta" + ) ) diff --git a/R/probability_distributions.R b/R/probability_distributions.R index 1f1c35ba..9488db1f 100644 --- a/R/probability_distributions.R +++ b/R/probability_distributions.R @@ -124,7 +124,6 @@ bernoulli_distribution <- R6Class( if (self$prob_is_logit) { tfp$distributions$Bernoulli(logits = parameters$prob) } else if (self$prob_is_probit) { - # in the probit case, get the log probability of success and compute the # log prob directly probit <- parameters$prob @@ -176,7 +175,6 @@ binomial_distribution <- R6Class( logits = parameters$prob ) } else if (self$prob_is_probit) { - # in the probit case, get the log probability of success and compute the # log prob directly size <- parameters$size @@ -430,7 +428,6 @@ weibull_distribution <- R6Class( } sample <- function(seed) { - # sample by pushing standard uniforms through the inverse cdf u <- tf_randu(self$dim, dag) quantile(u) @@ -484,7 +481,6 @@ pareto_distribution <- R6Class( self$add_parameter(b, "b") }, tf_distrib = function(parameters, dag) { - # a is shape, b is scale tfp$distributions$Pareto( concentration = parameters$a, @@ -680,7 +676,6 @@ f_distribution <- R6Class( } sample <- function(seed) { - # sample as the ratio of two scaled chi squared distributions d1 <- tfp$distributions$Chi2(df = df1) d2 <- tfp$distributions$Chi2(df = df2) @@ -718,7 +713,9 @@ dirichlet_distribution <- R6Class( # coerce the parameter arguments to nodes and add as parents and # parameters self$bounds <- c(0, Inf) - super$initialize("dirichlet", dim, + super$initialize( + "dirichlet", + dim, truncation = c(0, Inf), multivariate = TRUE ) @@ -743,7 +740,6 @@ dirichlet_multinomial_distribution <- R6Class( inherit = distribution_node, public = list( initialize = function(size, alpha, n_realisations, dimension) { - # coerce to greta arrays size <- as.greta_array(size) alpha <- as.greta_array(alpha) @@ -755,12 +751,12 @@ dirichlet_multinomial_distribution <- R6Class( dimension = dimension ) - # need to handle size as a vector! 
# coerce the parameter arguments to nodes and add as parents and # parameters - super$initialize("dirichlet_multinomial", + super$initialize( + "dirichlet_multinomial", dim = dim, discrete = TRUE, multivariate = TRUE @@ -787,7 +783,6 @@ multinomial_distribution <- R6Class( inherit = distribution_node, public = list( initialize = function(size, prob, n_realisations, dimension) { - # coerce to greta arrays size <- as.greta_array(size) prob <- as.greta_array(prob) @@ -803,7 +798,8 @@ multinomial_distribution <- R6Class( # coerce the parameter arguments to nodes and add as parents and # parameters - super$initialize("multinomial", + super$initialize( + "multinomial", dim = dim, discrete = TRUE, multivariate = TRUE @@ -815,8 +811,8 @@ multinomial_distribution <- R6Class( parameters$size <- tf_flatten(parameters$size) # scale probs to get absolute density correct # parameters$prob <- parameters$prob / tf_sum(parameters$prob) - parameters$prob <- parameters$prob / tf_rowsums(parameters$prob, - dims = 1L) + parameters$prob <- parameters$prob / + tf_rowsums(parameters$prob, dims = 1L) tfp$distributions$Multinomial( total_count = parameters$size, @@ -831,7 +827,6 @@ categorical_distribution <- R6Class( inherit = distribution_node, public = list( initialize = function(prob, n_realisations, dimension) { - # coerce to greta arrays prob <- as.greta_array(prob) @@ -843,7 +838,8 @@ categorical_distribution <- R6Class( # coerce the parameter arguments to nodes and add as parents and # parameters - super$initialize("categorical", + super$initialize( + "categorical", dim = dim, discrete = TRUE, multivariate = TRUE @@ -897,7 +893,6 @@ multivariate_normal_distribution <- R6Class( self$add_parameter(sigma, "sigma") }, tf_distrib = function(parameters, dag) { - # if Sigma is a cholesky factor transpose it to tensorflow expoectation, # otherwise decompose it @@ -926,7 +921,6 @@ wishart_distribution <- R6Class( "wishart_distribution", inherit = distribution_node, public = list( - # TF1/2 - consider setting this as NULL for debugging purposes # set when defining the distribution sigma_is_cholesky = FALSE, @@ -934,7 +928,8 @@ wishart_distribution <- R6Class( # TF1/2 - consider setting this as NULL for debugging purposes # set when defining the graph target_is_cholesky = FALSE, - initialize = function(df, Sigma) { # nolint + initialize = function(df, Sigma) { + # nolint # add the nodes as parents and parameters df <- as.greta_array(df) sigma <- as.greta_array(Sigma) @@ -1029,7 +1024,6 @@ wishart_distribution <- R6Class( log_prob <- log_prob_raw + adjustment log_prob - } sample <- function(seed) { @@ -1069,7 +1063,6 @@ lkj_correlation_distribution <- R6Class( "lkj_correlation_distribution", inherit = distribution_node, public = list( - # set when defining the graph target_is_cholesky = FALSE, eta_is_cholesky = FALSE, @@ -1097,7 +1090,6 @@ lkj_correlation_distribution <- R6Class( # default (cholesky factor, ignores truncation) create_target = function(truncation) { - # create (correlation matrix) cholesky factor variable greta array chol_greta_array <- cholesky_variable(self$dim[1], correlation = TRUE) @@ -1130,7 +1122,7 @@ lkj_correlation_distribution <- R6Class( eta <- tf$squeeze(parameters$eta, 1:2) dim <- self$dim[1] - log_prob <- function(x){ + log_prob <- function(x) { if (self$target_is_cholesky) { x_chol <- tf$linalg$matrix_transpose(x) } else { @@ -1158,7 +1150,6 @@ lkj_correlation_distribution <- R6Class( log_prob <- log_prob_raw + adjustment log_prob - } # tfp's lkj sampling can't detect the size of the 
output from eta, for @@ -1179,7 +1170,6 @@ lkj_correlation_distribution <- R6Class( # and R uses upper triangular (non zeroes are in top right) draws <- tf$matmul(chol_draws, chol_draws, adjoint_b = TRUE) draws - } tf$map_fn(sample_once, eta) @@ -1521,18 +1511,26 @@ f <- function(df1, df2, dim = NULL, truncation = c(0, Inf)) { # nolint start #' @rdname distributions #' @export -multivariate_normal <- function(mean, Sigma, - n_realisations = NULL, dimension = NULL) { +multivariate_normal <- function( + mean, + Sigma, + n_realisations = NULL, + dimension = NULL +) { # nolint end distrib( - "multivariate_normal", mean, Sigma, - n_realisations, dimension + "multivariate_normal", + mean, + Sigma, + n_realisations, + dimension ) } #' @rdname distributions #' @export -wishart <- function(df, Sigma) { # nolint +wishart <- function(df, Sigma) { + # nolint distrib("wishart", df, Sigma) } @@ -1562,10 +1560,17 @@ dirichlet <- function(alpha, n_realisations = NULL, dimension = NULL) { #' @rdname distributions #' @export -dirichlet_multinomial <- function(size, alpha, - n_realisations = NULL, dimension = NULL) { +dirichlet_multinomial <- function( + size, + alpha, + n_realisations = NULL, + dimension = NULL +) { distrib( "dirichlet_multinomial", - size, alpha, n_realisations, dimension + size, + alpha, + n_realisations, + dimension ) } diff --git a/R/progress_bar.R b/R/progress_bar.R index e28a9c24..ab9ac91f 100644 --- a/R/progress_bar.R +++ b/R/progress_bar.R @@ -8,18 +8,11 @@ # iterations respectively # 'pb_update' gives the number of iterations between updates of the progress bar create_progress_bar <- function(phase, iter, pb_update, width, ...) { - # name for formatting - name <- switch(phase, - warmup = " warmup", - sampling = "sampling" - ) + name <- switch(phase, warmup = " warmup", sampling = "sampling") # total iterations for bat - iter_this <- switch(phase, - warmup = iter[1], - sampling = iter[2] - ) + iter_this <- switch(phase, warmup = iter[1], sampling = iter[2]) # pad the frmat so that the width iterations counter is the same for both # warmup and sampling @@ -87,13 +80,15 @@ iterate_progress_bar <- function(pb, it, rejects, chains, file = NULL) { # tick the progess bar and record the output message # (or print it if file = NULL) - record(pb$tick(amount, - tokens = list( - iter = iter_pretty, - rejection = reject_text - ) - ), - file = file + record( + pb$tick( + amount, + tokens = list( + iter = iter_pretty, + rejection = reject_text + ) + ), + file = file ) } } diff --git a/R/reinstallers.R b/R/reinstallers.R index 9444d98f..dc527395 100644 --- a/R/reinstallers.R +++ b/R/reinstallers.R @@ -22,50 +22,52 @@ #' reinstall_greta_env() #' reinstall_miniconda() #' } -remove_greta_env <- function(){ - cli::cli_alert_info("removing 'greta-env-tf2' conda environment", - wrap = TRUE) +remove_greta_env <- function() { + cli::cli_alert_info("removing 'greta-env-tf2' conda environment", wrap = TRUE) reticulate::conda_remove( envname = "greta-env-tf2" ) - cli::cli_alert_success("greta-env-tf2 environment removed!", - wrap = TRUE) + cli::cli_alert_success("greta-env-tf2 environment removed!", wrap = TRUE) } #' @export #' @param timeout time in minutes to wait until timeout (default is 5 minutes) #' @rdname reinstallers -reinstall_greta_env <- function(timeout = 5){ +reinstall_greta_env <- function(timeout = 5) { remove_greta_env() greta_create_conda_env(timeout = timeout) } #' @export #' @rdname reinstallers -remove_miniconda <- function(){ +remove_miniconda <- function() { path_to_miniconda <- 
reticulate::miniconda_path() - if (!file.exists(path_to_miniconda)){ - cli::cli_alert_info("No miniconda files found at {path_to_miniconda}", - wrap = TRUE) + if (!file.exists(path_to_miniconda)) { + cli::cli_alert_info( + "No miniconda files found at {path_to_miniconda}", + wrap = TRUE + ) return(invisible()) } - if (yesno::yesno("Are you sure you want to delete miniconda from ", - path_to_miniconda,"?") ){ - cli::cli_alert_info("removing 'miniconda' installation", - wrap = TRUE) + if ( + yesno::yesno( + "Are you sure you want to delete miniconda from ", + path_to_miniconda, + "?" + ) + ) { + cli::cli_alert_info("removing 'miniconda' installation", wrap = TRUE) unlink(path_to_miniconda, recursive = TRUE) - cli::cli_alert_success("'miniconda' successfully removed!", - wrap = TRUE) + cli::cli_alert_success("'miniconda' successfully removed!", wrap = TRUE) } else { return(invisible()) } - } #' @param timeout time in minutes to wait until timeout (default is 5 minutes) #' @rdname reinstallers #' @export -reinstall_miniconda <- function(timeout = 5){ +reinstall_miniconda <- function(timeout = 5) { remove_miniconda() greta_install_miniconda(timeout) } @@ -78,16 +80,18 @@ reinstall_miniconda <- function(timeout = 5){ #' # issues with installing greta dependencies #' reinstall_greta_deps() #' } -reinstall_greta_deps <- function(deps = greta_deps_spec(), - timeout = 5, - restart = c("ask", "force", "no")){ +reinstall_greta_deps <- function( + deps = greta_deps_spec(), + timeout = 5, + restart = c("ask", "force", "no") +) { remove_greta_env() remove_miniconda() install_greta_deps( deps = deps, timeout = timeout, restart = restart - ) + ) } #' Remove greta dependencies and remove miniconda @@ -99,7 +103,7 @@ reinstall_greta_deps <- function(deps = greta_deps_spec(), #' #' @return nothing #' @export -destroy_greta_deps <- function(){ +destroy_greta_deps <- function() { cli::cli_progress_step( msg = "You are removing greta env and miniconda", msg_done = c("You have successfully removed greta env and miniconda") diff --git a/R/sampler_class.R b/R/sampler_class.R index e598f28d..988228e3 100644 --- a/R/sampler_class.R +++ b/R/sampler_class.R @@ -3,7 +3,6 @@ sampler <- R6Class( "sampler", inherit = inference, public = list( - # sampler information sampler_number = 1, n_samplers = 1, @@ -40,11 +39,13 @@ sampler <- R6Class( # batch sizes for tracing trace_batch_size = 100, - initialize = function(initial_values, - model, - parameters = list(), - seed, - compute_options) { + initialize = function( + initial_values, + model, + parameters = list(), + seed, + compute_options + ) { # initialize the inference method super$initialize( initial_values = initial_values, @@ -68,47 +69,46 @@ sampler <- R6Class( # define the draws tensor on the tf graph # define_tf_draws is now used in place of of run_burst self$define_tf_evaluate_sample_batch() - }, - define_tf_evaluate_sample_batch = function(){ + define_tf_evaluate_sample_batch = function() { self$tf_evaluate_sample_batch <- tensorflow::tf_function( f = self$define_tf_draws, input_signature = list( # free state - tf$TensorSpec(shape = list(NULL, self$n_free), - dtype = tf_float()), + tf$TensorSpec(shape = list(NULL, self$n_free), dtype = tf_float()), # sampler_burst_length - tf$TensorSpec(shape = list(), - dtype = tf$int32), + tf$TensorSpec(shape = list(), dtype = tf$int32), # sampler_thin - tf$TensorSpec(shape = list(), - dtype = tf$int32), + tf$TensorSpec(shape = list(), dtype = tf$int32), # sampler_param_vec - tf$TensorSpec(shape = list( - length( - unlist( - 
self$sampler_parameter_values() + tf$TensorSpec( + shape = list( + length( + unlist( + self$sampler_parameter_values() + ) ) - ) - ), - dtype = tf_float() + ), + dtype = tf_float() ) ) ) }, - run_chain = function(n_samples, - thin, - warmup, - verbose, - pb_update, - one_by_one, - plan_is, - n_cores, - float_type, - trace_batch_size, - from_scratch = TRUE) { + run_chain = function( + n_samples, + thin, + warmup, + verbose, + pb_update, + one_by_one, + plan_is, + n_cores, + float_type, + trace_batch_size, + from_scratch = TRUE + ) { self$warmup <- warmup self$thin <- thin dag <- self$model$dag @@ -120,22 +120,24 @@ sampler <- R6Class( self$print_sampler_number() } if (plan_is$parallel) { - dag$define_tf_trace_values_batch() dag$define_tf_log_prob_function() self$define_tf_evaluate_sample_batch() - } # create these objects if needed if (from_scratch) { - self$traced_free_state <- self$empty_matrices(n = self$n_chains, - ncol = self$n_free) + self$traced_free_state <- self$empty_matrices( + n = self$n_chains, + ncol = self$n_free + ) - self$traced_values <- self$empty_matrices(n = self$n_chains, - ncol = self$n_traced) + self$traced_values <- self$empty_matrices( + n = self$n_chains, + ncol = self$n_traced + ) } # how big would we like the bursts to be @@ -162,10 +164,10 @@ sampler <- R6Class( }, run_warmup = function( - n_samples, - pb_update, - ideal_burst_size, - verbose + n_samples, + pb_update, + ideal_burst_size, + verbose ) { perform_warmup <- self$warmup > 0 if (perform_warmup) { @@ -189,9 +191,11 @@ sampler <- R6Class( } # split up warmup iterations into bursts of sampling - burst_lengths <- self$burst_lengths(self$warmup, - ideal_burst_size, - warmup = TRUE) + burst_lengths <- self$burst_lengths( + self$warmup, + ideal_burst_size, + warmup = TRUE + ) completed_iterations <- cumsum(burst_lengths) @@ -213,7 +217,6 @@ sampler <- R6Class( self$tune(completed_iterations[burst], self$warmup) if (verbose) { - # update the progress bar/percentage log iterate_progress_bar( pb = pb_warmup, @@ -232,24 +235,25 @@ sampler <- R6Class( } # scrub the free state trace and numerical rejections - self$traced_free_state <- self$empty_matrices(n = self$n_chains, - ncol = self$n_free) + self$traced_free_state <- self$empty_matrices( + n = self$n_chains, + ncol = self$n_free + ) self$numerical_rejections <- 0 } # end warmup }, - run_sampling = function ( + run_sampling = function( n_samples, pb_update, ideal_burst_size, trace_batch_size, thin, verbose - ){ + ) { perform_sampling <- n_samples > 0 if (perform_sampling) { - # on exiting during the main sampling period (even if killed by the # user) trace the free state values @@ -269,7 +273,7 @@ sampler <- R6Class( rejects = 0, chains = self$n_chains, file = self$pb_file - ) + ) } else { pb_sampling <- NULL } @@ -283,13 +287,11 @@ sampler <- R6Class( # and how often to return them # TF1/2 check todo # replace with define_tf_draws - self$run_burst(n_samples = burst_lengths[burst], - thin = thin) + self$run_burst(n_samples = burst_lengths[burst], thin = thin) # trace is it receiving the python self$trace() if (verbose) { - # update the progress bar/percentage log iterate_progress_bar( pb = pb_sampling, @@ -307,13 +309,11 @@ sampler <- R6Class( } } } # end sampling - }, # update the welford accumulator for summary statistics of the posterior, # used for tuning update_welford = function() { - # unlist the states into a matrix trace_matrix <- do.call(rbind, self$last_burst_free_states) @@ -346,9 +346,10 @@ sampler <- R6Class( # convert traced free state to the 
traced values, accounting for # chain dimension trace_values = function(trace_batch_size) { - self$traced_values <- lapply(self$traced_free_state, - self$model$dag$trace_values, - trace_batch_size = trace_batch_size + self$traced_values <- lapply( + self$traced_free_state, + self$model$dag$trace_values, + trace_batch_size = trace_batch_size ) }, @@ -383,12 +384,10 @@ sampler <- R6Class( # considering the progress bar update frequency and the parameter tuning # schedule during warmup burst_lengths = function(n_samples, pb_update, warmup = FALSE) { - # when to stop for progress bar updates changepoints <- c(seq(0, n_samples, by = pb_update), n_samples) if (warmup) { - # when to break to update tuning tuning_points <- seq(0, n_samples, by = self$tuning_interval) @@ -410,7 +409,6 @@ sampler <- R6Class( self$tune_diag_sd(iterations_completed, total_iterations) }, tune_epsilon = function(iter, total) { - # tuning periods for the tunable parameters (first 10%, last 60%) tuning_periods <- list(c(0, 0.1), c(0.4, 1)) @@ -422,7 +420,6 @@ sampler <- R6Class( ) if (tuning_now) { - # epsilon & tuning parameters kappa <- 0.75 gamma <- 0.1 @@ -451,7 +448,6 @@ sampler <- R6Class( } }, tune_diag_sd = function(iterations_completed, total_iterations) { - # when, during warmup, to tune this parameter (after epsilon, but stopping # before halfway through) tuning_periods <- list(c(0.1, 0.4)) @@ -467,7 +463,6 @@ sampler <- R6Class( # provided there have been at least 5 acceptances in the warmup so far if (n_accepted > 5) { - # get the sample posterior variance and shrink it sample_var <- self$sample_variance() shrinkage <- 1 / (n_accepted + 5) @@ -478,13 +473,13 @@ sampler <- R6Class( }, # TF1/2 check todo # need to convert this into a TF function - define_tf_draws = function(free_state, - sampler_burst_length, - sampler_thin, - sampler_param_vec - # pass values through + define_tf_draws = function( + free_state, + sampler_burst_length, + sampler_thin, + sampler_param_vec + # pass values through ) { - dag <- self$model$dag tfe <- dag$tf_environment @@ -531,8 +526,7 @@ sampler <- R6Class( # this will be removed in favour of the tf_function decorated # define_tf_draws() function that takes in argument values # sampler_burst_length and sampler_thin - run_burst = function(n_samples, - thin = 1L) { + run_burst = function(n_samples, thin = 1L) { dag <- self$model$dag tfe <- dag$tf_environment @@ -551,7 +545,7 @@ sampler <- R6Class( ) # get trace of free state and drop the null dimension - if (is.null(batch_results$all_states)){ + if (is.null(batch_results$all_states)) { browser() } free_state_draws <- as.array(batch_results$all_states) @@ -562,7 +556,6 @@ sampler <- R6Class( dim(free_state_draws) <- c(1, dim(free_state_draws)) } - self$last_burst_free_states <- split_chains(free_state_draws) n_draws <- nrow(free_state_draws) @@ -573,7 +566,6 @@ sampler <- R6Class( } if (self$uses_metropolis) { - # log acceptance probability log_accept_stats <- as.array(batch_results$trace$log_accept_ratio) is_accepted <- as.array(batch_results$trace$is_accepted) @@ -589,11 +581,12 @@ sampler <- R6Class( tf_evaluate_sample_batch = NULL, - sample_carefully = function(free_state, - sampler_burst_length, - sampler_thin, - sampler_param_vec) { - + sample_carefully = function( + free_state, + sampler_burst_length, + sampler_thin, + sampler_param_vec + ) { # tryCatch handling for numerical errors dag <- self$model$dag tfe <- dag$tf_environment @@ -624,12 +617,11 @@ sampler <- R6Class( result }, - check_for_free_state_error = function(result, 
n_samples){ + check_for_free_state_error = function(result, n_samples) { # if it's fine, batch_results is the output # if it's a non-numerical error, it will error # if it's a numerical error, batch_results will be an error object if (inherits(result, "error")) { - # simple case that this is a single bad sample. Mock up a result and # pass it back if (n_samples == 1L) { @@ -641,7 +633,6 @@ sampler <- R6Class( ) ) } else { - greta_stash$tf_num_error <- result # otherwise, *one* of these multiple samples was bad. The sampler @@ -657,21 +648,20 @@ sampler <- R6Class( "{.code greta_notes_tf_num_error()}" ) ) - } } }, sampler_parameter_values = function() { - # random number of integration steps self$parameters }, - empty_matrices = function(n, - ncol){ - replicate(n = n, - matrix(data = NA, nrow = 0, ncol = ncol), - simplify = FALSE) - } + empty_matrices = function(n, ncol) { + replicate( + n = n, + matrix(data = NA, nrow = 0, ncol = ncol), + simplify = FALSE + ) + } ) ) diff --git a/R/samplers.R b/R/samplers.R index 2eacdaaf..b1591e82 100644 --- a/R/samplers.R +++ b/R/samplers.R @@ -26,10 +26,7 @@ NULL #' selected uniformly at random from between `Lmin` and `Lmax`. #' `diag_sd` is used to rescale the parameter space to make it more #' uniform, and make sampling more efficient. -hmc <- function(Lmin = 5, - Lmax = 10, - epsilon = 0.1, - diag_sd = 1) { +hmc <- function(Lmin = 5, Lmax = 10, epsilon = 0.1, diag_sd = 1) { # nolint end obj <- list( parameters = list( @@ -56,9 +53,11 @@ hmc <- function(Lmin = 5, #' @param proposal the probability distribution used to generate proposal states #' #' @export -rwmh <- function(proposal = c("normal", "uniform"), - epsilon = 0.1, - diag_sd = 1) { +rwmh <- function( + proposal = c("normal", "uniform"), + epsilon = 0.1, + diag_sd = 1 +) { proposal <- match.arg(proposal) obj <- list( @@ -87,7 +86,7 @@ slice <- function(max_doublings = 5) { obj <- list( parameters = list( max_doublings = as.integer(max_doublings)[1] - ), + ), name = "slice", class = slice_sampler ) @@ -98,7 +97,8 @@ slice <- function(max_doublings = 5) { #' @noRd #' @export print.sampler <- function(x, ...) { - values_text <- paste(names(x$parameters), + values_text <- paste( + names(x$parameters), prettyNum(x$parameters), sep = " = ", collapse = ", " @@ -106,10 +106,12 @@ print.sampler <- function(x, ...) 
{ if (!nzchar(values_text)) values_text <- "None" - parameters_text <- glue::glue(" + parameters_text <- glue::glue( + " parameters: {values_text} - ") + " + ) msg <- glue::glue( "{class(x)[1]} object with {parameters_text}" @@ -132,7 +134,6 @@ hmc_sampler <- R6Class( accept_target = 0.651, define_tf_kernel = function(sampler_param_vec) { - dag <- self$model$dag tfe <- dag$tf_environment @@ -143,7 +144,7 @@ hmc_sampler <- R6Class( hmc_l <- sampler_param_vec[0] hmc_epsilon <- sampler_param_vec[1] - hmc_diag_sd <- sampler_param_vec[2:(1+free_state_size)] + hmc_diag_sd <- sampler_param_vec[2:(1 + free_state_size)] hmc_step_sizes <- tf$cast( x = tf$reshape( @@ -170,7 +171,6 @@ hmc_sampler <- R6Class( ) }, sampler_parameter_values = function() { - # random number of integration steps l_min <- self$parameters$Lmin l_max <- self$parameters$Lmax @@ -201,19 +201,19 @@ rwmh_sampler <- R6Class( accept_target = 0.44, define_tf_kernel = function(sampler_param_vec) { - # wrap this up into a function to extract these out free_state_size <- length(sampler_param_vec) - 1 # get it from dag object # e.g., length(dag$free_state) rwmh_epsilon <- sampler_param_vec[0] - rwmh_diag_sd <- sampler_param_vec[1:(1+free_state_size)] + rwmh_diag_sd <- sampler_param_vec[1:(1 + free_state_size)] dag <- self$model$dag tfe <- dag$tf_environment - tfe$rwmh_proposal <- switch(self$parameters$proposal, - normal = tfp$mcmc$random_walk_normal_fn, - uniform = tfp$mcmc$random_walk_uniform_fn + tfe$rwmh_proposal <- switch( + self$parameters$proposal, + normal = tfp$mcmc$random_walk_normal_fn, + uniform = tfp$mcmc$random_walk_uniform_fn ) # TF1/2 check @@ -306,7 +306,6 @@ slice_sampler <- R6Class( # no additional here tuning tune = function(iterations_completed, total_iterations) { - } ) ) diff --git a/R/simulate.R b/R/simulate.R index 70392a5d..4cab20a2 100644 --- a/R/simulate.R +++ b/R/simulate.R @@ -48,11 +48,13 @@ #' sims <- simulate(m, nsim = 100) #' } #' # nolint start -simulate.greta_model <- function(object, - nsim = 1, - seed = NULL, - precision = c("double", "single"), - ...) { +simulate.greta_model <- function( + object, + nsim = 1, + seed = NULL, + precision = c("double", "single"), + ... +) { # nolint end # find all the greta arrays in the calling environment target_greta_arrays <- all_greta_arrays(parent.frame()) diff --git a/R/structures.R b/R/structures.R index b271de6d..76f34d21 100644 --- a/R/structures.R +++ b/R/structures.R @@ -65,7 +65,6 @@ greta_array <- function(data = 0, dim = length(data)) { # safely handle self-coersion, possibly with reshaping #' @export greta_array.greta_array <- function(data = 0, dim = length(data)) { - # reshape if necessary (apparently users expect this functionality) dim <- as.integer(dim) if (length(dim) == 1) { diff --git a/R/test_if_forked_cluster.R b/R/test_if_forked_cluster.R index 5aa2a56f..212a2a36 100644 --- a/R/test_if_forked_cluster.R +++ b/R/test_if_forked_cluster.R @@ -1,4 +1,4 @@ -test_if_forked_cluster <- function(){ +test_if_forked_cluster <- function() { is_forked <- value(future(parallelly::isForkedChild())) if (is_forked) { cli::cli_abort( diff --git a/R/testthat-helpers.R b/R/testthat-helpers.R index 854f49e9..b760545e 100644 --- a/R/testthat-helpers.R +++ b/R/testthat-helpers.R @@ -20,11 +20,16 @@ as_variable <- function(x) { # check a greta operation and the equivalent R operation give the same output # e.g. 
check_op(sum, randn(100, 3)) -check_op <- function(op, a, b, greta_op = NULL, - other_args = list(), - tolerance = 1e-3, - only = c("data", "variable", "batched"), - relative_error = FALSE) { +check_op <- function( + op, + a, + b, + greta_op = NULL, + other_args = list(), + tolerance = 1e-3, + only = c("data", "variable", "batched"), + relative_error = FALSE +) { greta_op <- greta_op %||% op r_out <- run_r_op(op, a, b, other_args) @@ -36,10 +41,15 @@ check_op <- function(op, a, b, greta_op = NULL, } } -compare_op <- function(r_out, greta_out, tolerance = 1e-4, relative_error = FALSE) { - if (relative_error){ +compare_op <- function( + r_out, + greta_out, + tolerance = 1e-4, + relative_error = FALSE +) { + if (relative_error) { difference <- as.vector(abs(r_out - greta_out) / abs(r_out)) - } else if (!relative_error){ + } else if (!relative_error) { difference <- as.vector(abs(r_out - greta_out)) } difference_lt_tolerance <- difference < tolerance @@ -57,14 +67,20 @@ run_r_op <- function(op, a, b, other_args) { do.call(op, arg_list) } -run_greta_op <- function(greta_op, a, b, other_args, - type = c("data", "variable", "batched")) { +run_greta_op <- function( + greta_op, + a, + b, + other_args, + type = c("data", "variable", "batched") +) { type <- match.arg(type) - converter <- switch(type, - data = as_data, - variable = as_variable, - batched = as_variable + converter <- switch( + type, + data = as_data, + variable = as_variable, + batched = as_variable ) g_a <- converter(a) diff --git a/R/tf_functions.R b/R/tf_functions.R index 3ff689d4..8fffa435 100644 --- a/R/tf_functions.R +++ b/R/tf_functions.R @@ -94,7 +94,6 @@ tf_cumprod <- function(x) { # set the dimensions of a tensor, reshaping in the same way (column-major) as R tf_set_dim <- function(x, dims) { - # transpose to do work in row-major order perm_old <- c(0L, rev(seq_along(dim(x)[-1]))) x <- tf$transpose(x, perm_old) @@ -111,7 +110,6 @@ tf_set_dim <- function(x, dims) { # expand the dimensions of a scalar tensor, reshaping in the same way # (column-major) as R tf_expand_dim <- function(x, dims) { - # prepend a batch dimension to dims (a 1 so we can tile with it) dims <- c(1L, dims) @@ -154,9 +152,11 @@ tf_tapply <- function(x, segment_ids, num_segments, op_name) { op_name <- glue::glue("unsorted_segment_{op_name}") x <- tf$transpose(x, perm = c(1:2, 0L)) - x <- tf$math[[op_name]](x, + x <- tf$math[[op_name]]( + x, segment_ids = segment_ids, - num_segments = num_segments) + num_segments = num_segments + ) x <- tf$transpose(x, perm = c(2L, 0:1)) x } @@ -222,16 +222,9 @@ tf_corrmat_row <- function(z, which = c("values", "ljac")) { ) # nolint end - body <- switch(which, - values = body_values, - ljac = body_ljac - ) + body <- switch(which, values = body_values, ljac = body_ljac) - out <- tf$while_loop(cond, - body, - values, - shape_invariants = shapes - ) + out <- tf$while_loop(cond, body, values, shape_invariants = shapes) if (which == "values") { x <- out[[2]] @@ -308,8 +301,14 @@ tf_kronecker <- function(x, y, tf_fun_name) { # expand dimensions of tensors to allow direct multiplication for kronecker # prod - x_rsh <- tf$reshape(x, tensorflow::as_tensor(shape(-1, dims[1], 1L, dims[2], 1L))) - y_rsh <- tf$reshape(y, tensorflow::as_tensor(shape(-1, 1L, dims[3], 1L, dims[4]))) + x_rsh <- tf$reshape( + x, + tensorflow::as_tensor(shape(-1, dims[1], 1L, dims[2], 1L)) + ) + y_rsh <- tf$reshape( + y, + tensorflow::as_tensor(shape(-1, 1L, dims[3], 1L, dims[4])) + ) # multiply tensors and reshape with appropriate dimensions z <- 
tf_function(x_rsh, y_rsh) @@ -321,7 +320,6 @@ tf_kronecker <- function(x, y, tf_fun_name) { # tensorflow version of sweep, based on broadcasting of tf ops tf_sweep <- function(x, stats, margin, fun) { - # if the second margin, transpose before and after if (margin == 2) { x <- tf_transpose(x) @@ -431,7 +429,6 @@ tf_imultilogit <- function(x) { # input tensor # dims_out - dimension of output array tf_extract <- function(x, nelem, index, dims_out) { - # flatten tensor, gather using index, reshape to output dimension tensor_in_flat <- tf$reshape(x, tensorflow::as_tensor(shape(-1, nelem))) tf_index <- tf$constant(as.integer(index), dtype = tf$int32) @@ -450,7 +447,6 @@ tf_extract <- function(x, nelem, index, dims_out) { # values, a tensor `updates` at the elements given by the R vector `index` (in # 0-indexing) tf_recombine <- function(ref, index, updates) { - # vector denoting whether an element is being updated nelem <- dim(ref)[[2]] replaced <- rep(0, nelem) @@ -498,7 +494,6 @@ tf_flatten <- function(x, extra_ones = 0) { # replace elements in a tensor with another tensor tf_replace <- function(x, replacement, index, dims) { - # flatten original tensor and new values x_flat <- tf_flatten(x, 1) replacement_flat <- tf_flatten(replacement, 1) @@ -590,7 +585,7 @@ tf_scalar_bijector <- function(dim, lower, upper) { ) } -tfb_shift_scale <- function(shift, scale){ +tfb_shift_scale <- function(shift, scale) { tfb_shift <- tfp$bijectors$Shift(shift) tfb_scale <- tfp$bijectors$Scale(scale) tfb_shift_scale <- tfb_shift(tfb_scale) @@ -651,7 +646,8 @@ tf_scalar_mixed_bijector <- function(dim, lower, upper, constraints) { # create bijectors for each block names(block_constructors) <- NULL - bijectors <- mapply(do.call, + bijectors <- mapply( + do.call, block_constructors, block_parameters, SIMPLIFY = FALSE @@ -671,11 +667,9 @@ tf_correlation_cholesky_bijector <- function() { ) tfp$bijectors$Chain(steps) - } tf_covariance_cholesky_bijector <- function() { - steps <- list( tfp$bijectors$Transpose(perm = 1:0), tfp$bijectors$FillScaleTriL(diag_shift = fl(1e-5)) diff --git a/R/transforms.R b/R/transforms.R index ac434219..1960aff8 100644 --- a/R/transforms.R +++ b/R/transforms.R @@ -48,7 +48,9 @@ NULL #' @rdname transforms #' @export iprobit <- function(x) { - op("iprobit", x, + op( + "iprobit", + x, tf_operation = "tf_iprobit", representations = list(probit = x) ) @@ -57,7 +59,9 @@ iprobit <- function(x) { #' @rdname transforms #' @export ilogit <- function(x) { - op("ilogit", x, + op( + "ilogit", + x, tf_operation = "tf$nn$sigmoid", representations = list(logit = x) ) @@ -88,8 +92,5 @@ imultilogit <- function(x) { check_2d(x) - op("imultilogit", x, - dim = dim + c(0, 1), - tf_operation = "tf_imultilogit" - ) + op("imultilogit", x, dim = dim + c(0, 1), tf_operation = "tf_imultilogit") } diff --git a/R/unknowns_class.R b/R/unknowns_class.R index db9aedaa..4a0860a9 100644 --- a/R/unknowns_class.R +++ b/R/unknowns_class.R @@ -1,23 +1,27 @@ #' @title Create objects of class 'unknowns' to nicely print ? 
valued arrays #' @param x object to convert to "unknowns" class #' @export -as.unknowns <- function(x) { # nolint +as.unknowns <- function(x) { + # nolint UseMethod("as.unknowns") } #' @export -as.unknowns.unknowns <- function(x) { # nolint +as.unknowns.unknowns <- function(x) { + # nolint x } #' @export -as.unknowns.array <- function(x) { # nolint +as.unknowns.array <- function(x) { + # nolint class(x) <- c("unknowns", class(x)) x } #' @export -as.unknowns.matrix <- function(x) { # nolint +as.unknowns.matrix <- function(x) { + # nolint as.unknowns.array(x) } @@ -44,7 +48,7 @@ print.unknowns <- function(x, ..., n = 10) { return(invisible(x)) } - if (remaining_vals > 0 ) { + if (remaining_vals > 0) { cli::cli_alert_info( text = c( "i" = "{remaining_vals} more values\n", @@ -52,7 +56,6 @@ print.unknowns <- function(x, ..., n = 10) { ) ) } - } # create an unknowns array from some dimensions @@ -65,7 +68,8 @@ unknowns <- function(dims = c(1, 1), data = NA_real_) { #' @param x matrix/array to set values to #' @param value values that are being set set #' @export -`dim<-.unknowns` <- function(x, value) { # nolint +`dim<-.unknowns` <- function(x, value) { + # nolint x <- unclass(x) dim(x) <- value as.unknowns(x) diff --git a/R/utils.R b/R/utils.R index 1458f8af..679a479c 100644 --- a/R/utils.R +++ b/R/utils.R @@ -136,7 +136,6 @@ record <- function(expr, file) { # convert an assumed numeric to an array with at least 2 dimensions as_2d_array <- function(x) { - # coerce data from common formats to an array here x <- as.array(x) @@ -365,15 +364,16 @@ palettize <- function(base_colour) { # colour scheme for plotting #' @importFrom grDevices col2rgb -greta_col <- function(which = c( - "main", - "dark", - "light", - "lighter", - "super_light" -), -colour = "#996bc7") { - +greta_col <- function( + which = c( + "main", + "dark", + "light", + "lighter", + "super_light" + ), + colour = "#996bc7" +) { # tests if a color encoded as string can be converted to RGB tryCatch( is.matrix(grDevices::col2rgb(colour)), @@ -386,12 +386,13 @@ colour = "#996bc7") { which <- match.arg(which) pal <- palettize(colour) - switch(which, - dark = pal(0.45), # 45% - main = pal(0.55), # 55% - light = pal(0.65), # 65%ish - lighter = pal(0.85), # 85%ish - super_light = pal(0.95) + switch( + which, + dark = pal(0.45), # 45% + main = pal(0.55), # 55% + light = pal(0.65), # 65%ish + lighter = pal(0.85), # 85%ish + super_light = pal(0.95) ) # 95%ish } @@ -402,17 +403,16 @@ colour_module <- module( # look in the environment specified by env, and return a named list of all greta # arrays in that environment -all_greta_arrays <- function(env = parent.frame(), - include_data = TRUE) { - +all_greta_arrays <- function(env = parent.frame(), include_data = TRUE) { # all objects in that environment as a named list all_object_names <- ls(envir = env) # loop carefully in case there are unfulfilled promises all_objects <- list() for (name in all_object_names) { - all_objects[[name]] <- tryCatch(get(name, envir = env), - error = function(e) NULL + all_objects[[name]] <- tryCatch( + get(name, envir = env), + error = function(e) NULL ) } @@ -423,9 +423,10 @@ all_greta_arrays <- function(env = parent.frame(), # optionally strip out the data arrays if (!include_data) { - is_data <- vapply(all_arrays, - function(x) is.data_node(get_node(x)), - FUN.VALUE = FALSE + is_data <- vapply( + all_arrays, + function(x) is.data_node(get_node(x)), + FUN.VALUE = FALSE ) all_arrays <- all_arrays[!is_data] } @@ -461,25 +462,32 @@ prepare_draws <- function(draws, thin 
= 1) { coda::mcmc(draws_df, thin = thin) } -build_sampler <- function(initial_values, sampler, model, seed = get_seed(), - compute_options) { +build_sampler <- function( + initial_values, + sampler, + model, + seed = get_seed(), + compute_options +) { ## TF1/2 retracing ## This is where a retracing warning happens ## in mcmc - sampler$class$new(initial_values, - model, - sampler$parameters, - seed = seed, - compute_options = compute_options + sampler$class$new( + initial_values, + model, + sampler$parameters, + seed = seed, + compute_options = compute_options ) } -build_samplers <- function(sampler, - initial_values, - chains, - model, - compute_options){ - +build_samplers <- function( + sampler, + initial_values, + chains, + model, + compute_options +) { # determine number of separate samplers to spin up, based on future plan max_samplers <- future::nbrOfWorkers() @@ -513,15 +521,15 @@ build_samplers <- function(sampler, } samplers - } -sampler_parallel_reporting <- function(n_chain, - samplers, - chains, - n_samples, - warmup){ - +sampler_parallel_reporting <- function( + n_chain, + samplers, + chains, + n_samples, + warmup +) { trace_log_files <- replicate(n_chain, create_log_file()) percentage_log_files <- replicate(n_chain, create_log_file(TRUE)) progress_bar_log_files <- replicate(n_chain, create_log_file(TRUE)) @@ -529,7 +537,6 @@ sampler_parallel_reporting <- function(n_chain, pb_width <- bar_width(n_chain) for (chain in chains) { - # set the log files sampler <- samplers[[chain]] sampler$trace_log_file <- trace_log_files[[chain]] @@ -549,7 +556,6 @@ sampler_parallel_reporting <- function(n_chain, ) sampler - } # unlist and flatten a list of arrays to a vector row-wise unlist_tf <- function(x) { @@ -592,7 +598,6 @@ flatten_trace <- function(i, trace_list) { # extract the model information object from mcmc samples returned by # stashed_samples, and error nicely if there's something fishy get_model_info <- function(draws) { - check_if_greta_mcmc_list(draws) model_info <- attr(draws, "model_info") @@ -624,7 +629,6 @@ sampler_utils_module <- module( # we could use this as a way of returning a function that TF recognises # as a function tensorflow function that returns tensors as_tf_function <- function(r_fun, ...) { - # run the operation on isolated greta arrays, so nothing gets attached to the # model real greta arrays in dots # creating a fake greta array @@ -690,7 +694,6 @@ as_tf_function <- function(r_fun, ...) 
{ tf_out <- list() for (i in seq_along(ga_out)) { - # define the output nodes node_out <- get_node(ga_out[[i]]) node_out$define_tf(sub_dag) @@ -724,12 +727,12 @@ utilities_module <- module( ) # remove empty strings -base_remove_empty_string <- function(string){ +base_remove_empty_string <- function(string) { string[string != ""] } -other_install_fail_msg <- function(error_passed){ +other_install_fail_msg <- function(error_passed) { # drop "" error_passed <- base_remove_empty_string(error_passed) cli::format_error( @@ -758,7 +761,7 @@ other_install_fail_msg <- function(error_passed){ ) } -timeout_install_msg <- function(timeout = 5, py_error = NULL){ +timeout_install_msg <- function(timeout = 5, py_error = NULL) { msg <- c( "Stopping as installation of {.pkg greta} dependencies took longer than \\ {timeout} minutes", @@ -787,7 +790,7 @@ timeout_install_msg <- function(timeout = 5, py_error = NULL){ py_error <- NULL } - if (is.null(py_error)){ + if (is.null(py_error)) { cli::format_error( message = msg ) @@ -803,12 +806,12 @@ timeout_install_msg <- function(timeout = 5, py_error = NULL){ } } -is_DiagrammeR_installed <- function(){ +is_DiagrammeR_installed <- function() { requireNamespace("DiagrammeR", quietly = TRUE) } -greta_conda_env_path <- function(){ - if (!have_greta_conda_env()){ +greta_conda_env_path <- function() { + if (!have_greta_conda_env()) { cli::cli_ul("path: no conda env found for {.var greta-env-tf2}") } @@ -816,7 +819,6 @@ greta_conda_env_path <- function(){ which_greta_env <- which(py_cl$name == "greta-env-tf2") greta_env_path <- py_cl$python[which_greta_env] greta_env_path - } # adapted from https://github.com/rstudio/tensorflow/blob/main/R/utils.R @@ -830,11 +832,11 @@ is_mac_arm64 <- function() { is_darwin && is_arm64 } -read_char <- function(path){ +read_char <- function(path) { trimws(readChar(path, nchars = file.info(path)$size)) } -create_temp_file <- function(path){ +create_temp_file <- function(path) { file_path <- tempfile(path, fileext = ".txt") file.create(file_path) return(file_path) @@ -847,17 +849,17 @@ create_temp_file <- function(path){ #' complexity. These functions are passed to `compute_options` inside of a few #' functions: [mcmc()], [opt()], and [calculate()]. #' @export -gpu_only <- function(){ +gpu_only <- function() { "GPU" } #' @rdname gpu_cpu #' @export -cpu_only <- function(){ +cpu_only <- function() { "CPU" } -compute_text <- function(n_cores, compute_options){ +compute_text <- function(n_cores, compute_options) { ifelse( test = n_cores == 1, yes = "each on 1 core", @@ -874,17 +876,17 @@ connected_to_draws <- function(dag, mcmc_dag) { names(dag$node_list) %in% names(mcmc_dag$node_list) } -is_using_gpu <- function(x){ +is_using_gpu <- function(x) { x == "GPU" } -is_using_cpu <- function(x){ +is_using_cpu <- function(x) { x == "CPU" } `%||%` <- function(x, y) if (is.null(x)) y else x -message_if_using_gpu <- function(compute_options){ +message_if_using_gpu <- function(compute_options) { gpu_used <- is_using_gpu(compute_options) greta_gpu_message <- getOption("greta_gpu_message") %||% TRUE gpu_used_and_message <- gpu_used && greta_gpu_message @@ -905,36 +907,35 @@ message_if_using_gpu <- function(compute_options){ n_dim <- function(x) length(dim(x)) is_2d <- function(x) n_dim(x) == 2 -is.node <- function(x, ...){ +is.node <- function(x, ...) { inherits(x, "node") } -is.data_node <- function(x, ...){ +is.data_node <- function(x, ...) { inherits(x, "data_node") } -is.distribution_node <- function(x, ...){ +is.distribution_node <- function(x, ...) 
{ inherits(x, "distribution_node") } -is.variable_node <- function(x, ...){ +is.variable_node <- function(x, ...) { inherits(x, "variable_node") } -is.greta_model <- function(x, ...){ +is.greta_model <- function(x, ...) { inherits(x, "greta_model") } -is.unknowns <- function(x, ...){ +is.unknowns <- function(x, ...) { inherits(x, "unknowns") } -is.initials <- function(x, ...){ +is.initials <- function(x, ...) { inherits(x, "initials") } -node_type_colour <- function(type){ - +node_type_colour <- function(type) { switch_cols <- switch( type, variable = cli::col_red(type), @@ -946,7 +947,7 @@ node_type_colour <- function(type){ switch_cols } -extract_unique_names <- function(x){ +extract_unique_names <- function(x) { vapply( X = x, FUN = member, @@ -955,7 +956,7 @@ extract_unique_names <- function(x){ ) } -are_identical <- function(x, y){ +are_identical <- function(x, y) { vapply( X = x, FUN = identical, @@ -977,7 +978,7 @@ are_identical <- function(x, y){ #' are_null(list(NULL, NULL, NULL)) #' are_null(list(1, 2, 3)) #' is.null(list(1, 2, 3)) -are_null <- function(x){ +are_null <- function(x) { vapply( x, is.null, @@ -985,7 +986,7 @@ are_null <- function(x){ ) } -are_greta_array <- function(x){ +are_greta_array <- function(x) { vapply( x, is.greta_array, @@ -993,7 +994,7 @@ are_greta_array <- function(x){ ) } -have_distribution <- function(x){ +have_distribution <- function(x) { vapply( x, has_distribution, @@ -1013,7 +1014,7 @@ is_linux <- function() { identical(tolower(Sys.info()[["sysname"]]), "linux") } -os_name <- function(){ +os_name <- function() { os <- c( windows = is_windows(), mac = is_mac(), @@ -1023,8 +1024,7 @@ os_name <- function(){ } # semantic version finder -closest_version <- function(current, available){ - +closest_version <- function(current, available) { available <- sort(available) not_available <- !(current %in% available) @@ -1043,13 +1043,12 @@ closest_version <- function(current, available){ closest <- min(available) } - if (current_btn_available){ + if (current_btn_available) { version_gt <- current > available closest <- max(available[version_gt]) } return(closest) - } outside_version_range <- function(provided, range) { @@ -1060,7 +1059,7 @@ outside_version_range <- function(provided, range) { outside_range } -pretty_dim <- function(x){ +pretty_dim <- function(x) { x_dim <- dim(x) print_dim_x <- x_dim %||% x @@ -1068,7 +1067,7 @@ pretty_dim <- function(x){ prettied_dim } -are_initials <- function(x){ +are_initials <- function(x) { vapply( X = x, FUN = is.initials, @@ -1076,7 +1075,7 @@ are_initials <- function(x){ ) } -n_warmup <- function(x){ +n_warmup <- function(x) { x_info <- attr(x, "model_info") x_info$warmup } diff --git a/R/variable.R b/R/variable.R index 1be29925..a00e974d 100644 --- a/R/variable.R +++ b/R/variable.R @@ -80,10 +80,7 @@ cholesky_variable <- function(dim, correlation = FALSE) { k <- dim[1] # dimension of the free state version - free_dim <- ifelse(correlation, - k * (k - 1) / 2, - k + k * (k - 1) / 2 - ) + free_dim <- ifelse(correlation, k * (k - 1) / 2, k + k * (k - 1) / 2) # create variable node node <- vble( @@ -93,9 +90,10 @@ cholesky_variable <- function(dim, correlation = FALSE) { ) # set the constraint, to enable transformation - node$constraint <- ifelse(correlation, - "correlation_matrix", - "covariance_matrix" + node$constraint <- ifelse( + correlation, + "correlation_matrix", + "covariance_matrix" ) # set the printed value to be nicer @@ -117,7 +115,6 @@ cholesky_variable <- function(dim, correlation = FALSE) { #' # a 4D 
simplex on the final dimension #' g <- simplex_variable(dim = c(2, 3, 4)) simplex_variable <- function(dim) { - # for scalar dims, return a row vector if (length(dim) == 1) { dim <- c(1, dim) @@ -157,7 +154,6 @@ simplex_variable <- function(dim) { #' # ordered positive variable #' i <- exp(ordered_variable(5)) ordered_variable <- function(dim) { - # for scalar dims, return a row vector if (length(dim) == 1) { dim <- c(1, dim) diff --git a/R/write-logfiles.R b/R/write-logfiles.R index 65d772c2..d51cb26a 100644 --- a/R/write-logfiles.R +++ b/R/write-logfiles.R @@ -9,8 +9,8 @@ #' @return nothing - sets an environment variable for use with #' [install_greta_deps()]. #' @export -greta_set_install_logfile <- function(path){ - Sys.setenv("GRETA_INSTALLATION_LOG"=path) +greta_set_install_logfile <- function(path) { + Sys.setenv("GRETA_INSTALLATION_LOG" = path) } #' Write greta dependency installation log file @@ -23,11 +23,10 @@ greta_set_install_logfile <- function(path){ #' @return nothing - writes to file #' @export write_greta_install_log <- function(path = greta_logfile) { - cli::cli_progress_step( msg = "Writing logfile to {.path {path}}", msg_done = "Logfile written to {.path {path}}" - ) + ) cli::cli_progress_step( msg = "Open with: {.run open_greta_install_log()}" @@ -122,16 +121,14 @@ write_greta_install_log <- function(path = greta_logfile) { conda_install_error = greta_stash$conda_install_error ) - writeLines(whisker::whisker.render(template, greta_install_data), - path) - + writeLines(whisker::whisker.render(template, greta_install_data), path) } # returns NULL if no envvar -sys_get_env <- function(envvar){ +sys_get_env <- function(envvar) { retrieved_envvar <- Sys.getenv(envvar) env_exists <- nzchar(retrieved_envvar) - if (env_exists){ + if (env_exists) { envvar } else { envvar <- NULL @@ -152,12 +149,10 @@ sys_get_env <- function(envvar){ #' #' @return opens a URL in your default HTML browser. 
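
Since the logfile workflow spans several helpers restyled here (greta_set_install_logfile(), sys_get_env(), and the %||% default), a minimal usage sketch may help; the path and the unset variable name below are hypothetical. Note that sys_get_env(), as written, returns the variable's *name* when it is set (and NULL otherwise), with %||% then supplying the default logfile.

# hypothetical path, for illustration only
greta_set_install_logfile("~/greta-install.log")
Sys.getenv("GRETA_INSTALLATION_LOG")   # "~/greta-install.log"
sys_get_env("GRETA_INSTALLATION_LOG")  # the name, since the variable is non-empty
sys_get_env("GRETA_UNSET_VARIABLE")    # NULL, so `%||%` falls through to the default
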
#' @export -open_greta_install_log <- function(){ - +open_greta_install_log <- function() { greta_logfile <- sys_get_env("GRETA_INSTALLATION_LOG") greta_logfile <- greta_logfile %||% greta_default_logfile() utils::browseURL(greta_logfile) - } diff --git a/R/zzz.R b/R/zzz.R index 9f5116fe..00f6e3f1 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -3,7 +3,8 @@ tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE) tf <- reticulate::import("tensorflow", delay_load = TRUE) # crate the node list object whenever the package is loaded -.onLoad <- function(libname, pkgname) { # nolint +.onLoad <- function(libname, pkgname) { + # nolint # unset reticulate python environment, for more details, see: # https://github.com/greta-dev/greta/issues/444 diff --git a/data-raw/valid-greta-deps.R b/data-raw/valid-greta-deps.R index 93a95a02..6324d47f 100644 --- a/data-raw/valid-greta-deps.R +++ b/data-raw/valid-greta-deps.R @@ -12,13 +12,13 @@ windows_deps <- "https://www.tensorflow.org/install/source_windows" |> bow() |> scrape() -get_tidy_html_tables <- function(html_raw){ +get_tidy_html_tables <- function(html_raw) { html_raw |> html_table() |> map(clean_names) } -bind_html_tables <- function(tidied_tables, os_hardware_names){ +bind_html_tables <- function(tidied_tables, os_hardware_names) { tidied_tables |> setNames(os_hardware_names) |> bind_rows( @@ -31,7 +31,7 @@ bind_html_tables <- function(tidied_tables, os_hardware_names){ ) } -tidy_tf_dep_tables <- function(html_raw, os_hardware_names){ +tidy_tf_dep_tables <- function(html_raw, os_hardware_names) { html_raw |> get_tidy_html_tables() |> bind_html_tables(os_hardware_names) @@ -70,9 +70,7 @@ tf_cpu_deps <- tf_deps |> filter( hardware == "cpu" ) |> - select(os, - version, - python_version) |> + select(os, version, python_version) |> mutate( version = str_remove_all(version, "tensorflow-") ) |> @@ -99,33 +97,56 @@ tf_cpu_deps <- tf_deps |> # tfp_to_tf_compatability <- tibble::tribble( - ~tfp_version, ~tf_version, - "tfp==0.24.0", "tf==2.16.1", - "tfp==0.23.0", "tf==2.15.0", - "tfp==0.22.1", "tf==2.14.0", - "tfp==0.22.0", "tf==2.14.0", - "tfp==0.21.0", "tf==2.13.0", - "tfp==0.20.0", "tf==2.12.0", - "tfp==0.19.0", "tf==2.11.0", - "tfp==0.18.0", "tf==2.10.0", - "tfp==0.17.0", "tf==2.9.1", - "tfp==0.16.0", "tf==2.8.0", - "tfp==0.15.0", "tf==2.7.0", - "tfp==0.14.1", "tf==2.6.0", - "tfp==0.14.0", "tf==2.6.0", - "tfp==0.13.0", "tf==2.5.0", - "tfp==0.12.2", "tf==2.4.0", - "tfp==0.12.1", "tf==2.4.0", - "tfp==0.12.0", "tf==2.4.0", - "tfp==0.11.1", "tf==2.3.0", - "tfp==0.11.0", "tf==2.3.0", - "tfp==0.10.1", "tf==2.2.0", - "tfp==0.9.0", "tf==2.1.0", - "tfp==0.8.0", "tf==2.0.0" + ~tfp_version, + ~tf_version, + "tfp==0.24.0", + "tf==2.16.1", + "tfp==0.23.0", + "tf==2.15.0", + "tfp==0.22.1", + "tf==2.14.0", + "tfp==0.22.0", + "tf==2.14.0", + "tfp==0.21.0", + "tf==2.13.0", + "tfp==0.20.0", + "tf==2.12.0", + "tfp==0.19.0", + "tf==2.11.0", + "tfp==0.18.0", + "tf==2.10.0", + "tfp==0.17.0", + "tf==2.9.1", + "tfp==0.16.0", + "tf==2.8.0", + "tfp==0.15.0", + "tf==2.7.0", + "tfp==0.14.1", + "tf==2.6.0", + "tfp==0.14.0", + "tf==2.6.0", + "tfp==0.13.0", + "tf==2.5.0", + "tfp==0.12.2", + "tf==2.4.0", + "tfp==0.12.1", + "tf==2.4.0", + "tfp==0.12.0", + "tf==2.4.0", + "tfp==0.11.1", + "tf==2.3.0", + "tfp==0.11.0", + "tf==2.3.0", + "tfp==0.10.1", + "tf==2.2.0", + "tfp==0.9.0", + "tf==2.1.0", + "tfp==0.8.0", + "tf==2.0.0" ) |> mutate( - tfp_version = str_remove_all(tfp_version,"tfp=="), - tf_version = str_remove_all(tf_version,"tf==") + tfp_version = str_remove_all(tfp_version, 
"tfp=="), + tf_version = str_remove_all(tf_version, "tf==") ) tfp_to_tf_compatability @@ -138,7 +159,6 @@ extra_rows <- tibble( ) - numeric_version(tf_cpu_deps$tf_version) .deps_tf <- bind_rows(tf_cpu_deps, extra_rows) |> @@ -147,18 +167,15 @@ numeric_version(tf_cpu_deps$tf_version) mutate(tf_version = as.character(tf_version)) .deps_tfp <- tfp_to_tf_compatability -remove_before_comma <- function(x){ +remove_before_comma <- function(x) { trimws(str_remove_all(x, ".*?,")) } greta_deps_tf_tfp <- .deps_tf |> - left_join(.deps_tfp, - by = "tf_version", - relationship = "many-to-many") |> - relocate(tfp_version, - .after = tf_version) |> + left_join(.deps_tfp, by = "tf_version", relationship = "many-to-many") |> + relocate(tfp_version, .after = tf_version) |> mutate( - python_version_min = remove_before_comma(python_version_min) + python_version_min = remove_before_comma(python_version_min) ) |> mutate( across( @@ -170,22 +187,20 @@ greta_deps_tf_tfp <- .deps_tf |> tfp_version, .after = os ) |> - arrange(os, - desc(tfp_version)) |> + arrange(os, desc(tfp_version)) |> drop_na() |> filter( tfp_version < "0.24.0" ) - usethis::use_data( - .deps_tf, - .deps_tfp, - internal = TRUE, - overwrite = TRUE - ) - - usethis::use_data( - greta_deps_tf_tfp, - overwrite = TRUE - ) +usethis::use_data( + .deps_tf, + .deps_tfp, + internal = TRUE, + overwrite = TRUE +) +usethis::use_data( + greta_deps_tf_tfp, + overwrite = TRUE +) diff --git a/docs/reference/figures/plot_greta_legend.R b/docs/reference/figures/plot_greta_legend.R index 639d37b7..a02a5db6 100644 --- a/docs/reference/figures/plot_greta_legend.R +++ b/docs/reference/figures/plot_greta_legend.R @@ -1,74 +1,75 @@ # plot legend -library (DiagrammeR) -library (raster) +library(DiagrammeR) +library(raster) ns <- 0.3 # bespoke set of nodes for legend # pad operation with some invisible nodes to get in the right position -nodes1 <- create_node_df(n = 6, - label = c("data", - "variable", - "distribution", - "", - "operation", - ""), - type = "lower", - style = "filled", - fontcolor = greta_col('dark'), - fontname = 'Helvetica', - fontsize = 12, - fillcolor = c('white', - greta_col('super_light'), - greta_col('lighter'), - 'white', - 'lightgray', - 'white'), - color = c(greta_col('lighter'), - greta_col('lighter'), - greta_col('light'), - 'white', - 'lightgray', - 'white'), - penwidth = 2, - shape = c("square", "circle", "diamond", "circle", "circle", "circle"), - width = c(0.5, 0.6, 1, 0.01, 0.2, 0.01), - height = c(0.5, 0.6, 0.8, 0.01, 0.2, 0.01)) +nodes1 <- create_node_df( + n = 6, + label = c("data", "variable", "distribution", "", "operation", ""), + type = "lower", + style = "filled", + fontcolor = greta_col('dark'), + fontname = 'Helvetica', + fontsize = 12, + fillcolor = c( + 'white', + greta_col('super_light'), + greta_col('lighter'), + 'white', + 'lightgray', + 'white' + ), + color = c( + greta_col('lighter'), + greta_col('lighter'), + greta_col('light'), + 'white', + 'lightgray', + 'white' + ), + penwidth = 2, + shape = c("square", "circle", "diamond", "circle", "circle", "circle"), + width = c(0.5, 0.6, 1, 0.01, 0.2, 0.01), + height = c(0.5, 0.6, 0.8, 0.01, 0.2, 0.01) +) gr1 <- create_graph(nodes1) gr1$global_attrs[1, 'value'] <- 'dot' f_nodes <- tempfile(fileext = '.png') -export_graph(gr1, file_name = f_nodes, - width = 1005, - height = 249) +export_graph(gr1, file_name = f_nodes, width = 1005, height = 249) ns <- 0.01 -nodes2 <- create_node_df(n = 4, - label = '', - type = "lower", - alpha = 0, - style = "filled", - fillcolor = rep('#ffffff00', 
4), - color = rep('#ffffff00', 4), - shape = rep("circle", 4), - width = rep(ns, 4)) - -edges2 <- create_edge_df(from = c(1, 3), - to = c(2, 4), - label = c('deterministic', 'stochastic'), - color = rep('Gainsboro', 2), - fontname = 'Helvetica', - fontcolor = greta_col('dark'), - fontsize = 12, - penwidth = 3, - style = c('solid', 'dashed')) +nodes2 <- create_node_df( + n = 4, + label = '', + type = "lower", + alpha = 0, + style = "filled", + fillcolor = rep('#ffffff00', 4), + color = rep('#ffffff00', 4), + shape = rep("circle", 4), + width = rep(ns, 4) +) + +edges2 <- create_edge_df( + from = c(1, 3), + to = c(2, 4), + label = c('deterministic', 'stochastic'), + color = rep('Gainsboro', 2), + fontname = 'Helvetica', + fontcolor = greta_col('dark'), + fontsize = 12, + penwidth = 3, + style = c('solid', 'dashed') +) gr2 <- create_graph(nodes2, edges2) gr2$global_attrs[1, 'value'] <- 'dot' f_edges <- tempfile(fileext = '.png') -export_graph(gr2, file_name = f_edges, - width = 631, - height = 249) +export_graph(gr2, file_name = f_edges, width = 631, height = 249) # combine the two panels into one @@ -80,18 +81,18 @@ edges <- brick(f_edges) nodes <- crop(nodes, extent(nodes) - 4) edges <- crop(edges, extent(edges) - 4) -dim_pixels <- c(width = round(1005 + 631 + 1005/16), - height = 249) +dim_pixels <- c(width = round(1005 + 631 + 1005 / 16), height = 249) dim_inches <- dim_pixels / 10 # layout with a gap inbetween -mat <- matrix(rep(c(1, 2, 3), c(16, 1, 10)), - nrow = 1) +mat <- matrix(rep(c(1, 2, 3), c(16, 1, 10)), nrow = 1) -png('plotlegend.png', - width = dim_pixels['width'] * 2, - height = dim_pixels['height'] * 2) +png( + 'plotlegend.png', + width = dim_pixels['width'] * 2, + height = dim_pixels['height'] * 2 +) layout(mat) raster::plotRGB(nodes, maxpixels = Inf) @@ -101,9 +102,11 @@ raster::plotRGB(edges, maxpixels = Inf) dev.off() -pdf('plotlegend.pdf', - width = dim_inches['width'], - height = dim_inches['height']) +pdf( + 'plotlegend.pdf', + width = dim_inches['width'], + height = dim_inches['height'] +) layout(mat) raster::plotRGB(nodes, maxpixels = Inf) diff --git a/logos/logo_functions.R b/logos/logo_functions.R index 1dd6098e..4091cfac 100644 --- a/logos/logo_functions.R +++ b/logos/logo_functions.R @@ -2,83 +2,80 @@ # get the coordinates and links to tile the logo 'repeats' times. # x_start and y_start give the position of the first node -logo_shape <- function (x_start = 0, y_range = c(0, 1)) { - +logo_shape <- function(x_start = 0, y_range = c(0, 1)) { # coordinates and links of the base shape - coords <- data.frame(x = c(0, 0, 1, 1, 2, 2, 3), - y = c(0, 1, 0.5, 1.5, 0, 1, 0.5)) + coords <- data.frame( + x = c(0, 0, 1, 1, 2, 2, 3), + y = c(0, 1, 0.5, 1.5, 0, 1, 0.5) + ) - links <- rbind(c(1, 3), - c(2, 3), - c(3, 6), - c(4, 6), - c(5, 7), - c(6, 7)) + links <- rbind(c(1, 3), c(2, 3), c(3, 6), c(4, 6), c(5, 7), c(6, 7)) scale <- abs(diff(y_range)) / 1.6 coords$x <- coords$x * scale + x_start coords$y <- coords$y * scale + y_range[1] - list(coords = coords, links =links) - + list(coords = coords, links = links) } -plot_logo <- function (background = c('white', 'purple', 'light', 'lighter'), - pointsize = 4.5, - add = FALSE, - edge_width = 1, - ...) { - +plot_logo <- function( + background = c('white', 'purple', 'light', 'lighter'), + pointsize = 4.5, + add = FALSE, + edge_width = 1, + ... +) { background <- match.arg(background) data <- logo_shape(...) 
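
A quick numeric check of logo_shape()'s scaling, since the 1.6 divisor is easy to misread: the base shape spans 3 units in x and 1.5 in y, and both coordinates are multiplied by abs(diff(y_range)) / 1.6 before the x_start and y_range[1] offsets are added. A sketch, assuming logo_functions.R has been sourced:

shape <- logo_shape(x_start = 2, y_range = c(0, 0.8))
range(shape$coords$x)  # 2.00 3.50 : 0..3   scaled by 0.8 / 1.6, shifted by x_start
range(shape$coords$y)  # 0.00 0.75 : 0..1.5 scaled by 0.8 / 1.6, shifted by y_range[1]
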
- bg_col <- switch (background, - white = 'white', - light = greta:::greta_col('light'), - lighter = greta:::greta_col('lighter'), - purple = greta:::greta_col('main')) - - link_col <- switch (background, - white = greta:::greta_col('light'), - light = greta:::greta_col('dark'), - lighter = greta:::greta_col('dark'), - purple = greta:::greta_col('dark')) - - node_col <- switch (background, - white = greta:::greta_col('dark'), - light = greta:::greta_col('dark'), - lighter = greta:::greta_col('dark'), - purple = greta:::greta_col('dark')) + bg_col <- switch( + background, + white = 'white', + light = greta:::greta_col('light'), + lighter = greta:::greta_col('lighter'), + purple = greta:::greta_col('main') + ) + + link_col <- switch( + background, + white = greta:::greta_col('light'), + light = greta:::greta_col('dark'), + lighter = greta:::greta_col('dark'), + purple = greta:::greta_col('dark') + ) + + node_col <- switch( + background, + white = greta:::greta_col('dark'), + light = greta:::greta_col('dark'), + lighter = greta:::greta_col('dark'), + purple = greta:::greta_col('dark') + ) if (!add) { - old_mar <- par()$mar old_xpd <- par()$xpd old_bg <- par()$bg - on.exit( par(mar = old_mar, xpd = old_xpd, bg = old_bg) ) + on.exit(par(mar = old_mar, xpd = old_xpd, bg = old_bg)) - par(mar = rep(2, 4), - xpd = NA, - bg = bg_col) + par(mar = rep(2, 4), xpd = NA, bg = bg_col) plot.new() - plot.window(xlim = range(data$coords$x), - ylim = range(data$coords$y), - asp = 1) - - + plot.window( + xlim = range(data$coords$x), + ylim = range(data$coords$y), + asp = 1 + ) } # loop though from right to left, plotting points and edges to ensure # gaps on either side distances - x_loc <- sort(unique(data$coords$x), - decreasing = TRUE) + x_loc <- sort(unique(data$coords$x), decreasing = TRUE) for (loc in x_loc) { - # find relevant nodes and edges idx_nodes <- which(data$coords$x %in% loc) idx_edges <- which(data$links[, 2] %in% idx_nodes) @@ -86,34 +83,38 @@ plot_logo <- function (background = c('white', 'purple', 'light', 'lighter'), edges <- data$links[idx_edges, , drop = FALSE] # plot nodes with fat edge - points(y ~ x, - data = nodes, - pch = 21, - bg = node_col, - col = bg_col, - cex = pointsize - 1.5, - lwd = pointsize * 4 * edge_width) + points( + y ~ x, + data = nodes, + pch = 21, + bg = node_col, + col = bg_col, + cex = pointsize - 1.5, + lwd = pointsize * 4 * edge_width + ) # plot lines for (i in seq_len(nrow(edges))) { link <- edges[i, ] - lines(x = data$coords$x[link], - y = data$coords$y[link], - lwd = pointsize * 2.3 * edge_width, - col = link_col) + lines( + x = data$coords$x[link], + y = data$coords$y[link], + lwd = pointsize * 2.3 * edge_width, + col = link_col + ) } # plot nodes with thin edge - points(y ~ x, - data = nodes, - pch = 21, - bg = node_col, - col = bg_col, - cex = pointsize - 1.5, - lwd = 0) - + points( + y ~ x, + data = nodes, + pch = 21, + bg = node_col, + col = bg_col, + cex = pointsize - 1.5, + lwd = 0 + ) } - } # greta logo generation @@ -124,35 +125,44 @@ plot_logo <- function (background = c('white', 'purple', 'light', 'lighter'), # 'margin' gives the proportion of the vertical height to use a border on each side # the text is scaled to never exceed that border #' @importFrom graphics par plot.new plot.window strheight strwidth text -banner <- function (background = c('purple', 'white', 'light', 'lighter'), - transparent_bg = FALSE, - width = 8, margin = 0.2, - font = c('Muli', 'sans'), - add_logo = TRUE, ...) 
{ - +banner <- function( + background = c('purple', 'white', 'light', 'lighter'), + transparent_bg = FALSE, + width = 8, + margin = 0.2, + font = c('Muli', 'sans'), + add_logo = TRUE, + ... +) { font <- match.arg(font) background <- match.arg(background) # warn if the banner isn't height-filled min_width <- 3.184175 + 2 * margin * (1 - 3.184175) if (width < min_width) { - warning ('with a margin of ', - margin, - ' the minimum width to ensure the banner is height-filled is ', - min_width) + warning( + 'with a margin of ', + margin, + ' the minimum width to ensure the banner is height-filled is ', + min_width + ) } - bg_col <- switch (background, - white = 'white', - light = greta:::greta_col('light'), - lighter = greta:::greta_col('lighter'), - purple = greta:::greta_col('main')) - - text_col <- switch (background, - white = greta:::greta_col('dark'), - light = 'white', - lighter = greta:::greta_col('dark'), - purple = 'white') + bg_col <- switch( + background, + white = 'white', + light = greta:::greta_col('light'), + lighter = greta:::greta_col('lighter'), + purple = greta:::greta_col('main') + ) + + text_col <- switch( + background, + white = greta:::greta_col('dark'), + light = 'white', + lighter = greta:::greta_col('dark'), + purple = 'white' + ) # cache the old graphics options old_bg <- par('bg') @@ -160,15 +170,11 @@ banner <- function (background = c('purple', 'white', 'light', 'lighter'), old_family <- par('family') # switch to a purple background, no margins and Muli typeface - par(bg = ifelse(transparent_bg, NA, bg_col), - mar = rep(0, 4), - family = font) + par(bg = ifelse(transparent_bg, NA, bg_col), mar = rep(0, 4), family = font) # set up the device, to have the correct width plot.new() - plot.window(xlim = c(0, width), - ylim = c(0, 1), - asp = 1) + plot.window(xlim = c(0, width), ylim = c(0, 1), asp = 1) # scale the font, so that 'greta' fills the area (excluding self-imposed # margins) either vertically or horizontally @@ -184,81 +190,77 @@ banner <- function (background = c('purple', 'white', 'light', 'lighter'), xpos <- margin # 'g' should be aligned to the left of the box - text(x = xpos, - y = 0.5, - label = 'greta', - col = text_col, - cex = fontsize, - pos = 4, - offset = 0) + text( + x = xpos, + y = 0.5, + label = 'greta', + col = text_col, + cex = fontsize, + pos = 4, + offset = 0 + ) if (add_logo) { - - plot_logo(background = background, - add = TRUE, - x_start = string_width + xpos * 3, - y_range = 0.55 + string_height * 0.5 * c(-1, 1), - ...) - + plot_logo( + background = background, + add = TRUE, + x_start = string_width + xpos * 3, + y_range = 0.55 + string_height * 0.5 * c(-1, 1), + ... 
+ ) } - par(bg = old_bg, - mar = old_mar, - family = old_family) + par(bg = old_bg, mar = old_mar, family = old_family) invisible(NULL) - } # same dimensions as banner, but with no text -blank_banner <- function (width = 8, margin = 0.2) { - +blank_banner <- function(width = 8, margin = 0.2) { # cache the old graphics options old_bg <- par('bg') old_mar <- par('mar') # switch to a purple background with no margins - par(bg = greta:::greta_col(), - mar = rep(0, 4)) + par(bg = greta:::greta_col(), mar = rep(0, 4)) # set up the device, to have the correct width plot.new() - plot.window(xlim = c(0, width), - ylim = c(0, 1), - asp = 1) + plot.window(xlim = c(0, width), ylim = c(0, 1), asp = 1) - par(bg = old_bg, - mar = old_mar) + par(bg = old_bg, mar = old_mar) invisible(NULL) - } # make and save an image of a triangular tesselation GMRF pattern in greta purple -tesselation_image <- function (ncol = 10, nrow = 10, - max_edge = 0.08, - jitter = 0.1, - thickness = 1, - line_col = greta:::greta_col('light'), - ramp_cols = NULL) { - +tesselation_image <- function( + ncol = 10, + nrow = 10, + max_edge = 0.08, + jitter = 0.1, + thickness = 1, + line_col = greta:::greta_col('light'), + ramp_cols = NULL +) { if (is.null(ramp_cols)) { - cols <- c(greta:::greta_col('lighter'), - greta:::greta_col('light')) + cols <- c(greta:::greta_col('lighter'), greta:::greta_col('light')) ramp_cols <- colorRampPalette(cols)(2000)[-(1:1000)] } - require (INLA) - require (raster) - require (greta) - require (fields) + require(INLA) + require(raster) + require(greta) + require(fields) # grid sizes for sampling the GRF and for the final image ncol_sim <- round(ncol / 10) nrow_sim <- round(nrow / 10) ratio <- ncol / nrow - grid <- list(x = seq(0, 1, length.out = ncol_sim), - y = seq(0, 1, length.out = nrow_sim)) + grid <- list( + x = seq(0, 1, length.out = ncol_sim), + y = seq(0, 1, length.out = nrow_sim) + ) obj <- Exp.image.cov(grid = grid, theta = 0.1, setup = TRUE) @@ -268,48 +270,41 @@ tesselation_image <- function (ncol = 10, nrow = 10, extent(image) <- c(0, ratio, 0, 1) extent(r) <- extent(image) - - pts <- expand.grid(seq(0, ratio, length.out = 10), - seq(0, 1, length.out = 10)) - pts <- pts + cbind(rnorm(100, 0, jitter), - rnorm(100, 0, jitter)) + pts <- expand.grid(seq(0, ratio, length.out = 10), seq(0, 1, length.out = 10)) + pts <- pts + cbind(rnorm(100, 0, jitter), rnorm(100, 0, jitter)) # make an inla mesh sp <- as(extent(image), 'SpatialPolygons') - mesh <- inla.mesh.2d(loc = pts, - boundary = inla.sp2segment(sp), - max.edge = max_edge, - offset = 0) + mesh <- inla.mesh.2d( + loc = pts, + boundary = inla.sp2segment(sp), + max.edge = max_edge, + offset = 0 + ) # sample GRF at nodes z <- extract(r, mesh$loc[, 1:2]) # get projection to raster - image_coords<- xyFromCell(image, 1:ncell(image)) + image_coords <- xyFromCell(image, 1:ncell(image)) A <- inla.spde.make.A(mesh, loc = image_coords) # instead of linear interpolation, average the three node values A2 <- A - A2@x[A2@x > 0] <- 1/3 + A2@x[A2@x > 0] <- 1 / 3 image[] <- (A2 %*% z)[, 1] pm <- par("mar") on.exit(par(mar = pm)) par(mar = rep(0, 4)) - image(image, - col = ramp_cols, - asp = 1, - axes = FALSE, - xlab = '', - ylab = '') - plot(mesh, - add = TRUE, - edge.color = line_col, - lwd = thickness, - draw.segments = FALSE) - points(mesh$loc, - pch = 16, - cex = 0.5 * sqrt(thickness), - col = line_col) + image(image, col = ramp_cols, asp = 1, axes = FALSE, xlab = '', ylab = '') + plot( + mesh, + add = TRUE, + edge.color = line_col, + lwd = thickness, + 
draw.segments = FALSE + ) + points(mesh$loc, pch = 16, cex = 0.5 * sqrt(thickness), col = line_col) } diff --git a/logos/make_hex.R b/logos/make_hex.R index 65f6e648..7094a2d2 100644 --- a/logos/make_hex.R +++ b/logos/make_hex.R @@ -5,11 +5,13 @@ font_add_google("Muli") attach(greta::.internals$utils$colours) # load various functions -source ("logos/logo_functions.R") +source("logos/logo_functions.R") # make a hex-shaped mask -hexd <- data.frame(x = 1 + c(rep(-sqrt(3) / 2, 2), 0, rep(sqrt(3) / 2, 2), 0), - y = 1 + c(0.5, -0.5, -1, -0.5, 0.5, 1)) +hexd <- data.frame( + x = 1 + c(rep(-sqrt(3) / 2, 2), 0, rep(sqrt(3) / 2, 2), 0), + y = 1 + c(0.5, -0.5, -1, -0.5, 0.5, 1) +) x_lim <- range(hexd$x) y_lim <- range(hexd$y) @@ -17,10 +19,7 @@ x_dim <- abs(diff(x_lim)) y_dim <- abs(diff(y_lim)) ex <- 4 pdf("logos/hex_mask.pdf", width = x_dim * ex, height = y_dim * ex) -par(pty = "s", - xpd = NA, - bg = "black", - mar = rep(0, 4), oma = rep(0, 4)) +par(pty = "s", xpd = NA, bg = "black", mar = rep(0, 4), oma = rep(0, 4)) plot.new() plot.window(xlim = x_lim, ylim = y_lim) polygon(hexd, col = "white") @@ -32,59 +31,63 @@ mask <- image_transparent(mask, "black") mask_dim <- as.numeric(image_info(mask)[1, 2:3]) # create a tesselation figure with the same footprint # make a nice GRF tesselation for the header image -set.seed(2018-02-20) +set.seed(2018 - 02 - 20) dim <- mask_dim * 2 -cols <- c(greta:::greta_col("lighter"), - greta:::greta_col("main")) -ramp_cols <- colorRampPalette(cols)(2000)#[-(1:1000)] +cols <- c(greta:::greta_col("lighter"), greta:::greta_col("main")) +ramp_cols <- colorRampPalette(cols)(2000) #[-(1:1000)] -png("logos/hex_bg.png", - width = dim[1], - height = dim[2], - pointsize = 30) -tesselation_image(ncol = dim[1], nrow = dim[2], - max_edge = 0.1, - jitter = 0.05, - thickness = 2, - ramp_cols = ramp_cols, - line_col = greta_col("main")) +png("logos/hex_bg.png", width = dim[1], height = dim[2], pointsize = 30) +tesselation_image( + ncol = dim[1], + nrow = dim[2], + max_edge = 0.1, + jitter = 0.05, + thickness = 2, + ramp_cols = ramp_cols, + line_col = greta_col("main") +) dev.off() # crop and mask the pattern to a hexagon bg <- image_read("logos/hex_bg.png") -geometry <- sprintf("%ix%i+%i+%i", - mask_dim[1], - mask_dim[2], - mask_dim[1] %/% 2, - mask_dim[2] %/% 2) +geometry <- sprintf( + "%ix%i+%i+%i", + mask_dim[1], + mask_dim[2], + mask_dim[1] %/% 2, + mask_dim[2] %/% 2 +) bg <- image_crop(bg, geometry) hex_bg <- image_composite(bg, mask, "CopyOpacity") image_write(hex_bg, path = "logos/hex_bg.pdf") -par(pty = "s", - xpd = NA, - bg = "black", - mar = c(5, 4, 4, 2) + 0.1) +par(pty = "s", xpd = NA, bg = "black", mar = c(5, 4, 4, 2) + 0.1) -greta_hex <- sticker("logos/hex_bg.pdf", - s_x = 1, - s_y = 1, - s_width = 1.05, - package = "greta", - p_y = 1.1, - p_family = "Muli", - p_size = 15, - h_fill = greta_col("main"), - h_color = greta_col("main"), - filename = "logos/greta_hex.png") +greta_hex <- sticker( + "logos/hex_bg.pdf", + s_x = 1, + s_y = 1, + s_width = 1.05, + package = "greta", + p_y = 1.1, + p_family = "Muli", + p_size = 15, + h_fill = greta_col("main"), + h_color = greta_col("main"), + filename = "logos/greta_hex.png" +) library(ggplot2) -ggsave(greta_hex, width = 43.9, height = 50.8, - filename = "logos/greta_hex.png", - bg = "transparent", - units = "mm", - dpi = 600) +ggsave( + greta_hex, + width = 43.9, + height = 50.8, + filename = "logos/greta_hex.png", + bg = "transparent", + units = "mm", + dpi = 600 +) file.remove("logos/hex_bg.pdf") file.remove("logos/hex_bg.png") 
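
Two non-obvious details in make_hex.R above: hexd is a regular hexagon with circumradius 1 centred at (1, 1) (hence x_dim = sqrt(3) and y_dim = 2 for the device size), and because hex_bg.png is rendered at twice the mask size, the "%ix%i+%i+%i" crop geometry (width x height + x-offset + y-offset) takes the central mask-sized window. The hexagon claim is easy to verify:

hexd <- data.frame(
  x = 1 + c(rep(-sqrt(3) / 2, 2), 0, rep(sqrt(3) / 2, 2), 0),
  y = 1 + c(0.5, -0.5, -1, -0.5, 0.5, 1)
)
# distance of each vertex from the centre (1, 1): all 1, up to floating point
sqrt((hexd$x - 1)^2 + (hexd$y - 1)^2)
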
diff --git a/logos/make_logos.R b/logos/make_logos.R index 6b77b1dc..329e03b7 100644 --- a/logos/make_logos.R +++ b/logos/make_logos.R @@ -1,114 +1,108 @@ # load various functions -source ("logos/logo_functions.R") +source("logos/logo_functions.R") # ~~~~~~~~~~~ # render logos and banners # banner for top of documents, with name and logo -png('README_files/top_banner.png', - height = 288, width = 4032, - pointsize = 25) +png('README_files/top_banner.png', height = 288, width = 4032, pointsize = 25) banner(width = 14) dev.off() # thin blank banner for between document sections -png('README_files/banner.png', - height = 4, width = 940, - pointsize = 25) -blank_banner(14/0.2) +png('README_files/banner.png', height = 4, width = 940, pointsize = 25) +blank_banner(14 / 0.2) dev.off() # thicker blank banner for end of document -png('README_files/bottom_banner.png', - height = 8, width = 940, - pointsize = 25) -blank_banner(14/0.5) +png('README_files/bottom_banner.png', height = 8, width = 940, pointsize = 25) +blank_banner(14 / 0.5) dev.off() -png('logos/icon_on_white.png', - height = 1000, width = 1800, - pointsize = 16) +png('logos/icon_on_white.png', height = 1000, width = 1800, pointsize = 16) plot_logo(pointsize = 24) dev.off() -png('logos/icon_on_purple.png', - height = 1000, width = 1800, - pointsize = 16) +png('logos/icon_on_purple.png', height = 1000, width = 1800, pointsize = 16) plot_logo('purple', pointsize = 24) dev.off() ptsz <- 80 -png('logos/name_on_purple.png', - height = 1000, width = 1800, - pointsize = ptsz) -banner(transparent_bg = TRUE, - width = 2.310505, - add_logo = FALSE) +png('logos/name_on_purple.png', height = 1000, width = 1800, pointsize = ptsz) +banner(transparent_bg = TRUE, width = 2.310505, add_logo = FALSE) dev.off() -png('logos/name_on_white.png', - height = 1000, width = 1800, - pointsize = ptsz) -banner('white', - transparent_bg = TRUE, - width = 2.310505, - add_logo = FALSE) +png('logos/name_on_white.png', height = 1000, width = 1800, pointsize = ptsz) +banner('white', transparent_bg = TRUE, width = 2.310505, add_logo = FALSE) dev.off() -png('logos/name_icon_on_white.png', - height = 1000, width = 3600, - pointsize = ptsz) -banner('white', - transparent_bg = TRUE, - width = 4, - add_logo = TRUE, - edge_width = 2.7) +png( + 'logos/name_icon_on_white.png', + height = 1000, + width = 3600, + pointsize = ptsz +) +banner( + 'white', + transparent_bg = TRUE, + width = 4, + add_logo = TRUE, + edge_width = 2.7 +) dev.off() -png('logos/name_icon_on_purple.png', - height = 1000, width = 3600, - pointsize = ptsz) -banner('purple', - transparent_bg = TRUE, - width = 4, - add_logo = TRUE, - edge_width = 2.7) +png( + 'logos/name_icon_on_purple.png', + height = 1000, + width = 3600, + pointsize = ptsz +) +banner( + 'purple', + transparent_bg = TRUE, + width = 4, + add_logo = TRUE, + edge_width = 2.7 +) dev.off() -png('logos/name_icon_on_light.png', - height = 1000, width = 3600, - pointsize = ptsz) -banner('light', - transparent_bg = TRUE, - width = 4, - add_logo = TRUE, - edge_width = 2.7) +png( + 'logos/name_icon_on_light.png', + height = 1000, + width = 3600, + pointsize = ptsz +) +banner( + 'light', + transparent_bg = TRUE, + width = 4, + add_logo = TRUE, + edge_width = 2.7 +) dev.off() -png('logos/name_icon_on_lighter.png', - height = 1000, width = 3600, - pointsize = ptsz) -banner('lighter', - transparent_bg = TRUE, - width = 4, - add_logo = TRUE, - edge_width = 2.7) +png( + 'logos/name_icon_on_lighter.png', + height = 1000, + width = 3600, + pointsize = ptsz +) 
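
The recurring width = 2.310505 in the text-only banners above is not a magic number: it is exactly the height-filling minimum that banner() warns about, min_width = 3.184175 + 2 * margin * (1 - 3.184175), evaluated at the default margin = 0.2. A one-line check:

margin <- 0.2
3.184175 + 2 * margin * (1 - 3.184175)
# [1] 2.310505
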
+banner( + 'lighter', + transparent_bg = TRUE, + width = 4, + add_logo = TRUE, + edge_width = 2.7 +) dev.off() -png('logos/gravatar.png', - height = 3600, width = 3600, - pointsize = ptsz) -banner('white', - width = 2.310505, - add_logo = FALSE) +png('logos/gravatar.png', height = 3600, width = 3600, pointsize = ptsz) +banner('white', width = 2.310505, add_logo = FALSE) dev.off() # make a nice GRF tesselation for the header image set.seed(1) -png("logos/greta-header.png", - width = 2732, - height = 1194, - pointsize = 30) +png("logos/greta-header.png", width = 2732, height = 1194, pointsize = 30) tesselation_image(ncol = 2732, nrow = 1194) dev.off() diff --git a/man/figures/plot_greta_legend.R b/man/figures/plot_greta_legend.R index 639d37b7..a02a5db6 100644 --- a/man/figures/plot_greta_legend.R +++ b/man/figures/plot_greta_legend.R @@ -1,74 +1,75 @@ # plot legend -library (DiagrammeR) -library (raster) +library(DiagrammeR) +library(raster) ns <- 0.3 # bespoke set of nodes for legend # pad operation with some invisible nodes to get in the right position -nodes1 <- create_node_df(n = 6, - label = c("data", - "variable", - "distribution", - "", - "operation", - ""), - type = "lower", - style = "filled", - fontcolor = greta_col('dark'), - fontname = 'Helvetica', - fontsize = 12, - fillcolor = c('white', - greta_col('super_light'), - greta_col('lighter'), - 'white', - 'lightgray', - 'white'), - color = c(greta_col('lighter'), - greta_col('lighter'), - greta_col('light'), - 'white', - 'lightgray', - 'white'), - penwidth = 2, - shape = c("square", "circle", "diamond", "circle", "circle", "circle"), - width = c(0.5, 0.6, 1, 0.01, 0.2, 0.01), - height = c(0.5, 0.6, 0.8, 0.01, 0.2, 0.01)) +nodes1 <- create_node_df( + n = 6, + label = c("data", "variable", "distribution", "", "operation", ""), + type = "lower", + style = "filled", + fontcolor = greta_col('dark'), + fontname = 'Helvetica', + fontsize = 12, + fillcolor = c( + 'white', + greta_col('super_light'), + greta_col('lighter'), + 'white', + 'lightgray', + 'white' + ), + color = c( + greta_col('lighter'), + greta_col('lighter'), + greta_col('light'), + 'white', + 'lightgray', + 'white' + ), + penwidth = 2, + shape = c("square", "circle", "diamond", "circle", "circle", "circle"), + width = c(0.5, 0.6, 1, 0.01, 0.2, 0.01), + height = c(0.5, 0.6, 0.8, 0.01, 0.2, 0.01) +) gr1 <- create_graph(nodes1) gr1$global_attrs[1, 'value'] <- 'dot' f_nodes <- tempfile(fileext = '.png') -export_graph(gr1, file_name = f_nodes, - width = 1005, - height = 249) +export_graph(gr1, file_name = f_nodes, width = 1005, height = 249) ns <- 0.01 -nodes2 <- create_node_df(n = 4, - label = '', - type = "lower", - alpha = 0, - style = "filled", - fillcolor = rep('#ffffff00', 4), - color = rep('#ffffff00', 4), - shape = rep("circle", 4), - width = rep(ns, 4)) - -edges2 <- create_edge_df(from = c(1, 3), - to = c(2, 4), - label = c('deterministic', 'stochastic'), - color = rep('Gainsboro', 2), - fontname = 'Helvetica', - fontcolor = greta_col('dark'), - fontsize = 12, - penwidth = 3, - style = c('solid', 'dashed')) +nodes2 <- create_node_df( + n = 4, + label = '', + type = "lower", + alpha = 0, + style = "filled", + fillcolor = rep('#ffffff00', 4), + color = rep('#ffffff00', 4), + shape = rep("circle", 4), + width = rep(ns, 4) +) + +edges2 <- create_edge_df( + from = c(1, 3), + to = c(2, 4), + label = c('deterministic', 'stochastic'), + color = rep('Gainsboro', 2), + fontname = 'Helvetica', + fontcolor = greta_col('dark'), + fontsize = 12, + penwidth = 3, + style = c('solid', 
'dashed') +) gr2 <- create_graph(nodes2, edges2) gr2$global_attrs[1, 'value'] <- 'dot' f_edges <- tempfile(fileext = '.png') -export_graph(gr2, file_name = f_edges, - width = 631, - height = 249) +export_graph(gr2, file_name = f_edges, width = 631, height = 249) # combine the two panels into one @@ -80,18 +81,18 @@ edges <- brick(f_edges) nodes <- crop(nodes, extent(nodes) - 4) edges <- crop(edges, extent(edges) - 4) -dim_pixels <- c(width = round(1005 + 631 + 1005/16), - height = 249) +dim_pixels <- c(width = round(1005 + 631 + 1005 / 16), height = 249) dim_inches <- dim_pixels / 10 # layout with a gap inbetween -mat <- matrix(rep(c(1, 2, 3), c(16, 1, 10)), - nrow = 1) +mat <- matrix(rep(c(1, 2, 3), c(16, 1, 10)), nrow = 1) -png('plotlegend.png', - width = dim_pixels['width'] * 2, - height = dim_pixels['height'] * 2) +png( + 'plotlegend.png', + width = dim_pixels['width'] * 2, + height = dim_pixels['height'] * 2 +) layout(mat) raster::plotRGB(nodes, maxpixels = Inf) @@ -101,9 +102,11 @@ raster::plotRGB(edges, maxpixels = Inf) dev.off() -pdf('plotlegend.pdf', - width = dim_inches['width'], - height = dim_inches['height']) +pdf( + 'plotlegend.pdf', + width = dim_inches['width'], + height = dim_inches['height'] +) layout(mat) raster::plotRGB(nodes, maxpixels = Inf) diff --git a/tests/spelling.R b/tests/spelling.R index 6713838f..a8cf85b9 100644 --- a/tests/spelling.R +++ b/tests/spelling.R @@ -1,3 +1,6 @@ -if(requireNamespace('spelling', quietly = TRUE)) - spelling::spell_check_test(vignettes = TRUE, error = FALSE, - skip_on_cran = TRUE) +if (requireNamespace('spelling', quietly = TRUE)) + spelling::spell_check_test( + vignettes = TRUE, + error = FALSE, + skip_on_cran = TRUE + ) diff --git a/tests/testthat/helper-expectations.R b/tests/testthat/helper-expectations.R index 1e26f570..2ee9b9ea 100644 --- a/tests/testthat/helper-expectations.R +++ b/tests/testthat/helper-expectations.R @@ -1,21 +1,21 @@ -ga_to_mat <- function(ga){ +ga_to_mat <- function(ga) { mat_nrow <- mat_ncol <- ncol(ga) matrix(ga, nrow = mat_nrow, ncol = mat_nrow, byrow = TRUE) } # ga_to_mat(calc_chol$chol_x) -get_upper_tri2 <- function(mat){ +get_upper_tri2 <- function(mat) { mat[upper.tri(mat)] } -get_lower_tri <- function(mat){ +get_lower_tri <- function(mat) { mat[lower.tri(mat)] } -expect_upper_tri <- function(object){ +expect_upper_tri <- function(object) { act <- quasi_label(rlang::enquo(object), arg = "object") - act$mat <- object[1,,] + act$mat <- object[1, , ] act$upper_tri <- get_upper_tri2(act$mat) act$lower_tri <- get_lower_tri(act$mat) @@ -23,50 +23,59 @@ expect_upper_tri <- function(object){ all_lower_zero <- all(act$lower_tri == 0) all_upper_non_zero <- all(act$upper_tri != 0) is_upper_tri <- all_lower_zero && all_upper_non_zero - if (is_upper_tri){ + if (is_upper_tri) { succeed() return(invisible(act$val)) } - if (!all_lower_zero){ - vals <- glue::glue_collapse(glue::glue("{round(act$lower_tri, 3)}"), sep = " ") - msg <- glue::glue("{act$lab} is not upper triangular. Values below the \\ - main diagonal are not all zero: {vals}") + if (!all_lower_zero) { + vals <- glue::glue_collapse( + glue::glue("{round(act$lower_tri, 3)}"), + sep = " " + ) + msg <- glue::glue( + "{act$lab} is not upper triangular. Values below the \\ + main diagonal are not all zero: {vals}" + ) } - if (!all_upper_non_zero){ - vals <- glue::glue_collapse(glue::glue("{round(act$upper_tri, 3)}"), sep = " ") - msg <- glue::glue_collapse(glue::glue("{act$lab} is not upper triangular. 
Some values above \\ - the main diagonal contain zero: {vals}")) + if (!all_upper_non_zero) { + vals <- glue::glue_collapse( + glue::glue("{round(act$upper_tri, 3)}"), + sep = " " + ) + msg <- glue::glue_collapse(glue::glue( + "{act$lab} is not upper triangular. Some values above \\ + the main diagonal contain zero: {vals}" + )) } fail(msg) - } -expect_square <- function(object){ +expect_square <- function(object) { # 1. Capture object and label act <- quasi_label(rlang::enquo(object), arg = "object") # 2. Call expect() - act$nrow <- dim(act$val[1,,])[1] - act$ncol <- dim(act$val[1,,])[2] + act$nrow <- dim(act$val[1, , ])[1] + act$ncol <- dim(act$val[1, , ])[2] expect( ok = act$nrow == act$ncol, - failure_message = glue::glue("{act$lab} has dim {act$nrow}x{act$ncol}, and is not square.") + failure_message = glue::glue( + "{act$lab} has dim {act$nrow}x{act$ncol}, and is not square." + ) ) # 3. Invisibly return the value invisible(act$val) - } # expect_square(calc_chol$chol_x) # expect_square(array(data = 1:9, c(1,3,3))) # expect_square(array(data = 1:12, c(1,3,4))) - -expect_symmetric <- function(object){ +expect_symmetric <- function(object) { # 1. Capture object and label act <- quasi_label(rlang::enquo(object), arg = "object") @@ -77,13 +86,12 @@ expect_symmetric <- function(object){ # 2. Call expect() expect( - ok = all.equal(act$upper,act$lower), + ok = all.equal(act$upper, act$lower), failure_message = glue::glue("{act$lab} is not symmetric") ) # 3. Invisibly return the value invisible(act$val) - } # xmat <- calculate(x, nsim = 1)[[1]] |> ga_to_mat() diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R index 4a2a8959..fe589749 100644 --- a/tests/testthat/helpers.R +++ b/tests/testthat/helpers.R @@ -97,7 +97,6 @@ greta_density <- function( dim = NULL, multivariate = FALSE ) { - dim <- dim %||% NROW(x) # add the output dimension to the arguments list @@ -376,7 +375,7 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) { r <- replicate(n, f()) if (dim(r)[3] == 1) { - r <- r[, , 1] + r <- r[,, 1] } else { r <- aperm(r, c(3, 1, 2)) } @@ -402,7 +401,12 @@ lkj_log_normalising <- function(eta, n) { ans } -dlkj_correlation_unnormalised <- function(x, eta, log = FALSE, dimension = NULL) { +dlkj_correlation_unnormalised <- function( + x, + eta, + log = FALSE, + dimension = NULL +) { res <- (eta - 1) * log(det(x)) if (!log) { res <- exp(res) @@ -420,11 +424,13 @@ dlkj_correlation <- function(x, eta, log = FALSE, dimension = NULL) { } # helper RNG functions -rmvnorm <- function(n, mean, Sigma) { # nolint +rmvnorm <- function(n, mean, Sigma) { + # nolint mvtnorm::rmvnorm(n = n, mean = mean, sigma = Sigma) } -rwish <- function(n, df, Sigma) { # nolint +rwish <- function(n, df, Sigma) { + # nolint draws <- stats::rWishart(n = n, df = df, Sigma = Sigma) aperm(draws, c(3, 1, 2)) } @@ -677,23 +683,25 @@ check_geweke <- function( ) geweke_checks - } -geweke_qq <- function(geweke_checks, title){ +geweke_qq <- function(geweke_checks, title) { # visualise correspondence quants <- (1:99) / 100 q1 <- stats::quantile(geweke_checks$target_theta, quants) q2 <- stats::quantile(geweke_checks$greta_theta, quants) plot(q2, q1, main = title) graphics::abline(0, 1) - } -geweke_ks <- function(geweke_checks){ +geweke_ks <- function(geweke_checks) { # do a formal hypothesis test - suppressWarnings(stat <- stats::ks.test(geweke_checks$target_theta, - geweke_checks$greta_theta)) + suppressWarnings( + stat <- stats::ks.test( + geweke_checks$target_theta, + geweke_checks$greta_theta + ) + ) stat } @@ -729,7 +737,6 
@@ p_theta_greta <- function( # now loop through, sampling and updating x and returning theta for (i in 2:niter) { - # update the progress bar cli::cli_progress_update() @@ -757,7 +764,6 @@ p_theta_greta <- function( # trace the sample theta[i] <- tail(as.numeric(draws[[1]]), 1) - } # kill the progress_bar @@ -808,8 +814,10 @@ get_enough_draws <- function( one_by_one = one_by_one ) - while (need_more_samples(draws, n_effective) && - still_have_time(start_time, time_limit)) { + while ( + need_more_samples(draws, n_effective) && + still_have_time(start_time, time_limit) + ) { n_samples <- new_samples(draws, n_effective) draws <- extra_samples( draws, @@ -866,7 +874,7 @@ scaled_error <- function(draws, expectation) { check_mvn_samples <- function(sampler, n_effective = 3000) { # get multivariate normal samples mu <- as_data(t(rnorm(2, 0, 5))) - sigma <- stats::rWishart(1, 3, diag(2))[, , 1] + sigma <- stats::rWishart(1, 3, diag(2))[,, 1] x <- multivariate_normal(mu, sigma) m <- model(x, precision = "single") @@ -908,9 +916,9 @@ do_thinning <- function(x, thinning = 1) { } -get_distribution_name <- function(x){ +get_distribution_name <- function(x) { x_node <- get_node(x) - if (inherits(x_node, "operation_node")){ + if (inherits(x_node, "operation_node")) { dist_name <- x_node$parents[[1]]$distribution$distribution_name } else { dist_name <- get_node(x)$distribution$distribution_name @@ -959,8 +967,7 @@ check_samples <- function( ) } -qqplot_checked_samples <- function(checked_samples, title){ - +qqplot_checked_samples <- function(checked_samples, title) { distrib <- checked_samples$distrib sampler_name <- checked_samples$sampler_name title <- paste(distrib, "with", sampler_name) @@ -972,16 +979,17 @@ qqplot_checked_samples <- function(checked_samples, title){ x = mcmc_samples, y = iid_samples, main = title - ) + ) graphics::abline(0, 1) } ## helpers for running Kolmogorov-Smirnov test for MCMC samples vs IID samples -ks_test_mcmc_vs_iid <- function(checked_samples){ +ks_test_mcmc_vs_iid <- function(checked_samples) { # do a formal hypothesis test - suppressWarnings(stat <- ks.test(checked_samples$mcmc_samples, - checked_samples$iid_samples)) + suppressWarnings( + stat <- ks.test(checked_samples$mcmc_samples, checked_samples$iid_samples) + ) stat } @@ -1023,9 +1031,10 @@ tidy_optimisers <- function(opt_df, tolerance = 1e-2) { par_x_diff = purrr::map2( .x = par, .y = x_val, - .f = function(.x, .y){ + .f = function(.x, .y) { abs(.y - .x) - }), + } + ), close_to_truth = purrr::map_lgl( par_x_diff, function(x) all(x < tolerance) diff --git a/tests/testthat/test-diagrammer-installed.R b/tests/testthat/test-diagrammer-installed.R index 7577b333..028a281c 100644 --- a/tests/testthat/test-diagrammer-installed.R +++ b/tests/testthat/test-diagrammer-installed.R @@ -4,9 +4,9 @@ test_that("DiagrammeR installation is checked", { local_mocked_bindings( is_DiagrammeR_installed = function() FALSE ) - m <- model(normal(0,1)) - expect_snapshot( - error = TRUE, - x = plot(m) - ) + m <- model(normal(0, 1)) + expect_snapshot( + error = TRUE, + x = plot(m) + ) }) diff --git a/tests/testthat/test-greta-sitrep.R b/tests/testthat/test-greta-sitrep.R index 3f45ec0d..fd346f0c 100644 --- a/tests/testthat/test-greta-sitrep.R +++ b/tests/testthat/test-greta-sitrep.R @@ -12,31 +12,30 @@ test_that("check_tf_version errors when have_python, _tf, or _tfp is FALSE", { have_tfp = function(...) 
FALSE ) - expect_snapshot( - error = TRUE, - check_tf_version("error") - ) - - expect_snapshot_warning( - check_tf_version("warn") - ) + expect_snapshot( + error = TRUE, + check_tf_version("error") + ) - expect_snapshot( - check_tf_version("message") - ) + expect_snapshot_warning( + check_tf_version("warn") + ) + expect_snapshot( + check_tf_version("message") + ) }) test_that("check_tf_version fails when tfp not available", { - greta_stash$python_has_been_initialised <- FALSE - local_mocked_bindings( - have_tfp = function(...) FALSE - ) - expect_snapshot( - error = TRUE, - check_tf_version("error") - ) - }) + greta_stash$python_has_been_initialised <- FALSE + local_mocked_bindings( + have_tfp = function(...) FALSE + ) + expect_snapshot( + error = TRUE, + check_tf_version("error") + ) +}) test_that("greta_sitrep warns when have_python, _tf, or _tfp is FALSE", { skip_if_not(check_tf_version()) @@ -111,7 +110,6 @@ test_that("greta_sitrep warns when have_python, _tf, or _tfp is FALSE", { expect_snapshot( greta_sitrep() ) - }) test_that("greta_sitrep warns when different versions of python, tf, tfp", { @@ -142,7 +140,6 @@ test_that("greta_sitrep warns when different versions of python, tf, tfp", { expect_snapshot( greta_sitrep() ) - }) test_that("greta_sitrep warns greta conda env not available", { @@ -180,5 +177,4 @@ test_that("greta_sitrep works with quiet, minimal, and detailed options", { expect_snapshot( greta_sitrep(verbosity = "detailed") ) - }) diff --git a/tests/testthat/test-message_if_using_gpu.R b/tests/testthat/test-message_if_using_gpu.R index 33c1ec06..ea6eb315 100644 --- a/tests/testthat/test-message_if_using_gpu.R +++ b/tests/testthat/test-message_if_using_gpu.R @@ -8,7 +8,7 @@ test_that("message_if_using_gpu gives the correct message for cpu or gpu use", { ) }) -test_that("message_if_using_gpu does not message when option set",{ +test_that("message_if_using_gpu does not message when option set", { skip_if_not(check_tf_version()) withr::local_options( list("greta_gpu_message" = FALSE) @@ -17,10 +17,9 @@ test_that("message_if_using_gpu does not message when option set",{ expect_snapshot( message_if_using_gpu(gpu_only()) ) - }) -test_that("message_if_using_gpu does message when option set",{ +test_that("message_if_using_gpu does message when option set", { skip_if_not(check_tf_version()) withr::local_options( list("greta_gpu_message" = TRUE) @@ -29,10 +28,9 @@ test_that("message_if_using_gpu does message when option set",{ expect_snapshot( message_if_using_gpu(gpu_only()) ) - }) -test_that("is_using_gpu and is_using_cpu work",{ +test_that("is_using_gpu and is_using_cpu work", { skip_if_not(check_tf_version()) expect_true(is_using_gpu(gpu_only())) expect_false(is_using_gpu(cpu_only())) @@ -52,10 +50,9 @@ test_that("calculate provides a message when GPU is set", { expect_snapshot( calc_x <- calculate(x, nsim = 1, compute_options = cpu_only()) ) - }) -test_that("calculate/mcmc does not message when option set",{ +test_that("calculate/mcmc does not message when option set", { skip_if_not(check_tf_version()) withr::local_options( list("greta_gpu_message" = FALSE) @@ -70,16 +67,17 @@ test_that("calculate/mcmc does not message when option set",{ m <- model(x) expect_snapshot( - mcmc_m <- mcmc(model = m, - n_samples = 1, - warmup = 0, - compute_options = gpu_only(), - verbose = FALSE) + mcmc_m <- mcmc( + model = m, + n_samples = 1, + warmup = 0, + compute_options = gpu_only(), + verbose = FALSE + ) ) - }) -test_that("calculate/mcmc does message when option set",{ 
+test_that("calculate/mcmc does message when option set", { skip_if_not(check_tf_version()) withr::local_options( list("greta_gpu_message" = TRUE) @@ -94,13 +92,14 @@ test_that("calculate/mcmc does message when option set",{ m <- model(x) expect_snapshot( - mcmc_m <- mcmc(model = m, - n_samples = 1, - warmup = 0, - compute_options = gpu_only(), - verbose = FALSE) + mcmc_m <- mcmc( + model = m, + n_samples = 1, + warmup = 0, + compute_options = gpu_only(), + verbose = FALSE + ) ) - }) test_that("mcmc provides a message when GPU is set", { @@ -110,27 +109,30 @@ test_that("mcmc provides a message when GPU is set", { m <- model(x) expect_snapshot( - mcmc_gpu <- mcmc(model = m, - n_samples = 1, - warmup = 0, - compute_options = gpu_only(), - verbose = FALSE) + mcmc_gpu <- mcmc( + model = m, + n_samples = 1, + warmup = 0, + compute_options = gpu_only(), + verbose = FALSE + ) ) expect_snapshot( - mcmc_cpu <- mcmc(model = m, - n_samples = 1, - warmup = 0, - compute_options = cpu_only(), - verbose = FALSE) + mcmc_cpu <- mcmc( + model = m, + n_samples = 1, + warmup = 0, + compute_options = cpu_only(), + verbose = FALSE + ) ) - }) test_that("mcmc prints out CPU and GPU text", { skip_if_not(check_tf_version()) - x <- normal(0,1) + x <- normal(0, 1) m <- model(x) # removed snapshot testing as it was too fickle cpu_output <- get_output( diff --git a/tests/testthat/test-print_calculate.R b/tests/testthat/test-print_calculate.R index 8dfc4c31..9ffa8133 100644 --- a/tests/testthat/test-print_calculate.R +++ b/tests/testthat/test-print_calculate.R @@ -4,9 +4,9 @@ test_that("calculate print method is different for different inputs", { skip_on_cran() skip_on_ci() - x <- normal(0,1) + x <- normal(0, 1) m <- model(x) - new_seed <- 2024-11-07-14-01 + new_seed <- 2024 - 11 - 07 - 14 - 01 x_sim_10 <- calculate(x, nsim = 10, seed = new_seed) expect_snapshot(names(x_sim_10)) expect_snapshot(dim(x_sim_10$x)) diff --git a/tests/testthat/test-tensorflow-rpkg-stability.R b/tests/testthat/test-tensorflow-rpkg-stability.R index 553102ae..a2ae6944 100644 --- a/tests/testthat/test-tensorflow-rpkg-stability.R +++ b/tests/testthat/test-tensorflow-rpkg-stability.R @@ -11,11 +11,11 @@ test_that("tensorflow returns appropriate thing with 'dim'", { expect_identical(dim(xt_int_64), integer(0)) expect_identical(dim(xt_float_32), integer(0)) expect_identical(dim(xt_float_32_dec), integer(0)) - expect_null(dim(shape(1,2,3))) + expect_null(dim(shape(1, 2, 3))) expect_identical(dim(tensorflow::as_tensor(c(1:3))), 3L) }) -test_that("Tensor behaves as we expect",{ +test_that("Tensor behaves as we expect", { skip_if_not(check_tf_version()) x <- tensorflow::as_tensor(42, "int32") expect_snapshot(length(x)) @@ -34,7 +34,7 @@ test_that("shape returns right thing", { expect_snapshot(shape(3, 4)) expect_snapshot(shape(NA, 4)) expect_snapshot(shape(dims = c(NA, 4))) - expect_snapshot(shape(1,1,1)) + expect_snapshot(shape(1, 1, 1)) expect_null(dim(shape())) expect_null(dim(shape(NULL))) expect_null(dim(shape(NA))) @@ -97,15 +97,15 @@ test_that("shape returns appropriate TensorShape object", { expect_snapshot(as.integer(shape(1, 3))) expect_snapshot(as.numeric(shape(1, 3))) expect_snapshot(as.double(shape(1, 3))) - expect_snapshot(shape(1, 3) == shape(1,3)) - expect_snapshot(shape(1, 3) == shape(1,2)) - expect_snapshot(shape(1, 3) != shape(1,3)) - expect_snapshot(shape(1, 3) != shape(1,2)) + expect_snapshot(shape(1, 3) == shape(1, 3)) + expect_snapshot(shape(1, 3) == shape(1, 2)) + expect_snapshot(shape(1, 3) != shape(1, 3)) + 
expect_snapshot(shape(1, 3) != shape(1, 2)) }) test_that("[, [[, and assignment returns right object", { skip_if_not(check_tf_version()) - x_extract <- shape(1,2,3) + x_extract <- shape(1, 2, 3) expect_snapshot(x_extract[1]) expect_snapshot(x_extract[[1]]) expect_snapshot(x_extract[2:3]) @@ -113,18 +113,17 @@ test_that("[, [[, and assignment returns right object", { expect_snapshot({ x_extract[1] <- 11 x_extract[1] - }) + }) expect_snapshot({ x_extract[1] <- shape(11) x_extract[1] - }) + }) expect_snapshot({ x_extract[1] <- list(11) x_extract[1] - }) + }) }) - # other parts to test: # batch_size <- tf$shape(x)[[0]] # shape_list <- c(list(batch_size), as.integer(to_shape(dims_out))) diff --git a/tests/testthat/test_as_data.R b/tests/testthat/test_as_data.R index 94230ebe..2934d518 100644 --- a/tests/testthat/test_as_data.R +++ b/tests/testthat/test_as_data.R @@ -72,13 +72,19 @@ test_that("as_data coerces correctly", { int_df <- as.data.frame(int_mat) num_df <- as.data.frame(num_mat) - expect_true(is.data.frame(log_df) & - all(vapply(log_df, is.logical, FALSE))) - expect_true(is.data.frame(int_df) & - all(vapply(int_df, is.numeric, FALSE)) & - all(vapply(int_df, is.integer, FALSE))) - expect_true(is.data.frame(num_df) & - all(vapply(num_df, is.numeric, FALSE))) + expect_true( + is.data.frame(log_df) & + all(vapply(log_df, is.logical, FALSE)) + ) + expect_true( + is.data.frame(int_df) & + all(vapply(int_df, is.numeric, FALSE)) & + all(vapply(int_df, is.integer, FALSE)) + ) + expect_true( + is.data.frame(num_df) & + all(vapply(num_df, is.numeric, FALSE)) + ) ga_log_df <- as_data(log_df) ga_int_df <- as_data(int_df) @@ -105,7 +111,6 @@ test_that("as_data coerces correctly", { test_that("as_data errors informatively", { skip_if_not(check_tf_version()) - # wrong class of object expect_snapshot( error = TRUE, diff --git a/tests/testthat/test_calculate.R b/tests/testthat/test_calculate.R index 0efaceb7..4f067a00 100644 --- a/tests/testthat/test_calculate.R +++ b/tests/testthat/test_calculate.R @@ -87,7 +87,8 @@ test_that("stochastic calculate works with correct lists", { y <- as_data(y) a <- normal(0, 1, dim = c(1, k)) distribution(y) <- categorical(ilogit(a * x), n_realisations = n) - sims <- calculate(y, + sims <- calculate( + y, nsim = nsim, values = list( a = c(50, 5, 0.5), @@ -185,9 +186,7 @@ test_that("stochastic calculate works with greta_mcmc_list objects", { ) # this should error without nsim being specified (y is stochastic) - expect_snapshot(error = TRUE, - calc_a <- calculate(a, y, values = draws) - ) + expect_snapshot(error = TRUE, calc_a <- calculate(a, y, values = draws)) # this should be OK sims <- calculate(y, values = draws, nsim = 10) @@ -215,7 +214,6 @@ test_that("stochastic calculate works with greta_mcmc_list objects", { expect_snapshot_warning( new_y <- calculate(y, values = draws, nsim = samples * chains + 1) ) - }) test_that("calculate errors if the mcmc samples unrelated to target", { @@ -240,9 +238,7 @@ test_that("calculate errors if the mcmc samples unrelated to target", { c <- normal(0, 1) - expect_snapshot(error = TRUE, - calc_c <- calculate(c, values = draws) - ) + expect_snapshot(error = TRUE, calc_c <- calculate(c, values = draws)) }) test_that("stochastic calculate works with mcmc samples & new stochastics", { @@ -270,9 +266,7 @@ test_that("stochastic calculate works with mcmc samples & new stochastics", { # this should error without nsim being specified (b is stochastic and not # given by draws) - expect_snapshot(error = TRUE, - calc_b <- calculate(b, values = 
draws) - ) + expect_snapshot(error = TRUE, calc_b <- calculate(b, values = draws)) sims <- calculate(b, values = draws, nsim = 10) expect_identical(dim(sims$b), c(10L, dim(b))) @@ -287,15 +281,13 @@ test_that("calculate errors nicely if non-greta arrays are passed", { y <- a * x # it should error nicely - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, x, values = list(x = c(2, 1))) ) # and a hint for this common error - expect_snapshot(error = TRUE, - calc_y <- calculate(y, list(x = c(2, 1))) - ) - + expect_snapshot(error = TRUE, calc_y <- calculate(y, list(x = c(2, 1)))) }) test_that("calculate errors nicely if values for stochastics not passed", { @@ -306,7 +298,8 @@ test_that("calculate errors nicely if values for stochastics not passed", { y <- a * x # it should error nicely - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, values = list(x = c(2, 1))) ) @@ -322,7 +315,8 @@ test_that("calculate errors nicely if values have incorrect dimensions", { y <- a * x # it should error nicely - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, values = list(a = c(1, 1))) ) }) @@ -365,13 +359,16 @@ test_that("calculate errors nicely with invalid batch sizes", { draws <- mcmc(m, warmup = 0, n_samples = samples, verbose = FALSE) # variable valid batch sizes - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, values = draws, trace_batch_size = 0) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, values = draws, trace_batch_size = NULL) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, calc_y <- calculate(y, values = draws, trace_batch_size = NA) ) }) @@ -419,13 +416,11 @@ test_that("calculate produces the right number of samples", { sims <- calculate(y, nsim = 19) expect_identical(dim(sims$y), c(19L, dim(y))) - }) test_that("calculate works if distribution-free variables are fixed", { skip_if_not(check_tf_version()) - # fix variable a <- variable() y <- normal(a, 1) @@ -440,9 +435,7 @@ test_that("calculate errors if distribution-free variables are not fixed", { # fix variable a <- variable() y <- normal(a, 1) - expect_snapshot(error = TRUE, - calc_a <- calculate(a, y, nsim = 1) - ) + expect_snapshot(error = TRUE, calc_a <- calculate(a, y, nsim = 1)) }) test_that("calculate errors if a distribution cannot be sampled from", { @@ -450,9 +443,7 @@ test_that("calculate errors if a distribution cannot be sampled from", { # fix variable y <- hypergeometric(5, 3, 2) - expect_snapshot(error = TRUE, - sims <- calculate(y, nsim = 1) - ) + expect_snapshot(error = TRUE, sims <- calculate(y, nsim = 1)) }) test_that("calculate errors nicely if nsim is invalid", { @@ -460,15 +451,9 @@ test_that("calculate errors nicely if nsim is invalid", { x <- normal(0, 1) - expect_snapshot(error = TRUE, - calc_x <- calculate(x, nsim = 0) - ) + expect_snapshot(error = TRUE, calc_x <- calculate(x, nsim = 0)) - expect_snapshot(error = TRUE, - calc_x <- calculate(x, nsim = -1) - ) + expect_snapshot(error = TRUE, calc_x <- calculate(x, nsim = -1)) - expect_snapshot(error = TRUE, - calc_x <- calculate(x, nsim = "five") - ) + expect_snapshot(error = TRUE, calc_x <- calculate(x, nsim = "five")) }) diff --git a/tests/testthat/test_distributions.R b/tests/testthat/test_distributions.R index 69f5d200..38af0b2b 100644 --- a/tests/testthat/test_distributions.R +++ b/tests/testthat/test_distributions.R @@ -273,10 +273,11 @@ 
test_that("multivariate normal distribution has correct density", { # parameters to test m <- 5 mn <- t(rnorm(m)) - sig <- rWishart(1, m + 1, diag(m))[, , 1] + sig <- rWishart(1, m + 1, diag(m))[,, 1] # function converting Sigma to sigma - dmvnorm2 <- function(x, mean, Sigma, log = FALSE) { # nolint + dmvnorm2 <- function(x, mean, Sigma, log = FALSE) { + # nolint mvtnorm::dmvnorm(x = x, mean = mean, sigma = Sigma, log = log) } @@ -301,7 +302,6 @@ test_that("Wishart and LKJ distributions have correct density", { # distributions are tested using integration tests with the MCMC sampler, in # test_posteriors_wishart.R and test_posteriors_lkj.R. skip() - }) test_that("multinomial distribution has correct density", { @@ -456,7 +456,7 @@ test_that("scalar-valued distributions can be defined in models", { expect_ok(model(uniform(-13, 2.4))) # multivariate continuous distributions - sig <- rWishart(1, 4, diag(3))[, , 1] + sig <- rWishart(1, 4, diag(3))[,, 1] expect_ok(model(multivariate_normal(t(rnorm(3)), sig))) expect_ok(model(wishart(4, sig))) @@ -536,7 +536,7 @@ test_that("array-valued distributions can be defined in models", { expect_ok(model(f(24.3, 2.4, dim = dim))) # multivariate continuous distributions - sig <- rWishart(1, 4, diag(3))[, , 1] + sig <- rWishart(1, 4, diag(3))[,, 1] expect_ok( model(multivariate_normal(t(rnorm(3)), sig, n_realisations = dim[1])) ) @@ -629,7 +629,7 @@ test_that("distributions can be sampled from by MCMC", { sample_distribution(uniform(-13, 2.4), lower = -13, upper = 2.4) # multivariate continuous - sig <- rWishart(1, 4, diag(3))[, , 1] + sig <- rWishart(1, 4, diag(3))[,, 1] sample_distribution(multivariate_normal(t(rnorm(3)), sig)) sample_distribution(wishart(10L, Sig = diag(2)), warmup = 0) sample_distribution(lkj_correlation(4, dimension = 3)) @@ -674,23 +674,15 @@ test_that("poisson() and binomial() error informatively in glm", { skip_if_not(check_tf_version()) # if passed as an object - expect_snapshot(error = TRUE, - glm(1 ~ 1, family = poisson) - ) + expect_snapshot(error = TRUE, glm(1 ~ 1, family = poisson)) - expect_snapshot(error = TRUE, - glm(1 ~ 1, family = binomial) - ) + expect_snapshot(error = TRUE, glm(1 ~ 1, family = binomial)) # if executed alone - expect_snapshot(error = TRUE, - glm(1 ~ 1, family = poisson()) - ) + expect_snapshot(error = TRUE, glm(1 ~ 1, family = poisson())) # if given a link - expect_snapshot(error = TRUE, - glm(1 ~ 1, family = poisson("sqrt")) - ) + expect_snapshot(error = TRUE, glm(1 ~ 1, family = poisson("sqrt"))) }) test_that("wishart distribution errors informatively", { @@ -726,29 +718,17 @@ test_that("lkj_correlation distribution errors informatively", { "greta_array" )) - expect_snapshot(error = TRUE, - lkj_correlation(-1, dim) - ) + expect_snapshot(error = TRUE, lkj_correlation(-1, dim)) - expect_snapshot(error = TRUE, - lkj_correlation(c(3, 3), dim) - ) + expect_snapshot(error = TRUE, lkj_correlation(c(3, 3), dim)) - expect_snapshot(error = TRUE, - lkj_correlation(uniform(0, 1, dim = 2), dim) - ) + expect_snapshot(error = TRUE, lkj_correlation(uniform(0, 1, dim = 2), dim)) - expect_snapshot(error = TRUE, - lkj_correlation(4, dimension = -1) - ) + expect_snapshot(error = TRUE, lkj_correlation(4, dimension = -1)) - expect_snapshot(error = TRUE, - lkj_correlation(4, dim = c(3, 3)) - ) + expect_snapshot(error = TRUE, lkj_correlation(4, dim = c(3, 3))) - expect_snapshot(error = TRUE, - lkj_correlation(4, dim = NA) - ) + expect_snapshot(error = TRUE, lkj_correlation(4, dim = NA)) }) test_that("multivariate_normal 
distribution errors informatively", { @@ -776,13 +756,9 @@ test_that("multivariate_normal distribution errors informatively", { )) # bad means - expect_snapshot(error = TRUE, - multivariate_normal(m_c, a) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_c, a)) - expect_snapshot(error = TRUE, - multivariate_normal(m_d, a) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_d, a)) # good sigmas expect_true(inherits( @@ -791,39 +767,32 @@ test_that("multivariate_normal distribution errors informatively", { )) # bad sigmas - expect_snapshot(error = TRUE, - multivariate_normal(m_a, b) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_a, b)) - expect_snapshot(error = TRUE, - multivariate_normal(m_a, c) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_a, c)) # mismatched parameters - expect_snapshot(error = TRUE, - multivariate_normal(m_a, d) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_a, d)) # scalars - expect_snapshot(error = TRUE, - multivariate_normal(0, 1) - ) + expect_snapshot(error = TRUE, multivariate_normal(0, 1)) # bad n_realisations - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, multivariate_normal(m_a, a, n_realisations = -1) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, multivariate_normal(m_a, a, n_realisations = c(1, 3)) ) # bad dimension - expect_snapshot(error = TRUE, - multivariate_normal(m_a, a, dimension = -1) - ) + expect_snapshot(error = TRUE, multivariate_normal(m_a, a, dimension = -1)) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, multivariate_normal(m_a, a, dimension = c(1, 3)) ) }) @@ -858,27 +827,17 @@ test_that("multinomial distribution errors informatively", { )) # scalars - expect_snapshot(error = TRUE, - multinomial(c(1), 1) - ) + expect_snapshot(error = TRUE, multinomial(c(1), 1)) # bad n_realisations - expect_snapshot(error = TRUE, - multinomial(10, p_a, n_realisations = -1) - ) + expect_snapshot(error = TRUE, multinomial(10, p_a, n_realisations = -1)) - expect_snapshot(error = TRUE, - multinomial(10, p_a, n_realisations = c(1, 3)) - ) + expect_snapshot(error = TRUE, multinomial(10, p_a, n_realisations = c(1, 3))) # bad dimension - expect_snapshot(error = TRUE, - multinomial(10, p_a, dimension = -1) - ) + expect_snapshot(error = TRUE, multinomial(10, p_a, dimension = -1)) - expect_snapshot(error = TRUE, - multinomial(10, p_a, dimension = c(1, 3)) - ) + expect_snapshot(error = TRUE, multinomial(10, p_a, dimension = c(1, 3))) }) test_that("categorical distribution errors informatively", { @@ -899,27 +858,17 @@ test_that("categorical distribution errors informatively", { )) # scalars - expect_snapshot(error = TRUE, - categorical(1) - ) + expect_snapshot(error = TRUE, categorical(1)) # bad n_realisations - expect_snapshot(error = TRUE, - categorical(p_a, n_realisations = -1) - ) + expect_snapshot(error = TRUE, categorical(p_a, n_realisations = -1)) - expect_snapshot(error = TRUE, - categorical(p_a, n_realisations = c(1, 3)) - ) + expect_snapshot(error = TRUE, categorical(p_a, n_realisations = c(1, 3))) # bad dimension - expect_snapshot(error = TRUE, - categorical(p_a, dimension = -1) - ) + expect_snapshot(error = TRUE, categorical(p_a, dimension = -1)) - expect_snapshot(error = TRUE, - categorical(p_a, dimension = c(1, 3)) - ) + expect_snapshot(error = TRUE, categorical(p_a, dimension = c(1, 3))) }) test_that("dirichlet distribution errors informatively", { @@ -934,34 +883,23 @@ test_that("dirichlet distribution errors informatively", { "greta_array" )) - 
expect_true(inherits( dirichlet(alpha_b), "greta_array" )) # scalars - expect_snapshot(error = TRUE, - dirichlet(1) - ) + expect_snapshot(error = TRUE, dirichlet(1)) # bad n_realisations - expect_snapshot(error = TRUE, - dirichlet(alpha_a, n_realisations = -1) - ) + expect_snapshot(error = TRUE, dirichlet(alpha_a, n_realisations = -1)) - expect_snapshot(error = TRUE, - dirichlet(alpha_a, n_realisations = c(1, 3)) - ) + expect_snapshot(error = TRUE, dirichlet(alpha_a, n_realisations = c(1, 3))) # bad dimension - expect_snapshot(error = TRUE, - dirichlet(alpha_a, dimension = -1) - ) + expect_snapshot(error = TRUE, dirichlet(alpha_a, dimension = -1)) - expect_snapshot(error = TRUE, - dirichlet(alpha_a, dimension = c(1, 3)) - ) + expect_snapshot(error = TRUE, dirichlet(alpha_a, dimension = c(1, 3))) }) @@ -1007,25 +945,27 @@ test_that("dirichlet-multinomial distribution errors informatively", { )) # scalars - expect_snapshot(error = TRUE, - dirichlet_multinomial(c(1), 1) - ) + expect_snapshot(error = TRUE, dirichlet_multinomial(c(1), 1)) # bad n_realisations - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, dirichlet_multinomial(10, alpha_a, n_realisations = -1) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, dirichlet_multinomial(10, alpha_a, n_realisations = c(1, 3)) ) # bad dimension - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, dirichlet_multinomial(10, alpha_a, dimension = -1) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, dirichlet_multinomial(10, alpha_a, dimension = c(1, 3)) ) }) diff --git a/tests/testthat/test_distributions_cholesky.R b/tests/testthat/test_distributions_cholesky.R index a5ccac0b..ec5747c5 100644 --- a/tests/testthat/test_distributions_cholesky.R +++ b/tests/testthat/test_distributions_cholesky.R @@ -6,7 +6,7 @@ test_that("Wishart can use a choleskied Sigma", { sig <- lkj_correlation(2, dim = 2) w <- wishart(5, sig) m <- model(w, precision = "double") - tensorflow::set_random_seed(2024-07-30-1520) + tensorflow::set_random_seed(2024 - 07 - 30 - 1520) expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE)) }) @@ -17,7 +17,7 @@ test_that("Cholesky factor of Wishart should be a lower triangular matrix", { x <- wishart(df = 4, Sigma = diag(3)) chol_x <- chol(x) expect_snapshot( - calculate(chol_x, nsim = 1, seed = 2024-10-31-1338)$chol_x[1,,] + calculate(chol_x, nsim = 1, seed = 2024 - 10 - 31 - 1338)$chol_x[1, , ] ) calc_x <- calculate(x, nsim = 1) calc_chol <- calculate(chol_x, nsim = 1) @@ -28,12 +28,12 @@ test_that("Cholesky factor of Wishart should be a lower triangular matrix", { ## Test if we do calculate on x and chol_x x <- wishart(df = 4, Sigma = diag(3)) chol_x <- chol(x) - calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024-10-31-1342) + calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024 - 10 - 31 - 1342) expect_snapshot( - calc_chol$x[1,,] + calc_chol$x[1, , ] ) expect_snapshot( - calc_chol$chol_x[1,,] + calc_chol$chol_x[1, , ] ) expect_square(calc_chol$chol_x) expect_upper_tri(calc_chol$chol_x) @@ -46,10 +46,10 @@ test_that("Cholesky factor of LJK_correlation should be a lower triangular matri x <- lkj_correlation(eta = 3, dimension = 3) chol_x <- chol(x) expect_snapshot( - calculate(chol_x, nsim = 1, seed = 2024-07-30-1431)$chol_x[1,,] + calculate(chol_x, nsim = 1, seed = 2024 - 07 - 30 - 1431)$chol_x[1, , ] ) - calc_x <- calculate(x, nsim = 1, seed = 2024-07-30-1431) - calc_chol <- calculate(chol_x, nsim = 1, seed = 2024-07-30-1431) + calc_x <- 
calculate(x, nsim = 1, seed = 2024 - 07 - 30 - 1431) + calc_chol <- calculate(chol_x, nsim = 1, seed = 2024 - 07 - 30 - 1431) expect_upper_tri(calc_chol$chol_x) expect_square(calc_chol$chol_x) @@ -57,26 +57,26 @@ test_that("Cholesky factor of LJK_correlation should be a lower triangular matri ## Test if we do calculate on x and chol_x x <- lkj_correlation(eta = 3, dimension = 3) chol_x <- chol(x) - calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024-07-30-1431) + calc_chol <- calculate(x, chol_x, nsim = 1, seed = 2024 - 07 - 30 - 1431) expect_snapshot( - calc_chol$x[1,,] + calc_chol$x[1, , ] ) expect_snapshot( - calc_chol$chol_x[1,,] + calc_chol$chol_x[1, , ] ) expect_square(calc_chol$chol_x) expect_upper_tri(calc_chol$chol_x) }) -test_that("Post-MCMC, Wishart distribution stays symmetric, chol remains lower tri",{ +test_that("Post-MCMC, Wishart distribution stays symmetric, chol remains lower tri", { skip_if_not(check_tf_version()) -# From https://github.com/greta-dev/greta/issues/585 + # From https://github.com/greta-dev/greta/issues/585 x <- wishart(df = 4, Sigma = diag(3)) m <- model(x) - tensorflow::set_random_seed(2024-07-30-1431) + tensorflow::set_random_seed(2024 - 07 - 30 - 1431) draws <- mcmc(m, warmup = 1, n_samples = 1) - calcs <- calculate(x, chol(x), nsim = 1, seed = 2024-07-30-1431) + calcs <- calculate(x, chol(x), nsim = 1, seed = 2024 - 07 - 30 - 1431) # ensure that the symmetric matrix is still symmetric expect_snapshot( calcs @@ -86,18 +86,17 @@ test_that("Post-MCMC, Wishart distribution stays symmetric, chol remains lower t expect_square(calcs$`chol(x)`) expect_symmetric(calcs$x) expect_upper_tri(calcs$`chol(x)`) - }) -test_that("Post-MCMC, LKJ distribution stays symmetric, chol remains lower tri",{ +test_that("Post-MCMC, LKJ distribution stays symmetric, chol remains lower tri", { skip_if_not(check_tf_version()) # From https://github.com/greta-dev/greta/issues/585 x <- lkj_correlation(eta = 3, dimension = 3) m <- model(x) - tensorflow::set_random_seed(2024-07-30-1431) + tensorflow::set_random_seed(2024 - 07 - 30 - 1431) draws <- mcmc(m, warmup = 1, n_samples = 1) - calcs <- calculate(x, chol(x), nsim = 1, seed = 2024-07-30-1431) + calcs <- calculate(x, chol(x), nsim = 1, seed = 2024 - 07 - 30 - 1431) # ensure that the symmetric matrix is still symmetric expect_snapshot( calcs @@ -107,5 +106,4 @@ test_that("Post-MCMC, LKJ distribution stays symmetric, chol remains lower tri", expect_square(calcs$`chol(x)`) expect_symmetric(calcs$x) expect_upper_tri(calcs$`chol(x)`) - }) diff --git a/tests/testthat/test_extract_replace_combine.R b/tests/testthat/test_extract_replace_combine.R index 43540d28..93740e78 100644 --- a/tests/testthat/test_extract_replace_combine.R +++ b/tests/testthat/test_extract_replace_combine.R @@ -31,12 +31,12 @@ test_that("extract works like R", { check_expr(d[2:1, , c(TRUE, FALSE), drop = FALSE], "d") # can extract with missing entries in various places - check_expr(d[, , 2:1], "d") + check_expr(d[,, 2:1], "d") check_expr(d[, 2:1, ], "d") check_expr(d[2:1, , ], "d") # can extract single elements without dropping dimensions - check_expr(d[, , 1, drop = FALSE], "d") + check_expr(d[,, 1, drop = FALSE], "d") check_expr(d[, 1, , drop = FALSE], "d") check_expr(d[1, , , drop = FALSE], "d") @@ -152,7 +152,7 @@ test_that("replace works like R", { # can assign with missing entries in various places x <- randn(10, 2, 2) check_expr({ - x[, , 1:2] <- seq_len(10 * 2 * 2) + x[,, 1:2] <- seq_len(10 * 2 * 2) x }) @@ -183,7 +183,7 @@ test_that("replace works like R", 
{ x <- randn(10, 2, 2) check_expr({ - x[, , 1] <- seq_len(10 * 2 * 1) + x[,, 1] <- seq_len(10 * 2 * 1) x }) @@ -388,14 +388,9 @@ test_that("abind errors informatively", { b <- ones(1, 1, 3) c <- ones(5, 1, 1) - expect_snapshot(error = TRUE, - abind(a, b) - ) - - expect_snapshot(error = TRUE, - abind(a, c, along = 5) - ) + expect_snapshot(error = TRUE, abind(a, b)) + expect_snapshot(error = TRUE, abind(a, c, along = 5)) }) test_that("rbind and cbind can prepend R arrays to greta arrays", { @@ -425,9 +420,7 @@ test_that("assign errors on variable greta arrays", { skip_if_not(check_tf_version()) z <- normal(0, 1, dim = 5) - expect_snapshot(error = TRUE, - z[1] <- 3 - ) + expect_snapshot(error = TRUE, z[1] <- 3) }) test_that("rbind and cbind give informative error messages", { @@ -436,45 +429,31 @@ test_that("rbind and cbind give informative error messages", { a <- as_data(randn(5, 1)) b <- as_data(randn(1, 5)) - expect_snapshot(error = TRUE, - rbind(a, b) - ) + expect_snapshot(error = TRUE, rbind(a, b)) - expect_snapshot(error = TRUE, - cbind(a, b) - ) + expect_snapshot(error = TRUE, cbind(a, b)) }) test_that("replacement gives informative error messages", { skip_if_not(check_tf_version()) x <- ones(2, 2, 2) - expect_snapshot(error = TRUE, - x[1:2, , 1] <- seq_len(3) - ) + expect_snapshot(error = TRUE, x[1:2, , 1] <- seq_len(3)) - expect_snapshot(error = TRUE, - x[1, 1, 3] <- 1 - ) + expect_snapshot(error = TRUE, x[1, 1, 3] <- 1) x <- ones(2) - expect_snapshot(error = TRUE, - x[3] <- 1 - ) + expect_snapshot(error = TRUE, x[3] <- 1) }) test_that("extraction gives informative error messages", { skip_if_not(check_tf_version()) x <- ones(2, 2, 2) - expect_snapshot(error = TRUE, - x[1, 1, 3] - ) + expect_snapshot(error = TRUE, x[1, 1, 3]) x <- ones(2) - expect_snapshot(error = TRUE, - x[3] - ) + expect_snapshot(error = TRUE, x[3]) }) test_that("stochastic and operation greta arrays can be extracted", { @@ -603,21 +582,13 @@ test_that("dim<- errors as expected", { x <- zeros(3, 4) - expect_snapshot(error = TRUE, - dim(x) <- pi[0] - ) + expect_snapshot(error = TRUE, dim(x) <- pi[0]) - expect_snapshot(error = TRUE, - dim(x) <- c(1, NA) - ) + expect_snapshot(error = TRUE, dim(x) <- c(1, NA)) - expect_snapshot(error = TRUE, - dim(x) <- c(1, -1) - ) + expect_snapshot(error = TRUE, dim(x) <- c(1, -1)) - expect_snapshot(error = TRUE, - dim(x) <- 13 - ) + expect_snapshot(error = TRUE, dim(x) <- 13) }) test_that("dim<- works in a model", { diff --git a/tests/testthat/test_functions.R b/tests/testthat/test_functions.R index 3f29a20b..f2b662cd 100644 --- a/tests/testthat/test_functions.R +++ b/tests/testthat/test_functions.R @@ -1,6 +1,6 @@ set.seed(2020 - 02 - 11) -test_that("log.greta_array has a warning when base argument used",{ +test_that("log.greta_array has a warning when base argument used", { skip_if_not(check_tf_version()) x <- normal(0, 1) @@ -20,7 +20,6 @@ test_that("log.greta_array has a warning when base argument used",{ expect_snapshot_warning( log(exp(x), base = 3) ) - }) test_that("simple functions work as expected", { @@ -92,9 +91,7 @@ test_that("cummax and cummin functions error informatively", { x <- as_data(randn(10)) for (fun in cumulative_funs) { - expect_snapshot(error = TRUE, - fun(x) - ) + expect_snapshot(error = TRUE, fun(x)) } }) @@ -105,16 +102,14 @@ test_that("complex number functions error informatively", { x <- as_data(randn(25, 4)) for (fun in complex_funs) { - expect_snapshot(error = TRUE, - fun(x) - ) + expect_snapshot(error = TRUE, fun(x)) } }) test_that("matrix functions 
work as expected", { skip_if_not(check_tf_version()) - a <- rWishart(1, 6, diag(5))[, , 1] + a <- rWishart(1, 6, diag(5))[,, 1] b <- randn(5, 25) c <- chol(a) d <- c(1, 1) @@ -146,7 +141,7 @@ test_that("matrix functions work as expected", { test_that("kronecker works with greta and base array arguments", { skip_if_not(check_tf_version()) - a <- rWishart(1, 6, diag(5))[, , 1] + a <- rWishart(1, 6, diag(5))[,, 1] b <- chol(a) a_greta <- as_data(a) @@ -222,8 +217,11 @@ test_that("apply works as expected", { skip_if_not(check_tf_version()) # check apply.greta_array works like R's apply for X - check_apply <- function(X, MARGIN, FUN) { # nolint - check_op(apply, a, + check_apply <- function(X, MARGIN, FUN) { + # nolint + check_op( + apply, + a, other_args = list( MARGIN = MARGIN, FUN = FUN @@ -266,23 +264,13 @@ test_that("cumulative functions error as expected", { a <- as_data(randn(1, 5)) b <- as_data(randn(5, 1, 1)) + expect_snapshot(error = TRUE, cumsum(a)) - expect_snapshot(error = TRUE, - cumsum(a) - ) - - expect_snapshot(error = TRUE, - cumsum(b) - ) - - expect_snapshot(error = TRUE, - cumprod(a) - ) + expect_snapshot(error = TRUE, cumsum(b)) - expect_snapshot(error = TRUE, - cumprod(b) - ) + expect_snapshot(error = TRUE, cumprod(a)) + expect_snapshot(error = TRUE, cumprod(b)) }) test_that("sweep works as expected", { @@ -399,7 +387,6 @@ test_that("solve and sweep and kronecker error as expected", { error = TRUE, kronecker(b, c) ) - }) test_that("colSums etc. error as expected", { @@ -407,22 +394,13 @@ test_that("colSums etc. error as expected", { x <- as_data(randn(3, 4, 5)) - expect_snapshot(error = TRUE, - colSums(x, dims = 3) - ) - - expect_snapshot(error = TRUE, - rowSums(x, dims = 3) - ) + expect_snapshot(error = TRUE, colSums(x, dims = 3)) - expect_snapshot(error = TRUE, - colMeans(x, dims = 3) - ) + expect_snapshot(error = TRUE, rowSums(x, dims = 3)) - expect_snapshot(error = TRUE, - rowMeans(x, dims = 3) - ) + expect_snapshot(error = TRUE, colMeans(x, dims = 3)) + expect_snapshot(error = TRUE, rowMeans(x, dims = 3)) }) test_that("forwardsolve and backsolve error as expected", { @@ -432,22 +410,13 @@ test_that("forwardsolve and backsolve error as expected", { b <- as_data(randn(5, 25)) c <- chol(a) - expect_snapshot(error = TRUE, - forwardsolve(a, b, k = 1) - ) - - expect_snapshot(error = TRUE, - backsolve(a, b, k = 1) - ) + expect_snapshot(error = TRUE, forwardsolve(a, b, k = 1)) - expect_snapshot(error = TRUE, - forwardsolve(a, b, transpose = TRUE) - ) + expect_snapshot(error = TRUE, backsolve(a, b, k = 1)) - expect_snapshot(error = TRUE, - backsolve(a, b, transpose = TRUE) - ) + expect_snapshot(error = TRUE, forwardsolve(a, b, transpose = TRUE)) + expect_snapshot(error = TRUE, backsolve(a, b, transpose = TRUE)) }) test_that("tapply errors as expected", { @@ -458,21 +427,17 @@ test_that("tapply errors as expected", { b <- ones(10, 2) # X must be a column vector - expect_snapshot(error = TRUE, - tapply(b, group, "sum") - ) + expect_snapshot(error = TRUE, tapply(b, group, "sum")) # INDEX can't be a greta array - expect_snapshot(error = TRUE, - tapply(a, as_data(group), "sum") - ) + expect_snapshot(error = TRUE, tapply(a, as_data(group), "sum")) }) test_that("eigen works as expected", { skip_if_not(check_tf_version()) k <- 4 - x <- rWishart(1, k + 1, diag(k))[, , 1] + x <- rWishart(1, k + 1, diag(k))[,, 1] x_ga <- as_data(x) r_out <- eigen(x) @@ -501,7 +466,8 @@ test_that("eigen works as expected", { } greta_vectors <- grab(greta_out$vectors) - difference <- vapply(seq_len(k), + 
difference <- vapply( + seq_len(k), function(i) { column_difference( r_out$vectors[, i], @@ -518,9 +484,7 @@ test_that("ignored options are errored/warned about", { skip_if_not(check_tf_version()) x <- ones(3, 3) - expect_snapshot(error = TRUE, - round(x, 2) - ) + expect_snapshot(error = TRUE, round(x, 2)) expect_snapshot_warning( chol(x, pivot = TRUE) @@ -546,47 +510,29 @@ test_that("incorrect dimensions are errored about", { x <- ones(3, 3, 3) y <- ones(3, 4) - expect_snapshot(error = TRUE, - t(x) - ) + expect_snapshot(error = TRUE, t(x)) - expect_snapshot(error = TRUE, - aperm(x, 2:1) - ) + expect_snapshot(error = TRUE, aperm(x, 2:1)) - expect_snapshot(error = TRUE, - chol(x) - ) + expect_snapshot(error = TRUE, chol(x)) - expect_snapshot(error = TRUE, - chol(y) - ) + expect_snapshot(error = TRUE, chol(y)) - expect_snapshot(error = TRUE, - chol2symm(x) - ) + expect_snapshot(error = TRUE, chol2symm(x)) - expect_snapshot(error = TRUE, - chol2symm(y) - ) + expect_snapshot(error = TRUE, chol2symm(y)) - expect_snapshot(error = TRUE, - eigen(x) - ) + expect_snapshot(error = TRUE, eigen(x)) - expect_snapshot(error = TRUE, - eigen(y) - ) + expect_snapshot(error = TRUE, eigen(y)) - expect_snapshot(error = TRUE, - rdist(x, y) - ) + expect_snapshot(error = TRUE, rdist(x, y)) }) test_that("chol2symm inverts chol", { skip_if_not(check_tf_version()) - x <- rWishart(1, 10, diag(9))[, , 1] + x <- rWishart(1, 10, diag(9))[,, 1] u <- chol(x) # check the R version diff --git a/tests/testthat/test_future.R b/tests/testthat/test_future.R index 0dde1f16..6be85361 100644 --- a/tests/testthat/test_future.R +++ b/tests/testthat/test_future.R @@ -9,8 +9,7 @@ test_that("check_future_plan() works when only one core available", { # one chain expect_snapshot_output( check_future_plan() - ) - + ) }) test_that("check_future_plan() works", { @@ -22,8 +21,7 @@ test_that("check_future_plan() works", { # one chain expect_snapshot_output( check_future_plan() - ) - + ) }) test_that("mcmc errors for invalid parallel plans", { @@ -41,19 +39,14 @@ test_that("mcmc errors for invalid parallel plans", { ) future::plan(future::multicore) - expect_snapshot(error = TRUE, - check_future_plan() - ) + expect_snapshot(error = TRUE, check_future_plan()) # skip on windows - if (.Platform$OS.type != "windows"){ + if (.Platform$OS.type != "windows") { cl <- parallel::makeCluster(2L, type = "FORK") future::plan(future::cluster, workers = cl) - expect_snapshot(error = TRUE, - check_future_plan() - ) + expect_snapshot(error = TRUE, check_future_plan()) } - }) test_that("parallel reporting works", { @@ -68,9 +61,14 @@ test_that("parallel reporting works", { # should report each sampler's progress with a fraction #out <- get_output(. <- mcmc(m, warmup = 50, n_samples = 50, chains = 2)) - expect_match(get_output(. <- mcmc(m, warmup = 50, n_samples = 50, chains = 2)), "2 samplers in parallel") - expect_match(get_output(. <- mcmc(m, warmup = 50, n_samples = 50, chains = 2)), "50/50") - + expect_match( + get_output(. <- mcmc(m, warmup = 50, n_samples = 50, chains = 2)), + "2 samplers in parallel" + ) + expect_match( + get_output(. 
<- mcmc(m, warmup = 50, n_samples = 50, chains = 2)), + "50/50" + ) }) @@ -89,16 +87,11 @@ test_that("mcmc errors for invalid parallel plans", { withr::defer(future::plan(op)) future::plan(future::multicore) - expect_snapshot(error = TRUE, - mcmc(m, verbose = FALSE) - ) + expect_snapshot(error = TRUE, mcmc(m, verbose = FALSE)) cl <- parallel::makeForkCluster(2L) future::plan(future::cluster, workers = cl) - expect_snapshot(error = TRUE, - mcmc(m, verbose = FALSE) - ) - + expect_snapshot(error = TRUE, mcmc(m, verbose = FALSE)) }) # this is the test that says: 'Loaded Tensorflow version 1.14.0' @@ -113,24 +106,19 @@ test_that("mcmc works in parallel", { future::plan(future::multisession) # one chain - expect_ok(draws <- mcmc(m, - warmup = 10, n_samples = 10, - chains = 1, - verbose = FALSE - )) + expect_ok( + draws <- mcmc(m, warmup = 10, n_samples = 10, chains = 1, verbose = FALSE) + ) expect_true(inherits(draws, "greta_mcmc_list")) expect_true(coda::niter(draws) == 10) rm(draws) # multiple chains - expect_ok(draws <- mcmc(m, - warmup = 10, n_samples = 10, - chains = 2, - verbose = FALSE - )) + expect_ok( + draws <- mcmc(m, warmup = 10, n_samples = 10, chains = 2, verbose = FALSE) + ) expect_true(inherits(draws, "greta_mcmc_list")) expect_true(coda::niter(draws) == 10) - }) diff --git a/tests/testthat/test_greta_array_class.R b/tests/testthat/test_greta_array_class.R index 2cff45d4..77d37708 100644 --- a/tests/testthat/test_greta_array_class.R +++ b/tests/testthat/test_greta_array_class.R @@ -53,7 +53,6 @@ test_that("print and summary work", { expect_snapshot( n ) - }) test_that("as.matrix works", { @@ -75,7 +74,7 @@ test_that("as.matrix works", { expect_true(inherits(o_mat, "matrix")) }) -test_that("print method works for longer greta arrays",{ +test_that("print method works for longer greta arrays", { skip_if_not(check_tf_version()) ga_data_long <- as_data(matrix(1:20, ncol = 1)) @@ -84,39 +83,38 @@ test_that("print method works for longer greta arrays",{ expect_snapshot( ga_data_long - ) + ) expect_snapshot( ga_stochastic_long - ) + ) expect_snapshot( ga_operation_long - ) + ) expect_snapshot( print(ga_data_long, n = 19) - ) + ) expect_snapshot( print(ga_data_long, n = 20) - ) + ) expect_snapshot( print(ga_data_long, n = 21) - ) + ) expect_snapshot( print(ga_stochastic_long, n = 19) - ) + ) expect_snapshot( print(ga_stochastic_long, n = 20) - ) + ) expect_snapshot( print(ga_stochastic_long, n = 21) - ) + ) expect_snapshot( print(ga_operation_long, n = 19) - ) + ) expect_snapshot( print(ga_operation_long, n = 20) - ) + ) expect_snapshot( print(ga_operation_long, n = 21) - ) - + ) }) diff --git a/tests/testthat/test_greta_deps_spec.R b/tests/testthat/test_greta_deps_spec.R index 804b52e4..677354ab 100644 --- a/tests/testthat/test_greta_deps_spec.R +++ b/tests/testthat/test_greta_deps_spec.R @@ -1,4 +1,4 @@ -test_that("greta python range detection works correctly",{ +test_that("greta python range detection works correctly", { skip_if_not(check_tf_version()) # correct ranges expect_snapshot(check_greta_python_range("3.11")) @@ -38,101 +38,131 @@ test_that("greta_deps_spec fails appropriately", { expect_snapshot(greta_deps_spec()) # some correct ranges expect_snapshot( - greta_deps_spec(tf_version = "2.14.0", - tfp_version = "0.22.1", - python_version = "3.9") + greta_deps_spec( + tf_version = "2.14.0", + tfp_version = "0.22.1", + python_version = "3.9" + ) ) expect_snapshot( - greta_deps_spec(tf_version = "2.12.0", - tfp_version = "0.20.0", - python_version = "3.9") + greta_deps_spec( + 
tf_version = "2.12.0", + tfp_version = "0.20.0", + python_version = "3.9" + ) ) # TF above range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.16.1", - tfp_version = "0.11.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.16.1", + tfp_version = "0.11.0", + python_version = "3.8" + ) ) # TF below range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "1.9.0", - tfp_version = "0.11.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "1.9.0", + tfp_version = "0.11.0", + python_version = "3.8" + ) ) # TFP above range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.15.0", - tfp_version = "0.24.0", - python_version = "3.10") + greta_deps_spec( + tf_version = "2.15.0", + tfp_version = "0.24.0", + python_version = "3.10" + ) ) # TFP below range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.15.0", - tfp_version = "0.6.0", - python_version = "3.10") + greta_deps_spec( + tf_version = "2.15.0", + tfp_version = "0.6.0", + python_version = "3.10" + ) ) # Python above range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.9.1", - tfp_version = "0.23.0", - python_version = "3.13") + greta_deps_spec( + tf_version = "2.9.1", + tfp_version = "0.23.0", + python_version = "3.13" + ) ) # Python below range expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.9.1", - tfp_version = "0.23.0", - python_version = "2.6") + greta_deps_spec( + tf_version = "2.9.1", + tfp_version = "0.23.0", + python_version = "2.6" + ) ) # Only Python is not valid # TODO - suggest changing python version in error message expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.15.0", - tfp_version = "0.23.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.15.0", + tfp_version = "0.23.0", + python_version = "3.8" + ) ) # Only TF is not valid expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.15.0", - tfp_version = "0.22.0", - python_version = "3.10") + greta_deps_spec( + tf_version = "2.15.0", + tfp_version = "0.22.0", + python_version = "3.10" + ) ) expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.14.0", - tfp_version = "0.21.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.14.0", + tfp_version = "0.21.0", + python_version = "3.8" + ) ) expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.13.0", - tfp_version = "0.20.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.13.0", + tfp_version = "0.20.0", + python_version = "3.8" + ) ) # Only TFP is not valid expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.15.0", - tfp_version = "0.17.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.15.0", + tfp_version = "0.17.0", + python_version = "3.8" + ) ) expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.17.0", - tfp_version = "0.23.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.17.0", + tfp_version = "0.23.0", + python_version = "3.8" + ) ) expect_snapshot( error = TRUE, - greta_deps_spec(tf_version = "2.9.0", - tfp_version = "0.17.0", - python_version = "3.8") + greta_deps_spec( + tf_version = "2.9.0", + tfp_version = "0.17.0", + python_version = "3.8" + ) ) }) diff --git a/tests/testthat/test_greta_mcmc_list_class.R b/tests/testthat/test_greta_mcmc_list_class.R index 97bfee9f..4fd573ad 100644 --- a/tests/testthat/test_greta_mcmc_list_class.R +++ b/tests/testthat/test_greta_mcmc_list_class.R @@ -72,7 +72,8 @@ test_that("windowing does not have spooky effects", 
{ n_samples <- as.integer(chains * samples) x <- normal(0, 1) m <- model(x) - draws <- mcmc(m, + draws <- mcmc( + m, warmup = 100, n_samples = samples, chains = chains, @@ -103,7 +104,7 @@ test_that("greta_mcmc_list print method works", { warmup <- 10 z <- normal(0, 1) m <- model(z) - tensorflow::set_random_seed(2024-07-29-1217) + tensorflow::set_random_seed(2024 - 07 - 29 - 1217) draws <- mcmc(m, warmup = warmup, n_samples = samples, verbose = FALSE) expect_snapshot( draws @@ -116,7 +117,7 @@ test_that("greta_mcmc_list print method works with larger sample size", { warmup <- 20 z <- normal(0, 1) m <- model(z) - tensorflow::set_random_seed(2024-07-30-1233) + tensorflow::set_random_seed(2024 - 07 - 30 - 1233) draws <- mcmc(m, warmup = warmup, n_samples = samples, verbose = FALSE) expect_snapshot( draws @@ -138,7 +139,7 @@ test_that("greta_mcmc_list print method works with smaller sample size", { warmup <- 2 z <- normal(0, 1) m <- model(z) - tensorflow::set_random_seed(2024-07-30-34) + tensorflow::set_random_seed(2024 - 07 - 30 - 34) draws <- mcmc(m, warmup = warmup, n_samples = samples, verbose = FALSE) expect_snapshot( draws diff --git a/tests/testthat/test_iid_samples.R b/tests/testthat/test_iid_samples.R index 8c2d547a..f66380f0 100644 --- a/tests/testthat/test_iid_samples.R +++ b/tests/testthat/test_iid_samples.R @@ -3,107 +3,91 @@ set.seed(2020 - 02 - 11) test_that("univariate samples are correct", { skip_if_not(check_tf_version()) - compare_iid_samples(uniform, - runif, - parameters = list(min = -2, max = 3) - ) + compare_iid_samples(uniform, runif, parameters = list(min = -2, max = 3)) - compare_iid_samples(normal, - rnorm, - parameters = list(mean = -2, sd = 3) - ) + compare_iid_samples(normal, rnorm, parameters = list(mean = -2, sd = 3)) - compare_iid_samples(lognormal, - rlnorm, - parameters = list(mean = -2, sd = 3) - ) + compare_iid_samples(lognormal, rlnorm, parameters = list(mean = -2, sd = 3)) - compare_iid_samples(bernoulli, + compare_iid_samples( + bernoulli, extraDistr::rbern, parameters = list(prob = 0.3) ) - compare_iid_samples(binomial, + compare_iid_samples( + binomial, rbinom, parameters = list(size = 5, prob = 0.3), nsim = 1000 ) - compare_iid_samples(beta_binomial, + compare_iid_samples( + beta_binomial, extraDistr::rbbinom, parameters = list(size = 12, alpha = 4, beta = 2) ) - compare_iid_samples(negative_binomial, + compare_iid_samples( + negative_binomial, rnbinom, parameters = list(size = 12, prob = 0.3), nsim = 1000 ) - compare_iid_samples(poisson, - rpois, - parameters = list(lambda = 3.14) - ) + compare_iid_samples(poisson, rpois, parameters = list(lambda = 3.14)) - compare_iid_samples(gamma, - rgamma, - parameters = list(shape = 3, rate = 1.2) - ) + compare_iid_samples(gamma, rgamma, parameters = list(shape = 3, rate = 1.2)) - compare_iid_samples(inverse_gamma, + compare_iid_samples( + inverse_gamma, extraDistr::rinvgamma, parameters = list(alpha = 3, beta = 1.2) ) - compare_iid_samples(weibull, + compare_iid_samples( + weibull, rweibull, parameters = list(shape = 1.2, scale = 3.2) ) - compare_iid_samples(exponential, - rexp, - parameters = list(rate = 0.54) - ) + compare_iid_samples(exponential, rexp, parameters = list(rate = 0.54)) - compare_iid_samples(pareto, + compare_iid_samples( + pareto, extraDistr::rpareto, parameters = list(a = 1, b = 2) ) - compare_iid_samples(student, + compare_iid_samples( + student, extraDistr::rlst, parameters = list(df = 3, mu = -2, sigma = 3) ) - compare_iid_samples(laplace, + compare_iid_samples( + laplace, 
extraDistr::rlaplace, parameters = list(mu = -2, sigma = 1.2) ) - compare_iid_samples(beta, - rbeta, - parameters = list(shape1 = 3, shape2 = 2) - ) + compare_iid_samples(beta, rbeta, parameters = list(shape1 = 3, shape2 = 2)) - compare_iid_samples(cauchy, + compare_iid_samples( + cauchy, rcauchy, parameters = list(location = -1, scale = 1.2) ) - compare_iid_samples(chi_squared, - rchisq, - parameters = list(df = 4) - ) + compare_iid_samples(chi_squared, rchisq, parameters = list(df = 4)) - compare_iid_samples(logistic, + compare_iid_samples( + logistic, rlogis, parameters = list(location = -2, scale = 1.3) ) - compare_iid_samples(f, - rf, - parameters = list(df1 = 4, df2 = 1) - ) + compare_iid_samples(f, rf, parameters = list(df1 = 4, df2 = 1)) }) test_that("truncated univariate samples are correct", { @@ -112,7 +96,8 @@ test_that("truncated univariate samples are correct", { # an originally unconstrained distribution # positive - compare_iid_samples(normal, + compare_iid_samples( + normal, rtnorm, parameters = list( mean = -2, @@ -122,7 +107,8 @@ test_that("truncated univariate samples are correct", { ) # negative - compare_iid_samples(normal, + compare_iid_samples( + normal, rtnorm, parameters = list( mean = -2, @@ -132,7 +118,8 @@ test_that("truncated univariate samples are correct", { ) # both - compare_iid_samples(normal, + compare_iid_samples( + normal, rtnorm, parameters = list( mean = -2, @@ -143,7 +130,8 @@ test_that("truncated univariate samples are correct", { # originally constrained distribution - compare_iid_samples(lognormal, + compare_iid_samples( + lognormal, rtlnorm, parameters = list( mean = -2, @@ -152,7 +140,8 @@ test_that("truncated univariate samples are correct", { ) ) - compare_iid_samples(weibull, + compare_iid_samples( + weibull, rtweibull, parameters = list( shape = 1.2, @@ -169,37 +158,36 @@ test_that("multivariate samples are correct", { prob <- t(runif(4)) prob <- prob / sum(prob) - compare_iid_samples(multivariate_normal, + compare_iid_samples( + multivariate_normal, rmvnorm, parameters = list(mean = t(rnorm(4)), Sigma = sigma) ) - compare_iid_samples(multinomial, + compare_iid_samples( + multinomial, rmulti, parameters = list(size = 12, prob = prob) ) - compare_iid_samples(categorical, - rcat, - parameters = list(prob = prob) - ) + compare_iid_samples(categorical, rcat, parameters = list(prob = prob)) - compare_iid_samples(dirichlet, + compare_iid_samples( + dirichlet, extraDistr::rdirichlet, parameters = list(alpha = t(runif(4))) ) - compare_iid_samples(dirichlet_multinomial, + compare_iid_samples( + dirichlet_multinomial, extraDistr::rdirmnom, parameters = list(size = 3, alpha = t(runif(4))) ) - compare_iid_samples(wishart, - rwish, - parameters = list(df = 7, Sigma = sigma) - ) + compare_iid_samples(wishart, rwish, parameters = list(df = 7, Sigma = sigma)) - compare_iid_samples(lkj_correlation, + compare_iid_samples( + lkj_correlation, rlkjcorr, parameters = list(eta = 6.5, dimension = 4) ) @@ -213,11 +201,7 @@ test_that("joint samples are correct", { list(mean = 0, sd = 1), list(mean = 0, sd = 2) ) - compare_iid_samples(joint_normals, - rjnorm, - parameters = params - ) - + compare_iid_samples(joint_normals, rjnorm, parameters = params) # joint of truncated normal distributions params <- list( @@ -225,10 +209,7 @@ test_that("joint samples are correct", { list(mean = 0, sd = 2, truncation = c(-Inf, 2)), list(mean = 0, sd = 3, truncation = c(-2, 1)) ) - compare_iid_samples(joint_normals, - rjtnorm, - parameters = params - ) + 
compare_iid_samples(joint_normals, rjtnorm, parameters = params) }) @@ -246,10 +227,7 @@ test_that("mixture samples are correct", { weights = weights ) - compare_iid_samples(mixture_normals, - rmixnorm, - parameters = params - ) + compare_iid_samples(mixture_normals, rmixnorm, parameters = params) # mixture of truncated normal distributions params <- list( @@ -259,10 +237,7 @@ test_that("mixture samples are correct", { weights = weights ) - compare_iid_samples(mixture_normals, - rmixtnorm, - parameters = params - ) + compare_iid_samples(mixture_normals, rmixtnorm, parameters = params) # mixture of multivariate normal distributions sigma <- diag(4) * 0.1 @@ -273,7 +248,8 @@ test_that("mixture samples are correct", { weights = weights ) - compare_iid_samples(mixture_multivariate_normals, + compare_iid_samples( + mixture_multivariate_normals, rmixmvnorm, parameters = params ) @@ -284,16 +260,20 @@ test_that("distributions without RNG error nicely", { skip_if_not(check_tf_version()) # univariate - expect_snapshot(error = TRUE, - compare_iid_samples(hypergeometric, - rhyper, - parameters = list(m = 11, n = 8, k = 5) + expect_snapshot( + error = TRUE, + compare_iid_samples( + hypergeometric, + rhyper, + parameters = list(m = 11, n = 8, k = 5) ) ) # truncated RNG not implemented - expect_snapshot(error = TRUE, - compare_iid_samples(f, + expect_snapshot( + error = TRUE, + compare_iid_samples( + f, rtf, parameters = list( df1 = 4, diff --git a/tests/testthat/test_inference.R b/tests/testthat/test_inference.R index b70f2f15..4f883a04 100644 --- a/tests/testthat/test_inference.R +++ b/tests/testthat/test_inference.R @@ -10,34 +10,38 @@ test_that("bad mcmc proposals are rejected", { m <- model(z, precision = "single") # catch badness in the progress bar - out <- get_output( - mcmc(m, n_samples = 10, warmup = 0, pb_update = 10) - ) - expect_match(out, "100% bad") - - expect_snapshot(error = TRUE, - draws <- mcmc(m, - chains = 1, - n_samples = 2, - warmup = 0, - verbose = FALSE, - initial_values = initials(z = 1e120) - ) + out <- get_output( + mcmc(m, n_samples = 10, warmup = 0, pb_update = 10) + ) + expect_match(out, "100% bad") + + expect_snapshot( + error = TRUE, + draws <- mcmc( + m, + chains = 1, + n_samples = 2, + warmup = 0, + verbose = FALSE, + initial_values = initials(z = 1e120) ) + ) # really bad proposals x <- rnorm(100000, 1e120, 1) z <- normal(-1e120, 1e-120) distribution(x) <- normal(z, 1e-120) m <- model(z, precision = "single") - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, n_samples = 1, warmup = 0, verbose = FALSE) ) # proposals that are fine, but rejected anyway z <- normal(0, 1) m <- model(z, precision = "single") - expect_ok(mcmc(m, + expect_ok(mcmc( + m, hmc( epsilon = 100, Lmin = 1, @@ -70,10 +74,10 @@ test_that("mcmc works with cpu and gpu options", { m <- model(z) quietly( expect_ok(mcmc(m, n_samples = 5, warmup = 5, compute_options = cpu_only())) - ) + ) quietly( expect_ok(mcmc(m, n_samples = 5, warmup = 5, compute_options = gpu_only())) - ) + ) }) test_that("mcmc works with multiple chains", { @@ -85,13 +89,21 @@ test_that("mcmc works with multiple chains", { m <- model(z) # multiple chains, automatic initial values - quietly(expect_ok(mcmc(m, warmup = 10, n_samples = 10, chains = 2, - verbose = FALSE))) + quietly(expect_ok(mcmc( + m, + warmup = 10, + n_samples = 10, + chains = 2, + verbose = FALSE + ))) # multiple chains, user-specified initial values inits <- list(initials(z = 1), initials(z = 2)) - quietly(expect_ok(mcmc(m, - 
warmup = 10, n_samples = 10, chains = 2, + quietly(expect_ok(mcmc( + m, + warmup = 10, + n_samples = 10, + chains = 2, initial_values = inits, verbose = FALSE ))) @@ -101,7 +113,7 @@ test_that("mcmc handles initial values nicely", { skip_if_not(check_tf_version()) # preserve R version - current_r_version <- paste0(R.version$major,".", R.version$minor) + current_r_version <- paste0(R.version$major, ".", R.version$minor) required_r_version <- "3.6.0" old_rng_r <- compareVersion(required_r_version, current_r_version) <= 0 @@ -119,30 +131,43 @@ test_that("mcmc handles initial values nicely", { # too many sets of initial values inits <- replicate(3, initials(z = rnorm(1)), simplify = FALSE) - expect_snapshot(error = TRUE, - draws <- mcmc(m, - warmup = 10, n_samples = 10, verbose = FALSE, - chains = 2, initial_values = inits + expect_snapshot( + error = TRUE, + draws <- mcmc( + m, + warmup = 10, + n_samples = 10, + verbose = FALSE, + chains = 2, + initial_values = inits ) ) # initial values have the wrong length inits <- replicate(2, initials(z = rnorm(2)), simplify = FALSE) - expect_snapshot(error = TRUE, - draws <- mcmc(m, - warmup = 10, n_samples = 10, verbose = FALSE, - chains = 2, initial_values = inits + expect_snapshot( + error = TRUE, + draws <- mcmc( + m, + warmup = 10, + n_samples = 10, + verbose = FALSE, + chains = 2, + initial_values = inits ) ) inits <- initials(z = rnorm(1)) quietly( expect_snapshot( - draws <- mcmc(m, - warmup = 10, n_samples = 10, - chains = 2, initial_values = inits, - verbose = FALSE - ) + draws <- mcmc( + m, + warmup = 10, + n_samples = 10, + chains = 2, + initial_values = inits, + verbose = FALSE + ) ) ) }) @@ -158,7 +183,6 @@ test_that("progress bar gives a range of messages", { # 10/10 should be 100% expect_snapshot(draws <- mock_mcmc(10)) - }) test_that("extra_samples works", { @@ -215,12 +239,16 @@ test_that("stashed_samples works", { # mock up a stash stash <- greta:::greta_stash - samplers_stash <- replicate(2, list( - traced_free_state = list(as.matrix(rnorm(17))), - traced_values = list(as.matrix(rnorm(17))), - thin = 1, - model = m - ), simplify = FALSE) + samplers_stash <- replicate( + 2, + list( + traced_free_state = list(as.matrix(rnorm(17))), + traced_values = list(as.matrix(rnorm(17))), + thin = 1, + model = m + ), + simplify = FALSE + ) assign("samplers", samplers_stash, envir = stash) # should convert to a greta_mcmc_list @@ -258,20 +286,22 @@ test_that("model errors nicely", { # model should give a nice error if passed something other than a greta array a <- 1 b <- normal(0, a) - expect_snapshot(error = TRUE, - model(a, b) - ) + expect_snapshot(error = TRUE, model(a, b)) }) test_that("mcmc supports rwmh sampler with normal proposals", { skip_if_not(check_tf_version()) x <- normal(0, 1) m <- model(x) - expect_ok(draws <- mcmc(m, - sampler = rwmh("normal"), - n_samples = 100, warmup = 100, - verbose = FALSE - )) + expect_ok( + draws <- mcmc( + m, + sampler = rwmh("normal"), + n_samples = 100, + warmup = 100, + verbose = FALSE + ) + ) }) test_that("mcmc supports rwmh sampler with uniform proposals", { @@ -279,11 +309,15 @@ test_that("mcmc supports rwmh sampler with uniform proposals", { set.seed(5) x <- uniform(0, 1) m <- model(x) - expect_ok(draws <- mcmc(m, - sampler = rwmh("uniform"), - n_samples = 100, warmup = 100, - verbose = FALSE - )) + expect_ok( + draws <- mcmc( + m, + sampler = rwmh("uniform"), + n_samples = 100, + warmup = 100, + verbose = FALSE + ) + ) }) test_that("mcmc supports slice sampler with single precision models", { @@ 
-291,30 +325,29 @@ test_that("mcmc supports slice sampler with single precision models", { set.seed(5) x <- uniform(0, 1) m <- model(x, precision = "single") - expect_ok(draws <- mcmc(m, - sampler = slice(), - n_samples = 100, warmup = 100, - verbose = FALSE - )) + expect_ok( + draws <- mcmc( + m, + sampler = slice(), + n_samples = 100, + warmup = 100, + verbose = FALSE + ) + ) }) test_that("initials works", { skip_if_not(check_tf_version()) # errors on bad objects - expect_snapshot(error = TRUE, - initials(a = FALSE) - ) + expect_snapshot(error = TRUE, initials(a = FALSE)) - expect_snapshot(error = TRUE, - initials(FALSE) - ) + expect_snapshot(error = TRUE, initials(FALSE)) # prints nicely expect_snapshot( initials(a = 3) ) - }) test_that("prep_initials errors informatively", { @@ -329,39 +362,47 @@ test_that("prep_initials errors informatively", { m <- model(z) # bad objects: - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, initial_values = FALSE, verbose = FALSE) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, initial_values = list(FALSE), verbose = FALSE) ) # an unrelated greta array g <- normal(0, 1) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(g = 1), verbose = FALSE) ) # non-variable greta arrays - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(f = 1), verbose = FALSE) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(z = 1), verbose = FALSE) ) # out of bounds errors - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(b = -1), verbose = FALSE) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(d = -1), verbose = FALSE) ) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mcmc(m, chains = 1, initial_values = initials(e = 2), verbose = FALSE) ) }) @@ -392,9 +433,15 @@ test_that("pb_update > thin to avoid bursts with no saved iterations", { set.seed(5) x <- uniform(0, 1) m <- model(x) - expect_ok(draws <- mcmc(m, - n_samples = 100, warmup = 100, - thin = 3, pb_update = 2, verbose = FALSE - )) + expect_ok( + draws <- mcmc( + m, + n_samples = 100, + warmup = 100, + thin = 3, + pb_update = 2, + verbose = FALSE + ) + ) expect_identical(thin(draws), 3) }) diff --git a/tests/testthat/test_install_greta_deps.R b/tests/testthat/test_install_greta_deps.R index 2a239f14..1b1e59a6 100644 --- a/tests/testthat/test_install_greta_deps.R +++ b/tests/testthat/test_install_greta_deps.R @@ -1,8 +1,6 @@ test_that("install_greta_deps errors appropriately", { skip_if_not(check_tf_version()) - expect_snapshot(error = TRUE, - install_greta_deps(timeout = 0.001) - ) + expect_snapshot(error = TRUE, install_greta_deps(timeout = 0.001)) }) # test_that("reinstall_greta_deps errors appropriately", { diff --git a/tests/testthat/test_joint.R b/tests/testthat/test_joint.R index c2750d98..37fc598f 100644 --- a/tests/testthat/test_joint.R +++ b/tests/testthat/test_joint.R @@ -51,7 +51,8 @@ test_that("fixed continuous joint distributions can be sampled from", { obs <- matrix(rnorm(3, 0, 2), 100, 3) mu <- variable(dim = 3) - distribution(obs) <- joint(normal(mu[1], 1), + distribution(obs) <- joint( + normal(mu[1], 1), normal(mu[2], 2), normal(mu[3], 3), dim = 100 @@ -78,7 +79,8 @@ test_that("fixed discrete joint distributions can be 
sampled from", { test_that("joint of fixed and continuous distributions errors", { skip_if_not(check_tf_version()) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, joint( bernoulli(0.5), normal(0, 1) @@ -89,19 +91,16 @@ test_that("joint of fixed and continuous distributions errors", { test_that("joint with insufficient distributions errors", { skip_if_not(check_tf_version()) - expect_snapshot(error = TRUE, - joint(normal(0, 2)) - ) + expect_snapshot(error = TRUE, joint(normal(0, 2))) - expect_snapshot(error = TRUE, - joint() - ) + expect_snapshot(error = TRUE, joint()) }) test_that("joint with non-scalar distributions errors", { skip_if_not(check_tf_version()) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, joint( normal(0, 2, dim = 3), normal(0, 1, dim = 3) @@ -113,7 +112,8 @@ test_that("joint of normals has correct density", { skip_if_not(check_tf_version()) joint_greta <- function(means, sds, dim) { - joint(normal(means[1], sds[1]), + joint( + normal(means[1], sds[1]), normal(means[2], sds[2]), normal(means[3], sds[3]), dim = dim @@ -121,10 +121,7 @@ test_that("joint of normals has correct density", { } joint_r <- function(x, means, sds) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(means) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(means)) for (i in seq_along(means)) { densities[, i] <- dnorm(x[, i], means[i], sds[i], log = TRUE) } @@ -137,7 +134,8 @@ test_that("joint of normals has correct density", { sds = c(3, 0.5, 1) ) - compare_distribution(joint_greta, + compare_distribution( + joint_greta, joint_r, parameters = params, x = matrix(rnorm(300, -2, 3), 100, 3) @@ -148,7 +146,8 @@ test_that("joint of truncated normals has correct density", { skip_if_not(check_tf_version()) joint_greta <- function(means, sds, lower, upper, dim) { - joint(normal(means[1], sds[1], truncation = c(lower[1], upper[1])), + joint( + normal(means[1], sds[1], truncation = c(lower[1], upper[1])), normal(means[2], sds[2], truncation = c(lower[2], upper[2])), normal(means[3], sds[3], truncation = c(lower[3], upper[3])), dim = dim @@ -156,12 +155,10 @@ test_that("joint of truncated normals has correct density", { } joint_r <- function(x, means, sds, lower, upper) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(means) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(means)) for (i in seq_along(means)) { - densities[, i] <- truncdist::dtrunc(x[, i], + densities[, i] <- truncdist::dtrunc( + x[, i], "norm", a = lower[i], b = upper[i], @@ -186,18 +183,15 @@ test_that("joint of truncated normals has correct density", { } x <- mapply(fun, params$means, params$sds, params$lower, params$upper) - compare_distribution(joint_greta, - joint_r, - parameters = params, - x = x - ) + compare_distribution(joint_greta, joint_r, parameters = params, x = x) }) test_that("joint of uniforms has correct density", { skip_if_not(check_tf_version()) joint_greta <- function(lower, upper, dim) { - joint(uniform(lower[1], upper[1]), + joint( + uniform(lower[1], upper[1]), uniform(lower[2], upper[2]), uniform(lower[3], upper[3]), dim = dim @@ -205,10 +199,7 @@ test_that("joint of uniforms has correct density", { } joint_r <- function(x, lower, upper) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(lower) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(lower)) for (i in seq_along(lower)) { densities[, i] <- dunif(x[, i], lower[i], upper[i], log = TRUE) } @@ -226,29 +217,18 @@ test_that("joint of uniforms has correct 
density", { } x <- mapply(fun, params$lower, params$upper) - compare_distribution(joint_greta, - joint_r, - parameters = params, - x = x - ) + compare_distribution(joint_greta, joint_r, parameters = params, x = x) }) test_that("joint of Poissons has correct density", { skip_if_not(check_tf_version()) joint_greta <- function(rates, dim) { - joint(poisson(rates[1]), - poisson(rates[2]), - poisson(rates[3]), - dim = dim - ) + joint(poisson(rates[1]), poisson(rates[2]), poisson(rates[3]), dim = dim) } joint_r <- function(x, rates) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(rates) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(rates)) for (i in seq_along(rates)) { densities[, i] <- dpois(x[, i], rates[i], log = TRUE) } @@ -257,7 +237,8 @@ test_that("joint of Poissons has correct density", { params <- list(rates = c(0.1, 2, 5)) - compare_distribution(joint_greta, + compare_distribution( + joint_greta, joint_r, parameters = params, x = matrix(rpois(300, 3), 100, 3) diff --git a/tests/testthat/test_misc.R b/tests/testthat/test_misc.R index 047966e9..42279f74 100644 --- a/tests/testthat/test_misc.R +++ b/tests/testthat/test_misc.R @@ -5,9 +5,7 @@ test_that("check_tf_version works", { true_version <- tf$`__version__` tf$`__version__` <- "0.9.0" # nolint - expect_snapshot(error = TRUE, - check_tf_version("error") - ) + expect_snapshot(error = TRUE, check_tf_version("error")) expect_snapshot_warning( check_tf_version("warn") ) @@ -65,58 +63,39 @@ test_that("greta_model objects print", { test_that("define and mcmc error informatively", { skip_if_not(check_tf_version()) - x <- as_data(randn(10)) # no model with non-probability density greta arrays - expect_snapshot(error = TRUE, - model(variable()) - ) + expect_snapshot(error = TRUE, model(variable())) - expect_snapshot(error = TRUE, - model(x) - ) + expect_snapshot(error = TRUE, model(x)) - expect_snapshot(error = TRUE, - model() - ) + expect_snapshot(error = TRUE, model()) # can't define a model for an unfixed discrete variable - expect_snapshot(error = TRUE, - model(bernoulli(0.5)) - ) + expect_snapshot(error = TRUE, model(bernoulli(0.5))) # no parameters here, so define or dag should error distribution(x) <- normal(0, 1) - expect_snapshot(error = TRUE, - model(x) - ) + expect_snapshot(error = TRUE, model(x)) # a bad number of cores a <- normal(0, 1) m <- model(a) expect_warning( - mcmc(m, - warmup = 1, - n_samples = 1, - n_cores = 1000000L, - verbose = FALSE - ), + mcmc(m, warmup = 1, n_samples = 1, n_cores = 1000000L, verbose = FALSE), "cores were requested, but only" ) # can't draw samples of a data greta array z <- normal(x, 1) m <- model(x, z) - expect_snapshot(error = TRUE, - draws <- mcmc(m, verbose = FALSE) - ) + expect_snapshot(error = TRUE, draws <- mcmc(m, verbose = FALSE)) }) test_that("check_dims errors informatively", { skip_if_not(check_tf_version()) - a <- ones(3, 3) b <- ones(1) c <- ones(2, 2) @@ -133,9 +112,7 @@ test_that("check_dims errors informatively", { expect_identical(check_dims(b, b), dim(b)) # with two differently shaped arrays it shouldn't - expect_snapshot(error = TRUE, - check_dims(a, c) - ) + expect_snapshot(error = TRUE, check_dims(a, c)) # with two scalars and a target dimension, just return the target dimension expect_identical(check_dims(b, b, target_dim = dim1), dim1) @@ -144,7 +121,6 @@ test_that("disjoint graphs are checked", { skip_if_not(check_tf_version()) - # if the target nodes aren't related, they should be checked separately
a <- uniform(0, 1) @@ -153,23 +129,17 @@ test_that("disjoint graphs are checked", { # c is unrelated and has no density c <- variable() - expect_snapshot(error = TRUE, - m <- model(a, b, c) - ) + expect_snapshot(error = TRUE, m <- model(a, b, c)) # d is unrelated and known d <- as_data(randn(3)) distribution(d) <- normal(0, 1) - expect_snapshot(error = TRUE, - m <- model(a, b, d) - ) - + expect_snapshot(error = TRUE, m <- model(a, b, d)) }) test_that("plotting models doesn't error", { skip_if_not(check_tf_version()) - a <- uniform(0, 1) m <- model(a) @@ -180,7 +150,6 @@ test_that("plotting models doesn't error", { test_that("structures work correctly", { skip_if_not(check_tf_version()) - a <- ones(2, 2) b <- zeros(2) c <- greta_array(3, dim = c(2, 2, 2)) @@ -193,7 +162,6 @@ test_that("structures work correctly", { test_that("cleanly() handles TF errors nicely", { skip_if_not(check_tf_version()) - inversion_stop <- function() { stop("this non-invertible thing is not invertible") } @@ -208,10 +176,7 @@ test_that("cleanly() handles TF errors nicely", { expect_s3_class(cleanly(inversion_stop()), "error") expect_s3_class(cleanly(cholesky_stop()), "error") - expect_snapshot(error = TRUE, - cleanly(other_stop()) - ) - + expect_snapshot(error = TRUE, cleanly(other_stop())) }) test_that("double precision works for all jacobians", { @@ -237,7 +202,8 @@ test_that("double precision works for all jacobians", { }) test_that("module works", { - mod <- module(mean, + mod <- module( + mean, functions = module( sum, exp, diff --git a/tests/testthat/test_mixture.R b/tests/testthat/test_mixture.R index 5dbdbcab..8d8b5b25 100644 --- a/tests/testthat/test_mixture.R +++ b/tests/testthat/test_mixture.R @@ -2,11 +2,7 @@ test_that("continuous mixture variables can be sampled from", { skip_if_not(check_tf_version()) weights <- uniform(0, 1, 3) - x <- mixture(normal(0, 1), - normal(0, 2), - normal(0, 3), - weights = weights - ) + x <- mixture(normal(0, 1), normal(0, 2), normal(0, 3), weights = weights) sample_distribution(x) }) @@ -16,7 +12,8 @@ test_that("fixed continuous mixture distributions can be sampled from", { weights <- uniform(0, 1, 3) obs <- rnorm(100, 0, 2) - distribution(obs) <- mixture(normal(0, 1), + distribution(obs) <- mixture( + normal(0, 1), normal(0, 2), normal(0, 3), weights = weights @@ -30,7 +27,8 @@ test_that("fixed discrete mixture distributions can be sampled from", { weights <- uniform(0, 1, 3) obs <- rbinom(100, 1, 0.5) - distribution(obs) <- mixture(bernoulli(0.1), + distribution(obs) <- mixture( + bernoulli(0.1), bernoulli(0.5), bernoulli(0.9), weights = weights @@ -43,7 +41,8 @@ test_that("mixtures of fixed and continuous distributions errors", { skip_if_not(check_tf_version()) weights <- uniform(0, 1, dim = 2) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( bernoulli(0.5), normal(0, 1), @@ -56,7 +55,8 @@ test_that("mixtures of multivariate and univariate errors", { skip_if_not(check_tf_version()) weights <- uniform(0, 1, dim = 2) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( multivariate_normal(zeros(1, 3), diag(3)), normal(0, 1, dim = c(1, 3)), @@ -71,7 +71,8 @@ test_that("mixtures of supports errors", { weights <- c(0.5, 0.5) # due to truncation - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( normal(0, 1, truncation = c(0, Inf)), normal(0, 1), @@ -80,7 +81,8 @@ test_that("mixtures of supports errors", { ) # due to bounds - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( 
lognormal(0, 1), normal(0, 1), @@ -93,7 +95,8 @@ test_that("incorrectly-shaped weights errors", { skip_if_not(check_tf_version()) weights <- uniform(0, 1, dim = c(1, 2)) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( normal(0, 1), normal(0, 2), @@ -107,24 +110,23 @@ test_that("mixtures with insufficient distributions errors", { weights <- uniform(0, 1) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, mixture( normal(0, 2), weights = weights ) ) - expect_snapshot(error = TRUE, - mixture(weights = weights) - ) - + expect_snapshot(error = TRUE, mixture(weights = weights)) }) test_that("mixture of normals has correct density", { skip_if_not(check_tf_version()) mix_greta <- function(means, sds, weights, dim) { - mixture(normal(means[1], sds[1], dim), + mixture( + normal(means[1], sds[1], dim), normal(means[2], sds[2], dim), normal(means[3], sds[3], dim), weights = weights @@ -132,10 +134,7 @@ test_that("mixture of normals has correct density", { } mix_r <- function(x, means, sds, weights) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(means) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(means)) for (i in seq_along(means)) { densities[, i] <- dnorm(x, means[i], sds[i]) } @@ -150,7 +149,8 @@ test_that("mixture of normals has correct density", { weights = c(0.3, 0.6, 0.1) ) - compare_distribution(mix_greta, + compare_distribution( + mix_greta, mix_r, parameters = params, x = rnorm(100, -2, 3) @@ -161,7 +161,8 @@ test_that("mixture of truncated normals has correct density", { skip_if_not(check_tf_version()) mix_greta <- function(means, sds, weights, dim) { - mixture(normal(means[1], sds[1], dim, truncation = c(0, Inf)), + mixture( + normal(means[1], sds[1], dim, truncation = c(0, Inf)), normal(means[2], sds[2], dim, truncation = c(0, Inf)), normal(means[3], sds[3], dim, truncation = c(0, Inf)), weights = weights @@ -169,10 +170,7 @@ test_that("mixture of truncated normals has correct density", { } mix_r <- function(x, means, sds, weights) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(means) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(means)) for (i in seq_along(means)) { densities[, i] <- truncdist::dtrunc( @@ -195,7 +193,8 @@ test_that("mixture of truncated normals has correct density", { weights = c(0.3, 0.6, 0.1) ) - compare_distribution(mix_greta, + compare_distribution( + mix_greta, mix_r, parameters = params, x = abs(rnorm(100, -2, 3)) @@ -206,7 +205,8 @@ test_that("mixture of Poissons has correct density", { skip_if_not(check_tf_version()) mix_greta <- function(rates, weights, dim) { - mixture(poisson(rates[1], dim), + mixture( + poisson(rates[1], dim), poisson(rates[2], dim), poisson(rates[3], dim), weights = weights @@ -214,10 +214,7 @@ test_that("mixture of Poissons has correct density", { } mix_r <- function(x, rates, weights) { - densities <- matrix(NA, - nrow = length(x), - ncol = length(rates) - ) + densities <- matrix(NA, nrow = length(x), ncol = length(rates)) for (i in seq_along(rates)) { densities[, i] <- dpois(x, rates[i]) } @@ -231,18 +228,15 @@ test_that("mixture of Poissons has correct density", { weights = c(0.3, 0.6, 0.1) ) - compare_distribution(mix_greta, - mix_r, - parameters = params, - x = rpois(100, 3) - ) + compare_distribution(mix_greta, mix_r, parameters = params, x = rpois(100, 3)) }) test_that("mixture of normals with varying weights has correct density", { skip_if_not(check_tf_version()) mix_greta <- function(means, sds, weights, dim) { - 
mixture(normal(means[1], sds[1], dim), + mixture( + normal(means[1], sds[1], dim), normal(means[2], sds[2], dim), normal(means[3], sds[3], dim), weights = weights @@ -276,10 +270,5 @@ test_that("mixture of normals with varying weights has correct density", { weights = weights ) - compare_distribution(mix_greta, - mix_r, - parameters = params, - x = x, - dim = dim - ) + compare_distribution(mix_greta, mix_r, parameters = params, x = x, dim = dim) }) diff --git a/tests/testthat/test_operators.R b/tests/testthat/test_operators.R index e32c9a89..db982be9 100644 --- a/tests/testthat/test_operators.R +++ b/tests/testthat/test_operators.R @@ -65,11 +65,17 @@ test_that("random strings of operators work as expected", { b <- randn(25, 4) # generate a 5-deep random function of operations - fun <- gen_opfun(5, + fun <- gen_opfun( + 5, ops = c( - "+", "-", "*", - "/", "&", - "|", "<", ">" + "+", + "-", + "*", + "/", + "&", + "|", + "<", + ">" ) ) @@ -94,18 +100,13 @@ test_that("random strings of operators work as expected", { test_that("%*% errors informatively", { skip_if_not(check_tf_version()) - a <- ones(3, 4) b <- ones(1, 4) c <- ones(2, 2, 2) - expect_snapshot(error = TRUE, - a %*% b - ) + expect_snapshot(error = TRUE, a %*% b) - expect_snapshot(error = TRUE, - a %*% c - ) + expect_snapshot(error = TRUE, a %*% c) }) test_that("%*% works when one is a non-greta array", { diff --git a/tests/testthat/test_opt.R b/tests/testthat/test_opt.R index 22f96ea9..0afce8a9 100644 --- a/tests/testthat/test_opt.R +++ b/tests/testthat/test_opt.R @@ -86,14 +86,14 @@ test_that("opt fails with defunct optimisers", { m <- model(z) # check that the right ones error about defunct - expect_snapshot(error = TRUE,o <- opt(m, optimiser = powell())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = momentum())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = cg())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = newton_cg())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = l_bfgs_b())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = tnc())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = cobyla())) - expect_snapshot(error = TRUE,o <- opt(m, optimiser = slsqp())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = powell())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = momentum())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = cg())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = newton_cg())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = l_bfgs_b())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = tnc())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = cobyla())) + expect_snapshot(error = TRUE, o <- opt(m, optimiser = slsqp())) }) test_that("opt accepts initial values for TF optimisers", { @@ -190,10 +190,10 @@ test_that("TF opt with `gradient_descent` fails with bad initial values", { distribution(x) <- normal(z, sd) m <- model(z) - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, o <- opt(m, hessian = TRUE, optimiser = gradient_descent()) ) - }) test_that("TF opt with `adam` succeeds with bad initial values", { @@ -233,7 +233,6 @@ test_that("TF opt with `adam` succeeds with bad initial values", { # the model density is IID normal, so we should be able to recover the SD approx_sd <- sqrt(diag(solve(hess))) expect_true(all(abs(approx_sd - sd) < 1e-9)) - }) ## diff --git a/tests/testthat/test_posteriors_chi_squared.R b/tests/testthat/test_posteriors_chi_squared.R index 68228fd6..1332d347 100644 --- 
a/tests/testthat/test_posteriors_chi_squared.R +++ b/tests/testthat/test_posteriors_chi_squared.R @@ -5,9 +5,11 @@ test_that("samplers are unbiased for chi-squared", { x <- chi_squared(df) iid <- function(n) rchisq(n, df) - chi_squared_checked <- check_samples(x = x, - iid_function = iid, - sampler = hmc()) + chi_squared_checked <- check_samples( + x = x, + iid_function = iid, + sampler = hmc() + ) # do the plotting qqplot_checked_samples(chi_squared_checked) diff --git a/tests/testthat/test_posteriors_geweke.R b/tests/testthat/test_posteriors_geweke.R index ef810518..40a6cac5 100644 --- a/tests/testthat/test_posteriors_geweke.R +++ b/tests/testthat/test_posteriors_geweke.R @@ -76,5 +76,4 @@ test_that("samplers pass geweke tests", { geweke_qq(geweke_hmc_slice, title = "slice sampler Geweke test") testthat::expect_gte(geweke_hmc_slice$p.value, 0.005) - }) diff --git a/tests/testthat/test_posteriors_wishart.R b/tests/testthat/test_posteriors_wishart.R index 9aa555b4..34de2aea 100644 --- a/tests/testthat/test_posteriors_wishart.R +++ b/tests/testthat/test_posteriors_wishart.R @@ -21,7 +21,6 @@ test_that("samplers are unbiased for Wishart", { one_by_one = TRUE ) - # do the plotting qqplot_checked_samples(wishart_checked) diff --git a/tests/testthat/test_representations.R b/tests/testthat/test_representations.R index a9a09c68..3312edb7 100644 --- a/tests/testthat/test_representations.R +++ b/tests/testthat/test_representations.R @@ -27,10 +27,9 @@ test_that("log and exp function representations work", { test_that("chol & chol2inv function representation works", { skip_if_not(check_tf_version()) - # get symmetric matrix m <- 10 - w <- rWishart(1, m + 1, diag(m))[, , 1] + w <- rWishart(1, m + 1, diag(m))[,, 1] u <- chol(w) # convert to greta arrays @@ -54,7 +53,6 @@ test_that("chol & chol2inv function representation works", { test_that("bernoulli prob representations have correct density", { skip_if_not(check_tf_version()) - n <- 100 x <- rbinom(n, 1, 0.5) probs <- runif(n) @@ -86,7 +84,6 @@ test_that("bernoulli prob representations have correct density", { test_that("binomial prob representations have correct density", { skip_if_not(check_tf_version()) - n <- 100 size <- rpois(n, 50) x <- rbinom(n, size, 0.5) @@ -129,7 +126,6 @@ test_that("binomial prob representations have correct density", { test_that("poisson lambda representation has correct density", { skip_if_not(check_tf_version()) - n <- 100 x <- rpois(n, 10) @@ -154,11 +150,10 @@ test_that("poisson lambda representation has correct density", { test_that("mvn Sigma representation has correct density", { skip_if_not(check_tf_version()) - n <- 100 m <- 5 mn <- t(rnorm(m)) - sig <- rWishart(1, m + 1, diag(m))[, , 1] + sig <- rWishart(1, m + 1, diag(m))[,, 1] x <- mvtnorm::rmvnorm(n, mn, sig) # greta arrays with and without representation @@ -166,7 +161,8 @@ test_that("mvn Sigma representation has correct density", { u <- as_data(chol(sig)) chol_sigs <- chol2symm(u) - sigs_dens <- greta_density(greta::multivariate_normal, + sigs_dens <- greta_density( + greta::multivariate_normal, list( mean = mn, Sigma = sigs @@ -175,7 +171,8 @@ test_that("mvn Sigma representation has correct density", { multivariate = TRUE ) - chol_sigs_dens <- greta_density(greta::multivariate_normal, + chol_sigs_dens <- greta_density( + greta::multivariate_normal, list( mean = mn, Sigma = chol_sigs @@ -190,17 +187,17 @@ test_that("mvn Sigma representation has correct density", { test_that("wishart target and Sigma representations have correct density", { 
skip_if_not(check_tf_version()) - m <- 10 - x <- rWishart(1, m + 1, diag(m))[, , 1] - sig <- rWishart(1, m + 1, diag(m))[, , 1] + x <- rWishart(1, m + 1, diag(m))[,, 1] + sig <- rWishart(1, m + 1, diag(m))[,, 1] # greta arrays for Sigma with and without representation sigs <- as_data(sig) u <- as_data(chol(sig)) chol_sigs <- chol2symm(u) - sigs_dens <- greta_density(greta::wishart, + sigs_dens <- greta_density( + greta::wishart, list( df = m + 1, Sigma = sigs @@ -209,7 +206,8 @@ test_that("wishart target and Sigma representations have correct density", { multivariate = TRUE ) - chol_sigs_dens <- greta_density(greta::wishart, + chol_sigs_dens <- greta_density( + greta::wishart, list( df = m + 1, Sigma = chol_sigs @@ -225,7 +223,8 @@ test_that("wishart target and Sigma representations have correct density", { ux <- as_data(chol(x)) chol_xs <- chol2symm(ux) - xs_dens <- greta_density(greta::wishart, + xs_dens <- greta_density( + greta::wishart, list( df = m + 1, Sigma = sig @@ -234,7 +233,8 @@ test_that("wishart target and Sigma representations have correct density", { multivariate = TRUE ) - chol_xs_dens <- greta_density(greta::wishart, + chol_xs_dens <- greta_density( + greta::wishart, list( df = m + 1, Sigma = sig @@ -249,10 +249,9 @@ test_that("wishart target and Sigma representations have correct density", { test_that("lkj target representation has correct density", { skip_if_not(check_tf_version()) - m <- 10 eta <- 3 - x <- rWishart(1, m + 1, diag(m))[, , 1] + x <- rWishart(1, m + 1, diag(m))[,, 1] x <- cov2cor(x) # greta arrays for x with and without representation diff --git a/tests/testthat/test_seed.R b/tests/testthat/test_seed.R index 673e82cc..7dfb6ffa 100644 --- a/tests/testthat/test_seed.R +++ b/tests/testthat/test_seed.R @@ -1,4 +1,3 @@ - test_that("calculate uses the local RNG seed", { skip_if_not(check_tf_version()) @@ -38,7 +37,7 @@ test_that("when calculate simulates multiple values, they are calculated using t skip_if_not(check_tf_version()) x <- normal(0, 1) - x_2 <- x*1 + x_2 <- x * 1 vals <- calculate(x, x_2, nsim = 10) expect_identical(vals$x, vals$x_2) @@ -84,7 +83,6 @@ test_that("calculate samples are the same when the argument seed is the same", { c_two <- calculate(y, nsim = 1, seed = 12345) expect_identical(as.numeric(c_one), as.numeric(c_two)) - }) test_that("calculate samples are the same when the R seed is the same", { @@ -123,13 +121,12 @@ test_that("mcmc samples are the same when the R seed is the same, also with tf s expect_identical(as.numeric(one_tf), as.numeric(two_tf)) # but these are not (always) equal to each other - mcmc_matches_tf_one <- identical(as.numeric(one),as.numeric(one_tf)) - mcmc_matches_tf_two <- identical(as.numeric(two),as.numeric(two_tf)) + mcmc_matches_tf_one <- identical(as.numeric(one), as.numeric(one_tf)) + mcmc_matches_tf_two <- identical(as.numeric(two), as.numeric(two_tf)) expect_false(mcmc_matches_tf_one) expect_false(mcmc_matches_tf_two) - }) test_that("simulate uses the local RNG seed", { diff --git a/tests/testthat/test_simulate.R b/tests/testthat/test_simulate.R index ee0d9a8b..e87e58d0 100644 --- a/tests/testthat/test_simulate.R +++ b/tests/testthat/test_simulate.R @@ -19,47 +19,34 @@ test_that("simulate produces the right number of samples", { test_that("simulate errors if distribution-free variables are not fixed", { skip_if_not(check_tf_version()) - # fix variable a <- variable() y <- normal(a, 1) m <- model(y) - expect_snapshot(error = TRUE, - sims <- simulate(m) - ) + expect_snapshot(error = TRUE, sims <- simulate(m)) 
}) test_that("simulate errors if a distribution cannot be sampled from", { skip_if_not(check_tf_version()) - # fix variable y_ <- rhyper(10, 5, 3, 2) y <- as_data(y_) m <- lognormal(0, 1) distribution(y) <- hypergeometric(m, 3, 2) m <- model(y) - expect_snapshot(error = TRUE, - sims <- simulate(m) - ) + expect_snapshot(error = TRUE, sims <- simulate(m)) }) test_that("simulate errors nicely if nsim is invalid", { skip_if_not(check_tf_version()) - x <- normal(0, 1) m <- model(x) - expect_snapshot(error = TRUE, - simulate(m, nsim = 0) - ) + expect_snapshot(error = TRUE, simulate(m, nsim = 0)) - expect_snapshot(error = TRUE, - simulate(m, nsim = -1) - ) + expect_snapshot(error = TRUE, simulate(m, nsim = -1)) - expect_snapshot(error = TRUE, - simulate(m, nsim = "five") - ) + expect_snapshot(error = TRUE, simulate(m, nsim = "five")) }) diff --git a/tests/testthat/test_syntax.R b/tests/testthat/test_syntax.R index dfaab03a..6001477c 100644 --- a/tests/testthat/test_syntax.R +++ b/tests/testthat/test_syntax.R @@ -36,58 +36,42 @@ test_that("`distribution<-` errors informatively", { x <- randn(1) # not a greta array with a distribution on the right - expect_snapshot(error = TRUE, - distribution(y) <- x - ) + expect_snapshot(error = TRUE, distribution(y) <- x) - expect_snapshot(error = TRUE, - distribution(y) <- as_data(x) - ) + expect_snapshot(error = TRUE, distribution(y) <- as_data(x)) # no density on the right - expect_snapshot(error = TRUE, - distribution(y) <- variable() - ) + expect_snapshot(error = TRUE, distribution(y) <- variable()) # non-scalar and wrong dimensions - expect_snapshot(error = TRUE, + expect_snapshot( + error = TRUE, distribution(y) <- normal(0, 1, dim = c(3, 3, 1)) ) # double assignment of distribution to node y_ <- as_data(y) distribution(y_) <- normal(0, 1) - expect_snapshot(error = TRUE, - distribution(y_) <- normal(0, 1) - ) + expect_snapshot(error = TRUE, distribution(y_) <- normal(0, 1)) # assignment with a greta array that already has a fixed value y1 <- as_data(y) y2 <- as_data(y) d <- normal(0, 1) distribution(y1) <- d - expect_snapshot(error = TRUE, - distribution(y2) <- y1 - ) + expect_snapshot(error = TRUE, distribution(y2) <- y1) # assignment to a variable z <- variable() - expect_snapshot(error = TRUE, - distribution(z) <- normal(0, 1) - ) + expect_snapshot(error = TRUE, distribution(z) <- normal(0, 1)) # assignment to an op z2 <- z^2 - expect_snapshot(error = TRUE, - distribution(z2) <- normal(0, 1) - ) + expect_snapshot(error = TRUE, distribution(z2) <- normal(0, 1)) # assignment to another distribution u <- uniform(0, 1) - expect_snapshot(error = TRUE, - distribution(z2) <- normal(0, 1) - ) + expect_snapshot(error = TRUE, distribution(z2) <- normal(0, 1)) # NB: this repeats the z2 (op) check; the intended target here may be u }) test_that("distribution() errors informatively", { diff --git a/tests/testthat/test_transforms.R b/tests/testthat/test_transforms.R index 0cf2a1ff..c623219e 100644 --- a/tests/testthat/test_transforms.R +++ b/tests/testthat/test_transforms.R @@ -31,8 +31,5 @@ test_that("imultilogit errors informatively", { x <- ones(3, 4, 3) - expect_snapshot(error = TRUE, - imultilogit(x) - ) - + expect_snapshot(error = TRUE, imultilogit(x)) }) diff --git a/tests/testthat/test_truncated.R b/tests/testthat/test_truncated.R index 1803e41e..5eca79ef 100644 --- a/tests/testthat/test_truncated.R +++ b/tests/testthat/test_truncated.R @@ -4,7 +4,8 @@ test_that("truncated normal has correct densities", { skip_if_not(check_tf_version()) # non truncated normal - compare_truncated_distribution(normal, + 
compare_truncated_distribution( + normal, "norm", parameters = list( mean = -1, @@ -14,7 +15,8 @@ test_that("truncated normal has correct densities", { ) # positive truncated - compare_truncated_distribution(normal, + compare_truncated_distribution( + normal, "norm", parameters = list( mean = -1, @@ -24,7 +26,8 @@ test_that("truncated normal has correct densities", { ) # negative truncated - compare_truncated_distribution(normal, + compare_truncated_distribution( + normal, "norm", parameters = list( mean = -1, @@ -34,7 +37,8 @@ test_that("truncated normal has correct densities", { ) # fully truncated - compare_truncated_distribution(normal, + compare_truncated_distribution( + normal, "norm", parameters = list( mean = -1, @@ -48,7 +52,8 @@ test_that("truncated lognormal has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(lognormal, + compare_truncated_distribution( + lognormal, "lnorm", parameters = list( meanlog = -1, @@ -58,7 +63,8 @@ test_that("truncated lognormal has correct densities", { ) # positive truncated - compare_truncated_distribution(lognormal, + compare_truncated_distribution( + lognormal, "lnorm", parameters = list( meanlog = -1, @@ -68,7 +74,8 @@ test_that("truncated lognormal has correct densities", { ) # negative truncated - compare_truncated_distribution(lognormal, + compare_truncated_distribution( + lognormal, "lnorm", parameters = list( meanlog = -1, @@ -78,7 +85,8 @@ test_that("truncated lognormal has correct densities", { ) # fully truncated - compare_truncated_distribution(lognormal, + compare_truncated_distribution( + lognormal, "lnorm", parameters = list( meanlog = -1, @@ -92,7 +100,8 @@ test_that("truncated gamma has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(gamma, + compare_truncated_distribution( + gamma, "gamma", parameters = list( shape = 2, @@ -102,7 +111,8 @@ test_that("truncated gamma has correct densities", { ) # positive truncated - compare_truncated_distribution(gamma, + compare_truncated_distribution( + gamma, "gamma", parameters = list( shape = 2, @@ -112,7 +122,8 @@ test_that("truncated gamma has correct densities", { ) # negative truncated - compare_truncated_distribution(gamma, + compare_truncated_distribution( + gamma, "gamma", parameters = list( shape = 2, @@ -122,7 +133,8 @@ test_that("truncated gamma has correct densities", { ) # fully truncated - compare_truncated_distribution(gamma, + compare_truncated_distribution( + gamma, "gamma", parameters = list( shape = 2, @@ -142,7 +154,8 @@ test_that("truncated inverse gamma has correct densities", { pinvgamma <<- extraDistr::pinvgamma # non truncated - compare_truncated_distribution(inverse_gamma, + compare_truncated_distribution( + inverse_gamma, "invgamma", parameters = list( alpha = 2, @@ -152,7 +165,8 @@ test_that("truncated inverse gamma has correct densities", { ) # positive truncated - compare_truncated_distribution(inverse_gamma, + compare_truncated_distribution( + inverse_gamma, "invgamma", parameters = list( alpha = 2, @@ -162,7 +176,8 @@ test_that("truncated inverse gamma has correct densities", { ) # negative truncated - compare_truncated_distribution(inverse_gamma, + compare_truncated_distribution( + inverse_gamma, "invgamma", parameters = list( alpha = 2, @@ -172,7 +187,8 @@ test_that("truncated inverse gamma has correct densities", { ) # fully truncated - compare_truncated_distribution(inverse_gamma, + compare_truncated_distribution( + inverse_gamma, "invgamma", 
parameters = list( alpha = 2, @@ -186,7 +202,8 @@ test_that("truncated weibull has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(weibull, + compare_truncated_distribution( + weibull, "weibull", parameters = list( shape = 2, @@ -196,7 +213,8 @@ test_that("truncated weibull has correct densities", { ) # positive truncated - compare_truncated_distribution(weibull, + compare_truncated_distribution( + weibull, "weibull", parameters = list( shape = 2, @@ -206,7 +224,8 @@ test_that("truncated weibull has correct densities", { ) # negative truncated - compare_truncated_distribution(weibull, + compare_truncated_distribution( + weibull, "weibull", parameters = list( shape = 2, @@ -216,7 +235,8 @@ test_that("truncated weibull has correct densities", { ) # fully truncated - compare_truncated_distribution(weibull, + compare_truncated_distribution( + weibull, "weibull", parameters = list( shape = 2, @@ -230,28 +250,32 @@ test_that("truncated exponential has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(exponential, + compare_truncated_distribution( + exponential, "exp", parameters = list(rate = 2), truncation = c(0, Inf) ) # positive truncated - compare_truncated_distribution(exponential, + compare_truncated_distribution( + exponential, "exp", parameters = list(rate = 2), truncation = c(1, Inf) ) # negative truncated - compare_truncated_distribution(exponential, + compare_truncated_distribution( + exponential, "exp", parameters = list(rate = 2), truncation = c(0, 2) ) # fully truncated - compare_truncated_distribution(exponential, + compare_truncated_distribution( + exponential, "exp", parameters = list(rate = 2), truncation = c(1, 2) @@ -271,7 +295,8 @@ test_that("truncated pareto has correct densities", { qpreto <<- function(p, a_, b_) extraDistr::qpareto(p, a_, b_) # non truncated - compare_truncated_distribution(preto, + compare_truncated_distribution( + preto, "preto", parameters = list( a_ = 1.9, @@ -281,7 +306,8 @@ test_that("truncated pareto has correct densities", { ) # positive truncated - compare_truncated_distribution(preto, + compare_truncated_distribution( + preto, "preto", parameters = list( a_ = 1.9, @@ -291,7 +317,8 @@ test_that("truncated pareto has correct densities", { ) # negative truncated - compare_truncated_distribution(preto, + compare_truncated_distribution( + preto, "preto", parameters = list( a_ = 1.9, @@ -301,7 +328,8 @@ test_that("truncated pareto has correct densities", { ) # fully truncated - compare_truncated_distribution(preto, + compare_truncated_distribution( + preto, "preto", parameters = list( a_ = 1.9, @@ -319,7 +347,8 @@ test_that("truncated student has correct densities", { pstudent <<- extraDistr::plst # non truncated - compare_truncated_distribution(student, + compare_truncated_distribution( + student, "student", parameters = list( df = 5, @@ -330,7 +359,8 @@ test_that("truncated student has correct densities", { ) # positive truncated - compare_truncated_distribution(student, + compare_truncated_distribution( + student, "student", parameters = list( df = 5, @@ -341,7 +371,8 @@ test_that("truncated student has correct densities", { ) # negative truncated - compare_truncated_distribution(student, + compare_truncated_distribution( + student, "student", parameters = list( df = 5, @@ -352,7 +383,8 @@ test_that("truncated student has correct densities", { ) # fully truncated - compare_truncated_distribution(student, + compare_truncated_distribution( + 
student, "student", parameters = list( df = 5, @@ -372,7 +404,8 @@ test_that("truncated laplace has correct densities", { plaplace <<- extraDistr::plaplace # non truncated - compare_truncated_distribution(laplace, + compare_truncated_distribution( + laplace, "laplace", parameters = list( mu = 2, @@ -382,7 +415,8 @@ test_that("truncated laplace has correct densities", { ) # positive truncated - compare_truncated_distribution(laplace, + compare_truncated_distribution( + laplace, "laplace", parameters = list( mu = 2, @@ -392,7 +426,8 @@ test_that("truncated laplace has correct densities", { ) # negative truncated - compare_truncated_distribution(laplace, + compare_truncated_distribution( + laplace, "laplace", parameters = list( mu = 2, @@ -402,7 +437,8 @@ test_that("truncated laplace has correct densities", { ) # fully truncated - compare_truncated_distribution(laplace, + compare_truncated_distribution( + laplace, "laplace", parameters = list( mu = 2, @@ -416,7 +452,8 @@ test_that("truncated beta has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(beta, + compare_truncated_distribution( + beta, "beta", parameters = list( shape1 = 2.1, @@ -426,7 +463,8 @@ test_that("truncated beta has correct densities", { ) # positive truncated - compare_truncated_distribution(beta, + compare_truncated_distribution( + beta, "beta", parameters = list( shape1 = 2.1, @@ -436,7 +474,8 @@ test_that("truncated beta has correct densities", { ) # negative truncated - compare_truncated_distribution(beta, + compare_truncated_distribution( + beta, "beta", parameters = list( shape1 = 2.1, @@ -446,7 +485,8 @@ test_that("truncated beta has correct densities", { ) # fully truncated - compare_truncated_distribution(beta, + compare_truncated_distribution( + beta, "beta", parameters = list( shape1 = 2.1, @@ -460,7 +500,8 @@ test_that("truncated cauchy has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(cauchy, + compare_truncated_distribution( + cauchy, "cauchy", parameters = list( location = -1.3, @@ -470,7 +511,8 @@ test_that("truncated cauchy has correct densities", { ) # positive truncated - compare_truncated_distribution(cauchy, + compare_truncated_distribution( + cauchy, "cauchy", parameters = list( location = -1.3, @@ -480,7 +522,8 @@ test_that("truncated cauchy has correct densities", { ) # negative truncated - compare_truncated_distribution(cauchy, + compare_truncated_distribution( + cauchy, "cauchy", parameters = list( location = -1.3, @@ -490,7 +533,8 @@ test_that("truncated cauchy has correct densities", { ) # fully truncated - compare_truncated_distribution(cauchy, + compare_truncated_distribution( + cauchy, "cauchy", parameters = list( location = -1.3, @@ -504,7 +548,8 @@ test_that("truncated logistic has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(logistic, + compare_truncated_distribution( + logistic, "logis", parameters = list( location = -1.3, @@ -514,7 +559,8 @@ test_that("truncated logistic has correct densities", { ) # positive truncated - compare_truncated_distribution(logistic, + compare_truncated_distribution( + logistic, "logis", parameters = list( location = -1.3, @@ -524,7 +570,8 @@ test_that("truncated logistic has correct densities", { ) # negative truncated - compare_truncated_distribution(logistic, + compare_truncated_distribution( + logistic, "logis", parameters = list( location = -1.3, @@ -534,7 +581,8 @@ test_that("truncated 
logistic has correct densities", { ) # fully truncated - compare_truncated_distribution(logistic, + compare_truncated_distribution( + logistic, "logis", parameters = list( location = -1.3, @@ -548,7 +596,8 @@ test_that("truncated f has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(f, + compare_truncated_distribution( + f, "f", parameters = list( df1 = 1.3, @@ -558,7 +607,8 @@ test_that("truncated f has correct densities", { ) # positive truncated - compare_truncated_distribution(f, + compare_truncated_distribution( + f, "f", parameters = list( df1 = 1.3, @@ -568,7 +618,8 @@ test_that("truncated f has correct densities", { ) # negative truncated - compare_truncated_distribution(f, + compare_truncated_distribution( + f, "f", parameters = list( df1 = 1.3, @@ -578,7 +629,8 @@ test_that("truncated f has correct densities", { ) # fully truncated - compare_truncated_distribution(f, + compare_truncated_distribution( + f, "f", parameters = list( df1 = 1.3, @@ -592,28 +644,32 @@ test_that("truncated chi squared has correct densities", { skip_if_not(check_tf_version()) # non truncated - compare_truncated_distribution(chi_squared, + compare_truncated_distribution( + chi_squared, "chisq", parameters = list(df = 9.3), truncation = c(0, Inf) ) # positive truncated - compare_truncated_distribution(chi_squared, + compare_truncated_distribution( + chi_squared, "chisq", parameters = list(df = 9.3), truncation = c(0.1, Inf) ) # negative truncated - compare_truncated_distribution(chi_squared, + compare_truncated_distribution( + chi_squared, "chisq", parameters = list(df = 9.3), truncation = c(0, 0.2) ) # fully truncated - compare_truncated_distribution(chi_squared, + compare_truncated_distribution( + chi_squared, "chisq", parameters = list(df = 9.3), truncation = c(0.1, 0.2) @@ -623,11 +679,7 @@ test_that("truncated chi squared has correct densities", { test_that("bad truncations error", { skip_if_not(check_tf_version()) - expect_snapshot(error = TRUE, - lognormal(0, 1, truncation = c(-1, Inf)) - ) + expect_snapshot(error = TRUE, lognormal(0, 1, truncation = c(-1, Inf))) - expect_snapshot(error = TRUE, - beta(1, 1, truncation = c(-1, 2)) - ) + expect_snapshot(error = TRUE, beta(1, 1, truncation = c(-1, 2))) }) diff --git a/tests/testthat/test_variables.R b/tests/testthat/test_variables.R index aff5fd5c..6a10f891 100644 --- a/tests/testthat/test_variables.R +++ b/tests/testthat/test_variables.R @@ -2,31 +2,19 @@ test_that("variable() errors informatively", { skip_if_not(check_tf_version()) # bad types - expect_snapshot(error = TRUE, - variable(upper = NA) - ) + expect_snapshot(error = TRUE, variable(upper = NA)) - expect_snapshot(error = TRUE, - variable(upper = head) - ) + expect_snapshot(error = TRUE, variable(upper = head)) - expect_snapshot(error = TRUE, - variable(lower = NA) - ) + expect_snapshot(error = TRUE, variable(lower = NA)) - expect_snapshot(error = TRUE, - variable(lower = head) - ) + expect_snapshot(error = TRUE, variable(lower = head)) # good types, bad values - expect_snapshot(error = TRUE, - variable(lower = 0:2, upper = 1:2) - ) + expect_snapshot(error = TRUE, variable(lower = 0:2, upper = 1:2)) # lower not below upper - expect_snapshot(error = TRUE, - variable(lower = 1, upper = 1) - ) + expect_snapshot(error = TRUE, variable(lower = 1, upper = 1)) }) test_that("constrained variable constructors error informatively", { @@ -141,9 +129,11 @@ test_that("cholesky_variable() correlation can be sampled correctly", { variances_one <- 
abs(variance_samples - 1) < 1e-3 correlations_above_minus_one <- sweep(correlation_samples, 2, -1, `>=`) correlations_below_one <- sweep(correlation_samples, 2, 1, `<=`) - expect_true(all(variances_one) & - all(correlations_above_minus_one) & - all(correlations_below_one)) + expect_true( + all(variances_one) & + all(correlations_above_minus_one) & + all(correlations_below_one) + ) }) test_that("simplex_variable() can be sampled correctly", { diff --git a/touchstone/script.R b/touchstone/script.R index 0152a476..4e5655eb 100644 --- a/touchstone/script.R +++ b/touchstone/script.R @@ -15,19 +15,19 @@ install_greta_deps(timeout = 50) touchstone::benchmark_run( # expr_before_benchmark = source("dir/data.R"), #<-- TODO OPTIONAL setup before benchmark expr_before_benchmark = library(greta), - create_normal = normal(0,1), + create_normal = normal(0, 1), n = 2 ) touchstone::benchmark_run( expr_before_benchmark = library(greta), - create_model = model(normal(0,1)), + create_model = model(normal(0, 1)), n = 5 ) touchstone::benchmark_run( expr_before_benchmark = library(greta), - run_mcmc = mcmc(model(normal(0,1))), + run_mcmc = mcmc(model(normal(0, 1))), n = 5 ) @@ -59,6 +59,5 @@ touchstone::benchmark_run( # n = 6 # ) - # create artifacts used downstream in the GitHub Action touchstone::benchmark_analyze()