diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py index f4f2c99a40..306904bf42 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/__init__.py @@ -70,7 +70,7 @@ def analyzeResults(out_dir, DT_dir, prediction_dir, mean_prefix): def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): itk_transform = image_utils.get_image_registration_transform(fixed_image_file, moving_image_file, - transform_type='rigid') + transform_type=transform_type) return itk_transform diff --git a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py index 5fe6397794..86e12fc03f 100644 --- a/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py +++ b/Python/DeepSSMUtilsPackage/DeepSSMUtils/image_utils.py @@ -3,7 +3,7 @@ import numpy as np def get_image_registration_transform(fixed_image_file, moving_image_file, transform_type='rigid'): - # Import Default Parameter Map + # Prepare parameter map parameter_object = itk.ParameterObject.New() parameter_map = parameter_object.GetDefaultParameterMap('rigid') if transform_type == 'similarity': @@ -12,11 +12,15 @@ def get_image_registration_transform(fixed_image_file, moving_image_file, transf parameter_map['Transform'] = ['TranslationTransform'] parameter_map['MaximumNumberOfIterations'] = ['1024'] parameter_object.AddParameterMap(parameter_map) - # Call registration function + + # Load images fixed_image = itk.imread(fixed_image_file, itk.F) moving_image = itk.imread(moving_image_file, itk.F) + + # Call registration method result_image, result_transform_parameters = itk.elastix_registration_method( fixed_image, moving_image, parameter_object=parameter_object) + # Get transform matrix parameter_map = result_transform_parameters.GetParameterMap(0) transform_params = np.array(parameter_map['TransformParameters'], dtype=float) @@ -25,11 +29,14 @@ def get_image_registration_transform(fixed_image_file, moving_image_file, transf elif transform_type == 'similarity': itk_transform = SimpleITK.Similarity3DTransform() elif transform_type == 'translation': - itk_transform = SimpleITK.TranslationTransform() + itk_transform = SimpleITK.TranslationTransform(3) else: - print("Error: " + transform_type + " transform unimplemented.") + raise NotImplementedError("Error: " + transform_type + " transform unimplemented.") itk_transform.SetParameters(transform_params) itk_transform_matrix = np.eye(4) - itk_transform_matrix[:3,:3] = np.array(itk_transform.GetMatrix()).reshape(3,3) - itk_transform_matrix[-1,:3] = np.array(itk_transform.GetTranslation()) + if transform_type == 'translation': + itk_transform_matrix[-1, :3] = np.array(itk_transform.GetOffset()) + else: + itk_transform_matrix[:3,:3] = np.array(itk_transform.GetMatrix()).reshape(3,3) + itk_transform_matrix[-1,:3] = np.array(itk_transform.GetTranslation()) return itk_transform_matrix \ No newline at end of file