Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Curved Planar Reformation via vtkMRMLSliceLogic #35

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
338 changes: 37 additions & 301 deletions CurvedPlanarReformat/CurvedPlanarReformat.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,326 +169,62 @@ def __init__(self):
# there is no need to compute displacement for each slice,
# we just compute for every n-th to make computation faster and inverse computation more robust
# (less contradiction because of there is less overlapping between neighbor slices)
self.transformSpacingFactor = 5.0

@staticmethod
def getPointsProjectedToPlane(pointsArray, transformWorldToPlane):
appLogic = slicer.app.applicationLogic()
resamplerName = "ResampleScalarVectorDWIVolume"
found = appLogic.IsVolumeResamplerRegistered(resamplerName)
if not found:
mesg = f"CurvedPlanarReformat: {resamplerName!r} is not registered"
raise LookupError(mesg)
collectionOfSliceLogics = appLogic.GetSliceLogics()
numSliceLogics = collectionOfSliceLogics.GetNumberOfItems()
if numSliceLogics == 0:
mesg = "CurvedPlanarReformat: no SliceLogics found"
raise LookupError(mesg)
self.sliceLogic = collectionOfSliceLogics.GetItemAsObject(0)
self.sliceLogic.CurvedPlanarReformationInit()

def getPointsProjectedToPlane(self, pointsArray, transformWorldToPlane):
"""
Returns points projected to the plane coordinate system (plane normal = plane Z axis).
pointsArray contains each point as a column vector.
"""
import numpy as np
numberOfPoints = pointsArray.shape[1]
# Concatenate a 4th line containing 1s so that we can transform the positions using
# a single matrix multiplication.
pointsArray_World = np.row_stack((pointsArray,np.ones(numberOfPoints)))

# Point positions in the plane coordinate system:
pointsArray_Plane = np.dot(transformWorldToPlane, pointsArray_World)
# Projected point positions in the plane coordinate system:
pointsArray_Plane[2,:] = np.zeros(numberOfPoints)
# Projected point positions in the world coordinate system:
pointsArrayProjected_World = np.dot(np.linalg.inv(transformWorldToPlane), pointsArray_Plane)

# remove the last row (all ones)
pointsArrayProjected_World = pointsArrayProjected_World[0:3,:]

return pointsArrayProjected_World
pointsArrayOut = vtk.vtkPoints()
success = self.sliceLogic.CurvedPlanarReformationGetPointsProjectedToPlane(
pointsArray, transformWorldToPlane, pointsArrayOut
)
if not success:
raise ValueError("getPointsProjectedToPlane failed")
return pointsArrayOut

def computeStraighteningTransform(self, transformToStraightenedNode, curveNode, sliceSizeMm, outputSpacingMm, stretching=False, rotationDeg=0.0, reslicingPlanesModelNode=None):
"""
Compute straightened volume (useful for example for visualization of curved vessels)
stretching: if True then stretching transform will be computed, otherwise straightening
"""

# Create a temporary resampled curve
resamplingCurveSpacing = outputSpacingMm * self.transformSpacingFactor
originalCurvePoints = curveNode.GetCurvePointsWorld()
sampledPoints = vtk.vtkPoints()
if not slicer.vtkMRMLMarkupsCurveNode.ResamplePoints(originalCurvePoints, sampledPoints, resamplingCurveSpacing, False):
raise ValueError("Resampling curve failed")
resampledCurveNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsCurveNode", "CurvedPlanarReformat_resampled_curve_temp")
resampledCurveNode.SetNumberOfPointsPerInterpolatingSegment(1)
resampledCurveNode.SetCurveTypeToLinear()
resampledCurveNode.SetControlPointPositionsWorld(sampledPoints)

curveNodePlane = vtk.vtkPlane()
slicer.modules.markups.logic().GetBestFitPlane(resampledCurveNode, curveNodePlane)

# Z axis (from first curve point to last, this will be the straightened curve long axis)
curveStartPoint = np.zeros(3)
curveEndPoint = np.zeros(3)
resampledCurveNode.GetNthControlPointPositionWorld(0, curveStartPoint)
resampledCurveNode.GetNthControlPointPositionWorld(resampledCurveNode.GetNumberOfControlPoints()-1, curveEndPoint)
transformGridAxisZ = (curveEndPoint-curveStartPoint)/np.linalg.norm(curveEndPoint-curveStartPoint)

