From 3cc0313bd7724fdaeddc4fd93b4c47d47d0cff5c Mon Sep 17 00:00:00 2001 From: Elizabeth Santorella Date: Tue, 19 Mar 2024 14:48:57 -0700 Subject: [PATCH] Update deprecated usage of base_samples with GPyTorchPosterior.rsample (#339) Summary: Pull Request resolved: https://github.com/facebookresearch/aepsych/pull/339 The base_samples argument of GPyTorchPosterior.rsample has been deprecated for a while and will soon be reaped. This PR updates the usage to `rsample_from_base_samples`, which does permit the `base_samples` argument. It reshapes the `base_samples` in the same way they were being reshaped by `GPyTorchPosterior.rsample`. A more elegant fix would be to deprecate the `base_samples` argument to `SemiPPosterior.rsample` so that it matches the signature of its superclass rather than add in this complex workaround. Reviewed By: Balandat Differential Revision: D55040610 fbshipit-source-id: 9f908ff52083849e4bf93fcd681e7f7af8f1ce72 --- aepsych/models/semi_p.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/aepsych/models/semi_p.py b/aepsych/models/semi_p.py index fcfb19106..0d678054f 100644 --- a/aepsych/models/semi_p.py +++ b/aepsych/models/semi_p.py @@ -74,7 +74,7 @@ def semi_p_posterior_transform(posterior): offset_cov=offset_cov, ) approx_mvn = MultivariateNormal(mean=approx_mean, covariance_matrix=approx_cov) - return GPyTorchPosterior(mvn=approx_mvn) + return GPyTorchPosterior(distribution=approx_mvn) class SemiPPosterior(GPyTorchPosterior): @@ -101,7 +101,8 @@ def rsample_from_base_samples( return ( super() .rsample_from_base_samples( - sample_shape=sample_shape, base_samples=base_samples + sample_shape=sample_shape, + base_samples=base_samples.expand(self._extended_shape(sample_shape)), ) .squeeze(-1) ) @@ -111,11 +112,13 @@ def rsample( sample_shape: Optional[torch.Size] = None, base_samples: Optional[torch.Tensor] = None, ): - kcsamps = ( - super() - .rsample(sample_shape=sample_shape, base_samples=base_samples) - .squeeze(-1) - ) + if base_samples is None: + samps_ = super().rsample(sample_shape=sample_shape) + else: + samps_ = super().rsample_from_base_samples( + sample_shape=sample_shape, base_samples=base_samples + ) + kcsamps = samps_.squeeze(-1) # fsamps is of shape nsamp x 2 x n, or nsamp x b x 2 x n return kcsamps