diff --git a/src/deepali/core/flow.py b/src/deepali/core/flow.py index 805b532..4693aaf 100644 --- a/src/deepali/core/flow.py +++ b/src/deepali/core/flow.py @@ -325,6 +325,7 @@ def expv( sampling: Union[Sampling, str] = Sampling.LINEAR, padding: Union[PaddingMode, str] = PaddingMode.BORDER, align_corners: bool = ALIGN_CORNERS, + inverse: bool = False, ) -> Tensor: r"""Group exponential maps of flow fields computed using scaling and squaring. @@ -336,6 +337,8 @@ def expv( padding: Flow field extrapolation mode. align_corners: Whether ``flow`` vectors are defined with respect to ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True). + inverse: Whether to negate scaled velocity field. Setting this to ``True`` + is equivalent to negating the ``scale`` (e.g., ``scale=-1``). Returns: Exponential map of input flow field. If ``steps=0``, a reference to ``flow`` is returned. @@ -343,6 +346,8 @@ def expv( """ if scale is None: scale = 1 + if inverse: + scale = -scale if steps is None: steps = 5 if not isinstance(steps, int): @@ -699,6 +704,56 @@ def normalize_flow( return data +def logv( + flow: Tensor, + num_iters: int = 5, + bch_terms: int = 1, + sigma: Optional[float] = 1.0, + spacing: Optional[Union[Scalar, Array]] = None, + exp_steps: Optional[int] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str] = PaddingMode.BORDER, + align_corners: bool = ALIGN_CORNERS, +) -> Tensor: + r"""Group logarithmic maps of flow fields computed using algorithm by Bossa & Olsom (2008). + + References: + - Bossa & Olmos, 2008. A new algorithm for the computation of the group logarithm of diffeomorphisms. + https://inria.hal.science/inria-00629873 + + Args: + num_iters: Number of iterations. + bch_terms: Number of Lie bracket terms of the Baker-Campbell-Hausdorff (BCH) formula to use + when computing the composite of current velocity field with the correction field. + sigma: Standard deviation of Gaussian kernel used as low-pass filter when computing spatial + derivatives required for evaluation of Lie brackets during application of BCH formula. + spacing: Physical size of image voxels used to compute spatial derivatives. + exp_steps: Number of exponentiation steps to evaluate current inverse displacement field. + sampling: Flow field interpolation mode when computing inverse displacement field. + padding: Flow field extrapolation mode when computing inverse displacement field. + align_corners: Whether ``flow`` vectors are defined with respect to + ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True). + + Returns: + Approximate stationary velocity field which when exponentiated (cf. :func:`expv`) results + in the given input ``flow`` field. + + """ + v = flow + for _ in range(num_iters): + u = expv( + v, + steps=exp_steps, + sampling=sampling, + padding=padding, + align_corners=align_corners, + inverse=True, + ) + u = compose_flows(flow, u) + v = compose_svfs(u, v, bch_terms=bch_terms, sigma=sigma, spacing=spacing) + return v + + def denormalize_flow( data: Tensor, size: Optional[Union[Tensor, torch.Size]] = None, diff --git a/src/deepali/core/functional.py b/src/deepali/core/functional.py index 31c56a7..3c8b6b9 100644 --- a/src/deepali/core/functional.py +++ b/src/deepali/core/functional.py @@ -104,6 +104,7 @@ from .flow import jacobian_dict from .flow import jacobian_matrix from .flow import lie_bracket +from .flow import logv from .flow import normalize_flow from .flow import sample_flow from .flow import warp_grid @@ -220,6 +221,7 @@ "jacobian_dict", "jacobian_matrix", "lie_bracket", + "logv", "max_pool", "min_pool", "normalize_flow", diff --git a/tests/_test_core_flow_logv.py b/tests/_test_core_flow_logv.py new file mode 100644 index 0000000..4b623b7 --- /dev/null +++ b/tests/_test_core_flow_logv.py @@ -0,0 +1,79 @@ +# %% +# Imports +from typing import Optional, Sequence + +import matplotlib.pyplot as plt + +import torch +from torch import Tensor +from torch.random import Generator + +from deepali.core import Grid +import deepali.core.bspline as B +import deepali.core.functional as U + + +# %% +# Auxiliary functions +def random_svf( + size: Sequence[int], + stride: int = 1, + generator: Optional[Generator] = None, +) -> Tensor: + cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride) + data = torch.randn((1, 3) + cp_grid_size, generator=generator) + data = U.fill_border(data, margin=3, value=0, inplace=True) + return B.evaluate_cubic_bspline(data, size=size, stride=stride) + + +def visualize_flow(ax, flow: Tensor, label: Optional[str] = None) -> None: + grid = Grid(shape=flow.shape[2:], align_corners=True) + x = grid.coords(channels_last=False, dtype=u.dtype, device=u.device) + x = U.move_dim(x.unsqueeze(0).add_(flow), 1, -1) + target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5)) + warped_grid = U.warp_image(target_grid, x) + ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray") + if label: + ax.set_title(label, fontsize=24) + + +# %% +# Random velocity fields +size = (128, 128, 128) +generator = torch.Generator().manual_seed(42) +v = random_svf(size, stride=8, generator=generator).mul_(0.1) + + +# %% +# Compute logarithm of exponentiated velocity field +bch_terms = 3 +exp_steps = 5 +log_steps = 5 + +u = U.expv(v, steps=exp_steps) +w = U.logv(u, num_iters=log_steps, bch_terms=bch_terms, exp_steps=exp_steps, sigma=1.0) + +fig, axes = plt.subplots(1, 4, figsize=(40, 10)) + +ax = axes[0] +ax.set_title("v", fontsize=32, pad=20) +visualize_flow(ax, v) + +ax = axes[1] +ax.set_title("u = exp(v)", fontsize=32, pad=20) +visualize_flow(ax, u) + +ax = axes[2] +ax.set_title("log(u)", fontsize=32, pad=20) +visualize_flow(ax, w) + +error = w.sub(v).norm(dim=1, keepdim=True) + +ax = axes[3] +ax.set_title("|log(u) - v|", fontsize=32, pad=20) +_ = ax.imshow(error[0, 0, error.shape[2] // 2], cmap="jet", vmin=0, vmax=0.1) + +print(f"Mean error: {error.mean():.5f}") +print(f"Maximium error: {error.max():.5f}") + +# %% diff --git a/tests/test_core_flow_utils.py b/tests/test_core_flow_utils.py index bf8cdc1..ce1125c 100644 --- a/tests/test_core_flow_utils.py +++ b/tests/test_core_flow_utils.py @@ -428,6 +428,17 @@ def test_flow_lie_bracket() -> None: assert error.max().lt(0.134) +def test_flow_logv() -> None: + size = (128, 128, 128) + generator = torch.Generator().manual_seed(42) + v = random_svf(size, stride=8, generator=generator).mul_(0.1) + u = U.expv(v) + w = U.logv(u) + error = w.sub(v).norm(dim=1, keepdim=True) + assert error.mean().lt(0.001) + assert error.max().lt(0.02) + + def test_flow_compose_svfs() -> None: # 3D flow fields p = U.move_dim(Grid(size=(64, 32, 16)).coords().unsqueeze_(0), -1, 1)