if stretching:
# Y axis = best fit plane normal
transformGridAxisY = np.copy(curveNodePlane.GetNormal())

# X axis normalize
transformGridAxisX = np.cross(transformGridAxisZ, transformGridAxisY)
transformGridAxisX = transformGridAxisX/np.linalg.norm(transformGridAxisX)

# Make sure that Z axis is orthogonal to X and Y
orthogonalizedTransformGridAxisZ = np.cross(transformGridAxisX, transformGridAxisY)
orthogonalizedTransformGridAxisZ = orthogonalizedTransformGridAxisZ/np.linalg.norm(orthogonalizedTransformGridAxisZ)
if np.dot(transformGridAxisZ, orthogonalizedTransformGridAxisZ) > 0:
transformGridAxisZ = orthogonalizedTransformGridAxisZ
else:
transformGridAxisZ = -orthogonalizedTransformGridAxisZ
transformGridAxisX = -transformGridAxisX

else:

# X axis = average X axis of curve, to minimize torsion (and so have a simple displacement field, which can be robustly inverted)
sumCurveAxisX_RAS = np.zeros(3)
numberOfPoints = resampledCurveNode.GetNumberOfControlPoints()
for gridK in range(numberOfPoints):
curvePointToWorld = vtk.vtkMatrix4x4()
resampledCurveNode.GetCurvePointToWorldTransformAtPointIndex(resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK), curvePointToWorld)
curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(curvePointToWorld)
curveAxisX_RAS = curvePointToWorldArray[0:3, 0]
sumCurveAxisX_RAS += curveAxisX_RAS
meanCurveAxisX_RAS = sumCurveAxisX_RAS/np.linalg.norm(sumCurveAxisX_RAS)
transformGridAxisX = meanCurveAxisX_RAS

# Y axis normalize
transformGridAxisY = np.cross(transformGridAxisZ, transformGridAxisX)
transformGridAxisY = transformGridAxisY/np.linalg.norm(transformGridAxisY)

# Make sure that X axis is orthogonal to Y and Z
transformGridAxisX = np.cross(transformGridAxisY, transformGridAxisZ)
transformGridAxisX = transformGridAxisX/np.linalg.norm(transformGridAxisX)

# Rotate by rotationDeg around the Z axis
gridDirectionMatrixArray = np.eye(4)
gridDirectionMatrixArray[0:3, 0] = transformGridAxisX
gridDirectionMatrixArray[0:3, 1] = transformGridAxisY
gridDirectionMatrixArray[0:3, 2] = transformGridAxisZ
gridDirectionMatrix = slicer.util.vtkMatrixFromArray(gridDirectionMatrixArray)
#
gridDirectionTransform = vtk.vtkTransform()
gridDirectionTransform.Concatenate(gridDirectionMatrix)
gridDirectionTransform.RotateZ(rotationDeg)
#
gridDirectionMatrixArray = slicer.util.arrayFromVTKMatrix(gridDirectionTransform.GetMatrix())
transformGridAxisX = gridDirectionMatrixArray[0:3, 0]
transformGridAxisY = gridDirectionMatrixArray[0:3, 1]
transformGridAxisZ = gridDirectionMatrixArray[0:3, 2]

if stretching:
# Project curve points to grid YZ plane
transformFromGridYZPlane = np.eye(4)
transformFromGridYZPlane[0:3, 0] = transformGridAxisY
transformFromGridYZPlane[0:3, 1] = transformGridAxisZ
transformFromGridYZPlane[0:3, 2] = transformGridAxisX
transformFromGridYZPlane[0:3, 3] = curveNodePlane.GetOrigin()
transformToGridYZPlane = np.linalg.inv(transformFromGridYZPlane)

originalCurvePointsArray = slicer.util.arrayFromMarkupsCurvePoints(curveNode)
curvePointsProjected_RAS = CurvedPlanarReformatLogic.getPointsProjectedToPlane(originalCurvePointsArray.T, transformToGridYZPlane).T
slicer.util.updateMarkupsControlPointsFromArray(resampledCurveNode, curvePointsProjected_RAS)

