Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use air formatter #776

Merged
merged 1 commit into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 22 additions & 32 deletions R/calculate.R
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,8 @@ calculate <- function(
trace_batch_size = 100,
compute_options = cpu_only()
) {

# set device to be CPU/GPU for the entire run
with(tf$device(compute_options), {

# message users about random seeds and GPU usage if they are using GPU
message_if_using_gpu(compute_options)

Expand Down Expand Up @@ -200,7 +198,6 @@ calculate <- function(
# checks and RNG seed setting if we're sampling
# REFACTOR: check_rng_seed(nim, seed, compute_option)
if (!is.null(nsim)) {

# check nsim is valid
nsim <- check_positive_integer(nsim, "nsim")

Expand All @@ -210,27 +207,25 @@ calculate <- function(
x = ".Random.seed",
envir = .GlobalEnv,
inherits = FALSE
)
)
if (no_global_random_seed) {
runif(1)
}


r_seed <- get(".Random.seed", envir = .GlobalEnv)
on.exit(assign(".Random.seed", r_seed, envir = .GlobalEnv))
tensorflow::set_random_seed(
seed = seed,
disable_gpu = is_using_cpu(compute_options)
)
)
}

if (is.null(seed)){
if (is.null(seed)) {
tensorflow::set_random_seed(
seed = get_seed(),
disable_gpu = is_using_cpu(compute_options)
)
}

}

# set precision
Expand Down Expand Up @@ -259,7 +254,6 @@ calculate <- function(
}

if (!is.greta_mcmc_list(result)) {

# if it's not mcmc samples, make sure the results are in the right order
# (tensorflow order seems to be platform specific?!?)
order <- match(names(result), names(target))
Expand All @@ -278,12 +272,13 @@ calculate <- function(

#' @importFrom coda thin
#' @importFrom stats start end
calculate_greta_mcmc_list <- function(target,
values,
nsim,
tf_float,
trace_batch_size) {

calculate_greta_mcmc_list <- function(
target,
values,
nsim,
tf_float,
trace_batch_size
) {
# assign the free state
stochastic <- !is.null(nsim)

Expand Down Expand Up @@ -321,13 +316,11 @@ calculate_greta_mcmc_list <- function(target,
# if they didn't specify nsim, check we can deterministically compute the
# targets from the draws
if (!stochastic) {

# see if the new dag introduces any new variables
check_dag_introduces_new_variables(dag, mcmc_dag)

# see if any of the targets are stochastic and not sampled in the mcmc
check_targets_stochastic_and_not_sampled(target, mcmc_dag_variables)

}

dag$target_nodes <- lapply(target, get_node)
Expand Down Expand Up @@ -356,19 +349,20 @@ calculate_greta_mcmc_list <- function(target,
# add the batch size to the data list
# assign
# pass these values in as the free state
trace <- dag$trace_values(draws,
trace <- dag$trace_values(
draws,
trace_batch_size = trace_batch_size,
flatten = FALSE
)

# hopefully values is already a list of the correct dimensions...
} else {

# for deterministic posterior prediction, just trace the target for each
# chain

values <- lapply(draws,
#double check the trace value part - can pronbanly just do that here
values <- lapply(
draws,
#double check the trace value part - can pronbanly just do that here
dag$trace_values,
trace_batch_size = trace_batch_size
)
Expand All @@ -394,7 +388,6 @@ calculate_list <- function(target, values, nsim, tf_float, env) {
values_exist <- !identical(values, list())

if (values_exist) {

# check the list of values makes sense, and return these and the
# corresponding greta arrays (looked up by name in environment env)
values_list <- check_values_list(values, env)
Expand All @@ -408,19 +401,16 @@ calculate_list <- function(target, values, nsim, tf_float, env) {

stochastic <- !is.null(nsim)
if (stochastic) {

check_if_unsampleable_and_unfixed(fixed_greta_arrays, dag)

} else {

# check there are no unspecified variables on which the target depends
lapply(target, check_dependencies_satisfied, fixed_greta_arrays, dag, env)
}

# TF1/2 check todo
# need to wrap this in tf_function I think?
if (Sys.getenv("GRETA_DEBUG") == "true") {
browser()
browser()
}
values <- calculate_target_tensor_list(
dag = dag,
Expand All @@ -438,12 +428,12 @@ calculate_list <- function(target, values, nsim, tf_float, env) {
# simultaneously
# assign("calculate_target_tensor_list", target_tensor_list, envir = tfe)

# # add values or data not specified by the user
# data_list <- dag$get_tf_data_list()
# missing <- !names(data_list) %in% names(values)
#
# # send list to tf environment and roll into a dict
#
# # add values or data not specified by the user
# data_list <- dag$get_tf_data_list()
# missing <- !names(data_list) %in% names(values)
#
# # send list to tf environment and roll into a dict
#
}


Expand Down
20 changes: 9 additions & 11 deletions R/callbacks.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ render_progress <- function(reads) {
reads[is.na(reads)] <- ""
some_results <- any(nchar(reads) > 0)
if (some_results) {

# optionally add blanks to put lines at the edges
if (length(reads) > 1) {
reads <- c("", reads, "")
Expand All @@ -50,24 +49,23 @@ percentages <- function() {
}

progress_bars <- function() {
reads <- lapply(greta_stash$progress_bar_log_files,
reads <- lapply(
greta_stash$progress_bar_log_files,
read_progress_log_file,
skip = 1
)
render_progress(reads)
}

# determine the type of progress information
set_progress_bar_type <- function(n_chain){

if (bar_width(n_chain) < 42) {
progress_callback <- percentages
} else {
progress_callback <- progress_bars
}

greta_stash$callbacks$parallel_progress <- progress_callback
set_progress_bar_type <- function(n_chain) {
if (bar_width(n_chain) < 42) {
progress_callback <- percentages
} else {
progress_callback <- progress_bars
}

greta_stash$callbacks$parallel_progress <- progress_callback
}

# register some
Expand Down
Loading
Loading