Skip to content

Commit abf45b7

Browse files
authored
Fix failed test related to PR #443 by replacing pt.batched_dot with pt.vectorize(pt.dot,...) (#453)
1 parent 1d8f7f5 commit abf45b7

File tree

2 files changed

+17
-24
lines changed

2 files changed

+17
-24
lines changed

Diff for: pymc_extras/inference/pathfinder/pathfinder.py

+16-20
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import collections
1616
import logging
1717
import time
18-
import warnings as _warnings
1918

2019
from collections import Counter
2120
from collections.abc import Callable, Iterator
@@ -40,7 +39,7 @@
4039
from pymc.model import modelcontext
4140
from pymc.model.core import Point
4241
from pymc.pytensorf import (
43-
compile_pymc,
42+
compile,
4443
find_rng_nodes,
4544
reseed_rngs,
4645
)
@@ -76,9 +75,6 @@
7675
)
7776

7877
logger = logging.getLogger(__name__)
79-
_warnings.filterwarnings(
80-
"ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
81-
)
8278

8379
REGULARISATION_TERM = 1e-8
8480
DEFAULT_LINKER = "cvm_nogc"
@@ -142,7 +138,7 @@ def get_logp_dlogp_of_ravel_inputs(
142138
[model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
143139
model.value_vars,
144140
)
145-
logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
141+
logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
146142
logp_dlogp_fn.trust_input = True
147143

148144
return logp_dlogp_fn
@@ -502,9 +498,10 @@ def bfgs_sample_dense(
502498

503499
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
504500

505-
with _warnings.catch_warnings():
506-
_warnings.simplefilter("ignore", category=FutureWarning)
507-
mu = x - pt.batched_dot(H_inv, g)
501+
# mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
502+
503+
batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
504+
mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))
508505

509506
phi = pt.matrix_transpose(
510507
# (L, N, 1)
@@ -573,17 +570,16 @@ def bfgs_sample_sparse(
573570
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
574571
logdet += pt.sum(pt.log(alpha), axis=-1)
575572

573+
# inverse Hessian
574+
# (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
575+
H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
576+
576577
# NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
577-
with _warnings.catch_warnings():
578-
_warnings.simplefilter("ignore", category=FutureWarning)
579-
mu = x - (
580-
# (L, N), (L, N) -> (L, N)
581-
pt.batched_dot(alpha_diag, g)
582-
# beta @ gamma @ beta.T
583-
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
584-
# (L, N, N), (L, N) -> (L, N)
585-
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
586-
)
578+
579+
# mu = x - pt.einsum("ijk,ik->ij", H_inv, g) # causes error: Multiple destroyers of g
580+
581+
batched_dot = pt.vectorize(pt.dot, signature="(ijk),(ilk)->(ij)")
582+
mu = x - batched_dot(H_inv, pt.matrix_transpose(g[..., None]))
587583

588584
phi = pt.matrix_transpose(
589585
# (L, N, 1)
@@ -857,7 +853,7 @@ def make_pathfinder_body(
857853

858854
# return psi, logP_psi, logQ_psi, elbo_argmax
859855

860-
pathfinder_body_fn = compile_pymc(
856+
pathfinder_body_fn = compile(
861857
[x_full, g_full],
862858
[psi, logP_psi, logQ_psi, elbo_argmax],
863859
**compile_kwargs,

Diff for: tests/test_pathfinder.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,6 @@
1818
import pymc as pm
1919
import pytest
2020

21-
pytestmark = pytest.mark.filterwarnings(
22-
"ignore:compile_pymc was renamed to compile:FutureWarning",
23-
)
24-
2521
import pymc_extras as pmx
2622

2723

@@ -55,6 +51,7 @@ def reference_idata():
5551

5652

5753
@pytest.mark.parametrize("inference_backend", ["pymc", "blackjax"])
54+
@pytest.mark.filterwarnings("ignore:JAXopt is no longer maintained.:DeprecationWarning")
5855
def test_pathfinder(inference_backend, reference_idata):
5956
if inference_backend == "blackjax" and sys.platform == "win32":
6057
pytest.skip("JAX not supported on windows")

0 commit comments

Comments
 (0)