From d6ebf5a5f6a796b5d48efd9bd800995fbce88a02 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ola=20R=C3=B8nning?= Date: Wed, 4 Dec 2024 04:39:13 +0100 Subject: [PATCH] added note and assert that sbvm conc < 10k (#3412) --- pyro/distributions/sine_bivariate_von_mises.py | 12 ++++++++++++ tests/distributions/test_sine_bivariate_von_mises.py | 9 ++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/pyro/distributions/sine_bivariate_von_mises.py b/pyro/distributions/sine_bivariate_von_mises.py index 40be29ec9a..e3723d6ae6 100644 --- a/pyro/distributions/sine_bivariate_von_mises.py +++ b/pyro/distributions/sine_bivariate_von_mises.py @@ -55,6 +55,8 @@ class SineBivariateVonMises(TorchDistribution): .. note:: In the context of :class:`~pyro.infer.SVI`, this distribution can be used as a likelihood but not for latent variables. + .. note:: Normalization remains accurate up to concentrations of 10,000. + ** References: ** 1. Probabilistic model for two dependent circular variables Singh, H., Hnizdo, V., and Demchuck, E. (2002) 2. Protein Bioinformatics and Mixtures of Bivariate von Mises Distributions for Angular Data, @@ -108,6 +110,16 @@ def __init__( ) = broadcast_all( phi_loc, psi_loc, phi_concentration, psi_concentration, correlation ) + + max_conc = torch.maximum( + torch.max(phi_concentration), torch.max(psi_concentration) + ) + assrt_hstr = ( + "Normalization of SineBiviateVonMises is inaccurate for" + f"current max concentration ({max_conc} > 10,000)." + ) + assert max_conc <= torch.tensor(10_000.0), assrt_hstr + self.phi_loc = phi_loc self.psi_loc = psi_loc self.phi_concentration = phi_concentration diff --git a/tests/distributions/test_sine_bivariate_von_mises.py b/tests/distributions/test_sine_bivariate_von_mises.py index 6212220dcb..93a7e90c30 100644 --- a/tests/distributions/test_sine_bivariate_von_mises.py +++ b/tests/distributions/test_sine_bivariate_von_mises.py @@ -132,8 +132,15 @@ def guide(data): assert_equal(expected[k].squeeze(), actual.squeeze(), 9e-2) -@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10000.0]) +@pytest.mark.parametrize("conc", [1.0, 10.0, 1000.0, 10_000.0, 10_001.0]) def test_sine_bivariate_von_mises_norm(conc): + if conc > 10_000.0: + try: + dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) + pytest.fail() + except AssertionError: + return + dist = SineBivariateVonMises(0, 0, conc, conc, 0.0) num_samples = 500 x = torch.linspace(-torch.pi, torch.pi, num_samples)