|
15 | 15 | import collections
|
16 | 16 | import logging
|
17 | 17 | import time
|
18 |
| -import warnings as _warnings |
19 | 18 |
|
20 | 19 | from collections import Counter
|
21 | 20 | from collections.abc import Callable, Iterator
|
|
40 | 39 | from pymc.model import modelcontext
|
41 | 40 | from pymc.model.core import Point
|
42 | 41 | from pymc.pytensorf import (
|
43 |
| - compile_pymc, |
| 42 | + compile, |
44 | 43 | find_rng_nodes,
|
45 | 44 | reseed_rngs,
|
46 | 45 | )
|
|
76 | 75 | )
|
77 | 76 |
|
78 | 77 | logger = logging.getLogger(__name__)
|
79 |
| -_warnings.filterwarnings( |
80 |
| - "ignore", category=FutureWarning, message="compile_pymc was renamed to compile" |
81 |
| -) |
82 | 78 |
|
83 | 79 | REGULARISATION_TERM = 1e-8
|
84 | 80 | DEFAULT_LINKER = "cvm_nogc"
|
@@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs(
|
142 | 138 | [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
|
143 | 139 | model.value_vars,
|
144 | 140 | )
|
145 |
| - logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs) |
| 141 | + logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs) |
146 | 142 | logp_dlogp_fn.trust_input = True
|
147 | 143 |
|
148 | 144 | return logp_dlogp_fn
|
@@ -502,9 +498,10 @@ def bfgs_sample_dense(
|
502 | 498 |
|
503 | 499 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
|
504 | 500 |
|
505 |
| - with _warnings.catch_warnings(): |
506 |
| - _warnings.simplefilter("ignore", category=FutureWarning) |
507 |
| - mu = x - pt.batched_dot(H_inv, g) |
| 501 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 502 | + |
| 503 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 504 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
508 | 505 |
|
509 | 506 | phi = pt.matrix_transpose(
|
510 | 507 | # (L, N, 1)
|
@@ -573,17 +570,16 @@ def bfgs_sample_sparse(
|
573 | 570 | logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
|
574 | 571 | logdet += pt.sum(pt.log(alpha), axis=-1)
|
575 | 572 |
|
| 573 | + # inverse Hessian |
| 574 | + # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
| 575 | + H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta)) |
| 576 | + |
576 | 577 | # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
|
577 |
| - with _warnings.catch_warnings(): |
578 |
| - _warnings.simplefilter("ignore", category=FutureWarning) |
579 |
| - mu = x - ( |
580 |
| - # (L, N), (L, N) -> (L, N) |
581 |
| - pt.batched_dot(alpha_diag, g) |
582 |
| - # beta @ gamma @ beta.T |
583 |
| - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) |
584 |
| - # (L, N, N), (L, N) -> (L, N) |
585 |
| - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) |
586 |
| - ) |
| 578 | + |
| 579 | + # mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g |
| 580 | + |
| 581 | + batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)") |
| 582 | + mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None])) |
587 | 583 |
|
588 | 584 | phi = pt.matrix_transpose(
|
589 | 585 | # (L, N, 1)
|
@@ -857,7 +853,7 @@ def make_pathfinder_body(
|
857 | 853 |
|
858 | 854 | # return psi, logP_psi, logQ_psi, elbo_argmax
|
859 | 855 |
|
860 |
| - pathfinder_body_fn = compile_pymc( |
| 856 | + pathfinder_body_fn = compile( |
861 | 857 | [x_full, g_full],
|
862 | 858 | [psi, logP_psi, logQ_psi, elbo_argmax],
|
863 | 859 | **compile_kwargs,
|
|
0 commit comments