Skip to content

Commit

Permalink
Merge pull request #39 from blowekamp/update_fft_initialize
Browse files Browse the repository at this point in the history
Add optional parameters fft_based_translation_initialization
  • Loading branch information
zivy authored Feb 8, 2024
2 parents 3e44d5f + dced646 commit 431d7c2
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 16 deletions.
68 changes: 53 additions & 15 deletions SimpleITK/utilities/fft.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,12 @@


def fft_based_translation_initialization(
fixed: sitk.Image, moving: sitk.Image
fixed: sitk.Image,
moving: sitk.Image,
*,
required_fraction_of_overlapping_pixels: float = 0.0,
initial_transform: sitk.Transform = None,
masked_pixel_value: float = None,
) -> sitk.TranslationTransform:
"""Perform fast Fourier transform based normalized correlation to find the translation which maximizes correlation
between the images.
Expand All @@ -34,30 +39,55 @@ def fft_based_translation_initialization(
:param fixed: A SimpleITK image object.
:param moving: Another SimpleITK Image object, which will be resampled onto the grid of the fixed image if it is not
congruent.
:return: A TranslationTransform mapping physical points from the fixed to the moving image.
:param required_fraction_of_overlapping_pixels: The required fraction of overlapping pixels between the fixed and
moving image. The value should be in the range of [0, 1]. If the value is 1, then the full overlap is required.
:param initial_transform: An initial transformation to be applied to the moving image by resampling before the
FFT registration. The returned transform will be of the initial_transform type with the translation updated.
:param masked_pixel_value: The value of input pixels to be ignored by correlation. If None, then the
FFTNormalizedCoorrelation will be used, otherwise the MaskedFFTNormalizedCorrelation will be used.
:return: A TranslationTransform (or the initial_transform tyype) mapping physical points from the fixed to the
moving image.
"""

if (
moving.GetSpacing() != fixed.GetSpacing()
initial_transform is not None
or moving.GetSpacing() != fixed.GetSpacing()
or moving.GetDirection() != fixed.GetDirection()
or moving.GetOrigin() != fixed.GetOrigin()
):
resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed)

if initial_transform is not None:
resampler.SetTransform(initial_transform)
moving = resampler.Execute(moving)

sigma = fixed.GetSpacing()[0]
pixel_type = sitk.sitkFloat32

fft_fixed = sitk.Cast(sitk.SmoothingRecursiveGaussian(fixed, sigma), pixel_type)
fft_moving = sitk.Cast(sitk.SmoothingRecursiveGaussian(moving, sigma), pixel_type)

out = sitk.FFTNormalizedCorrelation(fft_fixed, fft_moving)

out = sitk.SmoothingRecursiveGaussian(out)
cc = sitk.ConnectedComponent(sitk.RegionalMaxima(out, fullyConnected=True))
fixed = sitk.Cast(sitk.SmoothingRecursiveGaussian(fixed, sigma), pixel_type)
moving = sitk.Cast(sitk.SmoothingRecursiveGaussian(moving, sigma), pixel_type)

if masked_pixel_value is None:
xcorr = sitk.FFTNormalizedCorrelation(
fixed,
moving,
requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels,
)
else:
xcorr = sitk.MaskedFFTNormalizedCorrelation(
fixed,
moving,
sitk.Cast(fixed != masked_pixel_value, pixel_type),
sitk.Cast(moving != masked_pixel_value, pixel_type),
requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels,
)

xcorr = sitk.SmoothingRecursiveGaussian(xcorr, sigma)

cc = sitk.ConnectedComponent(sitk.RegionalMaxima(xcorr, fullyConnected=True))
stats = sitk.LabelStatisticsImageFilter()
stats.Execute(out, cc)
stats.Execute(xcorr, cc)
labels = sorted(stats.GetLabels(), key=lambda l: stats.GetMean(l))

peak_bb = stats.GetBoundingBox(labels[-1])
Expand All @@ -67,13 +97,21 @@ def fft_based_translation_initialization(
for min_idx, max_idx in zip(peak_bb[0::2], peak_bb[1::2])
]

peak_pt = out.TransformContinuousIndexToPhysicalPoint(peak_idx)
peak_pt = xcorr.TransformContinuousIndexToPhysicalPoint(peak_idx)
peak_value = stats.GetMean(labels[-1])

center_pt = out.TransformContinuousIndexToPhysicalPoint(
[p / 2.0 for p in out.GetSize()]
center_pt = xcorr.TransformContinuousIndexToPhysicalPoint(
[p / 2.0 for p in xcorr.GetSize()]
)

translation = [c - p for c, p in zip(center_pt, peak_pt)]
if initial_transform is not None:
offset = initial_transform.TransformVector(translation, point=[0, 0])

tx_out = sitk.Transform(initial_transform).Downcast()
tx_out.SetTranslation(
[a + b for (a, b) in zip(initial_transform.GetTranslation(), offset)]
)
return tx_out

return sitk.TranslationTransform(out.GetDimension(), translation)
return sitk.TranslationTransform(xcorr.GetDimension(), translation)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
SimpleITK>=2.0
SimpleITK>=2.3.0
numpy
39 changes: 39 additions & 0 deletions test/test_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_fft_initialization():


def test_fft_initialization2():
"""Testing with different spacing and origin to force resampling."""
fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8)

fixed_img[510:520, 255:265] = 10
Expand All @@ -60,6 +61,44 @@ def test_fft_initialization2():
assert tx.GetOffset() == (-85.0, -105.0)


def test_fft_initialization3():
"""Testing with required fraction of overlapping pixels."""
fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8)
fixed_img[0:10, 0:20] = 10
fixed_img[510:520, 255:265] = 10

moving_img = sitk.Image([1024, 512], sitk.sitkUInt8)
moving_img[425:435, 300:320] = 8

tx = sitkutils.fft_based_translation_initialization(
fixed_img,
moving_img,
)
assert tx.GetOffset() == (425, 300.0)

tx = sitkutils.fft_based_translation_initialization(
fixed_img, moving_img, required_fraction_of_overlapping_pixels=0.5
)
assert tx.GetOffset() == (-85.0, 50.0)


def test_fft_initialization4():
"""Testing with initial transform."""
fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8)
fixed_img[510:520, 255:265] = 10

moving_img = sitk.Image([1024, 512], sitk.sitkUInt8)
moving_img.SetSpacing((10, 10))
moving_img[425:435, 300:310] = 8

initial_transform = sitk.Similarity2DTransform(10)

tx = sitkutils.fft_based_translation_initialization(
fixed_img, moving_img, initial_transform=initial_transform
)
assert tx.GetTranslation() == (-850.0, 450.0)


def test_overlay_bounding_boxes():
bounding_boxes = [[10, 10, 60, 20], [200, 180, 230, 250]]
scalar_image = sitk.Image([256, 256], sitk.sitkUInt8)
Expand Down

0 comments on commit 431d7c2

Please sign in to comment.