# After projection, resampling is needed to get uniform distances
originalCurvePoints = resampledCurveNode.GetCurvePointsWorld()
sampledPoints = vtk.vtkPoints()
if not slicer.vtkMRMLMarkupsCurveNode.ResamplePoints(originalCurvePoints, sampledPoints, resamplingCurveSpacing, False):
raise ValueError("Resampling curve failed")
resampledCurveNode.SetControlPointPositionsWorld(sampledPoints)

# Origin (makes the grid centered at the curve)
curveLength = resampledCurveNode.GetCurveLengthWorld()
transformGridOrigin = np.array(curveNodePlane.GetOrigin())
transformGridOrigin -= transformGridAxisX * sliceSizeMm[0]/2.0
transformGridOrigin -= transformGridAxisY * sliceSizeMm[1]/2.0
transformGridOrigin -= transformGridAxisZ * curveLength/2.0

# Create grid transform
# Each corner of each slice is mapped from the original volume's reformatted slice
# to the straightened volume slice.
# The grid transform contains one vector at the corner of each slice.
# The transform is in the same space and orientation as the straightened volume.

numberOfSlices = resampledCurveNode.GetNumberOfControlPoints()
gridDimensions = [2, 2, numberOfSlices]
gridSpacing = [sliceSizeMm[0], sliceSizeMm[1], resamplingCurveSpacing]
gridDirectionMatrixArray = np.eye(4)
gridDirectionMatrixArray[0:3, 0] = transformGridAxisX
gridDirectionMatrixArray[0:3, 1] = transformGridAxisY
gridDirectionMatrixArray[0:3, 2] = transformGridAxisZ
gridDirectionMatrix = slicer.util.vtkMatrixFromArray(gridDirectionMatrixArray)

gridImage = vtk.vtkImageData()
gridImage.SetOrigin(transformGridOrigin)
gridImage.SetDimensions(gridDimensions)
gridImage.SetSpacing(gridSpacing)
gridImage.AllocateScalars(vtk.VTK_DOUBLE, 3)
transform = slicer.vtkOrientedGridTransform()
transform.SetDisplacementGridData(gridImage)
transform.SetGridDirectionMatrix(gridDirectionMatrix)
transformToStraightenedNode.SetAndObserveTransformFromParent(transform)

if reslicingPlanesModelNode:
appender = vtk.vtkAppendPolyData()

# Currently there is no API to set PreferredInitialNormalVector in the curve coordinate system, therefore
# a new coordinate system generator must be set up:
curveCoordinateSystemGeneratorWorld = slicer.vtkParallelTransportFrame()
curveCoordinateSystemGeneratorWorld.SetInputData(resampledCurveNode.GetCurveWorld())
curveCoordinateSystemGeneratorWorld.SetPreferredInitialNormalVector(transformGridAxisX)
curveCoordinateSystemGeneratorWorld.Update()
curvePoly = curveCoordinateSystemGeneratorWorld.GetOutput()
pointData = curvePoly.GetPointData()
normals = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetNormalsArrayName())
binormals = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetBinormalsArrayName())
tangents = pointData.GetAbstractArray(curveCoordinateSystemGeneratorWorld.GetTangentsArrayName())

# Compute displacements
transformDisplacements_RAS = slicer.util.arrayFromGridTransform(transformToStraightenedNode)
for gridK in range(gridDimensions[2]):

# The curve's built-in coordinate system generator could be used like this (if it had PreferredInitialNormalVector exposed):
#
# curvePointToWorld = vtk.vtkMatrix4x4()
# resampledCurveNode.GetCurvePointToWorldTransformAtPointIndex(resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK), curvePointToWorld)
# curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(curvePointToWorld)
# curveAxisX_RAS = curvePointToWorldArray[0:3, 0]
# curveAxisY_RAS = curvePointToWorldArray[0:3, 1]
# curvePoint_RAS = curvePointToWorldArray[0:3, 3]
#
# But now we get the values from our own coordinate system generator:
curvePointIndex = resampledCurveNode.GetCurvePointIndexFromControlPointIndex(gridK)
curveAxisX_RAS = np.array(normals.GetTuple3(curvePointIndex))
curveAxisY_RAS = np.array(binormals.GetTuple3(curvePointIndex))
curvePoint_RAS = np.array(curvePoly.GetPoint(curvePointIndex))

for gridJ in range(gridDimensions[1]):
for gridI in range(gridDimensions[0]):
straightenedVolume_RAS = (transformGridOrigin
+ gridI*gridSpacing[0]*transformGridAxisX
+ gridJ*gridSpacing[1]*transformGridAxisY
+ gridK*gridSpacing[2]*transformGridAxisZ)
inputVolume_RAS = (curvePoint_RAS
+ (gridI-0.5)*sliceSizeMm[0]*curveAxisX_RAS
+ (gridJ-0.5)*sliceSizeMm[1]*curveAxisY_RAS)
if reslicingPlanesModelNode:
if gridI == 0 and gridJ == 0:
plane = vtk.vtkPlaneSource()
plane.SetOrigin(inputVolume_RAS)
elif gridI == 1 and gridJ == 0:
plane.SetPoint1(inputVolume_RAS)
elif gridI == 0 and gridJ == 1:
plane.SetPoint2(inputVolume_RAS)
transformDisplacements_RAS[gridK][gridJ][gridI] = inputVolume_RAS - straightenedVolume_RAS

if reslicingPlanesModelNode:
plane.Update()
appender.AddInputData(plane.GetOutput())

slicer.util.arrayFromGridTransformModified(transformToStraightenedNode)

# delete temporary curve
slicer.mrmlScene.RemoveNode(resampledCurveNode)

if reslicingPlanesModelNode:
appender.Update()
if not reslicingPlanesModelNode.GetPolyData():
reslicingPlanesModelNode.CreateDefaultDisplayNodes()
reslicingPlanesModelNode.GetDisplayNode().SetVisibility2D(True)
reslicingPlanesModelNode.SetAndObservePolyData(appender.GetOutput())

return self.sliceLogic.CurvedPlanarReformationComputeStraighteningTransform(
transformToStraightenedNode,
curveNode,
sliceSizeMm,
outputSpacingMm,
stretching,
rotationDeg,
reslicingPlanesModelNode,
)

def straightenVolume(self, outputStraightenedVolume, volumeNode, outputStraightenedVolumeSpacing, straighteningTransformNode):
"""
Compute straightened volume (useful for example for visualization of curved vessels)
"""
gridTransform = straighteningTransformNode.GetTransformFromParentAs("vtkOrientedGridTransform")
if not gridTransform:
raise ValueError("Straightening transform is expected to contain a vtkOrientedGridTransform form parent")

# Get transformation grid geometry
gridIjkToRasDirectionMatrix = gridTransform.GetGridDirectionMatrix()
gridTransformImage = gridTransform.GetDisplacementGrid()
gridOrigin = gridTransformImage.GetOrigin()
gridSpacing = gridTransformImage.GetSpacing()
gridDimensions = gridTransformImage.GetDimensions()
gridExtentMm = [gridSpacing[0]*(gridDimensions[0]-1), gridSpacing[1]*(gridDimensions[1]-1), gridSpacing[2]*(gridDimensions[2]-1)]

# Compute IJK to RAS matrix of output volume
# Get grid axis directions
straightenedVolumeIJKToRASArray = slicer.util.arrayFromVTKMatrix(gridIjkToRasDirectionMatrix)
# Apply scaling
straightenedVolumeIJKToRASArray = np.dot(straightenedVolumeIJKToRASArray,
np.diag([outputStraightenedVolumeSpacing[0], outputStraightenedVolumeSpacing[1], outputStraightenedVolumeSpacing[2], 1]))
# Set origin
straightenedVolumeIJKToRASArray[0:3,3] = gridOrigin

outputStraightenedImageData = vtk.vtkImageData()
outputStraightenedImageData.SetExtent(
0, int(gridExtentMm[0]/outputStraightenedVolumeSpacing[0])-1,
0, int(gridExtentMm[1]/outputStraightenedVolumeSpacing[1])-1,
0, int(gridExtentMm[2]/outputStraightenedVolumeSpacing[2])-1)
outputStraightenedImageData.AllocateScalars(volumeNode.GetImageData().GetScalarType(), volumeNode.GetImageData().GetNumberOfScalarComponents())
outputStraightenedVolume.SetAndObserveImageData(outputStraightenedImageData)
outputStraightenedVolume.SetIJKToRASMatrix(slicer.util.vtkMatrixFromArray(straightenedVolumeIJKToRASArray))

# Resample input volume to straightened volume
parameters = {}
parameters["inputVolume"] = volumeNode.GetID()
parameters["outputVolume"] = outputStraightenedVolume.GetID()
parameters["referenceVolume"] = outputStraightenedVolume.GetID()
parameters["transformationFile"] = straighteningTransformNode.GetID()
# Use nearest neighbor interpolation for label volumes (to avoid incorrect labels at boundaries)
# and higher-order (bspline) interpolation for scalar volumes.
parameters["interpolationType"] = "nn" if volumeNode.IsA('vtkMRMLLabelMapVolumeNode') else "bs"
resamplerModule = slicer.modules.resamplescalarvectordwivolume
parameterNode = slicer.cli.runSync(resamplerModule, None, parameters)

outputStraightenedVolume.CreateDefaultDisplayNodes()
outputStraightenedVolume.GetDisplayNode().CopyContent(volumeNode.GetDisplayNode())
slicer.mrmlScene.RemoveNode(parameterNode)
return self.sliceLogic.CurvedPlanarReformationStraightenVolume(
outputStraightenedVolume, volumeNode, outputStraightenedVolumeSpacing, straighteningTransformNode
)

def projectVolume(self, outputProjectedVolume, inputStraightenedVolume, projectionAxisIndex = 0):
"""Create panoramic volume by mean intensity projection along an axis of the straightened volume
"""

projectedImageData = vtk.vtkImageData()
outputProjectedVolume.SetAndObserveImageData(projectedImageData)
straightenedImageData = inputStraightenedVolume.GetImageData()

outputImageDimensions = list(straightenedImageData.GetDimensions())
outputImageDimensions[projectionAxisIndex] = 1
projectedImageData.SetDimensions(outputImageDimensions)

projectedImageData.AllocateScalars(straightenedImageData.GetScalarType(), straightenedImageData.GetNumberOfScalarComponents())
outputProjectedVolumeArray = slicer.util.arrayFromVolume(outputProjectedVolume)
inputStraightenedVolumeArray = slicer.util.arrayFromVolume(inputStraightenedVolume)

if projectionAxisIndex == 0:
outputProjectedVolumeArray[:, :, 0] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex)
elif projectionAxisIndex == 1:
outputProjectedVolumeArray[:, 0, :] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex)
else:
outputProjectedVolumeArray[0, :, :] = inputStraightenedVolumeArray.mean(2-projectionAxisIndex)

slicer.util.arrayFromVolumeModified(outputProjectedVolume)

# Shift projection image into the center of the input image
ijkToRas = vtk.vtkMatrix4x4()
inputStraightenedVolume.GetIJKToRASMatrix(ijkToRas)
curvePointToWorldArray = slicer.util.arrayFromVTKMatrix(ijkToRas)
origin = curvePointToWorldArray[0:3, 3]
offsetToCenterDirectionVector = curvePointToWorldArray[0:3, projectionAxisIndex]
offsetToCenterDirectionLength = inputStraightenedVolume.GetImageData().GetDimensions()[projectionAxisIndex] * inputStraightenedVolume.GetSpacing()[projectionAxisIndex]
newOrigin = origin + offsetToCenterDirectionVector * offsetToCenterDirectionLength
ijkToRas.SetElement(0, 3, newOrigin[0])
ijkToRas.SetElement(1, 3, newOrigin[1])
ijkToRas.SetElement(2, 3, newOrigin[2])
outputProjectedVolume.SetIJKToRASMatrix(ijkToRas)
outputProjectedVolume.CreateDefaultDisplayNodes()

return True
return self.sliceLogic.CurvedPlanarReformationProjectVolume(
outputProjectedVolume, inputStraightenedVolume, projectionAxisIndex
)

class CurvedPlanarReformatTest(ScriptedLoadableModuleTest):
"""
Expand Down