From d040e5754dc773d2c9f6bc365bbdedf2fc34a220 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 12 Apr 2025 15:00:07 +0200 Subject: [PATCH 1/3] Filter warning for batched_dot until we change it --- tests/test_pathfinder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index af9213ff..bdb8fe81 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -18,7 +18,10 @@ import pymc as pm import pytest -pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning") +pytestmark = pytest.mark.filterwarnings( + "ignore:compile_pymc was renamed to compile:FutureWarning", + "ignore:batched_dot is deprecated:FutureWarning", +) import pymc_extras as pmx From 18a67778b8de60724052c21ca80ca3eb482ba6b2 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 12 Apr 2025 16:15:48 +0200 Subject: [PATCH 2/3] Move warning filter to code from test --- pymc_extras/inference/pathfinder/pathfinder.py | 2 ++ tests/test_pathfinder.py | 1 - 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 531efc56..12a2c666 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -485,6 +485,7 @@ def bfgs_sample_dense( shapes: L=batch_size, N=num_params, J=history_size, M=num_samples """ + _warnings.simplefilter("ignore", category=FutureWarning) N = x.shape[-1] IdN = pt.eye(N)[None, ...] @@ -560,6 +561,7 @@ def bfgs_sample_sparse( shapes: L=batch_size, N=num_params, J=history_size, M=num_samples """ + _warnings.simplefilter("ignore", category=FutureWarning) # qr_input: (L, N, 2J) qr_input = inv_sqrt_alpha_diag @ beta (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False) diff --git a/tests/test_pathfinder.py b/tests/test_pathfinder.py index bdb8fe81..ea9d7748 100644 --- a/tests/test_pathfinder.py +++ b/tests/test_pathfinder.py @@ -20,7 +20,6 @@ pytestmark = pytest.mark.filterwarnings( "ignore:compile_pymc was renamed to compile:FutureWarning", - "ignore:batched_dot is deprecated:FutureWarning", ) import pymc_extras as pmx From 1f2f2c1e106d9983930f0f8560631ac53426e450 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sat, 12 Apr 2025 16:23:21 +0200 Subject: [PATCH 3/3] Moving warning into contextmanager around batched_dot --- .../inference/pathfinder/pathfinder.py | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/pymc_extras/inference/pathfinder/pathfinder.py b/pymc_extras/inference/pathfinder/pathfinder.py index 12a2c666..66d081bd 100644 --- a/pymc_extras/inference/pathfinder/pathfinder.py +++ b/pymc_extras/inference/pathfinder/pathfinder.py @@ -485,7 +485,6 @@ def bfgs_sample_dense( shapes: L=batch_size, N=num_params, J=history_size, M=num_samples """ - _warnings.simplefilter("ignore", category=FutureWarning) N = x.shape[-1] IdN = pt.eye(N)[None, ...] @@ -503,7 +502,9 @@ def bfgs_sample_dense( logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1) - mu = x - pt.batched_dot(H_inv, g) + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore", category=FutureWarning) + mu = x - pt.batched_dot(H_inv, g) phi = pt.matrix_transpose( # (L, N, 1) @@ -561,7 +562,6 @@ def bfgs_sample_sparse( shapes: L=batch_size, N=num_params, J=history_size, M=num_samples """ - _warnings.simplefilter("ignore", category=FutureWarning) # qr_input: (L, N, 2J) qr_input = inv_sqrt_alpha_diag @ beta (Q, R), _ = pytensor.scan(fn=pt.nlinalg.qr, sequences=[qr_input], allow_gc=False) @@ -574,14 +574,16 @@ def bfgs_sample_sparse( logdet += pt.sum(pt.log(alpha), axis=-1) # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version. - mu = x - ( - # (L, N), (L, N) -> (L, N) - pt.batched_dot(alpha_diag, g) - # beta @ gamma @ beta.T - # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) - # (L, N, N), (L, N) -> (L, N) - + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) - ) + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore", category=FutureWarning) + mu = x - ( + # (L, N), (L, N) -> (L, N) + pt.batched_dot(alpha_diag, g) + # beta @ gamma @ beta.T + # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N) + # (L, N, N), (L, N) -> (L, N) + + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g) + ) phi = pt.matrix_transpose( # (L, N, 1)