Skip to content

Commit 386e21e

Browse files
committed
Adding more checking functions
* Add `dim` argument to `check_square` so you can either pass x, and then calculate dim, or pass a dimension * check_dim_length() * check_is_distribution_node() * check_values_dim() * check_dot_nodes_scalar() * inform_if_one_set_of_initials() * check_subgraphs() * check_has_representation() * check_is_greta_array()
1 parent 3efb9d4 commit 386e21e

15 files changed

+311
-205
lines changed

R/checkers.R

+199-17
Original file line numberDiff line numberDiff line change
@@ -183,17 +183,20 @@ check_2d_multivariate <- function(x,
183183
}
184184
}
185185

186-
check_square <- function(x,
186+
check_square <- function(x = NULL,
187+
dim = NULL,
187188
call = rlang::caller_env()) {
188-
dim <- dim(x)
189+
190+
# allows for specifying x or named dim = dim
191+
dim <- dim %||% dim(x)
189192
ndim <- length(dim)
190-
is_square <- ndim == 2 && dim[1] == dim[2]
191-
if (!is_square) {
193+
not_square <- dim[1] != dim[2]
194+
if (ndim == 2 && not_square) {
192195
cli::cli_abort(
193196
message = c(
194-
"Not 2D square greta array",
195-
"x" = "expected a 2D square greta array, but object {.var x} had \\
196-
dimension: {paste(dim, collapse = 'x')}"
197+
"Object must be 2D square array",
198+
"x" = "But it had dimension: \\
199+
{.val {paste(dim, collapse = 'x')}}"
197200
),
198201
call = call
199202
)
@@ -1395,20 +1398,27 @@ check_initials_are_numeric <- function(values,
13951398
check_initial_values_match_chains <- function(initial_values,
13961399
n_chains,
13971400
call = rlang::caller_env()){
1398-
n_sets <- length(initial_values)
13991401

1400-
initial_values_do_not_match_chains <- n_sets != n_chains
1401-
if (initial_values_do_not_match_chains) {
1402-
cli::cli_abort(
1403-
message = c(
1404-
"The number of provided initial values does not match chains",
1405-
"{n_sets} set{?s} of initial values were provided, but there \\
1402+
if (!is.initials(initial_values) && is.list(initial_values)) {
1403+
# if the user provided a list of initial values, check elements and length
1404+
are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE)
1405+
1406+
n_sets <- length(initial_values)
1407+
1408+
initial_values_do_not_match_chains <- n_sets != n_chains
1409+
if (initial_values_do_not_match_chains && all(are_initials)) {
1410+
cli::cli_abort(
1411+
message = c(
1412+
"The number of provided initial values does not match chains",
1413+
"{n_sets} set{?s} of initial values were provided, but there \\
14061414
{cli::qty(n_chains)} {?is only/are} {n_chains} \\
14071415
{cli::qty(n_chains)} chain{?s}"
1408-
),
1409-
call = call
1410-
)
1416+
),
1417+
call = call
1418+
)
1419+
}
14111420
}
1421+
14121422
}
14131423

14141424
check_initial_values_correct_dim <- function(target_dims,
@@ -1752,6 +1762,178 @@ check_for_errors <- function(res,
17521762

17531763
}
17541764

1765+
check_dim_length <- function(dim,
1766+
call = rlang::caller_env()){
1767+
1768+
ndim <- length(dim)
1769+
ndim_gt2 <- ndim > 2
1770+
if (ndim_gt2) {
1771+
cli::cli_abort(
1772+
message = c(
1773+
"{.arg dim} can either be a scalar or a vector of length 2",
1774+
"However {.arg dim} has length {.val {ndim}}, and contains:",
1775+
"{.val {dim}}"
1776+
),
1777+
call = call
1778+
)
1779+
}
1780+
}
1781+
1782+
check_is_distribution_node <- function(distribution,
1783+
call = rlang::caller_env()){
1784+
if (!is.distribution_node(distribution)) {
1785+
cli::cli_abort(
1786+
message = c("Invalid distribution"),
1787+
call = call
1788+
)
1789+
}
1790+
1791+
}
1792+
1793+
check_values_dim <- function(value,
1794+
dim,
1795+
call = rlang::caller_env()){
1796+
values_have_wrong_dim <- !is.null(value) && !all.equal(dim(value), dim)
1797+
if (values_have_wrong_dim) {
1798+
cli::cli_abort(
1799+
message = "Values have the wrong dimension so cannot be used",
1800+
call = call
1801+
)
1802+
}
1803+
1804+
}
1805+
1806+
# check they are all scalar
1807+
check_dot_nodes_scalar <- function(dot_nodes,
1808+
call = rlang::caller_env()){
1809+
are_scalar <- vapply(dot_nodes, is_scalar, logical(1))
1810+
if (!all(are_scalar)) {
1811+
cli::cli_abort(
1812+
message = "{.fun joint} only accepts probability distributions over \\
1813+
scalars",
1814+
call = call
1815+
)
1816+
}
1817+
1818+
}
1819+
1820+
inform_if_one_set_of_initials <- function(initial_values,
1821+
n_chains,
1822+
call = rlang::caller_env()){
1823+
1824+
is_blank <- identical(initial_values, initials())
1825+
1826+
one_set_of_initials <- !is_blank & n_chains > 1
1827+
if (one_set_of_initials) {
1828+
cli::cli_inform(
1829+
message = "Only one set of initial values was provided, and was used \\
1830+
for all chains"
1831+
)
1832+
}
1833+
}
1834+
1835+
# the user might pass greta arrays with groups of nodes that are unconnected
1836+
# to one another. Need to check there are densities in each graph
1837+
check_subgraphs <- function(dag,
1838+
call = rlang::caller_env()){
1839+
# get and check the types
1840+
types <- dag$node_types
1841+
1842+
# the user might pass greta arrays with groups of nodes that are unconnected
1843+
# to one another. Need to check there are densities in each graph
1844+
1845+
# so find the subgraph to which each node belongs
1846+
graph_id <- dag$subgraph_membership()
1847+
1848+
graphs <- unique(graph_id)
1849+
n_graphs <- length(graphs)
1850+
1851+
# separate messages to avoid the subgraphs issue for beginners
1852+
1853+
if (n_graphs == 1) {
1854+
density_message <- cli::format_error(
1855+
c(
1856+
"none of the {.cls greta_array}s in the model are associated with a \\
1857+
probability density, so a model cannot be defined"
1858+
)
1859+
)
1860+
variable_message <- cli::format_error(
1861+
c(
1862+
"none of the {.cls greta_array}s in the model are unknown, so a model \\
1863+
cannot be defined"
1864+
)
1865+
)
1866+
} else {
1867+
density_message <- cli::format_error(
1868+
c(
1869+
"the model contains {n_graphs} disjoint graphs",
1870+
"one or more of these sub-graphs does not contain any \\
1871+
{.cls greta_array}s that are associated with a probability density, \\
1872+
so a model cannot be defined"
1873+
)
1874+
)
1875+
variable_message <- cli::format_error(
1876+
c(
1877+
"the model contains {n_graphs} disjoint graphs",
1878+
"one or more of these sub-graphs does not contain any \\
1879+
{.cls greta_array}s that are unknown, so a model cannot be defined"
1880+
)
1881+
)
1882+
}
1883+
1884+
for (graph in graphs) {
1885+
types_sub <- types[graph_id == graph]
1886+
1887+
# check they have a density among them
1888+
no_distribution <- !("distribution" %in% types_sub)
1889+
if (no_distribution) {
1890+
cli::cli_abort(
1891+
message = density_message,
1892+
call = call
1893+
)
1894+
}
1895+
1896+
# check they have a variable node among them
1897+
no_variable_node <- !("variable" %in% types_sub)
1898+
if (no_variable_node) {
1899+
cli::cli_abort(
1900+
message = variable_message,
1901+
call = call
1902+
)
1903+
}
1904+
}
1905+
1906+
}
1907+
1908+
check_has_representation <- function(repr,
1909+
name,
1910+
error,
1911+
call = rlang::caller_env()){
1912+
not_represented <- error && is.null(repr)
1913+
if (not_represented) {
1914+
cli::cli_abort(
1915+
message = "{.cls greta_array} has no representation {.var {name}}",
1916+
call = call
1917+
)
1918+
}
1919+
}
1920+
1921+
check_is_greta_array <- function(x,
1922+
arg = rlang::caller_arg(x),
1923+
call = rlang::caller_env()){
1924+
# only for greta arrays
1925+
if (!is.greta_array(x)) {
1926+
cli::cli_abort(
1927+
message = c(
1928+
"{.arg {arg}} must be {.cls greta_array}",
1929+
"Object was is {.cls {class(x)}}"
1930+
),
1931+
call = call
1932+
)
1933+
}
1934+
}
1935+
1936+
17551937
checks_module <- module(
17561938
check_tf_version,
17571939
check_dims,

R/dag_class.R

+1
Original file line numberDiff line numberDiff line change
@@ -807,6 +807,7 @@ dag_class <- R6Class(
807807

808808
# try to draw a random sample from a distribution node
809809
draw_sample = function(distribution_node) {
810+
# self$check_sampling_implemented(distribution_node)
810811
tfp_distribution <- self$get_tfp_distribution(distribution_node)
811812

812813
sample <- tfp_distribution$sample

R/distribution.R

+1-8
Original file line numberDiff line numberDiff line change
@@ -123,14 +123,7 @@
123123
distribution <- function(greta_array) {
124124

125125
# only for greta arrays
126-
if (!is.greta_array(greta_array)) {
127-
cli::cli_abort(
128-
c(
129-
"{.fun distribution} expects object of type {.cls greta_array}",
130-
"object was not a {.cls greta_array}, but {.cls {class(greta_array)}}"
131-
)
132-
)
133-
}
126+
check_is_greta_array(greta_array)
134127

135128
# if greta_array has a distribution, return this greta array
136129
if (is.distribution_node(get_node(greta_array)$distribution)) {

R/greta_array_class.R

+1-6
Original file line numberDiff line numberDiff line change
@@ -252,12 +252,7 @@ representation <- function(x, name, error = TRUE) {
252252
x_node <- x
253253
}
254254
repr <- x_node$representations[[name]]
255-
not_represented <- error && is.null(repr)
256-
if (not_represented) {
257-
cli::cli_abort(
258-
"{.cls greta_array} has no representation {.var name}"
259-
)
260-
}
255+
check_has_representation(repr, name, error)
261256
repr
262257
}
263258

R/greta_model_class.R

+1-62
Original file line numberDiff line numberDiff line change
@@ -91,70 +91,9 @@ model <- function(...,
9191
compile = compile
9292
)
9393

94-
# get and check the types
95-
types <- dag$node_types
96-
9794
# the user might pass greta arrays with groups of nodes that are unconnected
9895
# to one another. Need to check there are densities in each graph
99-
100-
# so find the subgraph to which each node belongs
101-
graph_id <- dag$subgraph_membership()
102-
103-
graphs <- unique(graph_id)
104-
n_graphs <- length(graphs)
105-
106-
# separate messages to avoid the subgraphs issue for beginners
107-
if (n_graphs == 1) {
108-
density_message <- cli::format_error(
109-
c(
110-
"none of the {.cls greta_array}s in the model are associated with a \\
111-
probability density, so a model cannot be defined"
112-
)
113-
)
114-
variable_message <- cli::format_error(
115-
c(
116-
"none of the {.cls greta_array}s in the model are unknown, so a model \\
117-
cannot be defined"
118-
)
119-
)
120-
} else {
121-
density_message <- cli::format_error(
122-
c(
123-
"the model contains {n_graphs} disjoint graphs",
124-
"one or more of these sub-graphs does not contain any \\
125-
{.cls greta_array}s that are associated with a probability density, \\
126-
so a model cannot be defined"
127-
)
128-
)
129-
variable_message <- cli::format_error(
130-
c(
131-
"the model contains {n_graphs} disjoint graphs",
132-
"one or more of these sub-graphs does not contain any \\
133-
{.cls greta_array}s that are unknown, so a model cannot be defined"
134-
)
135-
)
136-
}
137-
138-
for (graph in graphs) {
139-
types_sub <- types[graph_id == graph]
140-
141-
# check they have a density among them
142-
no_distribution <- !("distribution" %in% types_sub)
143-
if (no_distribution) {
144-
cli::cli_abort(
145-
message = density_message
146-
)
147-
}
148-
149-
# check they have a variable node among them
150-
no_variable_node <- !("variable" %in% types_sub)
151-
if (no_variable_node) {
152-
cli::cli_abort(
153-
variable_message
154-
)
155-
}
156-
}
157-
96+
check_subgraphs(dag)
15897
check_unfixed_discrete_distributions(dag)
15998

16099
# define the TF graph

0 commit comments

Comments
 (0)