Skip to content

Commit

Permalink
added Gauss Fourier option for Base Cube.
Browse files Browse the repository at this point in the history
  • Loading branch information
iancze committed Mar 5, 2024
1 parent 0b9e416 commit add01fe
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 2 deletions.
88 changes: 88 additions & 0 deletions src/mpol/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,93 @@ def forward(self, cube: torch.Tensor) -> torch.Tensor:
return utils.sky_cube_to_packed_cube(conv_sky_cube)


class GaussBaseBeam(nn.Module):
r"""
This layer will convolve the base cube with a Gaussian beam of variable resolution.
The FWHM of the beam (in arcsec) is a trainable parameter of the layer.
Parameters
----------
coords : :class:`mpol.coordinates.GridCoords`
an object instantiated from the GridCoords class, containing information about
the image `cell_size` and `npix`.
nchan : int
the number of channels in the base cube. Default = 1.
FWHM: float, units of arcsec
the FWHH of the Gaussian
"""

def __init__(self, coords: GridCoords, nchan: int) -> None:
super().__init__()

self.coords = coords
self.nchan = nchan

self._FWHM_base = nn.Parameter(torch.tensor([-3.0]))
self.softplus = nn.Softplus()
# -3.0 corresponds to about 0.05 arcsec

# store coordinates to register so they transfer to GPU
self.register_buffer("u", torch.tensor(self.coords.packed_u_centers_2D, dtype=torch.float32))
self.register_buffer("v", torch.tensor(self.coords.packed_v_centers_2D, dtype=torch.float32))

@property
def FWHM(self):
r"""Map from base parameter to actual FWHM."""
return self.softplus(self._FWHM_base) # ensures always positive

def forward(self, packed_cube):
r"""
Convolve a packed_cube image with a 2D Gaussian PSF. Operation is carried out
in the Fourier domain using a Gaussian taper.
Parameters
----------
packed_cube : :class:`torch.Tensor` type
shape ``(nchan, npix, npix)`` image cube in packed format.
Returns
-------
:class:`torch.Tensor`
The convolved cube in packed format.
"""
nchan, npix_m, npix_l = packed_cube.size()
assert (
(npix_m == self.coords.npix) and (npix_l == self.coords.npix)
), "packed_cube {:} does not have the same pixel dimensions as indicated by coords {:}".format(
packed_cube.size(), self.coords.npix
)

# in FFT packed format
# we're round-tripping, so we can ignore prefactors for correctness
# calling this `vis_like`, since it's not actually the vis
vis_like = torch.fft.fftn(packed_cube, dim=(1, 2))

# convert FWHM to sigma and to radians
FWHM2sigma = 1 / (2 * np.sqrt(2 * np.log(2)))
sigma = self.FWHM * FWHM2sigma * constants.arcsec # radians

# calculate the UV taper from the FWHM size.
taper_2D = torch.exp(-2 * np.pi**2 * (sigma**2 * self.u**2 + sigma**2 * self.v**2))

# apply taper to packed image
tapered_vis = vis_like * torch.broadcast_to(taper_2D, packed_cube.size())

# iFFT back, ignoring prefactors for round-trip
convolved_packed_cube = torch.fft.ifftn(tapered_vis, dim=(1, 2))

# assert imaginaries are effectively zero, otherwise something went wrong
thresh = 1e-7
assert (
torch.max(convolved_packed_cube.imag) < thresh
), "Round-tripped image contains max imaginary value {:} > {:} threshold, something may be amiss.".format(
torch.max(convolved_packed_cube.imag), thresh
)

r_cube: torch.Tensor = convolved_packed_cube.real
return r_cube


class GaussConvCube(nn.Module):
r"""
Once instantiated, this convolutional layer is used to convolve the input cube with
Expand Down Expand Up @@ -322,6 +409,7 @@ def forward(self, sky_cube: torch.Tensor) -> torch.Tensor:
convolved_sky = self.m(sky_cube)
return convolved_sky


class ImageCube(nn.Module):
r"""
The parameter set is the pixel values of the image cube itself. The pixels are
Expand Down
40 changes: 38 additions & 2 deletions test/images_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def test_gaussian_kernel_rotate(coords, tmp_path):
plt.close("all")


def test_convolve(sky_cube, coords, tmp_path):
def test_GaussConvCube(sky_cube, coords, tmp_path):
# show only the first channel
chan = 0
nchan = sky_cube.size()[0]
Expand Down Expand Up @@ -272,7 +272,7 @@ def test_convolve(sky_cube, coords, tmp_path):

plt.close("all")

def test_convolve_rotate(sky_cube, coords, tmp_path):
def test_GaussConvCube_rotate(sky_cube, coords, tmp_path):
# show only the first channel
chan = 0
nchan = sky_cube.size()[0]
Expand Down Expand Up @@ -301,6 +301,42 @@ def test_convolve_rotate(sky_cube, coords, tmp_path):

plt.close("all")

def test_GaussBaseBeam(packed_cube, coords, tmp_path):
# show only the first channel
chan = 0
nchan = packed_cube.size()[0]

layer = images.GaussBaseBeam(coords, nchan)

for FWHM_base in np.linspace(-4, 0.5, num=10):
fig, ax = plt.subplots(ncols=2)
# put back to sky
sky_cube = utils.packed_cube_to_sky_cube(packed_cube)
im = ax[0].imshow(
sky_cube[chan], extent=coords.img_ext, origin="lower", cmap="inferno"
)
flux = coords.cell_size**2 * torch.sum(sky_cube[chan])
ax[0].set_title(f"tot flux: {flux:.3f} Jy")
plt.colorbar(im, ax=ax[0])

# set base resolution
layer._FWHM_base = torch.nn.Parameter(torch.tensor(FWHM_base, dtype=torch.float32))

c = layer(packed_cube)
# put back to sky
c_sky = utils.packed_cube_to_sky_cube(c)
flux = coords.cell_size**2 * torch.sum(c_sky[chan])
im = ax[1].imshow(
c_sky[chan].detach().numpy(), extent=coords.img_ext, origin="lower", cmap="inferno"
)
ax[1].set_title(f"tot flux: {flux:.3f} Jy")

plt.colorbar(im, ax=ax[1])
fig.savefig(tmp_path / "convolved_FWHM_{:.2f}.png".format(layer.FWHM), dpi=300)

plt.close("all")


# old rotate for FFT routine
# def test_convolve_rotate(packed_cube, coords, tmp_path):
# # show only the first channel
Expand Down

0 comments on commit add01fe

Please sign in to comment.