Skip to content

Commit

Permalink
added untested convolve method.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Jan 7, 2024
1 parent 296f0bf commit 65173f7
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 9 deletions.
6 changes: 3 additions & 3 deletions src/mpol/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def __init__(self, coords: GridCoords, persistent_vis: bool = False):
self.register_buffer("vis", None, persistent=persistent_vis)
self.vis: torch.Tensor

def forward(self, cube: torch.Tensor) -> torch.Tensor:
def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
"""
Perform the FFT of the image cube on each channel.
Expand All @@ -61,12 +61,12 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor:
"""

# make sure the cube is 3D
assert cube.dim() == 3, "cube must be 3D"
assert packed_cube.dim() == 3, "cube must be 3D"

# the self.cell_size prefactor (in arcsec) is to obtain the correct output units
# since it needs to correct for the spacing of the input grid.
# See MPoL documentation and/or TMS Eqn A8.18 for more information.
self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2))
self.vis = self.coords.cell_size**2 * torch.fft.fftn(packed_cube, dim=(1, 2))

return self.vis

Expand Down
64 changes: 58 additions & 6 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __init__(
# base_cube = -3 yields a nearly-blank cube after softplus, whereas
# base_cube = 0.0 yields a cube with avg value of ~0.7, which is too high
self.base_cube = nn.Parameter(
-3 * torch.ones(
-3
* torch.ones(
(self.nchan, self.coords.npix, self.coords.npix),
requires_grad=True,
dtype=torch.double,
Expand Down Expand Up @@ -227,13 +228,13 @@ def __init__(

def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
r"""
Pass the cube through as an identity operation, storing the value to the
internal buffer. After the cube has been passed through, convenience
Pass the cube through as an identity operation, storing the value to the
internal buffer. After the cube has been passed through, convenience
instance attributes like `sky_cube` and `flux` will reflect the updated cube.
Parameters
----------
packed_cube : :class:`torch.Tensor` of type :class:`torch.double`
packed_cube : :class:`torch.Tensor` of type :class:`torch.double`
3D torch tensor of shape ``(nchan, npix, npix)``) in 'packed' format
Returns
Expand Down Expand Up @@ -287,10 +288,10 @@ def to_FITS(
Returns:
None
"""

from astropy import wcs
from astropy.io import fits

w = wcs.WCS(naxis=2)

w.wcs.crpix = np.array([1, 1])
Expand All @@ -312,3 +313,54 @@ def to_FITS(
hdul.writeto(fname, overwrite=overwrite)

hdul.close()

def convolve_packed_cube(
packed_cube: torch.Tensor,
coords: GridCoords,
FWHM_maj: float,
FWHM_min: float,
Omega: float,
) -> torch.Tensor:
r"""
Convolve an image cube with a 2D Gaussian PSF. Operation is carried out in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` of :class:`torch.double` type
shape ``(nchan, npix, npix)`` image cube in packed format.
coords: :class:`mpol.coordinates.GridCoords`
object indicating image and Fourier grid specifications.
FWHM_maj: float, units of arcsec
the FWHH of the Gaussian along the major axis
FWHM_min: float, units of arcsec
the FWHM of the Gaussian along the minor axis
Omega: float, degrees
the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the East-West direction.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (npix_m == coords.npix) and (
npix_l == coords.npix
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

# convert FWHM to sigma
FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
sigma_x = FWHM_maj * FWHM2sigma
sigma_y = FWHM_min * FWHM2sigma

# calculate corresponding uu and vv matrices in packed format
taper_2D = utils.fourier_gaussian_lambda_arcsec(coords.packed_u_centers_2D, coords.packed_v_centers_2D, a=1.0, delta_x=0.0, delta_y=0.0, sigma_x=sigma_x, sigma_y=sigma_y, Omega=Omega)

# calculate taper on packed image
tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.fftn(tapered_vis, dim=(1,2))

return convolved_packed_cube

0 comments on commit 65173f7

Please sign in to comment.