Skip to content

Commit

Permalink
Update convolve for full
Browse files Browse the repository at this point in the history
  • Loading branch information
DeanHazineh committed Sep 12, 2024
1 parent 0f60a5f commit 57f8942
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions dflat/render/fft_convolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,21 @@
from dflat.radial_tranforms import resize_with_crop_or_pad


def general_convolve(image, filter, rfft=False):


def general_convolve(image, filter, rfft=False, mode="valid"):
"""Runs the Fourier space convolution between an image and filter, where the filter kernels may have a different size from the image shape.
Args:
`image` (tf.float or tf.complex): Input image to apply the convolution filter kernel to, of shape [..., Ny, Nx]
`filter` (tf.float or tf.complex): Convolutional filter kernel, of shape [..., My, Mx], but the same rank as the image input
`rfft` (bool, optional): Flag to use real rfft instead of the general fft. Defaults to False.
'mode' (str, optional): Choice of valid or full convolution.
Returns:
tf.float or tf.complex: Image with the filter convolved on the inner-most two dimensions
"""
assert mode in ["valid", "full"]
init_image_shape = image.shape
im_ny = init_image_shape[-2]
im_nx = init_image_shape[-1]
Expand All @@ -25,22 +29,10 @@ def general_convolve(image, filter, rfft=False):
filt_ny = init_filter_shape[-2]
filt_nx = init_filter_shape[-1]

# If the image is smaller than the filter size, then increase the image size
if im_ny < filt_ny or im_nx < filt_nx:
image = resize_with_crop_or_pad(
image, np.maximum(im_ny, filt_ny), np.maximum(im_nx, filt_nx), False
)

# Zero pad the image with half the filter dimensionality and ensure the image is odd (non-cyclic boundary)
# Zero pad the image by half filter
padby = (len(init_image_shape) * 2) * [0]
if np.mod(im_nx, 2) == 0 and np.mod(filt_nx, 2) == 0:
padby[0:2] = [filt_nx // 2, filt_nx // 2 + 1]
else:
padby[0:2] = [filt_nx // 2, filt_nx // 2]
if np.mod(im_ny, 2) == 0 and np.mod(filt_ny, 2) == 0:
padby[2:4] = [filt_ny // 2, filt_ny // 2 + 1]
else:
padby[2:4] = [filt_ny // 2, filt_ny // 2]
padby[0:2] = [filt_nx // 2, filt_nx // 2]
padby[2:4] = [filt_ny // 2, filt_ny // 2]
image = F.pad(image, padby, mode="constant", value=0.0)

### Pad the psf to match the new image dimensionality
Expand All @@ -51,8 +43,8 @@ def general_convolve(image, filter, rfft=False):
image = checkpoint(fourier_convolve, image, filter_resh, rfft)
image = torch.real(image)

### Undo odd padding if it was done before FFT
image = resize_with_crop_or_pad(image, *init_image_shape[-2:], False)
if mode == "valid":
image = resize_with_crop_or_pad(image, *init_image_shape[-2:], False)
return image


Expand Down

0 comments on commit 57f8942

Please sign in to comment.