Skip to content

Commit fe3356e

Browse files
committed
Add more checking functions:
* check_for_free_state_error()
* check_initial_values_correct_class()
* check_missing_infinite_values()
* check_timeout()
* check_sampling_implemented()
* check_truncation_implemented()
* add internal `are_initials()` function
1 parent 360d23c commit fe3356e

14 files changed

+149
-75
lines changed

R/checkers.R

+93-6
Original file line numberDiff line numberDiff line change
@@ -1399,14 +1399,16 @@ check_initial_values_match_chains <- function(initial_values,
13991399
n_chains,
14001400
call = rlang::caller_env()){
14011401

1402-
if (!is.initials(initial_values) && is.list(initial_values)) {
1402+
initials <- initial_values
1403+
not_initials_but_list <- !is.initials(initials) && is.list(initials)
1404+
if (not_initials_but_list) {
14031405
# if the user provided a list of initial values, check elements and length
1404-
are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE)
1406+
all_initials <- all(are_initials(initials))
14051407

1406-
n_sets <- length(initial_values)
1408+
n_sets <- length(initials)
14071409

14081410
initial_values_do_not_match_chains <- n_sets != n_chains
1409-
if (initial_values_do_not_match_chains && all(are_initials)) {
1411+
if (initial_values_do_not_match_chains && all_initials) {
14101412
cli::cli_abort(
14111413
message = c(
14121414
"The number of provided initial values does not match chains",
@@ -1437,6 +1439,29 @@ check_initial_values_correct_dim <- function(target_dims,
14371439

14381440
}
14391441

1442+
# Abort unless `initial_values` is either a single initials object, or a
# plain list whose elements are all initials objects.
#
# initial_values: the user-supplied object to validate.
# call: passed to cli::cli_abort() as the call to report in the error;
#   defaults to the caller of this checker.
check_initial_values_correct_class <- function(initial_values,
                                               call = rlang::caller_env()){

  # `||` and `&&` short-circuit, so the elements are only inspected when the
  # input is a list that is not itself an initials object; other objects are
  # rejected without being iterated over
  valid_class <- is.initials(initial_values) ||
    (is.list(initial_values) && all(are_initials(initial_values)))

  if (!valid_class) {
    cli::cli_abort(
      message = c(
        "{.arg initial_values} must be an initials object created with \\
        {.fun initials}, or a simple list of initials objects"
      ),
      call = call
    )
  }

}
1464+
14401465
check_nodes_all_variable <- function(nodes,
14411466
call = rlang::caller_env()){
14421467
types <- lapply(nodes, node_type)
@@ -1921,16 +1946,78 @@ check_has_representation <- function(repr,
19211946
check_is_greta_array <- function(x,
19221947
arg = rlang::caller_arg(x),
19231948
call = rlang::caller_env()){
1924-
# only for greta arrays
19251949
if (!is.greta_array(x)) {
19261950
cli::cli_abort(
19271951
message = c(
19281952
"{.arg {arg}} must be {.cls greta_array}",
1929-
"Object was is {.cls {class(x)}}"
1953+
"{.arg {arg}} is: {.cls {class(x)}}"
1954+
),
1955+
call = call
1956+
)
1957+
}
1958+
}
1959+
1960+
# Abort if `x` contains any missing or infinite values, unless `optional` is
# TRUE.
#
# x: values tested with is.finite().
# optional: when TRUE the check is skipped entirely.
# call: passed to cli::cli_abort() as the call to report in the error;
#   defaults to the caller of this checker.
check_missing_infinite_values <- function(x,
                                          optional,
                                          call = rlang::caller_env()){
  # `&&` is the scalar operator suited to an `if` condition; it also avoids
  # scanning `x` at all when the check is optional
  contains_missing_or_inf <- !optional && any(!is.finite(x))
  if (contains_missing_or_inf) {
    cli::cli_abort(
      message = c(
        "{.cls greta_array} must not contain missing or infinite values"
      ),
      call = call
    )
  }
}
1973+
1974+
# Abort if `tfp_distribution` does not provide both a `cdf` and a `quantile`
# element, which means a truncated version of the distribution cannot be
# sampled.
#
# tfp_distribution: object whose `$cdf` and `$quantile` elements are tested
#   for NULL.
# distribution_node: object whose `$distribution_name` is shown in the error.
# call: passed to cli::cli_abort() as the call to report in the error;
#   defaults to the caller of this checker.
check_truncation_implemented <- function(tfp_distribution,
                                         distribution_node,
                                         call = rlang::caller_env()){

  # scalar `||`: both operands are single logicals used in an `if` condition
  cdf_or_quantile_missing <- is.null(tfp_distribution$cdf) ||
    is.null(tfp_distribution$quantile)

  if (cdf_or_quantile_missing) {
    cli::cli_abort(
      message = c(
        "Sampling is not yet implemented for truncated \\
        {.val {distribution_node$distribution_name}} distributions"
      ),
      call = call
    )
  }

}
1993+
1994+
# Abort if no sampling function is available for a distribution.
#
# sample: the sampling function, or NULL when none is available.
# distribution_node: object whose `$distribution_name` is shown in the error.
# call: passed to cli::cli_abort() as the call to report in the error;
#   defaults to the caller of this checker.
check_sampling_implemented <- function(sample,
                                       distribution_node,
                                       call = rlang::caller_env()){
  if (is.null(sample)) {
    cli::cli_abort(
      message = c(
        "Sampling is not yet implemented for \\
        {.val {distribution_node$distribution_name}} distributions"
      ),
      # forward `call` so the error is attributed to the calling function
      # rather than to this checker
      call = call
    )
  }
}
2004+
2005+
# Abort if an iterative search used up its whole iteration budget.
#
# it: the iteration count that was reached.
# maxit: the maximum number of iterations allowed.
# call: passed to cli::cli_abort() as the call to report in the error;
#   defaults to the caller of this checker.
check_timeout <- function(it,
                          maxit,
                          call = rlang::caller_env()){
  # check we didn't time out
  if (it == maxit) {
    cli::cli_abort(
      message = c(
        "Could not determine the number of independent models in a reasonable \\
        amount of time",
        "Iterations = {.val {it}}",
        # `.val` is the cli inline class for values (was mistyped as `.cal`)
        "Maximum iterations = {.val {maxit}}"
      ),
      call = call
    )
  }

}
19352022

19362023

R/dag_class.R

+3-21
Original file line numberDiff line numberDiff line change
@@ -769,13 +769,7 @@ dag_class <- R6Class(
769769
}
770770
}
771771

772-
# check we didn't time out
773-
if (it == maxit) {
774-
cli::cli_abort(
775-
"could not determine the number of independent models in a \\
776-
reasonable amount of time"
777-
)
778-
}
772+
check_timeout(it, maxit)
779773

780774
# find the cluster IDs
781775
n <- nrow(r)
@@ -812,12 +806,7 @@ dag_class <- R6Class(
812806

813807
sample <- tfp_distribution$sample
814808

815-
if (is.null(sample)) {
816-
cli::cli_abort(
817-
"sampling is not yet implemented for \\
818-
{.val {distribution_node$distribution_name}} distributions"
819-
)
820-
}
809+
check_sampling_implemented(sample, distribution_node)
821810

822811
truncation <- distribution_node$truncation
823812

@@ -833,14 +822,7 @@ dag_class <- R6Class(
833822

834823
cdf <- tfp_distribution$cdf
835824
quantile <- tfp_distribution$quantile
836-
837-
is_truncated <- is.null(cdf) | is.null(quantile)
838-
if (is_truncated) {
839-
cli::cli_abort(
840-
"sampling is not yet implemented for truncated \\
841-
{.val {distribution_node$distribution_name}} distributions"
842-
)
843-
}
825+
check_truncation_implemented(tfp_distribution, distribution_node)
844826

845827
# generate a random uniform sample of the correct shape and transform
846828
# through truncated inverse CDF to get draws on truncated scale

R/distribution.R

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343

4444
node <- get_node(greta_array)
4545

46+
# TODO revisit checking functions here
4647
# only for greta arrays without distributions
4748
## TODO provide more detail on the distribution already assigned
4849
## This might come up when the user accidentally runs assignment

R/extract_replace_combine.R

+2
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ abind.greta_array <- function(...,
405405
along <- max(1, min(n + 1, ceiling(along)))
406406
}
407407

408+
# TODO revisit checking functions here
408409
along_outside_0_n <- !(along %in% 0:n)
409410
if (along_outside_0_n) {
410411
cli::cli_abort(
@@ -530,6 +531,7 @@ length.greta_array <- function(x) {
530531

531532
dims <- dims %||% length(x)
532533

534+
# TODO revisit logic / checking functions here
533535
if (length(dims) == 0L) {
534536
cli::cli_abort(
535537
"length-0 dimension vector is invalid"

R/greta_array_class.R

+2-6
Original file line numberDiff line numberDiff line change
@@ -78,12 +78,8 @@ as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) {
7878
# finally, reject if there are any missing values, or set up the greta_array
7979
#' @export
8080
as.greta_array.numeric <- function(x, optional = FALSE, original_x = x, ...) {
81-
contains_missing_or_inf <- !optional & any(!is.finite(x))
82-
if (contains_missing_or_inf) {
83-
cli::cli_abort(
84-
"{.cls greta_array} must not contain missing or infinite values"
85-
)
86-
}
81+
check_missing_infinite_values(x, optional)
82+
8783
as.greta_array.node(data_node$new(x),
8884
optional = optional,
8985
original_x = original_x,

R/inference.R

+5-28
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,8 @@ to_free <- function(node, data) {
561561
lower <- node$lower
562562
upper <- node$upper
563563

564+
# TODO
565+
# replace these with more informative errors related to the range of values
564566
unsupported_error <- function() {
565567
cli::cli_abort(
566568
"Some provided initial values are outside the range of values their \\
@@ -661,7 +663,6 @@ parse_initial_values <- function(initials, dag) {
661663
# correct length, with nice error messages
662664
prep_initials <- function(initial_values, n_chains, dag) {
663665

664-
# TODO: Tidy up the logic here for errors and messages
665666
# if the user passed a single set of initial values, repeat them for all
666667
# chains
667668
if (is.initials(initial_values)) {
@@ -673,33 +674,9 @@ prep_initials <- function(initial_values, n_chains, dag) {
673674
)
674675
}
675676

676-
not_initials_but_list <- !is.initials(initial_values) && is.list(initial_values)
677-
if (not_initials_but_list) {
678-
679-
# if the user provided a list of initial values, check elements and the
680-
# length
681-
are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE)
682-
683-
if (all(are_initials)) {
684-
check_initial_values_match_chains(initial_values, n_chains)
685-
}
686-
if (!all(are_initials)) {
687-
initial_values <- NULL
688-
}
689-
}
690-
if (!not_initials_but_list) {
691-
initial_values <- NULL
692-
}
693-
694-
# error on a bad object
695-
if (is.null(initial_values)) {
696-
cli::cli_abort(
697-
c(
698-
"{.arg initial_values} must be an initials object created with \\
699-
{.fun initials}, or a simple list of initials objects"
700-
)
701-
)
702-
}
677+
# TODO: revisit logic here for errors and messages
678+
check_initial_values_match_chains(initial_values, n_chains)
679+
check_initial_values_correct_class(initial_values)
703680

704681
# convert them to free state vectors
705682
initial_values <- lapply(

R/inference_class.R

+11-3
Original file line numberDiff line numberDiff line change
@@ -804,6 +804,15 @@ sampler <- R6Class(
804804
)
805805
) # closing cleanly
806806

807+
# if it's fine, batch_results is the output
808+
# if it's a non-numerical error, it will error
809+
# if it's a numerical error, batch_results will be an error object
810+
self$check_for_free_state_error(result, n_samples)
811+
812+
result
813+
},
814+
815+
check_for_free_state_error = function(result, n_samples){
807816
# if it's fine, batch_results is the output
808817
# if it's a non-numerical error, it will error
809818
# if it's a numerical error, batch_results will be an error object
@@ -827,7 +836,7 @@ sampler <- R6Class(
827836
# won't be valid if we just restart, so we need to error here,
828837
# informing the user how to run one sample at a time
829838
cli::cli_abort(
830-
c(
839+
message = c(
831840
"TensorFlow hit a numerical problem that caused it to error",
832841
"{.pkg greta} can handle these as bad proposals if you rerun \\
833842
{.fun mcmc} with the argument {.code one_by_one = TRUE}.",
@@ -839,9 +848,8 @@ sampler <- R6Class(
839848

840849
}
841850
}
842-
843-
result
844851
},
852+
845853
sampler_parameter_values = function() {
846854

847855
# random number of integration steps

R/utils.R

+8
Original file line numberDiff line numberDiff line change
@@ -1166,3 +1166,11 @@ outside_version_range <- function(provided, range) {
11661166
}
11671167

11681168
pretty_dim <- function(x) paste0(dim(x), collapse = "x")
1169+
1170+
# Apply is.initials() to every element of `x`, returning one logical value
# per element.
are_initials <- function(x){
  element_is_initials <- vapply(x, is.initials, logical(1))
  element_is_initials
}

tests/testthat/_snaps/calculate.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959

6060
# calculate errors if a distribution cannot be sampled from
6161

62-
sampling is not yet implemented for "hypergeometric" distributions
62+
Sampling is not yet implemented for "hypergeometric" distributions
6363

6464
# calculate errors nicely if nsim is invalid
6565

tests/testthat/_snaps/distributions.md

+12-4
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,21 @@
8181

8282
# wishart distribution errors informatively
8383

84-
`Sigma` must be a square 2D greta array
85-
However, `Sigma` has dimensions "3x3x3"
84+
Code
85+
wishart(3, b)
86+
Condition
87+
Error in `initialize()`:
88+
! `Sigma` must be a square 2D greta array
89+
However, `Sigma` has dimensions "3x3x3"
8690

8791
---
8892

89-
`Sigma` must be a square 2D greta array
90-
However, `Sigma` has dimensions "3x2"
93+
Code
94+
wishart(3, c)
95+
Condition
96+
Error in `initialize()`:
97+
! `Sigma` must be a square 2D greta array
98+
However, `Sigma` has dimensions "3x2"
9199

92100
# lkj_correlation distribution errors informatively
93101

tests/testthat/_snaps/iid_samples.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# distributions without RNG error nicely
22

3-
sampling is not yet implemented for "hypergeometric" distributions
3+
Sampling is not yet implemented for "hypergeometric" distributions
44

55
---
66

7-
sampling is not yet implemented for truncated "f" distributions
7+
Sampling is not yet implemented for truncated "f" distributions
88

tests/testthat/_snaps/simulate.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
# simulate errors if a distribution cannot be sampled from
66

7-
sampling is not yet implemented for "hypergeometric" distributions
7+
Sampling is not yet implemented for "hypergeometric" distributions
88

99
# simulate errors nicely if nsim is invalid
1010

tests/testthat/_snaps/syntax.md

+6-2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@
3737

3838
# distribution() errors informatively
3939

40-
`distribution()` expects object of type <greta_array>
41-
object was not a <greta_array>, but <array>
40+
Code
41+
distribution(y)
42+
Condition
43+
Error in `distribution()`:
44+
! `greta_array` must be <greta_array>
45+
`greta_array` is: <array>
4246

tests/testthat/test_syntax.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ test_that("distribution() errors informatively", {
9595

9696
y <- randn(3)
9797

98-
expect_snapshot_error(
98+
expect_snapshot(
99+
error = TRUE,
99100
distribution(y)
100101
)
101102
})

0 commit comments

Comments
 (0)