Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent observation masks improperly expanding samples in plates #3317

Merged
merged 3 commits into from
Feb 7, 2024

Conversation

austinv11
Copy link
Contributor

Occasionally (I haven't quite determined why this happens, I suspect it has to do with nesting subsampling plates) I have noticed that partially masked observations will double in shape (A,B,C) -> (A, B, C, A, B, C) when using the Predictive interface.

This patch fixes this and all tests related to masking appear to pass.

@fritzo
Copy link
Member

fritzo commented Feb 2, 2024

Hi @austinv11, thanks for this fix! This should be able to merge after we fix our CI in #3318.

It would be great to have a regression test for this, maybe editing or forking test_obs_mask_multivariate. (Actually I'm surprised and disappointed that test didn't catch this bug 🤔). Would you be up to either add a regression test in this PR or give us some hints about what a test might look like, based on your models that found this error?

@austinv11
Copy link
Contributor Author

austinv11 commented Feb 2, 2024

Hey @fritzo, I tried modifying some tests to see if I could trigger the error, but I am having difficulty replicating it outside of my model. Perhaps I don't understand the dimension broadcasting mechanisms in pyro enough. But here is snippets from my svi model that always creates issues when using the Predictive interface with parellel=False:

def model(...):
	with pyro.poutine.scale(scale=annealing_factor):
            with pyro.plate(
                "cell_ligand_plate", total_cells, dim=-2, subsample=batch_idx
            ):
                with pyro.plate("ligand_plate", n_ligands, dim=-1):
					lavail = pyro.sample(
                        "ligand_availability",
                        dist.ContinuousBernoulli(
                            logits=ligand_available_logits[
                                data.samples.argmax(1)
                            ].unsqueeze(-1)
                        ).to_event(1),
                        obs=(data.ligand_X > 0).unsqueeze(-1).float(),
                        obs_mask=(data.ligand_X > 0),
                    )

def guide(...):
        with pyro.poutine.scale(scale=annealing_factor):
		    with pyro.plate(
                "cell_ligand_plate", total_cells, dim=-2, subsample=batch_idx
            ):
                with pyro.plate("ligand_plate", n_ligands, dim=-1):
                    with pyro.poutine.mask(mask=data.ligand_X <= 0):
                        # Predict ligand availability from the current cell's profile
                        pyro.sample(
                            "ligand_availability_unobserved",
                            dist.ContinuousBernoulli(
                                logits=self._predict_ligand_activation_from_cell(
                                    data.n_genes, data.n_ligands, data.dense_X.log1p()
                                ).unsqueeze(-1)
                            ).to_event(1),
                        )


...


predictive = pyro.infer.Predictive(
    model,
    guide=guide,
    num_samples=replicates,
    parallel=False,
)

# Calls to predictive(...) will now have dimensionality issues, but during training with svi does not

@fritzo fritzo removed the Blocked label Feb 2, 2024
@fritzo fritzo added this to the 1.9 release milestone Feb 7, 2024
@fritzo
Copy link
Member

fritzo commented Feb 7, 2024

Hi @austinv11, thanks for providing the example. Here's a regression test we could use, on the 3317-regression branch

diff --git a/tests/test_primitives.py b/tests/test_primitives.py
index 663e3a67..c208a2cf 100644
--- a/tests/test_primitives.py
+++ b/tests/test_primitives.py
@@ -1,11 +1,14 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0

+from typing import Optional
+
 import pytest
 import torch

 import pyro
 import pyro.distributions as dist
+from pyro import poutine

 pytestmark = pytest.mark.stage("unit")

@@ -31,3 +34,30 @@ def test_deterministic_ok():
     x = pyro.deterministic("x", torch.tensor(0.0))
     assert isinstance(x, torch.Tensor)
     assert x.shape == ()
+
+
+@pytest.mark.parametrize(
+    "mask",
+    [
+        None,
+        torch.tensor(True),
+        torch.tensor([True]),
+        torch.tensor([True, False, True]),
+    ],
+)
+def test_obs_mask_shape(mask: Optional[torch.Tensor]):
+    data = torch.randn(3, 2)
+
+    def model():
+        with pyro.plate("data", 3):
+            pyro.sample(
+                "y",
+                dist.MultivariateNormal(torch.zeros(2), scale_tril=torch.eye(2)),
+                obs=data,
+                obs_mask=mask,
+            )
+
+    trace = poutine.trace(model).get_trace()
+    y_dist = trace.nodes["y"]["fn"]
+    assert y_dist.batch_shape == (3,)
+    assert y_dist.event_shape == (2,)

Could you merge in recent changes to dev (so CI passes), and add this test to your branch? I'd like to get your fix into our upcoming 1.9 release. Thanks again!

fritzo added a commit that referenced this pull request Feb 7, 2024
Copy link
Member

@fritzo fritzo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but let's add a regression test.

@austinv11
Copy link
Contributor Author

Hey @fritzo I appreciate your effort in developing the regression test! My fix passes that test locally. Just updated the PR.

@fritzo
Copy link
Member

fritzo commented Feb 7, 2024

@austinv11 BTW how did you do that cross-repo cherry-pick or merge? Did you do that in the github gui or git command line?

@austinv11
Copy link
Contributor Author

@fritzo It was a little bit of a pain, but I was able to do it using the Git interface in PyCharm (added the original repo as a remote, fetched its changes, then PyCharm let me choose that commit and cherrypick)

@fritzo fritzo merged commit 6337ced into pyro-ppl:dev Feb 7, 2024
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants