Skip to content

Commit

Permalink
Update deprecated usage of base_samples with GPyTorchPosterior.rsample (
Browse files Browse the repository at this point in the history
#339)

Summary:
Pull Request resolved: #339

The base_samples argument of GPyTorchPosterior.rsample has been deprecated for a while and will soon be reaped. This PR updates the usage to `rsample_from_base_samples`, which does permit the `base_samples` argument. It reshapes the `base_samples` in the same way they were being reshaped by `GPyTorchPosterior.rsample`.

A more elegant fix would be to deprecate the `base_samples` argument to `SemiPPosterior.rsample` so that it matches the signature of its superclass rather than add in this complex workaround.

Reviewed By: Balandat

Differential Revision: D55040610

fbshipit-source-id: 9f908ff52083849e4bf93fcd681e7f7af8f1ce72
  • Loading branch information
esantorella authored and facebook-github-bot committed Mar 19, 2024
1 parent d46b729 commit 3cc0313
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions aepsych/models/semi_p.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def semi_p_posterior_transform(posterior):
offset_cov=offset_cov,
)
approx_mvn = MultivariateNormal(mean=approx_mean, covariance_matrix=approx_cov)
return GPyTorchPosterior(mvn=approx_mvn)
return GPyTorchPosterior(distribution=approx_mvn)


class SemiPPosterior(GPyTorchPosterior):
Expand All @@ -101,7 +101,8 @@ def rsample_from_base_samples(
return (
super()
.rsample_from_base_samples(
sample_shape=sample_shape, base_samples=base_samples
sample_shape=sample_shape,
base_samples=base_samples.expand(self._extended_shape(sample_shape)),
)
.squeeze(-1)
)
Expand All @@ -111,11 +112,13 @@ def rsample(
sample_shape: Optional[torch.Size] = None,
base_samples: Optional[torch.Tensor] = None,
):
kcsamps = (
super()
.rsample(sample_shape=sample_shape, base_samples=base_samples)
.squeeze(-1)
)
if base_samples is None:
samps_ = super().rsample(sample_shape=sample_shape)
else:
samps_ = super().rsample_from_base_samples(
sample_shape=sample_shape, base_samples=base_samples
)
kcsamps = samps_.squeeze(-1)
# fsamps is of shape nsamp x 2 x n, or nsamp x b x 2 x n
return kcsamps

Expand Down

0 comments on commit 3cc0313

Please sign in to comment.