@@ -183,17 +183,20 @@ check_2d_multivariate <- function(x,
183
183
}
184
184
}
185
185
186
- check_square <- function (x ,
186
+ check_square <- function (x = NULL ,
187
+ dim = NULL ,
187
188
call = rlang :: caller_env()) {
188
- dim <- dim(x )
189
+
190
+ # allows for specifying x or named dim = dim
191
+ dim <- dim %|| % dim(x )
189
192
ndim <- length(dim )
190
- is_square <- ndim == 2 && dim [1 ] = = dim [2 ]
191
- if (! is_square ) {
193
+ not_square <- dim [1 ] ! = dim [2 ]
194
+ if (ndim == 2 && not_square ) {
192
195
cli :: cli_abort(
193
196
message = c(
194
- " Not 2D square greta array" ,
195
- " x" = " expected a 2D square greta array, but object {.var x} had \\
196
- dimension: {paste(dim, collapse = 'x')}"
197
+ " Object must be 2D square array" ,
198
+ " x" = " But it had dimension: \\
199
+ {.val {paste(dim, collapse = 'x')} }"
197
200
),
198
201
call = call
199
202
)
@@ -1395,20 +1398,27 @@ check_initials_are_numeric <- function(values,
1395
1398
check_initial_values_match_chains <- function (initial_values ,
1396
1399
n_chains ,
1397
1400
call = rlang :: caller_env()){
1398
- n_sets <- length(initial_values )
1399
1401
1400
- initial_values_do_not_match_chains <- n_sets != n_chains
1401
- if (initial_values_do_not_match_chains ) {
1402
- cli :: cli_abort(
1403
- message = c(
1404
- " The number of provided initial values does not match chains" ,
1405
- " {n_sets} set{?s} of initial values were provided, but there \\
1402
+ if (! is.initials(initial_values ) && is.list(initial_values )) {
1403
+ # if the user provided a list of initial values, check elements and length
1404
+ are_initials <- vapply(initial_values , is.initials , FUN.VALUE = FALSE )
1405
+
1406
+ n_sets <- length(initial_values )
1407
+
1408
+ initial_values_do_not_match_chains <- n_sets != n_chains
1409
+ if (initial_values_do_not_match_chains && all(are_initials )) {
1410
+ cli :: cli_abort(
1411
+ message = c(
1412
+ " The number of provided initial values does not match chains" ,
1413
+ " {n_sets} set{?s} of initial values were provided, but there \\
1406
1414
{cli::qty(n_chains)} {?is only/are} {n_chains} \\
1407
1415
{cli::qty(n_chains)} chain{?s}"
1408
- ),
1409
- call = call
1410
- )
1416
+ ),
1417
+ call = call
1418
+ )
1419
+ }
1411
1420
}
1421
+
1412
1422
}
1413
1423
1414
1424
check_initial_values_correct_dim <- function (target_dims ,
@@ -1752,6 +1762,178 @@ check_for_errors <- function(res,
1752
1762
1753
1763
}
1754
1764
1765
+ check_dim_length <- function (dim ,
1766
+ call = rlang :: caller_env()){
1767
+
1768
+ ndim <- length(dim )
1769
+ ndim_gt2 <- ndim > 2
1770
+ if (ndim_gt2 ) {
1771
+ cli :: cli_abort(
1772
+ message = c(
1773
+ " {.arg dim} can either be a scalar or a vector of length 2" ,
1774
+ " However {.arg dim} has length {.val {ndim}}, and contains:" ,
1775
+ " {.val {dim}}"
1776
+ ),
1777
+ call = call
1778
+ )
1779
+ }
1780
+ }
1781
+
1782
+ check_is_distribution_node <- function (distribution ,
1783
+ call = rlang :: caller_env()){
1784
+ if (! is.distribution_node(distribution )) {
1785
+ cli :: cli_abort(
1786
+ message = c(" Invalid distribution" ),
1787
+ call = call
1788
+ )
1789
+ }
1790
+
1791
+ }
1792
+
1793
+ check_values_dim <- function (value ,
1794
+ dim ,
1795
+ call = rlang :: caller_env()){
1796
+ values_have_wrong_dim <- ! is.null(value ) && ! all.equal(dim(value ), dim )
1797
+ if (values_have_wrong_dim ) {
1798
+ cli :: cli_abort(
1799
+ message = " Values have the wrong dimension so cannot be used" ,
1800
+ call = call
1801
+ )
1802
+ }
1803
+
1804
+ }
1805
+
1806
+ # check they are all scalar
1807
+ check_dot_nodes_scalar <- function (dot_nodes ,
1808
+ call = rlang :: caller_env()){
1809
+ are_scalar <- vapply(dot_nodes , is_scalar , logical (1 ))
1810
+ if (! all(are_scalar )) {
1811
+ cli :: cli_abort(
1812
+ message = " {.fun joint} only accepts probability distributions over \\
1813
+ scalars" ,
1814
+ call = call
1815
+ )
1816
+ }
1817
+
1818
+ }
1819
+
1820
+ inform_if_one_set_of_initials <- function (initial_values ,
1821
+ n_chains ,
1822
+ call = rlang :: caller_env()){
1823
+
1824
+ is_blank <- identical(initial_values , initials())
1825
+
1826
+ one_set_of_initials <- ! is_blank & n_chains > 1
1827
+ if (one_set_of_initials ) {
1828
+ cli :: cli_inform(
1829
+ message = " Only one set of initial values was provided, and was used \\
1830
+ for all chains"
1831
+ )
1832
+ }
1833
+ }
1834
+
1835
+ # the user might pass greta arrays with groups of nodes that are unconnected
1836
+ # to one another. Need to check there are densities in each graph
1837
+ check_subgraphs <- function (dag ,
1838
+ call = rlang :: caller_env()){
1839
+ # get and check the types
1840
+ types <- dag $ node_types
1841
+
1842
+ # the user might pass greta arrays with groups of nodes that are unconnected
1843
+ # to one another. Need to check there are densities in each graph
1844
+
1845
+ # so find the subgraph to which each node belongs
1846
+ graph_id <- dag $ subgraph_membership()
1847
+
1848
+ graphs <- unique(graph_id )
1849
+ n_graphs <- length(graphs )
1850
+
1851
+ # separate messages to avoid the subgraphs issue for beginners
1852
+
1853
+ if (n_graphs == 1 ) {
1854
+ density_message <- cli :: format_error(
1855
+ c(
1856
+ " none of the {.cls greta_array}s in the model are associated with a \\
1857
+ probability density, so a model cannot be defined"
1858
+ )
1859
+ )
1860
+ variable_message <- cli :: format_error(
1861
+ c(
1862
+ " none of the {.cls greta_array}s in the model are unknown, so a model \\
1863
+ cannot be defined"
1864
+ )
1865
+ )
1866
+ } else {
1867
+ density_message <- cli :: format_error(
1868
+ c(
1869
+ " the model contains {n_graphs} disjoint graphs" ,
1870
+ " one or more of these sub-graphs does not contain any \\
1871
+ {.cls greta_array}s that are associated with a probability density, \\
1872
+ so a model cannot be defined"
1873
+ )
1874
+ )
1875
+ variable_message <- cli :: format_error(
1876
+ c(
1877
+ " the model contains {n_graphs} disjoint graphs" ,
1878
+ " one or more of these sub-graphs does not contain any \\
1879
+ {.cls greta_array}s that are unknown, so a model cannot be defined"
1880
+ )
1881
+ )
1882
+ }
1883
+
1884
+ for (graph in graphs ) {
1885
+ types_sub <- types [graph_id == graph ]
1886
+
1887
+ # check they have a density among them
1888
+ no_distribution <- ! (" distribution" %in% types_sub )
1889
+ if (no_distribution ) {
1890
+ cli :: cli_abort(
1891
+ message = density_message ,
1892
+ call = call
1893
+ )
1894
+ }
1895
+
1896
+ # check they have a variable node among them
1897
+ no_variable_node <- ! (" variable" %in% types_sub )
1898
+ if (no_variable_node ) {
1899
+ cli :: cli_abort(
1900
+ message = variable_message ,
1901
+ call = call
1902
+ )
1903
+ }
1904
+ }
1905
+
1906
+ }
1907
+
1908
+ check_has_representation <- function (repr ,
1909
+ name ,
1910
+ error ,
1911
+ call = rlang :: caller_env()){
1912
+ not_represented <- error && is.null(repr )
1913
+ if (not_represented ) {
1914
+ cli :: cli_abort(
1915
+ message = " {.cls greta_array} has no representation {.var {name}}" ,
1916
+ call = call
1917
+ )
1918
+ }
1919
+ }
1920
+
1921
+ check_is_greta_array <- function (x ,
1922
+ arg = rlang :: caller_arg(x ),
1923
+ call = rlang :: caller_env()){
1924
+ # only for greta arrays
1925
+ if (! is.greta_array(x )) {
1926
+ cli :: cli_abort(
1927
+ message = c(
1928
+ " {.arg {arg}} must be {.cls greta_array}" ,
1929
+ " Object was is {.cls {class(x)}}"
1930
+ ),
1931
+ call = call
1932
+ )
1933
+ }
1934
+ }
1935
+
1936
+
1755
1937
checks_module <- module(
1756
1938
check_tf_version ,
1757
1939
check_dims ,
0 commit comments