initial_values seem to fail with LKJ priors #440

Open
hrlai opened this issue Sep 25, 2021 · 4 comments
@hrlai

hrlai commented Sep 25, 2021

Hi! Whenever the lkj_correlation() prior is part of a model, the initial_values argument seems to always cause mcmc() to fail.

Reproducible code from the greta example models page:

library(greta)

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)

# prior on the correlation between the varying coefficient
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))

On my computer this throws an error:

Error in py_call_impl(callable, dots$args, dots$keywords) : 
  ValueError: Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)'

Detailed traceback:
  File "/home/hrlai/.local/share/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 950, in run
    run_metadata_ptr)
  File "/home/hrlai/.local/share/r-miniconda/envs/r-reticulate/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1149, in _run
    str(subfeed_t.get_shape())))

Across various datasets, the message Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)' always has a first number (here 13) that is larger than the second (here 10), and the difference between them seems to be N. This is what led me to suspect the LKJ prior in the first place.

When I remove the LKJ prior from the model, initial values work again (after restarting the R session). This is reproducible via:
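
As a rough sanity check on the two numbers in the error, the count of free parameters can be worked out by hand in base R. This is only a sketch: it assumes one unconstrained value per scalar variable and M(M-1)/2 values for an M x M correlation matrix, which is the usual parameterisation but is not taken from greta's source. The name n_free is illustrative.

```r
# Sizes taken from the reproducible example above (iris ships with base R).
modmat <- model.matrix(~ Sepal.Width, iris)
M <- ncol(modmat)                    # 2 varying coefficients
N <- max(as.numeric(iris$Species))   # 3 species

# Assumed free-parameter count per variable (an assumption, not greta internals):
n_free <- c(
  tau     = M,                # one per standard deviation
  Omega   = M * (M - 1) / 2,  # off-diagonal entries of a correlation matrix
  z       = N * M,            # whitened varying coefficients
  sigma_e = 1                 # residual scale
)

total_free <- sum(n_free)     # 10, matching the placeholder shape '(?, 10)'
fed <- 13                     # first number in the error message
excess <- fed - total_free    # 3

# In this example the excess equals N, but it also equals the difference
# between a full M x M matrix and its free entries, so this dataset alone
# cannot tell the two explanations apart.
excess == N
excess == M^2 - M * (M - 1) / 2
```

Under these assumptions the placeholder size of 10 is the expected one, which suggests it is the fed initial-value vector of length 13 that is too long.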

library(greta)

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)

Sigma_U <- zeros(dim = c(M, M))
diag(Sigma_U) <- tau
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))

I'd really like to keep the LKJ prior while still being able to specify initial values to help chain convergence. Looking forward to hearing your ideas!
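
For readers unfamiliar with the Cholesky construction in the first example, the following base R sketch (no greta required) checks what the chol() and sweep() steps do. The values of tau and Omega are made up for illustration. It shows that the implied covariance of each row of ab is crossprod(Sigma_U), that is diag(tau) %*% Omega %*% diag(tau), rather than Sigma_U itself.

```r
# Hypothetical fixed values standing in for one draw of the priors.
M <- 2
tau <- c(0.5, 2)                          # standard deviations
Omega <- matrix(c(1, 0.3, 0.3, 1), M, M)  # correlation matrix

# Base R chol() returns the upper-triangular factor U with t(U) %*% U == Omega.
Omega_U <- chol(Omega)

# sweep(..., 2, tau, "*") multiplies column j by tau[j], i.e. Omega_U %*% diag(tau).
Sigma_U <- sweep(Omega_U, 2, tau, "*")

# Target covariance of the varying coefficients.
Sigma <- diag(tau) %*% Omega %*% diag(tau)

# The scaled factor reproduces the target covariance exactly.
stopifnot(isTRUE(all.equal(crossprod(Sigma_U), Sigma)))

# Rows of z %*% Sigma_U therefore have covariance close to Sigma.
set.seed(1)
z <- matrix(rnorm(1e5 * M), ncol = M)
round(cov(z %*% Sigma_U), 2)
```

Sampling z on the standard normal scale and transforming it is what the code comments call whitening; the correlation structure enters only through Omega_U.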

@njtierney
Collaborator

Thanks for posting! I confirm that I can get the same error:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

# model matrix
modmat <- model.matrix(~ Sepal.Width, iris) 
# index of species
jj <- as.numeric(iris$Species)

M <- ncol(modmat) # number of varying coefficients
N <- max(jj) # number of species

# prior on the standard deviation of the varying coefficient
tau <- exponential(0.5, dim = M)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✓ Initialising python and checking dependencies ... done!
#> 

# prior on the correlation between the varying coefficient
Omega <- lkj_correlation(3, M)

# optimization of the varying coefficient sampling through
# cholesky factorization and whitening
Omega_U <- chol(Omega)
Sigma_U <- sweep(Omega_U, 2, tau, "*")
z <- normal(0, 1, dim = c(N, M)) 
ab <- z %*% Sigma_U # equivalent to: ab ~ multi_normal(0, Sigma_U)

# the linear predictor
mu <- rowSums(ab[jj,] * modmat)

# the residual variance
sigma_e <- cauchy(0, 3, truncation = c(0, Inf))

#model
y <- iris$Sepal.Length
distribution(y) <- normal(mu, sigma_e)
m <- model(ab, sigma_e)

draws <- mcmc(m, chains = 4, initial_values = initials(sigma_e = 1))
#> only one set of initial values was provided, and was used for all chains
#> Error in py_call_impl(callable, dots$args, dots$keywords): ValueError: Cannot feed value of shape (1, 13) for Tensor 'Placeholder:0', which has shape '(?, 10)'
#> 
#> Detailed traceback:
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 950, in run
#>     run_metadata_ptr)
#>   File "/Users/njtierney/Library/r-miniconda/envs/greta-env/lib/python3.7/site-packages/tensorflow/python/client/session.py", line 1149, in _run
#>     str(subfeed_t.get_shape())))

Created on 2021-09-28 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.0 (2021-05-18)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_AU.UTF-8                 
#>  ctype    en_AU.UTF-8                 
#>  tz       Australia/Perth             
#>  date     2021-09-28                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date       lib source        
#>  backports     1.2.1      2020-12-09 [1] CRAN (R 4.1.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.1.0)
#>  callr         3.7.0      2021-04-20 [1] CRAN (R 4.1.0)
#>  cli           3.0.1      2021-07-17 [1] CRAN (R 4.1.0)
#>  coda          0.19-4     2020-09-30 [1] CRAN (R 4.1.0)
#>  codetools     0.2-18     2020-11-04 [1] CRAN (R 4.1.0)
#>  crayon        1.4.1      2021-02-08 [1] CRAN (R 4.1.0)
#>  digest        0.6.27     2020-10-24 [1] CRAN (R 4.1.0)
#>  ellipsis      0.3.2      2021-04-29 [1] CRAN (R 4.1.0)
#>  evaluate      0.14       2019-05-28 [1] CRAN (R 4.1.0)
#>  fansi         0.5.0      2021-05-25 [1] CRAN (R 4.1.0)
#>  fs            1.5.0      2020-07-31 [1] CRAN (R 4.1.0)
#>  future        1.22.1     2021-08-25 [1] CRAN (R 4.1.0)
#>  globals       0.14.0     2020-11-22 [1] CRAN (R 4.1.0)
#>  glue          1.4.2      2020-08-27 [1] CRAN (R 4.1.0)
#>  greta       * 0.3.1.9012 2021-09-23 [1] local         
#>  here          1.0.1      2020-12-13 [1] CRAN (R 4.1.0)
#>  highr         0.9        2021-04-16 [1] CRAN (R 4.1.0)
#>  hms           1.1.0      2021-05-17 [1] CRAN (R 4.1.0)
#>  htmltools     0.5.1.1    2021-01-22 [1] CRAN (R 4.1.0)
#>  jsonlite      1.7.2      2020-12-09 [1] CRAN (R 4.1.0)
#>  knitr         1.33       2021-04-24 [1] CRAN (R 4.1.0)
#>  lattice       0.20-44    2021-05-02 [1] CRAN (R 4.1.0)
#>  lifecycle     1.0.0      2021-02-15 [1] CRAN (R 4.1.0)
#>  listenv       0.8.0      2019-12-05 [1] CRAN (R 4.1.0)
#>  magrittr      2.0.1      2020-11-17 [1] CRAN (R 4.1.0)
#>  Matrix        1.3-4      2021-06-01 [1] CRAN (R 4.1.0)
#>  parallelly    1.28.1     2021-09-09 [1] CRAN (R 4.1.0)
#>  pillar        1.6.2      2021-07-29 [1] CRAN (R 4.1.0)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.1.0)
#>  png           0.1-7      2013-12-03 [1] CRAN (R 4.1.0)
#>  prettyunits   1.1.1      2020-01-24 [1] CRAN (R 4.1.0)
#>  processx      3.5.2      2021-04-30 [1] CRAN (R 4.1.0)
#>  progress      1.2.2      2019-05-16 [1] CRAN (R 4.1.0)
#>  ps            1.6.0      2021-02-28 [1] CRAN (R 4.1.0)
#>  purrr         0.3.4      2020-04-17 [1] CRAN (R 4.1.0)
#>  R6            2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp          1.0.7      2021-07-07 [1] CRAN (R 4.1.0)
#>  reprex        2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  reticulate    1.22       2021-09-17 [1] CRAN (R 4.1.0)
#>  rlang         0.4.11     2021-04-30 [1] CRAN (R 4.1.0)
#>  rmarkdown     2.9        2021-06-15 [1] CRAN (R 4.1.0)
#>  rprojroot     2.0.2      2020-11-15 [1] CRAN (R 4.1.0)
#>  rstudioapi    0.13       2020-11-12 [1] CRAN (R 4.1.0)
#>  sessioninfo   1.1.1      2018-11-05 [1] CRAN (R 4.1.0)
#>  stringi       1.7.4      2021-08-25 [1] CRAN (R 4.1.0)
#>  stringr       1.4.0      2019-02-10 [1] CRAN (R 4.1.0)
#>  styler        1.4.1      2021-03-30 [1] CRAN (R 4.1.0)
#>  tensorflow    2.6.0      2021-08-19 [1] CRAN (R 4.1.0)
#>  tfruns        1.5.0      2021-02-26 [1] CRAN (R 4.1.0)
#>  tibble        3.1.4      2021-08-25 [1] CRAN (R 4.1.0)
#>  utf8          1.2.2      2021-07-24 [1] CRAN (R 4.1.0)
#>  vctrs         0.3.8      2021-04-29 [1] CRAN (R 4.1.0)
#>  whisker       0.4        2019-08-28 [1] CRAN (R 4.1.0)
#>  withr         2.4.2      2021-04-18 [1] CRAN (R 4.1.0)
#>  xfun          0.24       2021-06-15 [1] CRAN (R 4.1.0)
#>  yaml          2.2.1      2020-02-01 [1] CRAN (R 4.1.0)
#> 
#> [1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

I'll have a think about this. We're a bit pushed for time at the moment, so I just wanted to give you a heads up that we might not get to this as soon as we would like :)

@njtierney njtierney added this to the 0.3.2 milestone Nov 1, 2021
@njtierney njtierney modified the milestones: 0.4.0, 0.4.1 Nov 26, 2021
@hrlai
Author

hrlai commented Dec 14, 2021

I was just browsing #314 and saw some discussion on dimensions and placeholders; noting it here in case the two are related.

@njtierney njtierney added Up Next and removed Up Next labels Mar 28, 2022
@njtierney njtierney modified the milestones: 0.5.0, 0.6.0 Feb 8, 2023
@hrlai
Author

hrlai commented Sep 26, 2023

I was recently trying to do a prior predictive check and discovered that calculate() also doesn't work on a greta array produced by a chol() operation:

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply

M <- 3

Omega <- lkj_correlation(3, M)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
Omega_U <- chol(Omega)

calculate(Omega, nsim = 1)    # works
#> $Omega
#> , , 1
#> 
#>      [,1]       [,2]      [,3]
#> [1,]    1 -0.1751263 0.1297896
#> 
#> , , 2
#> 
#>            [,1] [,2]      [,3]
#> [1,] -0.1751263    1 0.1951698
#> 
#> , , 3
#> 
#>           [,1]      [,2] [,3]
#> [1,] 0.1297896 0.1951698    1
calculate(Omega_U, nsim = 1)  # fails
#> You must feed a value for placeholder tensor 'Placeholder_1' with dtype double and shape [1,3,3]
#>   [[node Placeholder_1 (defined at /ops/array_ops.py:2143) ]]
#> 
#> Original stack trace for 'Placeholder_1':
#>   File "/ops/array_ops.py", line 2143, in placeholder
#>     return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
#>   File "/ops/gen_array_ops.py", line 6262, in placeholder
#>     "Placeholder", dtype=dtype, shape=shape, name=name)
#>   File "/framework/op_def_library.py", line 788, in _apply_op_helper
#>     op_def=op_def)
#>   File "/util/deprecation.py", line 507, in new_func
#>     return func(*args, **kwargs)
#>   File "/framework/ops.py", line 3616, in create_op
#>     op_def=op_def)
#>   File "/framework/ops.py", line 2005, in __init__
#>     self._traceback = tf_stack.extract_stack()

Created on 2023-09-26 with reprex v2.0.2

In case they are related, I'm linking #585 here.
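
For reference, here is a base R sketch of what the failing call would be expected to return for the Omega draw printed above. It uses base R's chol() as a stand-in and assumes greta's chol() follows the same upper-triangular convention, which the name Omega_U suggests but which is not confirmed here.

```r
# The correlation matrix printed by calculate(Omega, nsim = 1) above.
Omega <- matrix(
  c( 1.0000000, -0.1751263, 0.1297896,
    -0.1751263,  1.0000000, 0.1951698,
     0.1297896,  0.1951698, 1.0000000),
  nrow = 3, byrow = TRUE
)

# Upper-triangular Cholesky factor: t(U) %*% U reproduces Omega.
U <- chol(Omega)

# Entries below the diagonal are exactly zero.
stopifnot(all(U[lower.tri(U)] == 0))

# Multiplying the factor back together recovers the original matrix.
stopifnot(isTRUE(all.equal(crossprod(U), Omega)))

U
```

A working calculate(Omega_U, nsim = 1) should give a factor with these two properties for whichever Omega is drawn.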

@hrlai
Author

hrlai commented Nov 30, 2024

Hi @njtierney, just noting that the second issue (calculate() failing on chol() output) was resolved together with #747, but the initial issue (initial values with LKJ priors) still persists. No hurry though, just noting it down.
