Skip to content

Commit 80192ed

Browse files
authored
Merge pull request #734 from njtierney/tweaking-wishart-fix-from-729
Tweaking wishart fix from 729
2 parents cb14e95 + 118ee26 commit 80192ed

21 files changed

+462
-340
lines changed

DESCRIPTION

+1-1
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ Imports:
4545
R6,
4646
reticulate (>= 1.19.0),
4747
rlang,
48-
tensorflow (>= 2.8.0),
48+
tensorflow (== 2.16.0),
4949
tools,
5050
utils,
5151
whisker,

NEWS.md

+35-21
Original file line numberDiff line numberDiff line change
@@ -28,51 +28,65 @@ The following optimisers are removed, as they are no longer supported by Tensorf
2828

2929
## Installation revamp
3030

31-
This release provides a few improvements to installation in greta. It should now provide more information about installation progress, and be more robust. The intention is, it should _just work_, and if it doesn't fail gracefully with some useful advice on problem solving.
31+
This release provides a few improvements to installation in greta. It should now provide more information about installation progress, and be more robust. The intention is, it should _just work_, and if it doesn't, it should fail gracefully with some useful advice on problem solving.
3232

33-
* Added option to restart R + run `library(greta)` after installation (#523)
34-
* Added installation deps object, `greta_deps_sepc()` to help simplify specifying package versions (#664)
35-
* removed `method` and `conda` arguments from `install_greta_deps()` as they
33+
* Added option to restart R + run `library(greta)` after installation (#523).
34+
* Added installation deps object, `greta_deps_sepc()` to help simplify specifying package versions (#664).
35+
* Removed `method` and `conda` arguments from `install_greta_deps()` as they
3636
were not used.
37-
* removed `manual` argument in `install_greta_deps()`.
38-
* added default 5 minute timer to installation processes
39-
* Added `greta_deps_receipt()` to list the current main python packages installed. (#668)
40-
* Added checking suite to ensure you are using valid versions of TF, TFP, and Python(#666)
37+
* Removed `manual` argument in `install_greta_deps()`.
38+
* Added default 5 minute timer to installation processes.
39+
* Added `greta_deps_receipt()` to list the current main python packages installed (#668).
40+
* Added checking suite to ensure you are using valid versions of TF, TFP, and Python(#666).
4141
* Added data `greta_deps_tf_tfp` (#666), which contains valid versions combinations of TF, TFP, and Python.
42-
* remove `greta_nodes_install/conda_*()` options as #493 makes them defunct.
43-
* Added option to write to a single logfile with `greta_set_install_logfile()`, and `write_greta_install_log()`, and `open_greta_install_log()` (#493)
44-
* Added `destroy_greta_deps()` function to remove miniconda and python conda environment
45-
* Improved `write_greta_install_log()` and `open_greta_install_log()` to use `tools::R_user_dir()` to always write to a file location. `open_greta_install_log()` will open one found from an environment variable or go to the default location. (#703)
42+
* Remove `greta_nodes_install/conda_*()` options as #493 makes them defunct.
43+
* Added option to write to a single logfile with `greta_set_install_logfile()`, and `write_greta_install_log()`, and `open_greta_install_log()` (#493).
44+
* Added `destroy_greta_deps()` function to remove miniconda and python conda environment.
45+
* Improved `write_greta_install_log()` and `open_greta_install_log()` to use `tools::R_user_dir()` to always write to a file location. `open_greta_install_log()` will open one found from an environment variable or go to the default location (#703).
4646

4747
## New Print methods
4848

49-
* New print method for `greta_mcmc_list`. This means MCMC output will be shorter and more informative. (#644)
50-
* greta arrays now have a print method that stops them from printing too many rows into the console. Similar to MCMC print method, you can control the print output with the `n` argument: `print(object, n = <elements to print>)`. (#644)
49+
* New print method for `greta_mcmc_list`. This means MCMC output will be shorter and more informative (#644).
50+
* greta arrays now have a print method that stops them from printing too many rows into the console. Similar to MCMC print method, you can control the print output with the `n` argument: `print(object, n = <elements to print>)` (#644).
5151

5252
## Minor
5353

54-
* `greta_sitrep()` now checks for installations of Python, TF, and TFP
54+
* `greta_sitrep()` now checks for installations of Python, TF, and TFP.
5555
* Slice sampler no longer needs precision = "single" to work.
5656
* greta now depends on R 4.1.0, which was released May 2021, over 3 years ago.
57-
* export `is.greta_array()` and `is.greta_mcmc_list()`
58-
* `restart` argument for `install_greta_deps()` and `reinstall_greta_deps()` to automatically restart R (#523)
57+
* export `is.greta_array()` and `is.greta_mcmc_list()`.
58+
* `restart` argument for `install_greta_deps()` and `reinstall_greta_deps()` to automatically restart R (#523).
5959

6060
## Internals
6161

6262
* Internally we are replacing most of the error handling code as separate
6363
`check_*` functions.
6464
* Implemented `cli::cli_abort/warn/inform()` in place of `cli::format_error/warning/message()` + `stop/warning/message(msg, call. = FALSE)` pattern.
6565
* Uses legacy optimizer internally (Use `tf$keras$optimizers$legacy$METHOD` over `tf$keras$optimizers$METHOD`). No user impact expected.
66-
* Update photo of Grete Hermann (#598)
67-
* Use `%||%` internally to replace the pattern: `if (is.null(x)) x <- thing` with `x <- x %||% thing`. (#630)
66+
* Update photo of Grete Hermann (#598).
67+
* Use `%||%` internally to replace the pattern: `if (is.null(x)) x <- thing` with `x <- x %||% thing` (#630).
6868
* Add more explaining variables - replace `if (thing & thing & what == this)` with `if (explanation_of_thing)`.
69-
* Refactored repeated uses of `vapply` into functions (#377, #658)
69+
* Refactored repeated uses of `vapply` into functions (#377, #658).
7070
* Add internal data files `.deps_tf` and `.deps_tfp` to track dependencies of TF and TFP. Related to #666.
7171

72+
- Posterior density checks (#720):
73+
- Don't run Geweke on CI as it takes 30 minutes to run.
74+
- Add thinning to Geweke tests.
75+
- Fix broken geweke tests from TF1-->TF2 change.
76+
- Increase the number of effective samples for check_samples for lkj distribution
77+
- Add more checks to posterior to run on CI/on each test of greta
78+
7279
## Bug fixes
7380

7481
* Fix bug where matrix multiply had dimension error before coercing to greta array. (#464)
75-
*
82+
- Fixes for Wishart and LKJ Correlation distributions (#729 #733 #734):
83+
- Add bijection density to choleskied distributions.
84+
- Note about some issues with LKJ and our normalisation constant for the density.
85+
- Removed our custom `forward_log_det_jacobian()` function from `tf_correlation_cholesky_bijector()` (used in `lkj_correlation()`). Previously, it did not work with unknown dimensions, but it now works with them.
86+
- Ensure wishart uses sigma_chol in scale_tril
87+
- Wishart uses `tf$matmul(chol_draws, chol_draws, adjoint_b = TRUE)` instead of `tf_chol2symm(chol_draws)`.
88+
- Test log prob function returns valid numeric numbers.
89+
- Addresses issue with log prob returning NaNs--replace `FillTriangular` with `FillScaleTriL` and apply Chaining to first transpose input.
7690

7791
# greta 0.4.5
7892

R/checkers.R

+13
Original file line numberDiff line numberDiff line change
@@ -1943,6 +1943,19 @@ check_has_representation <- function(repr,
19431943
}
19441944
}
19451945

1946+
check_has_anti_representation <- function(repr,
1947+
name,
1948+
error,
1949+
call = rlang::caller_env()){
1950+
not_anti_represented <- error && is.null(repr)
1951+
if (not_anti_represented) {
1952+
cli::cli_abort(
1953+
message = "{.cls greta_array} has no anti representation {.var {name}}",
1954+
call = call
1955+
)
1956+
}
1957+
}
1958+
19461959
check_is_greta_array <- function(x,
19471960
arg = rlang::caller_arg(x),
19481961
call = rlang::caller_env()){

R/greta_array_class.R

+17
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,23 @@ has_representation <- function(x, name) {
257257
!is.null(repr)
258258
}
259259

260+
anti_representation <- function(x, name, error = TRUE) {
261+
if (is.greta_array(x)) {
262+
x_node <- get_node(x)
263+
} else {
264+
x_node <- x
265+
}
266+
repr <- x_node$anti_representations[[name]]
267+
check_has_anti_representation(repr, name, error)
268+
repr
269+
}
270+
271+
272+
has_anti_representation <- function(x, name){
273+
repr <- anti_representation(x, name, error = FALSE)
274+
!is.null(repr)
275+
}
276+
260277
# helper function to make a copy of the greta array & tensor
261278
copy_representation <- function(x, name) {
262279
repr <- representation(x, name)

R/greta_stash.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -28,5 +28,6 @@ greta_stash$tf_num_error <- greta_note_msg
2828
#' greta_notes_tf_error()
2929
#' }
3030
greta_notes_tf_num_error <- function() {
31-
cat(greta_stash$tf_num_error)
31+
# wrap in paste0 to remove list properties
32+
cat(paste0(greta_stash$tf_num_error))
3233
}

R/inference_class.R

+3
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,9 @@ sampler <- R6Class(
744744
)
745745

746746
# get trace of free state and drop the null dimension
747+
if (is.null(batch_results$all_states)){
748+
browser()
749+
}
747750
free_state_draws <- as.array(batch_results$all_states)
748751

749752
# if there is one sample at a time, and it's rejected, conversion from

R/node_class.R

+43
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,12 @@ node <- R6Class(
55
unique_name = "",
66
parents = list(),
77
children = list(),
8+
# named greta arrays giving different representations of the greta array
9+
# represented by this node that have already been calculated, to be used for
10+
# computational speedups or numerical stability. E.g. a logarithm or a
11+
# cholesky factor
812
representations = list(),
13+
anti_representations = list(),
914
.value = array(NA),
1015
dim = NA,
1116
distribution = NULL,
@@ -82,6 +87,19 @@ node <- R6Class(
8287
parents <- c(parents, list(self$distribution))
8388
}
8489

90+
if (mode == "sampling" & has_representation(self, "cholesky")){
91+
# remove cholesky representation node from parents
92+
parent_names <- extract_unique_names(parents)
93+
antirep_name <- get_node(self$representations$cholesky)$unique_name
94+
parent_names_keep <- setdiff(parent_names, antirep_name)
95+
parents <- parents[match(parent_names_keep, parent_names)]
96+
}
97+
98+
if (mode == "sampling" & has_anti_representation(self, "chol2symm")){
99+
chol2symm_node <- get_node(self$anti_representations$chol2symm)
100+
parents <- c(parents, list(chol2symm_node))
101+
}
102+
85103
parents
86104
},
87105
add_child = function(node) {
@@ -273,6 +291,31 @@ node <- R6Class(
273291
}
274292

275293
label
294+
},
295+
make_antirepresentations = function(representations){
296+
mapply(
297+
FUN = self$make_one_anti_representation,
298+
representations,
299+
names(representations)
300+
)
301+
},
302+
make_one_anti_representation = function(ga, name){
303+
node <- get_node(ga)
304+
anti_name <- self$find_anti_name(name)
305+
node$anti_representations[[anti_name]] <- as.greta_array(self)
306+
node
307+
},
308+
find_anti_name = function(name){
309+
switch(name,
310+
cholesky = "chol2symm",
311+
chol2symm = "chol",
312+
exp = "log",
313+
log = "exp",
314+
probit = "iprobit",
315+
iprobit = "probit",
316+
logit = "ilogit",
317+
ilogit = "logit"
318+
)
276319
}
277320
)
278321
)

R/node_types.R

+34-24
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,6 @@ operation_node <- R6Class(
8484
operation_args = NA,
8585
arguments = list(),
8686
tf_function_env = NA,
87-
88-
# named greta arrays giving different representations of the greta array
89-
# represented by this node that have already been calculated, to be used for
90-
# computational speedups or numerical stability. E.g. a logarithm or a
91-
# cholesky factor
92-
representations = list(),
9387
initialize = function(operation,
9488
...,
9589
dim = NULL,
@@ -127,6 +121,7 @@ operation_node <- R6Class(
127121
self$operation <- tf_operation
128122
self$operation_args <- operation_args
129123
self$representations <- representations
124+
self$make_antirepresentations(representations)
130125
self$tf_function_env <- tf_function_env
131126

132127
# assign empty value of the right dimension, or the values passed via the
@@ -158,23 +153,23 @@ operation_node <- R6Class(
158153
# browser()
159154
tensor <- dag$draw_sample(self$distribution)
160155

161-
if (has_representation(self, "cholesky")) {
156+
# if (has_representation(self, "cholesky")) {
162157
# browser()
163-
cholesky_tensor <- tf_chol(tensor)
164-
# cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
165-
cholesky_node <- get_node(representation(self, "cholesky"))
166-
cholesky_tf_name <- dag$tf_name(cholesky_node)
167-
assign(cholesky_tf_name, cholesky_tensor, envir = tfe)
158+
# cholesky_tensor <- tf_chol(tensor)
159+
# # cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
160+
# cholesky_node <- get_node(representation(self, "cholesky"))
161+
# cholesky_tf_name <- dag$tf_name(cholesky_node)
162+
# assign(cholesky_tf_name, cholesky_tensor, envir = tfe)
168163
## TF1/2
169164
## This assignment I think is supposed to be passed down to later on
170165
## in the script, as `cholesky_tf_name` gets overwritten
171166
# cholesky_tf_name <- dag$tf_name(self)
172167
# tf_name <- cholesky_tf_name
173168
# tensor <- cholesky_tensor
174-
cholesky_tensor <- tf_chol(tensor)
175-
cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
176-
assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
177-
}
169+
# cholesky_tensor <- tf_chol(tensor)
170+
# cholesky_tf_name <- dag$tf_name(self$representation$cholesky)
171+
# assign(cholesky_tf_name, cholesky_tensor, envir = dag$tf_environment)
172+
# }
178173
}
179174

180175
if (mode == "forward") {
@@ -292,14 +287,29 @@ variable_node <- R6Class(
292287
distrib_node <- self$distribution
293288

294289
if (is.null(distrib_node)) {
295-
296-
# if the variable has no distribution create a placeholder instead
297-
# (the value must be passed in via values when using simulate)
298-
shape <- to_shape(c(1, self$dim))
299-
# TF1/2 check
300-
# need to change the placeholder approach here.
301-
# NT: can we change this to be a tensor of the right shape with 1s?
302-
tensor <- tensorflow::as_tensor(1L, shape = shape, dtype = tf_float())
290+
# does it have an anti-representation where it is the cholesky?
291+
# the antirepresentation of cholesky is chol2symm
292+
# if it does, we will take the antirepresentation and get it to `tf` itself
293+
# then we need to get the tf_name
294+
chol2symm_ga <- self$anti_representations$chol2symm
295+
chol2symm_existing <- !is.null(chol2symm_ga)
296+
if (chol2symm_existing) {
297+
chol2symm_node <- get_node(chol2symm_ga)
298+
chol2symm_name <- dag$tf_name(chol2symm_node)
299+
chol2symm_tensor <- get(chol2symm_name, envir = dag$tf_environment)
300+
tensor <- tf_chol(chol2symm_tensor)
301+
}
302+
303+
# chol2symm_ga$define_tf(dag)
304+
# } else {
305+
#
306+
# # if the variable has no distribution create a placeholder instead
307+
# # (the value must be passed in via values when using simulate)
308+
# shape <- to_shape(c(1, self$dim))
309+
# # TF1/2 check
310+
# # need to change the placeholder approach here.
311+
# # NT: can we change this to be a tensor of the right shape with 1s?
312+
# tensor <- tensorflow::as_tensor(1L, shape = shape, dtype = tf_float())
303313
} else {
304314
tensor <- dag$draw_sample(self$distribution)
305315
}

0 commit comments

Comments
 (0)