Skip to content

Commit 3efb9d4

Browse files
committed
add check_for_errors()
change logic in cholesky variable to clarify error catching intent
1 parent 016464b commit 3efb9d4

File tree

3 files changed

+48
-37
lines changed

3 files changed

+48
-37
lines changed

R/checkers.R

+29-2
Original file line numberDiff line numberDiff line change
@@ -1693,9 +1693,9 @@ check_final_dim <- function(dim,
16931693
if (last_dim_gt_1) {
16941694
cli::cli_abort(
16951695
message = c(
1696-
"The final dimension of a {thing} must have more than \\
1696+
"The final dimension of a {thing} must have more than \\
16971697
one element",
1698-
"The final dimension has: {.val {n_last_dim} element{?s}}"
1698+
"The final dimension has: {.val {n_last_dim} element{?s}}"
16991699
),
17001700
call = call
17011701
)
@@ -1725,6 +1725,33 @@ check_not_greta_array <- function(x,
17251725
}
17261726
}
17271727

1728+
# if it errored
1729+
check_for_errors <- function(res,
1730+
call = rlang::caller_env()){
1731+
1732+
if (inherits(res, "error")) {
1733+
1734+
# check for known numerical errors
1735+
numerical_errors <- vapply(greta_stash$numerical_messages,
1736+
grepl,
1737+
res$message,
1738+
FUN.VALUE = 0
1739+
) == 1
1740+
1741+
# if it was just a numerical error, quietly return a bad value
1742+
if (!any(numerical_errors)) {
1743+
cli::cli_abort(
1744+
message = c(
1745+
"{.pkg greta} hit a tensorflow error:",
1746+
"{res}"
1747+
),
1748+
call = call
1749+
)
1750+
}
1751+
}
1752+
1753+
}
1754+
17281755
checks_module <- module(
17291756
check_tf_version,
17301757
check_dims,

R/utils.R

+1-20
Original file line numberDiff line numberDiff line change
@@ -510,26 +510,7 @@ quietly <- function(expr) {
510510
cleanly <- function(expr) {
511511
res <- tryCatch(expr, error = function(e) e)
512512

513-
# if it errored
514-
if (inherits(res, "error")) {
515-
516-
# check for known numerical errors
517-
numerical_errors <- vapply(greta_stash$numerical_messages,
518-
grepl,
519-
res$message,
520-
FUN.VALUE = 0
521-
) == 1
522-
523-
# if it was just a numerical error, quietly return a bad value
524-
if (!any(numerical_errors)) {
525-
cli::cli_abort(
526-
c(
527-
"{.pkg greta} hit a tensorflow error:",
528-
"{res}"
529-
)
530-
)
531-
}
532-
}
513+
check_for_errors(res)
533514

534515
res
535516
}

R/variable.R

+18-15
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,21 @@ cholesky_variable <- function(dim, correlation = FALSE) {
7272
n_dim <- length(dim)
7373
if (n_dim == 1) {
7474
dim <- c(dim, dim)
75-
} else if (n_dim == 2) {
76-
not_square <- dim[1] != dim[2]
77-
if (not_square) {
78-
msg <- cli::cli_abort(
79-
c(
80-
"cholesky variables must be square",
81-
"However its dimension is: {.val {paste(dim, collapse = 'x')}}"
82-
)
75+
}
76+
77+
not_square <- dim[1] != dim[2]
78+
79+
if (n_dim == 2 && not_square){
80+
cli::cli_abort(
81+
c(
82+
"cholesky variables must be square",
83+
"However its dimension is: {.val {paste(dim, collapse = 'x')}}"
8384
)
84-
}
85-
} else {
86-
msg <- cli::cli_abort(
85+
)
86+
}
87+
88+
if (length(dim) > 2) {
89+
cli::cli_abort(
8790
c(
8891
"{.arg dim} can either be a scalar or a vector of length 2",
8992
"However {.arg dim} has length {.val {length(dim)}}, and contains: \\
@@ -96,8 +99,8 @@ cholesky_variable <- function(dim, correlation = FALSE) {
9699

97100
# dimension of the free state version
98101
free_dim <- ifelse(correlation,
99-
k * (k - 1) / 2,
100-
k + k * (k - 1) / 2
102+
k * (k - 1) / 2,
103+
k + k * (k - 1) / 2
101104
)
102105

103106
# create variable node
@@ -109,8 +112,8 @@ cholesky_variable <- function(dim, correlation = FALSE) {
109112

110113
# set the constraint, to enable transformation
111114
node$constraint <- ifelse(correlation,
112-
"correlation_matrix",
113-
"covariance_matrix"
115+
"correlation_matrix",
116+
"covariance_matrix"
114117
)
115118

116119
# set the printed value to be nicer

0 commit comments

Comments
 (0)