Skip to content

Commit 65173f7

Browse files
committed
added untested convolve method.
1 parent 296f0bf commit 65173f7

File tree

2 files changed

+61
-9
lines changed

2 files changed

+61
-9
lines changed

src/mpol/fourier.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def __init__(self, coords: GridCoords, persistent_vis: bool = False):
4444
self.register_buffer("vis", None, persistent=persistent_vis)
4545
self.vis: torch.Tensor
4646

47-
def forward(self, cube: torch.Tensor) -> torch.Tensor:
47+
def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
4848
"""
4949
Perform the FFT of the image cube on each channel.
5050
@@ -61,12 +61,12 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor:
6161
"""
6262

6363
# make sure the cube is 3D
64-
assert cube.dim() == 3, "cube must be 3D"
64+
assert packed_cube.dim() == 3, "cube must be 3D"
6565

6666
# the self.cell_size prefactor (in arcsec) is to obtain the correct output units
6767
# since it needs to correct for the spacing of the input grid.
6868
# See MPoL documentation and/or TMS Eqn A8.18 for more information.
69-
self.vis = self.coords.cell_size**2 * torch.fft.fftn(cube, dim=(1, 2))
69+
self.vis = self.coords.cell_size**2 * torch.fft.fftn(packed_cube, dim=(1, 2))
7070

7171
return self.vis
7272

src/mpol/images.py

Lines changed: 58 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,8 @@ def __init__(
6363
# base_cube = -3 yields a nearly-blank cube after softplus, whereas
6464
# base_cube = 0.0 yields a cube with avg value of ~0.7, which is too high
6565
self.base_cube = nn.Parameter(
66-
-3 * torch.ones(
66+
-3
67+
* torch.ones(
6768
(self.nchan, self.coords.npix, self.coords.npix),
6869
requires_grad=True,
6970
dtype=torch.double,
@@ -227,13 +228,13 @@ def __init__(
227228

228229
def forward(self, packed_cube: torch.Tensor) -> torch.Tensor:
229230
r"""
230-
Pass the cube through as an identity operation, storing the value to the
231-
internal buffer. After the cube has been passed through, convenience
231+
Pass the cube through as an identity operation, storing the value to the
232+
internal buffer. After the cube has been passed through, convenience
232233
instance attributes like `sky_cube` and `flux` will reflect the updated cube.
233234
234235
Parameters
235236
----------
236-
packed_cube : :class:`torch.Tensor` of type :class:`torch.double`
237+
packed_cube : :class:`torch.Tensor` of type :class:`torch.double`
237238
3D torch tensor of shape ``(nchan, npix, npix)``) in 'packed' format
238239
239240
Returns
@@ -287,10 +288,10 @@ def to_FITS(
287288
Returns:
288289
None
289290
"""
290-
291+
291292
from astropy import wcs
292293
from astropy.io import fits
293-
294+
294295
w = wcs.WCS(naxis=2)
295296

296297
w.wcs.crpix = np.array([1, 1])
@@ -312,3 +313,54 @@ def to_FITS(
312313
hdul.writeto(fname, overwrite=overwrite)
313314

314315
hdul.close()
316+
317+
def convolve_packed_cube(
318+
packed_cube: torch.Tensor,
319+
coords: GridCoords,
320+
FWHM_maj: float,
321+
FWHM_min: float,
322+
Omega: float,
323+
) -> torch.Tensor:
324+
r"""
325+
Convolve an image cube with a 2D Gaussian PSF. Operation is carried out in the Fourier domain using a Gaussian taper.
326+
327+
Parameters
328+
----------
329+
packed_cube : :class:`torch.Tensor` of :class:`torch.double` type
330+
shape ``(nchan, npix, npix)`` image cube in packed format.
331+
coords: :class:`mpol.coordinates.GridCoords`
332+
object indicating image and Fourier grid specifications.
333+
FWHM_maj: float, units of arcsec
334+
the FWHH of the Gaussian along the major axis
335+
FWHM_min: float, units of arcsec
336+
the FWHM of the Gaussian along the minor axis
337+
Omega: float, degrees
338+
the rotation of the major axis of the PSF, in degrees East of North. 0 degrees rotation has the major axis aligned in the East-West direction.
339+
"""
340+
nchan, npix_m, npix_l = packed_cube.size()
341+
assert (npix_m == coords.npix) and (
342+
npix_l == coords.npix
343+
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
344+
packed_cube.size(), coords.npix
345+
)
346+
347+
# in FFT packed format
348+
# we're round-tripping, so we can ignore prefactors for correctness
349+
# calling this `vis_like`, since it's not actually the vis
350+
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))
351+
352+
# convert FWHM to sigma
353+
FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
354+
sigma_x = FWHM_maj * FWHM2sigma
355+
sigma_y = FWHM_min * FWHM2sigma
356+
357+
# calculate corresponding uu and vv matrices in packed format
358+
taper_2D = utils.fourier_gaussian_lambda_arcsec(coords.packed_u_centers_2D, coords.packed_v_centers_2D, a=1.0, delta_x=0.0, delta_y=0.0, sigma_x=sigma_x, sigma_y=sigma_y, Omega=Omega)
359+
360+
# calculate taper on packed image
361+
tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size())
362+
363+
# iFFT back, ignoring prefactors for round-trip
364+
convolved_packed_cube = torch.fft.fftn(tapered_vis, dim=(1,2))
365+
366+
return convolved_packed_cube

0 commit comments

Comments
 (0)