diff --git a/src/mpol/fourier.py b/src/mpol/fourier.py index c45c272b..d57c412c 100644 --- a/src/mpol/fourier.py +++ b/src/mpol/fourier.py @@ -44,7 +44,7 @@ def __init__(self, coords: GridCoords, persistent_vis: bool = False): self.register_buffer("vis", None, persistent=persistent_vis) self.vis: torch.Tensor - def forward(self, cube: torch.Tensor) -> torch.Tensor: + def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: """ Perform the FFT of the image cube on each channel. @@ -61,12 +61,12 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor: """ # make sure the cube is 3D - assert cube.dim() == 3, "cube must be 3D" + assert packed_cube.dim() == 3, "cube must be 3D" # the self.cell_size prefactor (in arcsec) is to obtain the correct output units # since it needs to correct for the spacing of the input grid. # See MPoL documentation and/or TMS Eqn A8.18 for more information. - self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2)) + self.vis = self.coords.cell_size**2 * torch.fft.fftn(packed_cube, dim=(1, 2)) return self.vis diff --git a/src/mpol/images.py b/src/mpol/images.py index a5939989..169b4126 100644 --- a/src/mpol/images.py +++ b/src/mpol/images.py @@ -63,7 +63,8 @@ def __init__( # base_cube = -3 yields a nearly-blank cube after softplus, whereas # base_cube = 0.0 yields a cube with avg value of ~0.7, which is too high self.base_cube = nn.Parameter( - -3 * torch.ones( + -3 + * torch.ones( (self.nchan, self.coords.npix, self.coords.npix), requires_grad=True, dtype=torch.double, @@ -227,13 +228,13 @@ def __init__( def forward(self, packed_cube: torch.Tensor) -> torch.Tensor: r""" - Pass the cube through as an identity operation, storing the value to the - internal buffer. After the cube has been passed through, convenience + Pass the cube through as an identity operation, storing the value to the + internal buffer. After the cube has been passed through, convenience instance attributes like `sky_cube` and `flux` will reflect the updated cube. Parameters ---------- - packed_cube : :class:`torch.Tensor` of type :class:`torch.double` + packed_cube : :class:`torch.Tensor` of type :class:`torch.double` 3D torch tensor of shape ``(nchan, npix, npix)``) in 'packed' format Returns @@ -287,10 +288,10 @@ def to_FITS( Returns: None """ - + from astropy import wcs from astropy.io import fits - + w = wcs.WCS(naxis=2) w.wcs.crpix = np.array([1, 1]) @@ -312,3 +313,54 @@ def to_FITS( hdul.writeto(fname, overwrite=overwrite) hdul.close() + + def convolve_packed_cube( + packed_cube: torch.Tensor, + coords: GridCoords, + FWHM_maj: float, + FWHM_min: float, + Omega: float, + ) -> torch.Tensor: + r""" + Convolve an image cube with a 2D Gaussian PSF. Operation is carried out in the Fourier domain using a Gaussian taper. + + Parameters + ---------- + packed_cube : :class:`torch.Tensor` of :class:`torch.double` type + shape ``(nchan, npix, npix)`` image cube in packed format. + coords: :class:`mpol.coordinates.GridCoords` + object indicating image and Fourier grid specifications. + FWHM_maj: float, units of arcsec + the FWHH of the Gaussian along the major axis + FWHM_min: float, units of arcsec + the FWHM of the Gaussian along the minor axis + Omega: float, degrees + the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the East-West direction. + """ + nchan, npix_m, npix_l = packed_cube.size() + assert (npix_m == coords.npix) and ( + npix_l == coords.npix + ), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format( + packed_cube.size(), coords.npix + ) + + # in FFT packed format + # we're round-tripping, so we can ignore prefactors for correctness + # calling this `vis_like`, since it's not actually the vis + vis_like = torch.fft.fftn(packed_cube, dim=(1, 2)) + + # convert FWHM to sigma + FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2))) + sigma_x = FWHM_maj * FWHM2sigma + sigma_y = FWHM_min * FWHM2sigma + + # calculate corresponding uu and vv matrices in packed format + taper_2D = utils.fourier_gaussian_lambda_arcsec(coords.packed_u_centers_2D, coords.packed_v_centers_2D, a=1.0, delta_x=0.0, delta_y=0.0, sigma_x=sigma_x, sigma_y=sigma_y, Omega=Omega) + + # calculate taper on packed image + tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size()) + + # iFFT back, ignoring prefactors for round-trip + convolved_packed_cube = torch.fft.fftn(tapered_vis, dim=(1,2)) + + return convolved_packed_cube