Fix(align_tpm): use quadratic interpolation in logit space instead of linear interpolation in prob space (mimics spm_maff8)
balbasty committed Oct 16, 2024
1 parent 782bcad commit e149d82
Showing 2 changed files with 94 additions and 26 deletions.
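
In short, the first file now converts the tissue probability map (TPM) to log-probabilities and spline-prefilters it once, warps those logits with quadratic (degree-2) interpolation, and applies a softmax to the pulled values to recover probabilities, mimicking SPM's spm_maff8. A minimal sketch of that pipeline, condensed from the diff below; the helper names, the value of `tiny`, and the sampling grid `phi` are assumptions, not code from the repository:

```python
import torch
from nitorch import spatial
from nitorch.core import math

tiny = 1e-8  # assumed small constant; affine_tpm.py defines its own


def prefilter_log_tpm(tpm):
    """(K, *spatial) probabilities -> spline-prefiltered log-probabilities."""
    logtpm = tpm.clone()
    # renormalize, then move to logit (log-probability) space
    logtpm = logtpm.clamp(tiny, 1 - tiny).div_(logtpm.sum(0, keepdim=True))
    logtpm = logtpm.add_(tiny).log_()
    # prefilter so that degree-2 B-spline sampling interpolates the logits
    return spatial.spline_coeff_nd(logtpm, dim=3, inplace=True,
                                   interpolation=2, bound='replicate')


def warp_tpm(logtpm, phi):
    """Sample prefiltered logits at grid `phi`, then map back to probabilities.

    Assumes the pulled tensor carries the tissue-class dimension at position 1,
    as in fit_affine_tpm's batched layout.
    """
    mov = spatial.grid_pull(logtpm, phi, interpolation=2,
                            bound='replicate', extrapolate=True)
    return math.softmax(mov, dim=1)  # softmax over the tissue-class dimension
```

Interpolating logits and re-normalising with a softmax keeps the warped map on the probability simplex even where a quadratic spline would overshoot in probability space.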
46 changes: 34 additions & 12 deletions nitorch/tools/registration/affine_tpm.py
@@ -33,7 +33,7 @@
𝓛 = 𝔼_q[ln p(𝒙)] = ∑ₙᵢ q(𝑥ₙ = 𝑖) ln ∑ⱼ 𝐻ᵢⱼ (𝜇 ∘ 𝜙)ₙⱼ
"""
import torch
from nitorch.core import linalg, utils, py
from nitorch.core import linalg, utils, py, math
from nitorch import spatial, io
from .utils import jg, jhj, affine_grid_backward
import nitorch.plot as niplt
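
As a concrete reading of the objective in the docstring above, here is a toy evaluation with random placeholder tensors (N voxels, I intensity bins, J tissue classes; none of these names exist in the repository):

```python
import torch

N, I, J = 5, 2, 3                         # voxels, intensity bins, tissue classes (made up)
q  = torch.softmax(torch.randn(N, I), 1)  # responsibilities q(xₙ = i)
H  = torch.softmax(torch.randn(I, J), 1)  # mixing matrix Hᵢⱼ
mu = torch.softmax(torch.randn(N, J), 1)  # warped prior (μ ∘ φ)ₙⱼ

# 𝓛 = Σₙᵢ q(xₙ = i) ln Σⱼ Hᵢⱼ (μ ∘ φ)ₙⱼ
L = (q * (mu @ H.T).log()).sum()
print(float(L))
```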
@@ -154,6 +154,18 @@ def align_tpm(dat, tpm=None, weights=None, spacing=(8, 4), device=None,
# ------------------------------------------------------------------
dat = discretize(dat, nbins=bins, mask=weights)

# ------------------------------------------------------------------
# PREFILTER TPM
# ------------------------------------------------------------------
logtpm = tpm.clone()
# ensure normalized
logtpm = logtpm.clamp(tiny, 1-tiny).div_(logtpm.sum(0, keepdim=True))
# transform to logits
logtpm = logtpm.add_(tiny).log_()
# spline prefilter
splineopt = dict(interpolation=2, bound='replicate')
logtpm = spatial.spline_coeff_nd(logtpm, dim=3, inplace=True, **splineopt)

# ------------------------------------------------------------------
# OPTIONS
# ------------------------------------------------------------------
@@ -175,8 +187,8 @@ def do_spacing(sp):
if not sp:
return dat0, affine_dat0, weights0
sp = [max(1, int(pymath.floor(sp / vx1))) for vx1 in vx]
sp = [slice(None, None, sp1) for sp1 in sp]
affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], tuple(sp))
sp = tuple([slice(None, None, sp1) for sp1 in sp])
affine_dat, _ = spatial.affine_sub(affine_dat0, dat0.shape[-dim:], sp)
dat = dat0[(Ellipsis, *sp)]
if weights0 is not None:
weights = weights0[(Ellipsis, *sp)]
@@ -234,7 +246,7 @@ def do_spacing(sp):
if reorient is not None:
affine_dat = reorient.matmul(affine_dat)

mi, aff, prm = fit_affine_tpm(dat, tpm, affine_dat, affine_tpm,
mi, aff, prm = fit_affine_tpm(dat, logtpm, affine_dat, affine_tpm,
weights, **opt, prm=prm)

if reorient is not None:
@@ -263,7 +275,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
affine_tpm : (4, 4) tensor
weights : (*spatial) tensor
basis : {'translation', 'rotation', 'rigid', 'similitude', 'affine'}
fwhm : float, default=J/32
fwhm : float, default=J/64
max_iter_gn : int, default=100
max_iter_em : int, default=32
max_line_search : int, default=12
@@ -276,6 +288,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
prm : (F) tensor
"""
# !!! NOTE: `tpm` must contain spline-prefiltered log-probabilities

dim = tpm.dim() - 1

# ------------------------------------------------------------------
@@ -326,7 +340,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
affine_tpm = affine_tpm.to(**utils.backend(tpm))
shape = dat.shape[-dim:]

tpm = tpm.to(dat.device).clamp(tiny, 1-tiny)
tpm = tpm.to(dat.device)
basis = make_basis(basis, dim, **utils.backend(tpm))
F = len(basis)

@@ -337,7 +351,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
em_opt = dict(fwhm=fwhm, max_iter=max_iter_em, weights=weights,
verbose=verbose-2)
drv_opt = dict(weights=weights)
pull_opt = dict(bound='replicate', extrapolate=True)
pull_opt = dict(bound='replicate', extrapolate=True, interpolation=2)

# ------------------------------------------------------------------
# OPTIMIZE
@@ -365,6 +379,7 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,

# --- warp TPM ---------------------------------------------
mov = spatial.grid_pull(tpm, phi, **pull_opt)
mov = math.softmax(mov, dim=1)

# --- mutual info ------------------------------------------
mi, Nm, prior = em_prior(mov, dat, prior0, **em_opt)
@@ -399,8 +414,8 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
end = '\n' if verbose >= 2 else '\r'
print(f'({basis_name[:6]}){space} | {n_iter:02d} | {mi.mean():12.6g}', end=end)

if mi.mean() - mi0.mean() < 1e-4:
# print('converged', mi.mean() - mi0.mean())
if mi.mean() - mi0.mean() < 0: #1e-4:
print('converged', mi.mean() - mi0.mean())
break

# --------------------------------------------------------------
@@ -412,16 +427,22 @@ def fit_affine_tpm(dat, tpm, affine=None, affine_tpm=None, weights=None,
g = g.sum(0)
h = h.sum(0)

# --- chain rule -----------------------------------------------
# --- spatial derivatives --------------------------------------
mov = mov.unsqueeze(-1)
gmov = spatial.grid_grad(tpm, phi, **pull_opt)
gmov = mov * (gmov - (mov * gmov).sum(1, keepdim=True))
mov = mov.squeeze(-1)

# --- chain rule -----------------------------------------------
gaff = lmdiv(affine_tpm, mm(gaff, affine))
g, h = chain_rule(g, h, gmov, gaff, maj=False)
del gmov

# --- Gauss-Newton ---------------------------------------------
h.diagonal(0, -1, -2).add_(h.diagonal(0, -1, -2).abs().max() * 1e-5)
delta = lmdiv(h, g.unsqueeze(-1)).squeeze(-1)
foo = 0

plot_registration(dat, mov, f'{basis_name} | {n_iter}')

if verbose == 1:
print('')
@@ -898,7 +919,8 @@ def discretize(dat, nbins=256, mask=None):

def get_spm_prior(**backend):
"""Download the SPM prior"""
url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
# url = 'https://github.com/spm/spm12/raw/master/tpm/TPM.nii'
url = 'https://github.com/spm/spm12/raw/refs/heads/main/tpm/TPM.nii'
fname = os.path.join(cache_dir, 'SPM12_TPM.nii')
if not os.path.exists(fname):
os.makedirs(cache_dir, exist_ok=True)
74 changes: 60 additions & 14 deletions nitorch/tools/registration/pairwise_preproc.py
@@ -12,7 +12,7 @@

def preproc_image(input, mask=None, label=False, missing=0,
world=None, affine=None, rescale=.95,
pad=None, bound='zero', fwhm=None,
pad=None, bound='zero', fwhm=None, channels=None,
dim=None, device=None, **kwargs):
"""Load an image and preprocess it as required
@@ -43,6 +43,8 @@ def preproc_image(input, mask=None, label=False, missing=0,
fwhm : [sequence of] float
Smooth the volume with a Gaussian kernel of that FWHM.
If last element is "mm", values are in mm and converted to voxels.
channels : [sequence of] int or range or slice
Channels to load
dim : int, optional
Number of spatial dimensions
device : torch.device
@@ -58,14 +60,30 @@ def preproc_image(input, mask=None, label=False, missing=0,
Orientation matrix
"""
if not torch.is_tensor(input):
dat, mask0, affine0 = load_image(input, dim=dim, device=device,
label=label, missing=missing)
else:
dat = input
mask0 = torch.isfinite(dat)
dat = dat.masked_fill(~mask0, 0)
affine0 = spatial.affine_default(dat.shape[1:])
dat, mask0, affine0 = load_image(input, dim=dim, device=device,
label=label, missing=missing,
channels=channels)

# if not torch.is_tensor(input):
# dat, mask0, affine0 = load_image(input, dim=dim, device=device,
# label=label, missing=missing,
# channels=channels)
# else:
# dat = input
# if channels is not None:
# channels = make_list(channels)
# channels = [
# list(c) if isinstance(c, range) else
# list(range(len(dat)))[c] if isinstance(c, slice) else
# c for c in channels
# ]
# if not all([isinstance(c, int) for c in channels]):
# raise ValueError('Channel list should be a list of integers')
# dat = dat[channels]
# mask0 = torch.isfinite(dat)
# dat = dat.masked_fill(~mask0, 0)
# affine0 = spatial.affine_default(dat.shape[1:])

dim = dat.dim() - 1

# load user-defined mask
@@ -199,7 +217,7 @@ def prepare_pyramid_levels(images, levels, dim=None, **opt):
return pyrutils.pyramid_levels(vxs, shapes, levels, **opt)


def map_image(fnames, dim=None):
def map_image(fnames, dim=None, channels=None):
"""Map an ND image from disk
Parameters
@@ -229,7 +247,6 @@ def map_image(fnames, dim=None):
affine = img.affine
if dim is None:
dim = img.affine.shape[-1] - 1
# img = img.fdata(rand=True, device=device)
if img.dim > dim:
img = img.movedim(-1, 0)
else:
@@ -241,10 +258,24 @@ def map_image(fnames, dim=None):
imgs.append(img)
del img
imgs = io.cat(imgs, dim=0)

# select a subset of channels
if channels is not None:
channels = make_list(channels)
channels = [
list(c) if isinstance(c, range) else
list(range(len(imgs)))[c] if isinstance(c, slice) else
c for c in channels
]
if not all([isinstance(c, int) for c in channels]):
raise ValueError('Channel list should be a list of integers')
imgs = io.stack([imgs[c] for c in channels])

return imgs, affine


def load_image(input, dim=None, device=None, label=False, missing=0):
def load_image(input, dim=None, device=None, label=False, missing=0,
channels=None):
"""
Load a N-D image from disk
@@ -272,15 +303,30 @@ def load_image(input, dim=None, device=None, label=False, missing=0):
Orientation matrix
"""
if not torch.is_tensor(input):
dat, affine = map_image(input, dim)
dat, affine = map_image(input, dim, channels=channels)
else:
dat, affine = input, spatial.affine_default(input.shape[1:])

if channels is not None:
channels = make_list(channels)
channels = [
list(c) if isinstance(c, range) else
list(range(len(dat)))[c] if isinstance(c, slice) else
c for c in channels
]
if not all([isinstance(c, int) for c in channels]):
raise ValueError('Channel list should be a list of integers')
dat = dat[channels]

if label:
dtype = dat.dtype
if isinstance(dtype, (list, tuple)):
dtype = dtype[0]
dtype = dtypes.as_torch(dtype, upcast=True)
dat0 = dat.data(device=device, dtype=dtype)[0] # assume single channel
if torch.is_tensor(dat):
dat0 = dat[0]
else:
dat0 = dat.data(device=device, dtype=dtype)[0] # assume single channel
if label is True:
label = dat0.unique(sorted=True)
label = label[label != 0].tolist()
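
The new `channels` option accepts integers, ranges, and slices and reduces them to concrete channel indices before the channel dimension is indexed. A standalone sketch of that normalisation step, with a made-up helper name and plain Python in place of nitorch's `make_list` (the sketch flattens ranges and slices to integers, which simplifies the list comprehension shown in the diff):

```python
def normalize_channels(spec, nb_channels):
    """Turn an int/range/slice, or a sequence of them, into a flat list of channel indices."""
    items = list(spec) if isinstance(spec, (list, tuple)) else [spec]
    channels = []
    for c in items:
        if isinstance(c, range):
            channels.extend(c)                      # explicit index range
        elif isinstance(c, slice):
            channels.extend(range(nb_channels)[c])  # resolve slice against the channel count
        elif isinstance(c, int):
            channels.append(c)
        else:
            raise ValueError('Channel list should be a list of integers')
    return channels


# e.g. first channel plus channels 2 and 3 of a 5-channel volume
print(normalize_channels([0, slice(2, 4)], nb_channels=5))   # [0, 2, 3]
```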
