diff --git a/SimpleITK/utilities/fft.py b/SimpleITK/utilities/fft.py index b586ced..b301022 100644 --- a/SimpleITK/utilities/fft.py +++ b/SimpleITK/utilities/fft.py @@ -20,7 +20,12 @@ def fft_based_translation_initialization( - fixed: sitk.Image, moving: sitk.Image + fixed: sitk.Image, + moving: sitk.Image, + *, + required_fraction_of_overlapping_pixels: float = 0.0, + initial_transform: sitk.Transform = None, + masked_pixel_value: float = None, ) -> sitk.TranslationTransform: """Perform fast Fourier transform based normalized correlation to find the translation which maximizes correlation between the images. @@ -34,30 +39,55 @@ def fft_based_translation_initialization( :param fixed: A SimpleITK image object. :param moving: Another SimpleITK Image object, which will be resampled onto the grid of the fixed image if it is not congruent. - :return: A TranslationTransform mapping physical points from the fixed to the moving image. + :param required_fraction_of_overlapping_pixels: The required fraction of overlapping pixels between the fixed and + moving image. The value should be in the range of [0, 1]. If the value is 1, then the full overlap is required. + :param initial_transform: An initial transformation to be applied to the moving image by resampling before the + FFT registration. The returned transform will be of the initial_transform type with the translation updated. + :param masked_pixel_value: The value of input pixels to be ignored by correlation. If None, then the + FFTNormalizedCoorrelation will be used, otherwise the MaskedFFTNormalizedCorrelation will be used. + :return: A TranslationTransform (or the initial_transform tyype) mapping physical points from the fixed to the + moving image. """ if ( - moving.GetSpacing() != fixed.GetSpacing() + initial_transform is not None + or moving.GetSpacing() != fixed.GetSpacing() or moving.GetDirection() != fixed.GetDirection() or moving.GetOrigin() != fixed.GetOrigin() ): resampler = sitk.ResampleImageFilter() resampler.SetReferenceImage(fixed) + + if initial_transform is not None: + resampler.SetTransform(initial_transform) moving = resampler.Execute(moving) sigma = fixed.GetSpacing()[0] pixel_type = sitk.sitkFloat32 - fft_fixed = sitk.Cast(sitk.SmoothingRecursiveGaussian(fixed, sigma), pixel_type) - fft_moving = sitk.Cast(sitk.SmoothingRecursiveGaussian(moving, sigma), pixel_type) - - out = sitk.FFTNormalizedCorrelation(fft_fixed, fft_moving) - - out = sitk.SmoothingRecursiveGaussian(out) - cc = sitk.ConnectedComponent(sitk.RegionalMaxima(out, fullyConnected=True)) + fixed = sitk.Cast(sitk.SmoothingRecursiveGaussian(fixed, sigma), pixel_type) + moving = sitk.Cast(sitk.SmoothingRecursiveGaussian(moving, sigma), pixel_type) + + if masked_pixel_value is None: + xcorr = sitk.FFTNormalizedCorrelation( + fixed, + moving, + requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels, + ) + else: + xcorr = sitk.MaskedFFTNormalizedCorrelation( + fixed, + moving, + sitk.Cast(fixed != masked_pixel_value, pixel_type), + sitk.Cast(moving != masked_pixel_value, pixel_type), + requiredFractionOfOverlappingPixels=required_fraction_of_overlapping_pixels, + ) + + xcorr = sitk.SmoothingRecursiveGaussian(xcorr, sigma) + + cc = sitk.ConnectedComponent(sitk.RegionalMaxima(xcorr, fullyConnected=True)) stats = sitk.LabelStatisticsImageFilter() - stats.Execute(out, cc) + stats.Execute(xcorr, cc) labels = sorted(stats.GetLabels(), key=lambda l: stats.GetMean(l)) peak_bb = stats.GetBoundingBox(labels[-1]) @@ -67,13 +97,21 @@ def fft_based_translation_initialization( for min_idx, max_idx in zip(peak_bb[0::2], peak_bb[1::2]) ] - peak_pt = out.TransformContinuousIndexToPhysicalPoint(peak_idx) + peak_pt = xcorr.TransformContinuousIndexToPhysicalPoint(peak_idx) peak_value = stats.GetMean(labels[-1]) - center_pt = out.TransformContinuousIndexToPhysicalPoint( - [p / 2.0 for p in out.GetSize()] + center_pt = xcorr.TransformContinuousIndexToPhysicalPoint( + [p / 2.0 for p in xcorr.GetSize()] ) translation = [c - p for c, p in zip(center_pt, peak_pt)] + if initial_transform is not None: + offset = initial_transform.TransformVector(translation, point=[0, 0]) + + tx_out = sitk.Transform(initial_transform).Downcast() + tx_out.SetTranslation( + [a + b for (a, b) in zip(initial_transform.GetTranslation(), offset)] + ) + return tx_out - return sitk.TranslationTransform(out.GetDimension(), translation) + return sitk.TranslationTransform(xcorr.GetDimension(), translation) diff --git a/requirements.txt b/requirements.txt index b492803..421b76e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -SimpleITK>=2.0 +SimpleITK>=2.3.0 numpy \ No newline at end of file diff --git a/test/test_utilities.py b/test/test_utilities.py index bf66373..7ac86f0 100644 --- a/test/test_utilities.py +++ b/test/test_utilities.py @@ -47,6 +47,7 @@ def test_fft_initialization(): def test_fft_initialization2(): + """Testing with different spacing and origin to force resampling.""" fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8) fixed_img[510:520, 255:265] = 10 @@ -60,6 +61,44 @@ def test_fft_initialization2(): assert tx.GetOffset() == (-85.0, -105.0) +def test_fft_initialization3(): + """Testing with required fraction of overlapping pixels.""" + fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8) + fixed_img[0:10, 0:20] = 10 + fixed_img[510:520, 255:265] = 10 + + moving_img = sitk.Image([1024, 512], sitk.sitkUInt8) + moving_img[425:435, 300:320] = 8 + + tx = sitkutils.fft_based_translation_initialization( + fixed_img, + moving_img, + ) + assert tx.GetOffset() == (425, 300.0) + + tx = sitkutils.fft_based_translation_initialization( + fixed_img, moving_img, required_fraction_of_overlapping_pixels=0.5 + ) + assert tx.GetOffset() == (-85.0, 50.0) + + +def test_fft_initialization4(): + """Testing with initial transform.""" + fixed_img = sitk.Image([1024, 512], sitk.sitkUInt8) + fixed_img[510:520, 255:265] = 10 + + moving_img = sitk.Image([1024, 512], sitk.sitkUInt8) + moving_img.SetSpacing((10, 10)) + moving_img[425:435, 300:310] = 8 + + initial_transform = sitk.Similarity2DTransform(10) + + tx = sitkutils.fft_based_translation_initialization( + fixed_img, moving_img, initial_transform=initial_transform + ) + assert tx.GetTranslation() == (-850.0, 450.0) + + def test_overlay_bounding_boxes(): bounding_boxes = [[10, 10, 60, 20], [200, 180, 230, 250]] scalar_image = sitk.Image([256, 256], sitk.sitkUInt8)