Skip to content

Commit

Permalink
initial commit n. 2
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet committed Nov 1, 2024
1 parent 5a9aed2 commit 6fb1951
Show file tree
Hide file tree
Showing 24 changed files with 1,466 additions and 1 deletion.
8 changes: 7 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,13 @@ classifiers = [
"Typing :: Typed",
]
# add your package dependencies here
dependencies = []
dependencies = [
"torch",
"einops",
"numpy",
"torch-grid-utils",
"torch-cubic-spline-grids"
]

# https://peps.python.org/pep-0621/#dependencies-optional-dependencies
# "extras" (e.g. for `pip install .[test]`)
Expand Down
5 changes: 5 additions & 0 deletions src/tttsa/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@
__version__ = "uninstalled"
__author__ = "Marten Chaillet"
__email__ = "[email protected]"

import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

1 change: 1 addition & 0 deletions src/tttsa/affine/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .affine_transform import affine_transform_2d
47 changes: 47 additions & 0 deletions src/tttsa/affine/affine_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import torch
import einops
from typing import Sequence
import torch.nn.functional as F
from torch_grid_utils import coordinate_grid

from tttsa.utils import homogenise_coordinates, array_to_grid_sample


def affine_transform_2d(
images: torch.Tensor, # shape: '... h w'
affine_matrices: torch.Tensor, # shape: '... 3 3'
out_shape: Sequence[int] | None = None,
interpolation: str = "bicubic",
):
"""Affine transform 1 or a batch of images."""
if out_shape is None:
out_shape = images.shape[-2:]
device = images.device
grid = homogenise_coordinates(coordinate_grid(out_shape, device=device))
grid = einops.rearrange(grid, "h w coords -> 1 h w coords 1")
M = einops.rearrange(affine_matrices, "... i j -> ... 1 1 i j").to(device)
grid = M @ grid
grid = einops.rearrange(grid, "... h w coords 1 -> ... h w coords")[
..., :2
].contiguous()
grid_sample_coordinates = array_to_grid_sample(grid, images.shape[-2:])
if images.dim() == 2: # needed for grid sample
images = einops.repeat(images, "h w -> n h w", n=M.shape[0])
transformed = einops.rearrange(
F.grid_sample(
einops.rearrange(images, "... h w -> ... 1 h w"),
grid_sample_coordinates,
align_corners=True,
mode=interpolation,
),
"... 1 h w -> ... h w",
).squeeze() # remove starter dimensions in case we got one image
return transformed


# TODO write some functions like
# def affine_transform_3d(
# volumes: torch.Tensor, # shape: 'n d h w'
# affine_matrices: torch.Tensor, # shape: 'n 4 4'
# interpolation: str = 'bilinear', # is actually trilinear in grid_sample
# ):
1 change: 1 addition & 0 deletions src/tttsa/alignment/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .find_shift import find_image_shift
56 changes: 56 additions & 0 deletions src/tttsa/alignment/find_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import torch

from tttsa.correlation import correlate_2d
from tttsa.utils import dft_center


