From 98c11098bca33dea63b221433faa6ef8ca2c6f54 Mon Sep 17 00:00:00 2001 From: Andreas Schuh Date: Fri, 18 Oct 2024 10:20:22 +0000 Subject: [PATCH] fix: Rotation applied in surface_image_stencil() based on image orientation --- src/deepali/utils/vtk/image.py | 40 ++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/deepali/utils/vtk/image.py b/src/deepali/utils/vtk/image.py index e536c50..893121f 100644 --- a/src/deepali/utils/vtk/image.py +++ b/src/deepali/utils/vtk/image.py @@ -8,9 +8,9 @@ vtkImageData, vtkImageStencilData, vtkImageStencilToImage, - vtkMatrixToLinearTransform, vtkPolyData, vtkPolyDataToImageStencil, + vtkTransform, vtkTransformPolyDataFilter, ) @@ -42,29 +42,31 @@ def surface_mesh_grid(*mesh: vtkPolyData, resolution: Optional[float] = None) -> def surface_image_stencil(mesh: vtkPolyData, grid: Grid) -> vtkImageStencilData: - r"""Convert vtkPolyData surface mesh to image stencil.""" - max_index = [n - 1 for n in grid.size().tolist()] - - rot = np.eye(4, dtype=np.float) - rot[:3, :3] = np.array(grid.direction).reshape(3, 3) - rot = numpy_to_vtk_matrix4x4(rot) - - transform = vtkMatrixToLinearTransform() - transform.SetInput(rot) - + r"""Convert vtkPolyData surface mesh to image stencil.""" + # Create the transform + transform = vtkTransform() + transform.Translate(grid.center().tolist()) + transform.Concatenate(numpy_to_vtk_matrix4x4(grid.direction().numpy().T)) # type: ignore + transform.Translate(grid.center().neg().tolist()) + + # Apply the transform to the polydata transformer = vtkTransformPolyDataFilter() transformer.SetInputData(mesh) transformer.SetTransform(transform) - converter = vtkPolyDataToImageStencil() - converter.SetInputConnection(transformer.GetOutputPort()) - converter.SetOutputOrigin(grid.origin().tolist()) - converter.SetOutputSpacing(grid.spacing().tolist()) - converter.SetOutputWholeExtent([0, max_index[0], 0, max_index[1], 0, max_index[2]]) - converter.Update() - + # Convert the transformed polydata to an image stencil + stencil_grid = Grid(size=grid.size(), spacing=grid.spacing(), center=grid.center()) + stencil_extent = [0, grid.size(0) - 1, 0, grid.size(1) - 1, 0, grid.size(2) - 1] + polydata_to_stencil = vtkPolyDataToImageStencil() + polydata_to_stencil.SetInputConnection(transformer.GetOutputPort()) + polydata_to_stencil.SetOutputOrigin(stencil_grid.origin().tolist()) + polydata_to_stencil.SetOutputSpacing(stencil_grid.spacing().tolist()) + polydata_to_stencil.SetOutputWholeExtent(stencil_extent) + polydata_to_stencil.Update() + + # Get the output stencil stencil = vtkImageStencilData() - stencil.DeepCopy(converter.GetOutput()) + stencil.DeepCopy(polydata_to_stencil.GetOutput()) return stencil