Skip to content

Commit

Permalink
Add traditional recon as models
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Aug 21, 2024
1 parent cbf5a56 commit 2eaed0b
Show file tree
Hide file tree
Showing 3 changed files with 317 additions and 24 deletions.
39 changes: 37 additions & 2 deletions direct/nn/registration/config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,50 @@
"""Configuration for the UnetRegistrationModel."""

from __future__ import annotations

from dataclasses import dataclass

from direct.config.defaults import ModelConfig
from direct.registration.demons import DemonsFilterType


@dataclass
class RegistrationModelConfig(ModelConfig):
warp_num_integration_steps: int = 1


@dataclass
class OpticalFlowILKRegistration2dModelConfig(RegistrationModelConfig):
radius: int = 7
num_warp: int = 10
gaussian: bool = False
prefilter: bool = True


@dataclass
class UnetRegistrationModelConfig(ModelConfig):
class OpticalFlowTVL1Registration2dModelConfig(RegistrationModelConfig):
attachment: float = 15
tightness: float = 0.3
num_warp: int = 5
num_iter: int = 10
tol: float = 1e-3
prefilter: bool = True


@dataclass
class DemonsRegistration2dModelConfig(RegistrationModelConfig):
demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES
demons_num_iterations: int = 50
demons_smooth_displacement_field: bool = True
demons_standard_deviations: float = 1.0
demons_intensity_difference_threshold: float | None = None
demons_maximum_rms_error: float | None = None


@dataclass
class UnetRegistration2dModelConfig(RegistrationModelConfig):
max_seq_len: int = 12
unet_num_filters: int = 16
unet_num_pool_layers: int = 4
unet_dropout_probability: float = 0.0
unet_normalized: bool = False
warp_num_integration_steps: int = 1
262 changes: 259 additions & 3 deletions direct/nn/registration/registration.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,270 @@
"""Registration models for direct registration."""

from __future__ import annotations

from functools import partial
from typing import Callable

import torch
import torch.nn as nn

from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d
from direct.registration.demons import DemonsFilterType, multiscale_demons_displacement
from direct.registration.optical_flow import OpticalFlowEstimatorType, optical_flow_displacement
from direct.registration.registration import DISCPLACEMENT_FIELD_2D_DIMENSIONS
from direct.registration.warp import warp

__all__ = [
"OpticalFlowILKRegistration2dModel",
"OpticalFlowTVL1Registration2dModel",
"DemonsRegistration2dModel",
"UnetRegistration2dModel",
]


class ClassicalRegistration2dModel(nn.Module):

def __init__(
self,
displacement_transform: Callable,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
super().__init__()
self.displacement_transform = displacement_transform
self.warp_num_integration_steps = warp_num_integration_steps

def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`UnetRegistrationModel`.
Parameters
----------
moving_image : torch.Tensor
Moving image tensor of shape (batch_size, seq_len, height, width).
reference_image : torch.Tensor
Reference image tensor of shape (batch_size, height, width).
Returns
-------
tuple[torch.Tensor, torch.Tensor]
Tuple containing the warped image tensor of shape (batch_size, seq_len, height, width)
and the displacement field tensor of shape (batch_size, seq_len, 2, height, width).
"""
batch_size, seq_len, height, width = moving_image.shape

device = reference_image.device

# Estimate the displacement field
displacement_field = [
self.displacement_transform(reference_image[_].cpu(), moving_image[_].cpu())
for _ in range(moving_image.shape[0])
]
displacement_field = torch.stack(displacement_field, dim=0)
displacement_field = displacement_field.to(device).reshape(
batch_size * seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
)

moving_image = moving_image.reshape(batch_size * seq_len, 1, height, width)

# Warp the moving image
warped_image = warp(moving_image, displacement_field, num_integration_steps=self.warp_num_integration_steps)

return (
warped_image.reshape(batch_size, seq_len, height, width),
displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
)


class OpticalFlowRegistration2dModel(ClassicalRegistration2dModel):

def __init__(
self,
estimator_type: OpticalFlowEstimatorType,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
super().__init__(
displacement_transform=partial(
optical_flow_displacement,
estimator_type=estimator_type,
**kwargs,
),
warp_num_integration_steps=warp_num_integration_steps,
)

def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`UnetRegistrationModel`.
Parameters
----------
moving_image : torch.Tensor
Moving image tensor of shape (batch_size, seq_len, height, width).
reference_image : torch.Tensor
Reference image tensor of shape (batch_size, height, width).
Returns
-------
tuple[torch.Tensor, torch.Tensor]
Tuple containing the warped image tensor of shape (batch_size, seq_len, height, width)
and the displacement field tensor of shape (batch_size, seq_len, 2, height, width).
"""
batch_size, seq_len, height, width = moving_image.shape

device = reference_image.device

# Estimate the displacement field
displacement_field = [
self.displacement_transform(reference_image[_].cpu(), moving_image[_].cpu())
for _ in range(moving_image.shape[0])
]
displacement_field = torch.stack(displacement_field, dim=0)
displacement_field = displacement_field.to(device).reshape(
batch_size * seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
)

moving_image = moving_image.reshape(batch_size * seq_len, 1, height, width)

# Warp the moving image
warped_image = warp(moving_image, displacement_field, num_integration_steps=self.warp_num_integration_steps)

return (
warped_image.reshape(batch_size, seq_len, height, width),
displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
)


class OpticalFlowILKRegistration2dModel(OpticalFlowRegistration2dModel):

def __init__(
self,
radius: int = 7,
num_warp: int = 10,
gaussian: bool = False,
prefilter: bool = True,
warp_num_integration_steps: int = 1,
) -> None:
super().__init__(
estimator_type=OpticalFlowEstimatorType.ILK,
warp_num_integration_steps=warp_num_integration_steps,
radius=radius,
num_warp=num_warp,
gaussian=gaussian,
prefilter=prefilter,
)


class OpticalFlowTVL1Registration2dModel(OpticalFlowRegistration2dModel):

def __init__(
self,
attachment: float = 15,
tightness: float = 0.3,
num_warp: int = 5,
num_iter: int = 10,
tol: float = 1e-3,
prefilter: bool = True,
warp_num_integration_steps: int = 1,
) -> None:
super().__init__(
estimator_type=OpticalFlowEstimatorType.TV_L1,
warp_num_integration_steps=warp_num_integration_steps,
attachment=attachment,
tightness=tightness,
num_warp=num_warp,
num_iter=num_iter,
tol=tol,
prefilter=prefilter,
)


class DemonsRegistration2dModel(ClassicalRegistration2dModel):

def __init__(
self,
demons_filter_type: DemonsFilterType = DemonsFilterType.SYMMETRIC_FORCES,
demons_num_iterations: int = 50,
demons_smooth_displacement_field: bool = True,
demons_standard_deviations: float = 1.0,
demons_intensity_difference_threshold: float | None = None,
demons_maximum_rms_error: float | None = None,
warp_num_integration_steps: int = 1,
) -> None:
"""Inits :class:`DemonsRegistration2dModel`.
Parameters
----------
demons_filter_type : DemonsFilterType, optional
Type of the Demons filter (DemonsFilterType.DEMONS, DemonsFilterType.FAST_SYMMETRIC_FORCES,
DemonsFilterType.SYMMETRIC_FORCES, DemonsFilterType.DIFFEOMORPHIC). Default: DemonsFilterType.SYMMETRIC_FORCES.
demons_num_iterations : int
Number of iterations for the Demons filter. Default: 100.
demons_smooth_displacement_field : bool
Whether to smooth the displacement field. Default: True.
demons_standard_deviations : float
Standard deviations for Gaussian smoothing. Default: 1.5.
demons_intensity_difference_threshold : float, optional
Intensity difference threshold. Default: None.
demons_maximum_rms_error : float, optional
Maximum RMS error. Default: None.
warp_num_integration_steps : int
Number of integration steps to perform when warping the moving image. Default: 1.
"""

super().__init__(
displacement_transform=partial(
multiscale_demons_displacement,
filter_type=demons_filter_type,
num_iterations=demons_num_iterations,
smooth_displacement_field=demons_smooth_displacement_field,
standard_deviations=demons_standard_deviations,
intensity_difference_threshold=demons_intensity_difference_threshold,
maximum_rms_error=demons_maximum_rms_error,
),
warp_num_integration_steps=warp_num_integration_steps,
)

class UnetRegistrationModel(nn.Module):
def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`UnetRegistrationModel`.
Parameters
----------
moving_image : torch.Tensor
Moving image tensor of shape (batch_size, seq_len, height, width).
reference_image : torch.Tensor
Reference image tensor of shape (batch_size, height, width).
Returns
-------
tuple[torch.Tensor, torch.Tensor]
Tuple containing the warped image tensor of shape (batch_size, seq_len, height, width)
and the displacement field tensor of shape (batch_size, seq_len, 2, height, width).
"""
batch_size, seq_len, height, width = moving_image.shape

device = reference_image.device

# Estimate the displacement field
displacement_field = [
self.displacement_transform(reference_image[_].cpu(), moving_image[_].cpu())
for _ in range(moving_image.shape[0])
]
displacement_field = torch.stack(displacement_field, dim=0)
displacement_field = displacement_field.to(device).reshape(
batch_size * seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width
)

moving_image = moving_image.reshape(batch_size * seq_len, 1, height, width)

# Warp the moving image
warped_image = warp(moving_image, displacement_field, num_integration_steps=self.warp_num_integration_steps)

return (
warped_image.reshape(batch_size, seq_len, height, width),
displacement_field.reshape(batch_size, seq_len, DISCPLACEMENT_FIELD_2D_DIMENSIONS, height, width),
)


class UnetRegistration2dModel(nn.Module):

def __init__(
self,
Expand All @@ -19,7 +275,7 @@ def __init__(
unet_normalized: bool = False,
warp_num_integration_steps: int = 1,
) -> None:
"""Inits :class:`UnetRegistrationModel`.
"""Inits :class:`UnetRegistration2dModel`.
Parameters
----------
Expand Down Expand Up @@ -50,7 +306,7 @@ def __init__(
self.warp_num_integration_steps = warp_num_integration_steps

def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> torch.Tensor:
"""Forward pass of :class:`UnetRegistrationModel`.
"""Forward pass of :class:`UnetRegistration2dModel`.
Parameters
----------
Expand Down
40 changes: 21 additions & 19 deletions direct/nn/vsharp/vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,25 +117,27 @@ def _do_iteration(
# Perform registration and compute loss on registered image and displacement field
registered_image, displacement_field = self.do_registration(data, output_images[-1])

shape = data["reference_image"].shape
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_image=registered_image,
target_image=(
data["reference_image"]
if shape == registered_image.shape
else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:]))))
),
)
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_displacement_field=displacement_field,
target_displacement_field=data["displacement_field"],
)
# If DL-based model calculate loss
if len(list(self.models["registration_model"].parameters())) > 0:
shape = data["reference_image"].shape
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_image=registered_image,
target_image=(
data["reference_image"]
if shape == registered_image.shape
else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:]))))
),
)
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_displacement_field=displacement_field,
target_displacement_field=data["displacement_field"],
)

auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0])
for i, output_image in enumerate(output_images):
Expand Down

0 comments on commit 2eaed0b

Please sign in to comment.