def find_image_shift(
image_a: torch.Tensor,
image_b: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Find the shift between image a and b.
Applying the shift to b aligns it with
image a. The region around the maximum in the correlation image is by default
upsampled with bicubic interpolation to find a more precise shift.
Parameters
----------
image_a: torch.Tensor
`(h, w)` image.
image_b: torch.Tensor
`(h, w)` image with the same shape as image_a
mask: torch.Tensor | None, default None
`(h, w)` mask used for normalization
Returns
-------
shift, correlation: torch.Tensor, torch.Tensor
`(2, )` shift in y and x; and maximal correlation
"""
center = dft_center(
image_a.shape, rfft=False, fftshifted=True, device=image_a.device
)

# calculate initial shift with integer precision
correlation = correlate_2d(image_a, image_b, normalize=True)
maximum_idx = torch.tensor( # explicitly put tensor on CPU in case input is on GPU
np.unravel_index(correlation.argmax().cpu(), shape=image_a.shape),
device=image_a.device,
)
y, x = maximum_idx
# Parabolic interpolation in the y direction
f_y0 = correlation[y - 1, x]
f_y1 = correlation[y, x]
f_y2 = correlation[y + 1, x]
subpixel_dy = y + 0.5 * (f_y0 - f_y2) / (f_y0 - 2 * f_y1 + f_y2)

# Parabolic interpolation in the x direction
f_x0 = correlation[y, x - 1]
f_x1 = correlation[y, x]
f_x2 = correlation[y, x + 1]
subpixel_dx = x + 0.5 * (f_x0 - f_x2) / (f_x0 - 2 * f_x1 + f_x2)

shift = torch.tensor([subpixel_dy, subpixel_dx]) - center
return shift
34 changes: 34 additions & 0 deletions src/tttsa/alignment/tests/test_find_shift.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
import pytest

from libtilt.alignment import find_image_shift


def test_find_image_shift():
a = torch.zeros((4, 4))
a[1, 1] = 1
b = torch.zeros((4, 4))
b[2, 2] = .7
b[2, 3] = .3

with pytest.raises(ValueError, match=r'Upsampling factor .*'):
find_image_shift(a, b, upsampling_factor=0.5)

shift = find_image_shift(a, b, upsampling_factor=5)
assert torch.all(shift == -1), ("Interpolating a shift too close to a border is "
"not possible, so an integer shift should be "
"returned.")

a = torch.zeros((8, 8))
a[3, 3] = 1
b = torch.zeros((8, 8))
b[4, 4] = .7
b[4, 5] = .3
shift = find_image_shift(a, b, upsampling_factor=1)
assert torch.all(shift == -1), ("Finding shift with upsampling_factor of 1 should "
"return an integer shift (i.e. no interpolation.")

shift = find_image_shift(a, b)
assert shift[0] == -1.1, "y shift should be interpolated to specific value."
assert shift[1] == -1.2, "x shift should be interpolated to specific value."

1 change: 1 addition & 0 deletions src/tttsa/back_projection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .filtered_back_projection import filtered_back_projection_3d
154 changes: 154 additions & 0 deletions src/tttsa/back_projection/filtered_back_projection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
import torch
import einops
import torch.nn.functional as F
from torch_grid_utils import coordinate_grid

from tttsa.transformations import R_2d, T_2d, T, Ry
from tttsa.utils import dft_center, homogenise_coordinates, array_to_grid_sample


def filtered_back_projection_3d(
tilt_series,
tomogram_dimensions,
tilt_angles,
tilt_axis_angles,
shifts,
weighting: str = "exact",
object_diameter: float | None = None,
):
"""Run weighted back projection incorporating some alignment parameters.
weighting: str, default "hamming"
all filters here start at 1/N (instead of 0 for ramp and hamming) which
improves the low res signal upon forward projection of the reconstruction
Options:
- "ramp": increases linearly from 1/N to 1 from the zero frequency to
nyquist
- "exact": is based on and improves low-res signal on forward projection:
Reference : Optik, Exact filters for general geometry three-dimensional
reconstruction, vol.73,146,1986.
- "hamming": modified hamming as used in AreTomo, further modified here to
also start a 1/N
object_diameter: float | None, default None
object diameter specified in number of pixels, only needed for the exact filter
"""
# initializes sizes
device = tilt_series.device
n_tilts, h, w = tilt_series.shape # for simplicity assume square images
tilt_image_dimensions = (h, w)
transformed_image_dimensions = tomogram_dimensions[-2:]
tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
tilt_image_center = dft_center(tilt_image_dimensions, rfft=False, fftshifted=True)
transformed_image_center = dft_center(transformed_image_dimensions, rfft=False,
fftshifted=True)
_, filter_size = transformed_image_dimensions

# generate the 2d alignment affine matrix
s0 = T_2d(-transformed_image_center)
r0 = R_2d(tilt_axis_angles, yx=True)
s1 = T_2d(-shifts)
s2 = T_2d(tilt_image_center)
M = einops.rearrange((s2 @ s1 @ r0 @ s0), "... i j -> ... 1 1 i j").to(device)

grid = homogenise_coordinates(coordinate_grid(transformed_image_dimensions, device=device))
grid = einops.rearrange(grid, "h w coords -> h w coords 1")
grid = M @ grid
grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")[
..., :2
].contiguous()
grid_sample_coordinates = array_to_grid_sample(grid, tilt_image_dimensions)
aligned_ts = torch.squeeze(
F.grid_sample(
einops.rearrange(tilt_series, "n h w -> n 1 h w"),
grid_sample_coordinates,
align_corners=True,
mode="bicubic",
)
)

# generate weighting function and apply to aligned tilt series
if weighting == "exact":
if object_diameter is None:
raise ValueError(
"Calculation of exact weighting requires an object " "diameter."
)
if len(tilt_angles) == 1:
filters = 1
else: # slice_width could be provided as a function argument it can be
# calculated as: (pixel_size * 2 * imdim) / object_diameter
q = einops.rearrange(
torch.arange(
filter_size // 2 + filter_size % 2 + 1, dtype=torch.float32,
device=device
)
/ filter_size,
"q -> 1 1 q",
)
sampling = torch.sin(
torch.deg2rad(
torch.abs(einops.rearrange(tilt_angles, "n -> n 1") - tilt_angles)
)
).to(device)
sampling = einops.rearrange(sampling, "n m -> n m 1")
q_overlap_inv = sampling / (2 / object_diameter)
over_weighting = 1 - torch.clip(q * q_overlap_inv, min=0, max=1)
filters = 1 / einops.reduce(over_weighting, "n m q -> n q", "sum")
filters = einops.rearrange(filters, "n w -> n 1 w")
elif weighting == "ramp":
filters = torch.arange(
filter_size // 2 + filter_size % 2 + 1, dtype=torch.float32, device=device
)
filters /= filters.max()
filters = filters * (1 - 1 / n_tilts) + 1 / n_tilts # start at 1 / N
elif weighting == "hamming": # AreTomo3 code uses a modified hamming window
# 2 * q * (0.55f + 0.45f * cosf(6.2831852f * q)) # with q from 0 to .5 (Ny)
# https://github.com/czimaginginstitute/AreTomo3/blob/
# c39dcdad9525ee21d7308a95622f3d47fe7ab4b9/AreTomo/Recon/GRWeight.cu#L20
q = (
torch.arange(filter_size // 2 + filter_size % 2 + 1, dtype=torch.float32,
device=device)
/ filter_size
)
# regular hamming: q * (.54 + .46 * torch.cos(torch.pi * q))
filters = 2 * q * (0.54 + 0.46 * torch.cos(2 * torch.pi * q))
filters /= filters.max() # 0-1 normalization
filters = filters * (1 - 1 / n_tilts) + 1 / n_tilts # start at 1 / N
else:
raise ValueError("Invalid weighting option provided for FBP.")

weighted = torch.fft.irfftn(
torch.fft.rfftn(aligned_ts, dim=(-2, -1)) * filters, dim=(-2, -1)
)
if len(weighted.shape) == 2: # rfftn gets rid of batch dimension: add it back
weighted = einops.rearrange(weighted, "h w -> 1 h w")

# time for real space back projection
s0 = T(-tomogram_center)
r0 = Ry(tilt_angles, zyx=True)
s1 = T(F.pad(transformed_image_center, pad=(1, 0), value=0))
M = einops.rearrange(s1 @ r0 @ s0, "... i j -> ... 1 1 i j").to(device)

reconstruction = torch.zeros(
tomogram_dimensions, dtype=torch.float32, device=device
)
grid = homogenise_coordinates(coordinate_grid(tomogram_dimensions, device=device))
grid = einops.rearrange(grid, "d h w coords -> d h w coords 1")

for i in range(n_tilts): # could do all grids simultaneously
grid_t = M[i] @ grid
grid_t = einops.rearrange(grid_t, "... d h w coords 1 -> ... d h w coords")[
..., :3
].contiguous()
grid_sample_coordinates = array_to_grid_sample(grid_t, tomogram_dimensions)
reconstruction += torch.squeeze(
F.grid_sample(
einops.rearrange(weighted[i], "h w -> 1 1 1 h w"),
einops.rearrange(
grid_sample_coordinates, "d h w coords -> 1 d h w coords"
),
align_corners=True,
mode="bilinear",
)
)
return reconstruction, aligned_ts
Loading

0 comments on commit 6fb1951

Please sign in to comment.