@@ -1399,14 +1399,16 @@ check_initial_values_match_chains <- function(initial_values,
1399
1399
n_chains ,
1400
1400
call = rlang :: caller_env()){
1401
1401
1402
- if (! is.initials(initial_values ) && is.list(initial_values )) {
1402
+ initials <- initial_values
1403
+ not_initials_but_list <- ! is.initials(initials ) && is.list(initials )
1404
+ if (not_initials_but_list ) {
1403
1405
# if the user provided a list of initial values, check elements and length
1404
- are_initials <- vapply( initial_values , is. initials, FUN.VALUE = FALSE )
1406
+ all_initials <- all(are_initials( initials ) )
1405
1407
1406
- n_sets <- length(initial_values )
1408
+ n_sets <- length(initials )
1407
1409
1408
1410
initial_values_do_not_match_chains <- n_sets != n_chains
1409
- if (initial_values_do_not_match_chains && all( are_initials ) ) {
1411
+ if (initial_values_do_not_match_chains && all_initials ) {
1410
1412
cli :: cli_abort(
1411
1413
message = c(
1412
1414
" The number of provided initial values does not match chains" ,
@@ -1437,6 +1439,29 @@ check_initial_values_correct_dim <- function(target_dims,
1437
1439
1438
1440
}
1439
1441
1442
+ check_initial_values_correct_class <- function (initial_values ,
1443
+ call = rlang :: caller_env()){
1444
+
1445
+ initials <- initial_values
1446
+ not_initials_but_list <- ! is.initials(initials ) && is.list(initials )
1447
+ not_initials_not_list <- ! is.initials(initials ) && ! is.list(initials )
1448
+ # if the user provided a list of initial values, check elements and the
1449
+ # length
1450
+ all_initials <- all(are_initials(initials ))
1451
+ not_all_initials <- ! all_initials
1452
+
1453
+ if (not_initials_but_list && not_all_initials || not_initials_not_list ) {
1454
+ cli :: cli_abort(
1455
+ message = c(
1456
+ " {.arg initial_values} must be an initials object created with \\
1457
+ {.fun initials}, or a simple list of initials objects"
1458
+ ),
1459
+ call = call
1460
+ )
1461
+ }
1462
+
1463
+ }
1464
+
1440
1465
check_nodes_all_variable <- function (nodes ,
1441
1466
call = rlang :: caller_env()){
1442
1467
types <- lapply(nodes , node_type )
@@ -1921,16 +1946,78 @@ check_has_representation <- function(repr,
1921
1946
check_is_greta_array <- function (x ,
1922
1947
arg = rlang :: caller_arg(x ),
1923
1948
call = rlang :: caller_env()){
1924
- # only for greta arrays
1925
1949
if (! is.greta_array(x )) {
1926
1950
cli :: cli_abort(
1927
1951
message = c(
1928
1952
" {.arg {arg}} must be {.cls greta_array}" ,
1929
- " Object was is {.cls {class(x)}}"
1953
+ " {.arg {arg}} is: {.cls {class(x)}}"
1954
+ ),
1955
+ call = call
1956
+ )
1957
+ }
1958
+ }
1959
+
1960
+ check_missing_infinite_values <- function (x ,
1961
+ optional ,
1962
+ call = rlang :: caller_env()){
1963
+ contains_missing_or_inf <- ! optional & any(! is.finite(x ))
1964
+ if (contains_missing_or_inf ) {
1965
+ cli :: cli_abort(
1966
+ message = c(
1967
+ " {.cls greta_array} must not contain missing or infinite values"
1968
+ ),
1969
+ call = call
1970
+ )
1971
+ }
1972
+ }
1973
+
1974
+ check_truncation_implemented <- function (tfp_distribution ,
1975
+ distribution_node ,
1976
+ call = rlang :: caller_env()){
1977
+
1978
+ cdf <- tfp_distribution $ cdf
1979
+ quantile <- tfp_distribution $ quantile
1980
+
1981
+ is_truncated <- is.null(cdf ) | is.null(quantile )
1982
+ if (is_truncated ) {
1983
+ cli :: cli_abort(
1984
+ message = c(
1985
+ " Sampling is not yet implemented for truncated \\
1986
+ {.val {distribution_node$distribution_name}} distributions"
1987
+ ),
1988
+ call = call
1989
+ )
1990
+ }
1991
+
1992
+ }
1993
+
1994
+ check_sampling_implemented <- function (sample ,
1995
+ distribution_node ,
1996
+ call = rlang :: caller_env()){
1997
+ if (is.null(sample )) {
1998
+ cli :: cli_abort(
1999
+ " Sampling is not yet implemented for \\
2000
+ {.val {distribution_node$distribution_name}} distributions"
2001
+ )
2002
+ }
2003
+ }
2004
+
2005
+ check_timeout <- function (it ,
2006
+ maxit ,
2007
+ call = rlang :: caller_env()){
2008
+ # check we didn't time out
2009
+ if (it == maxit ) {
2010
+ cli :: cli_abort(
2011
+ message = c(
2012
+ " Could not determine the number of independent models in a reasonable \\
2013
+ amount of time" ,
2014
+ " Iterations = {.val {it}}" ,
2015
+ " Maximum iterations = {.cal {maxit}}"
1930
2016
),
1931
2017
call = call
1932
2018
)
1933
2019
}
2020
+
1934
2021
}
1935
2022
1936
2023
0 commit comments