Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/marginalise #344

Draft
wants to merge 41 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
312347f
flesh out marginalisation interface
goldingn Jan 7, 2019
1e5e907
add tests for marginalisation error messages
goldingn Jan 7, 2019
3b56a6c
fix typo in marginalisation docs
goldingn Jan 7, 2019
0c766e6
flesh out internals of marginalisation
goldingn Jan 7, 2019
7e802cb
fill in last of guess at internals
goldingn Jan 8, 2019
ad98d72
be strict about float type in as_tf_function
goldingn Jan 8, 2019
9219a42
bugfixes to get marginalise apparently working
goldingn Jan 8, 2019
24fabd1
add marginalisation tests and realted fixes; fix marginalise() example
goldingn Jan 8, 2019
7b9d845
split marginalisers into separate helpfile
goldingn Jan 9, 2019
8689481
flesh out laplace_approximation marginaliser
goldingn Jan 9, 2019
125a989
get laplace approximation (almost) working
goldingn Jan 9, 2019
990e497
finish making laplace approximation run
goldingn Jan 10, 2019
1a8901c
add golden section search and get laplace approximation working
goldingn Jan 11, 2019
f086507
remove stepsize from laplace_approximation options (we optimise it now)
goldingn Jan 11, 2019
0744608
bugfixes and tests to make sure laplace approximation works
goldingn Jan 11, 2019
f6c89e8
Merge branch 'master' into feature/marginalise
goldingn Aug 24, 2019
8e75baf
Merge branch 'master' into feature/marginalise
goldingn Feb 14, 2020
d65ad25
fix lints
goldingn Feb 14, 2020
f8c88dc
fix lints
goldingn Feb 14, 2020
faa9cad
bring marginalisation branch up to date with greta
goldingn Feb 18, 2020
f0c68a2
Merge branch 'master' into feature/marginalise
goldingn Feb 18, 2020
ff34171
more stable computation of discrete marginalisation weights
goldingn Feb 18, 2020
2c6ff39
add posterior conjugate normal test; run short posterior checks every…
goldingn Feb 18, 2020
ddf9e9e
add failing test for laplace approximation
goldingn Feb 18, 2020
18049ea
prepare for non-diagonal hessians
goldingn Feb 18, 2020
e6ae1cb
define distributions as their tfp distribution objects
goldingn Feb 19, 2020
e0f7e23
tidy up tiling to batch size
goldingn Feb 19, 2020
3138eff
stop double-definition of dags
goldingn Feb 19, 2020
8346c8e
enable batch size to be found without access to dag
goldingn Feb 19, 2020
6a02b8d
let marginalisation return the marginalation parameters
goldingn Feb 19, 2020
fd532a5
fix lints
goldingn Feb 19, 2020
57047d8
fix golden section search
goldingn Feb 24, 2020
60ba1b2
stash function calls in gss()
goldingn Feb 24, 2020
5e3dd96
make laplace agree with GRaF
goldingn Feb 26, 2020
2bae9ab
get laplace variances working
goldingn Feb 26, 2020
ca8920b
make laplace work for MVN test case
goldingn Feb 26, 2020
5191f45
bugfix and add notes for univariate Laplace
goldingn Feb 26, 2020
f8b4755
add failing test for univariate normal laplace
goldingn Feb 28, 2020
fc8fe92
switch marginalisers to R6 classes
goldingn Feb 28, 2020
0780a31
pass the distribution to the marginaliser and check inputs early
goldingn Feb 28, 2020
fe40342
get univariate laplace running (with incorrect)
goldingn Feb 28, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Collate:
'internals.R'
'calculate.R'
'callbacks.R'
'marginalise.R'
'marginalisers.R'
'simulate.R'
'chol2symm.R'
Imports:
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ S3method(plot,greta_model)
S3method(print,greta_array)
S3method(print,greta_model)
S3method(print,initials)
S3method(print,marginaliser)
S3method(print,optimiser)
S3method(print,sampler)
S3method(print,summary.greta_array)
Expand Down Expand Up @@ -164,6 +165,7 @@ export(cov2cor)
export(diag)
export(dirichlet)
export(dirichlet_multinomial)
export(discrete_marginalisation)
export(distribution)
export(eigen)
export(exponential)
Expand All @@ -188,12 +190,14 @@ export(iprobit)
export(joint)
export(l_bfgs_b)
export(laplace)
export(laplace_approximation)
export(lkj_correlation)
export(log10.greta_array)
export(log1pe)
export(log2.greta_array)
export(logistic)
export(lognormal)
export(marginalise)
export(mcmc)
export(mixture)
export(model)
Expand Down
85 changes: 39 additions & 46 deletions R/dag_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,16 @@ dag_class <- R6Class(
# float type
on_graph = function(expr) {

# temporarily pass float type info to options, so it can be accessed by
# nodes on definition, without cluncky explicit passing
# temporarily pass float type and batch size info to options, so it can be
# accessed by nodes on definition, without clunky explicit passing
old_float_type <- options()$greta_tf_float
on.exit(options(greta_tf_float = old_float_type))
options(greta_tf_float = self$tf_float)
old_batch_size <- options()$greta_batch_size

on.exit(options(greta_tf_float = old_float_type,
greta_batch_size = old_batch_size))

options(greta_tf_float = self$tf_float,
greta_batch_size = self$tf_environment$batch_size)

with(self$tf_graph$as_default(), expr)
},
Expand Down Expand Up @@ -375,7 +380,7 @@ dag_class <- R6Class(
},

# define tensor for overall log density and gradients
define_joint_density = function() {
define_joint_density = function(adjusted = TRUE) {

tfe <- self$tf_environment

Expand All @@ -392,6 +397,12 @@ dag_class <- R6Class(
target_nodes,
SIMPLIFY = FALSE)

# assign the un-reduced densities, for use in marginalisation
names(densities) <- NULL
assign("component_densities",
densities,
envir = self$tf_environment)

# reduce_sum each of them (skipping the batch dimension)
self$on_graph(summed_densities <- lapply(densities, tf_sum, drop = TRUE))

Expand All @@ -404,43 +415,37 @@ dag_class <- R6Class(
joint_density,
envir = self$tf_environment)

# define adjusted joint density
if (adjusted) {

# get names of Jacobian adjustment tensors for all variable nodes
adj_names <- paste0(self$get_tf_names(types = "variable"), "_adj")
# get names of adjustment tensors for all variable nodes
adj_names <- paste0(self$get_tf_names(types = "variable"), "_adj")

# get TF density tensors for all distribution
adj <- lapply(adj_names, get, envir = self$tf_environment)
# get TF density tensors for all distribution
adj <- lapply(adj_names, get, envir = self$tf_environment)

# remove their names and sum them together (accounting for tfp bijectors
# sometimes returning a scalar tensor)
names(adj) <- NULL
adj <- match_batches(adj)
self$on_graph(total_adj <- tf$add_n(adj))
# remove their names and sum them together (accounting for tfp bijectors
# sometimes returning a scalar tensor)
adj <- match_batches(adj)

# assign overall density to environment
assign("joint_density_adj",
joint_density + total_adj,
envir = self$tf_environment)
# remove their names and sum them together
names(adj) <- NULL
self$on_graph(total_adj <- tf$add_n(adj))

# assign overall density to environment
assign("joint_density_adj",
joint_density + total_adj,
envir = self$tf_environment)

}

},

# evaluate the (truncation-corrected) density of a tfp distribution on its
# target tensor
evaluate_density = function(distribution_node, target_node) {

tfe <- self$tf_environment

parameter_nodes <- distribution_node$parameters

# get the tensorflow objects for these
distrib_constructor <- self$get_tf_object(distribution_node)
tfp_distribution <- self$get_tf_object(distribution_node)
tf_target <- self$get_tf_object(target_node)
tf_parameter_list <- lapply(parameter_nodes, self$get_tf_object)

# execute the distribution constructor functions to return a tfp
# distribution object
tfp_distribution <- distrib_constructor(tf_parameter_list, dag = self)

self$tf_evaluate_density(tfp_distribution,
tf_target,
Expand Down Expand Up @@ -520,6 +525,9 @@ dag_class <- R6Class(
for (name in data_names)
tfe[[name]] <- tfe_old[[name]]

# copy the batch size over
tfe$batch_size <- tfe_old$batch_size

# put the free state in the environment, and build out the tf graph
tfe$free_state <- free_state
self$define_tf_body()
Expand Down Expand Up @@ -775,25 +783,10 @@ dag_class <- R6Class(

},

# get the tfp distribution object for a distribution node
get_tfp_distribution = function(distrib_node) {

# build the tfp distribution object for the distribution, and use it
# to get the tensor for the sample
distrib_constructor <- self$get_tf_object(distrib_node)
parameter_nodes <- distrib_node$parameters
tf_parameter_list <- lapply(parameter_nodes, self$get_tf_object)

# execute the distribution constructor functions to return a tfp
# distribution object
tfp_distribution <- distrib_constructor(tf_parameter_list, dag = self)

},

# try to draw a random sample from a distribution node
draw_sample = function(distribution_node) {

tfp_distribution <- self$get_tfp_distribution(distribution_node)
tfp_distribution <- self$get_tf_object(distribution_node)

sample <- tfp_distribution$sample

Expand Down
1 change: 0 additions & 1 deletion R/inference.R
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,6 @@ run_samplers <- function(samplers,
thin <- as.integer(thin)

dag <- samplers[[1]]$model$dag
chains <- samplers[[1]]$n_chains
n_cores <- check_n_cores(n_cores, length(samplers), plan_is)
float_type <- dag$tf_float

Expand Down
8 changes: 4 additions & 4 deletions R/inference_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ inference <- R6Class(
free_parameters <- model$dag$example_parameters(free = TRUE)
free_parameters <- unlist_tf(free_parameters)
self$n_free <- length(free_parameters)
self$set_initial_values(initial_values)
self$n_traced <- length(model$dag$trace_values(self$free_state))
self$seed <- seed

},
Expand Down Expand Up @@ -277,8 +275,6 @@ sampler <- R6Class(
parameters = parameters,
seed = seed)

self$n_chains <- nrow(self$free_state)

# duplicate diag_sd if needed
n_diag <- length(self$parameters$diag_sd)
n_parameters <- self$n_free
Expand All @@ -289,6 +285,8 @@ sampler <- R6Class(

# define the draws tensor on the tf graph
self$define_tf_draws()
self$set_initial_values(initial_values)
self$n_chains <- nrow(self$free_state)

},

Expand Down Expand Up @@ -1025,6 +1023,8 @@ optimiser <- R6Class(

self$create_optimiser_objective()
self$create_tf_minimiser()
self$set_initial_values(initial_values)
self$n_traced <- length(model$dag$trace_values(self$free_state))

},

Expand Down
10 changes: 5 additions & 5 deletions R/joint.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,15 +124,15 @@ joint_distribution <- R6Class(

tf_distrib = function(parameters, dag) {

# get information from the *nodes* for component distributions, not the tf
# objects passed in here
# get tfp distributions
tfp_distributions <- parameters
names(tfp_distributions) <- NULL

# get tfp distributions, truncations, & bounds of component distributions
# get information on truncations, & bounds of component distributions from
# the *nodes* for component distributions
distribution_nodes <- self$parameters
truncations <- lapply(distribution_nodes, member, "truncation")
bounds <- lapply(distribution_nodes, member, "bounds")
tfp_distributions <- lapply(distribution_nodes, dag$get_tfp_distribution)
names(tfp_distributions) <- NULL

log_prob <- function(x) {

Expand Down
Loading