
Use adaptive hmc #779

Draft: wants to merge 9 commits into master


njtierney (Collaborator)

Resolves #765

Merge branch 'add-snaper-hmc' into adaptive-hmc-v2-i765

# Conflicts:
#	R/inference_class.R
#	tests/testthat/test_posteriors_geweke.R
…rror in trace_list_batches[[1]] : subscript out of bounds` - need to investigate further
njtierney (Collaborator, Author)

library(greta)
#> 
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#> 
#>     binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#> 
#>     %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#>     eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#>     tapply
library(tictoc)
x <- normal(0, c(0.1, 1, 10, 100))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#> 
m <- model(x)
tic()
draws_hmc <- mcmc(
  model = m,
  sampler = hmc()
)
#> running 4 chains simultaneously on up to 8 CPU cores
#> 
#>     warmup ====================================== 1000/1000 | eta:  0s
#>   sampling ====================================== 1000/1000 | eta:  0s
toc()
#> 2.877 sec elapsed

par(mfrow = c(2, 4))
plot(draws_hmc, auto.layout = FALSE)

tic()
draws_hmc_adapt <- mcmc(
  model = m,
  sampler = adaptive_hmc()
)
#> running 4 chains simultaneously on up to 8 CPU cores
#>   sampling ====================================== 1000/1000 | eta:  0s
toc()
#> 5.167 sec elapsed

plot(draws_hmc_adapt, auto.layout = FALSE)

Created on 2025-03-12 with reprex v2.1.1

Session info

sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value
#>  version  R version 4.4.2 (2024-10-31)
#>  os       macOS Sequoia 15.1
#>  system   aarch64, darwin20
#>  ui       X11
#>  language (EN)
#>  collate  en_US.UTF-8
#>  ctype    en_US.UTF-8
#>  tz       Australia/Hobart
#>  date     2025-03-12
#>  pandoc   3.2.1 @ /opt/homebrew/bin/ (via rmarkdown)
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package     * version    date (UTC) lib source
#>  abind         1.4-8      2024-09-12 [1] CRAN (R 4.4.1)
#>  backports     1.5.0      2024-05-23 [1] CRAN (R 4.4.0)
#>  base64enc     0.1-3      2015-07-28 [1] CRAN (R 4.4.0)
#>  callr         3.7.6      2024-03-25 [1] CRAN (R 4.4.0)
#>  cli           3.6.4      2025-02-13 [1] CRAN (R 4.4.1)
#>  coda          0.19-4.1   2024-01-31 [1] CRAN (R 4.4.0)
#>  codetools     0.2-20     2024-03-31 [2] CRAN (R 4.4.2)
#>  crayon        1.5.3      2024-06-20 [1] CRAN (R 4.4.0)
#>  curl          6.2.0      2025-01-23 [1] CRAN (R 4.4.1)
#>  digest        0.6.37     2024-08-19 [1] CRAN (R 4.4.1)
#>  evaluate      1.0.1      2024-10-10 [1] CRAN (R 4.4.1)
#>  fastmap       1.2.0      2024-05-15 [1] CRAN (R 4.4.0)
#>  fs            1.6.5      2024-10-30 [1] CRAN (R 4.4.1)
#>  future        1.34.0     2024-07-29 [1] CRAN (R 4.4.0)
#>  globals       0.16.3     2024-03-08 [1] CRAN (R 4.4.0)
#>  glue          1.8.0      2024-09-30 [1] CRAN (R 4.4.1)
#>  greta       * 0.5.0.9000 2025-03-12 [1] local
#>  hms           1.1.3      2023-03-21 [1] CRAN (R 4.4.0)
#>  htmltools     0.5.8.1    2024-04-04 [1] CRAN (R 4.4.0)
#>  jsonlite      1.8.9      2024-09-20 [1] CRAN (R 4.4.1)
#>  knitr         1.49       2024-11-08 [1] CRAN (R 4.4.1)
#>  lattice       0.22-6     2024-03-20 [2] CRAN (R 4.4.2)
#>  lifecycle     1.0.4      2023-11-07 [1] CRAN (R 4.4.0)
#>  listenv       0.9.1      2024-01-29 [1] CRAN (R 4.4.0)
#>  magrittr      2.0.3      2022-03-30 [1] CRAN (R 4.4.0)
#>  Matrix        1.7-1      2024-10-18 [2] CRAN (R 4.4.2)
#>  parallelly    1.41.0     2024-12-18 [1] CRAN (R 4.4.1)
#>  pkgconfig     2.0.3      2019-09-22 [1] CRAN (R 4.4.0)
#>  png           0.1-8      2022-11-29 [1] CRAN (R 4.4.0)
#>  prettyunits   1.2.0      2023-09-24 [1] CRAN (R 4.4.0)
#>  processx      3.8.5      2025-01-08 [1] CRAN (R 4.4.1)
#>  progress      1.2.3      2023-12-06 [1] CRAN (R 4.4.0)
#>  ps            1.8.1      2024-10-28 [1] CRAN (R 4.4.1)
#>  R6            2.6.1      2025-02-15 [1] CRAN (R 4.4.1)
#>  Rcpp          1.0.14     2025-01-12 [1] CRAN (R 4.4.1)
#>  reprex        2.1.1      2024-07-06 [1] CRAN (R 4.4.0)
#>  reticulate    1.40.0     2024-11-15 [1] CRAN (R 4.4.1)
#>  rlang         1.1.5      2025-01-17 [1] CRAN (R 4.4.1)
#>  rmarkdown     2.29       2024-11-04 [1] CRAN (R 4.4.1)
#>  rstudioapi    0.17.1     2024-10-22 [1] CRAN (R 4.4.1)
#>  sessioninfo   1.2.2      2021-12-06 [1] CRAN (R 4.4.0)
#>  tensorflow    2.16.0     2024-04-15 [1] CRAN (R 4.4.0)
#>  tfautograph   0.3.2      2021-09-17 [1] CRAN (R 4.4.0)
#>  tfruns        1.5.3      2024-04-19 [1] CRAN (R 4.4.0)
#>  tictoc      * 1.2.1      2024-03-18 [1] CRAN (R 4.4.0)
#>  vctrs         0.6.5      2023-12-01 [1] CRAN (R 4.4.0)
#>  whisker       0.4.1      2022-12-05 [1] CRAN (R 4.4.0)
#>  withr         3.0.2      2024-10-28 [1] CRAN (R 4.4.1)
#>  xfun          0.50.5     2025-01-15 [1] Github (yihui/xfun@116d689)
#>  xml2          1.3.6      2023-12-04 [1] CRAN (R 4.4.0)
#>  yaml          2.3.10     2024-07-26 [1] CRAN (R 4.4.0)
#> 
#>  [1] /Users/nick/Library/R/arm64/4.4/library
#>  [2] /Library/Frameworks/R.framework/Versions/4.4-arm64/Resources/library
#> 
#> ─ Python configuration ───────────────────────────────────────────────────────
#>  python:         /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/bin/python
#>  libpython:      /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/libpython3.10.dylib
#>  pythonhome:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2:/Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2
#>  version:        3.10.14 | packaged by conda-forge | (main, Mar 20 2024, 12:51:49) [Clang 16.0.6 ]
#>  numpy:          /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/numpy
#>  numpy_version:  1.26.4
#>  tensorflow:     /Users/nick/Library/r-miniconda-arm64/envs/greta-env-tf2/lib/python3.10/site-packages/tensorflow
#>  
#>  NOTE: Python version was forced by use_python() function
#> 
#> ──────────────────────────────────────────────────────────────────────────────

Successfully merging this pull request may close these issues.

Add snaper HMC sampler