Showing 24 changed files with 1,466 additions and 1 deletion.
@@ -8,3 +8,8 @@
__version__ = "uninstalled"
__author__ = "Marten Chaillet"
__email__ = "[email protected]"

import logging
import sys
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
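Because logging.basicConfig now runs at package import time, importing the package is enough to route INFO-level messages from any module to stdout. A minimal sketch of the effect, assuming the package imports as tttsa (inferred from the absolute imports used elsewhere in this commit); note that basicConfig is a no-op if the root logger already has handlers:

import logging

import tttsa  # noqa: F401  # assumed import name; basicConfig runs on import

log = logging.getLogger("tttsa.example")  # hypothetical module logger
log.info("visible on stdout at INFO level")
log.debug("suppressed: below the configured INFO level")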
@@ -0,0 +1 @@
from .affine_transform import affine_transform_2d
@@ -0,0 +1,47 @@
import torch
import einops
from typing import Sequence
import torch.nn.functional as F
from torch_grid_utils import coordinate_grid

from tttsa.utils import homogenise_coordinates, array_to_grid_sample


def affine_transform_2d(
    images: torch.Tensor,  # shape: '... h w'
    affine_matrices: torch.Tensor,  # shape: '... 3 3'
    out_shape: Sequence[int] | None = None,
    interpolation: str = "bicubic",
):
    """Affine transform one image or a batch of images."""
    if out_shape is None:
        out_shape = images.shape[-2:]
    device = images.device
    # homogeneous (y, x, 1) array coordinates for every pixel of the output grid
    grid = homogenise_coordinates(coordinate_grid(out_shape, device=device))
    grid = einops.rearrange(grid, "h w coords -> 1 h w coords 1")
    M = einops.rearrange(affine_matrices, "... i j -> ... 1 1 i j").to(device)
    # map output coordinates through the matrices (pull-style transform)
    grid = M @ grid
    grid = einops.rearrange(grid, "... h w coords 1 -> ... h w coords")[
        ..., :2
    ].contiguous()
    # convert array coordinates to the [-1, 1] convention used by grid_sample
    grid_sample_coordinates = array_to_grid_sample(grid, images.shape[-2:])
    if images.dim() == 2:  # grid_sample needs a batch dimension
        images = einops.repeat(images, "h w -> n h w", n=M.shape[0])
    transformed = einops.rearrange(
        F.grid_sample(
            einops.rearrange(images, "... h w -> ... 1 h w"),
            grid_sample_coordinates,
            align_corners=True,
            mode=interpolation,
        ),
        "... 1 h w -> ... h w",
    ).squeeze()  # remove starter dimensions in case we got one image
    return transformed


# TODO write some functions like
# def affine_transform_3d(
#     volumes: torch.Tensor,  # shape: 'n d h w'
#     affine_matrices: torch.Tensor,  # shape: 'n 4 4'
#     interpolation: str = 'bilinear',  # is actually trilinear in grid_sample
# ):
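For reference, a minimal usage sketch of affine_transform_2d. It assumes this module is reachable as tttsa.affine (an assumption; adjust to the actual package layout). The matrices act on homogeneous (y, x, 1) array coordinates of the output grid, i.e. they pull values from the input image:

import torch
from tttsa.affine import affine_transform_2d  # assumed import path

image = torch.rand(64, 64)
angle = torch.deg2rad(torch.tensor(30.0))
c, s = torch.cos(angle).item(), torch.sin(angle).item()
cy, cx = (64 - 1) / 2, (64 - 1) / 2  # rotate about the image centre

# homogeneous (y, x, 1) matrices: move centre to origin, rotate, move back
to_origin = torch.tensor([[1.0, 0.0, -cy], [0.0, 1.0, -cx], [0.0, 0.0, 1.0]])
rotate = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])
back = torch.tensor([[1.0, 0.0, cy], [0.0, 1.0, cx], [0.0, 0.0, 1.0]])
M = back @ rotate @ to_origin

rotated = affine_transform_2d(image, M)
print(rotated.shape)  # torch.Size([64, 64]) after the final squeeze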
@@ -0,0 +1 @@
from .find_shift import find_image_shift
@@ -0,0 +1,56 @@
import numpy as np
import torch

from tttsa.correlation import correlate_2d
from tttsa.utils import dft_center


def find_image_shift(
    image_a: torch.Tensor,
    image_b: torch.Tensor,
) -> torch.Tensor:
    """Find the shift between image a and image b.

    Applying the shift to image b aligns it with image a. The maximum of the
    correlation image is refined to sub-pixel precision with parabolic
    interpolation in y and x.

    Parameters
    ----------
    image_a: torch.Tensor
        `(h, w)` image.
    image_b: torch.Tensor
        `(h, w)` image with the same shape as image_a.

    Returns
    -------
    shift: torch.Tensor
        `(2, )` shift in y and x.
    """
    center = dft_center(
        image_a.shape, rfft=False, fftshifted=True, device=image_a.device
    )

    # calculate initial shift with integer precision
    correlation = correlate_2d(image_a, image_b, normalize=True)
    maximum_idx = torch.tensor(  # explicitly put tensor on CPU in case input is on GPU
        np.unravel_index(correlation.argmax().cpu(), shape=image_a.shape),
        device=image_a.device,
    )
    y, x = maximum_idx

    # parabolic interpolation in the y direction
    f_y0 = correlation[y - 1, x]
    f_y1 = correlation[y, x]
    f_y2 = correlation[y + 1, x]
    subpixel_dy = y + 0.5 * (f_y0 - f_y2) / (f_y0 - 2 * f_y1 + f_y2)

    # parabolic interpolation in the x direction
    f_x0 = correlation[y, x - 1]
    f_x1 = correlation[y, x]
    f_x2 = correlation[y, x + 1]
    subpixel_dx = x + 0.5 * (f_x0 - f_x2) / (f_x0 - 2 * f_x1 + f_x2)

    # stack keeps the result on the same device as the correlation image
    shift = torch.stack([subpixel_dy, subpixel_dx]) - center
    return shift
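For reference, a minimal usage sketch, assuming find_image_shift is re-exported from tttsa.alignment (an assumption based on the __init__.py above; adjust to the actual layout). The returned shift is the one that, applied to image_b, re-aligns it with image_a:

import torch
from tttsa.alignment import find_image_shift  # assumed import path

a = torch.zeros((32, 32))
a[10, 12] = 1.0
b = torch.roll(a, shifts=(3, -2), dims=(0, 1))  # b is a displaced by (+3, -2)

shift = find_image_shift(a, b)
print(shift)  # expected to be roughly (-3, 2): moving b by this re-aligns it with a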
@@ -0,0 +1,34 @@
import torch
import pytest

from libtilt.alignment import find_image_shift


def test_find_image_shift():
    a = torch.zeros((4, 4))
    a[1, 1] = 1
    b = torch.zeros((4, 4))
    b[2, 2] = .7
    b[2, 3] = .3

    with pytest.raises(ValueError, match=r'Upsampling factor .*'):
        find_image_shift(a, b, upsampling_factor=0.5)

    shift = find_image_shift(a, b, upsampling_factor=5)
    assert torch.all(shift == -1), ("Interpolating a shift too close to a border is "
                                    "not possible, so an integer shift should be "
                                    "returned.")

    a = torch.zeros((8, 8))
    a[3, 3] = 1
    b = torch.zeros((8, 8))
    b[4, 4] = .7
    b[4, 5] = .3
    shift = find_image_shift(a, b, upsampling_factor=1)
    assert torch.all(shift == -1), ("Finding shift with upsampling_factor of 1 should "
                                    "return an integer shift (i.e. no interpolation).")

    shift = find_image_shift(a, b)
    assert shift[0] == -1.1, "y shift should be interpolated to specific value."
    assert shift[1] == -1.2, "x shift should be interpolated to specific value."
@@ -0,0 +1 @@
from .filtered_back_projection import filtered_back_projection_3d
@@ -0,0 +1,154 @@
import torch
import einops
import torch.nn.functional as F
from torch_grid_utils import coordinate_grid

from tttsa.transformations import R_2d, T_2d, T, Ry
from tttsa.utils import dft_center, homogenise_coordinates, array_to_grid_sample


def filtered_back_projection_3d(
    tilt_series,
    tomogram_dimensions,
    tilt_angles,
    tilt_axis_angles,
    shifts,
    weighting: str = "exact",
    object_diameter: float | None = None,
):
    """Run weighted back projection incorporating some alignment parameters.

    weighting: str, default "exact"
        All filters here start at 1/N (instead of 0 for ramp and hamming), which
        improves the low-resolution signal upon forward projection of the
        reconstruction. Options:
        - "ramp": increases linearly from 1/N to 1 from the zero frequency to
          Nyquist
        - "exact": based on, and improves low-resolution signal on, forward
          projection. Reference: Optik, "Exact filters for general geometry
          three-dimensional reconstruction", vol. 73, 146, 1986.
        - "hamming": modified Hamming window as used in AreTomo, further modified
          here to also start at 1/N
    object_diameter: float | None, default None
        object diameter specified in number of pixels, only needed for the exact
        filter
    """
    # initialize sizes
    device = tilt_series.device
    n_tilts, h, w = tilt_series.shape  # for simplicity assume square images
    tilt_image_dimensions = (h, w)
    transformed_image_dimensions = tomogram_dimensions[-2:]
    tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
    tilt_image_center = dft_center(tilt_image_dimensions, rfft=False, fftshifted=True)
    transformed_image_center = dft_center(
        transformed_image_dimensions, rfft=False, fftshifted=True
    )
    _, filter_size = transformed_image_dimensions

    # generate the 2d alignment affine matrix
    s0 = T_2d(-transformed_image_center)
    r0 = R_2d(tilt_axis_angles, yx=True)
    s1 = T_2d(-shifts)
    s2 = T_2d(tilt_image_center)
    M = einops.rearrange((s2 @ s1 @ r0 @ s0), "... i j -> ... 1 1 i j").to(device)

    grid = homogenise_coordinates(
        coordinate_grid(transformed_image_dimensions, device=device)
    )
    grid = einops.rearrange(grid, "h w coords -> h w coords 1")
    grid = M @ grid
    grid = einops.rearrange(grid, "... d h w coords 1 -> ... d h w coords")[
        ..., :2
    ].contiguous()
    grid_sample_coordinates = array_to_grid_sample(grid, tilt_image_dimensions)
    aligned_ts = torch.squeeze(
        F.grid_sample(
            einops.rearrange(tilt_series, "n h w -> n 1 h w"),
            grid_sample_coordinates,
            align_corners=True,
            mode="bicubic",
        )
    )

    # generate weighting function and apply to aligned tilt series
    if weighting == "exact":
        if object_diameter is None:
            raise ValueError(
                "Calculation of exact weighting requires an object diameter."
            )
        if len(tilt_angles) == 1:
            filters = 1
        else:  # slice_width could be provided as a function argument; it can be
            # calculated as: (pixel_size * 2 * imdim) / object_diameter
            q = einops.rearrange(
                torch.arange(
                    filter_size // 2 + filter_size % 2 + 1,
                    dtype=torch.float32,
                    device=device,
                )
                / filter_size,
                "q -> 1 1 q",
            )
            sampling = torch.sin(
                torch.deg2rad(
                    torch.abs(einops.rearrange(tilt_angles, "n -> n 1") - tilt_angles)
                )
            ).to(device)
            sampling = einops.rearrange(sampling, "n m -> n m 1")
            q_overlap_inv = sampling / (2 / object_diameter)
            over_weighting = 1 - torch.clip(q * q_overlap_inv, min=0, max=1)
            filters = 1 / einops.reduce(over_weighting, "n m q -> n q", "sum")
            filters = einops.rearrange(filters, "n w -> n 1 w")
    elif weighting == "ramp":
        filters = torch.arange(
            filter_size // 2 + filter_size % 2 + 1, dtype=torch.float32, device=device
        )
        filters /= filters.max()
        filters = filters * (1 - 1 / n_tilts) + 1 / n_tilts  # start at 1 / N
    elif weighting == "hamming":  # AreTomo3 code uses a modified hamming window
        # 2 * q * (0.55f + 0.45f * cosf(6.2831852f * q))  # with q from 0 to .5 (Ny)
        # https://github.com/czimaginginstitute/AreTomo3/blob/
        # c39dcdad9525ee21d7308a95622f3d47fe7ab4b9/AreTomo/Recon/GRWeight.cu#L20
        q = (
            torch.arange(
                filter_size // 2 + filter_size % 2 + 1,
                dtype=torch.float32,
                device=device,
            )
            / filter_size
        )
        # regular hamming: q * (.54 + .46 * torch.cos(torch.pi * q))
        filters = 2 * q * (0.54 + 0.46 * torch.cos(2 * torch.pi * q))
        filters /= filters.max()  # 0-1 normalization
        filters = filters * (1 - 1 / n_tilts) + 1 / n_tilts  # start at 1 / N
    else:
        raise ValueError("Invalid weighting option provided for FBP.")

    weighted = torch.fft.irfftn(
        torch.fft.rfftn(aligned_ts, dim=(-2, -1)) * filters, dim=(-2, -1)
    )
    if len(weighted.shape) == 2:  # the squeeze above drops the tilt dimension when
        # there is a single tilt: add it back
        weighted = einops.rearrange(weighted, "h w -> 1 h w")

    # time for real space back projection
    s0 = T(-tomogram_center)
    r0 = Ry(tilt_angles, zyx=True)
    s1 = T(F.pad(transformed_image_center, pad=(1, 0), value=0))
    M = einops.rearrange(s1 @ r0 @ s0, "... i j -> ... 1 1 i j").to(device)

    reconstruction = torch.zeros(
        tomogram_dimensions, dtype=torch.float32, device=device
    )
    grid = homogenise_coordinates(coordinate_grid(tomogram_dimensions, device=device))
    grid = einops.rearrange(grid, "d h w coords -> d h w coords 1")

    for i in range(n_tilts):  # could do all grids simultaneously
        grid_t = M[i] @ grid
        grid_t = einops.rearrange(grid_t, "... d h w coords 1 -> ... d h w coords")[
            ..., :3
        ].contiguous()
        grid_sample_coordinates = array_to_grid_sample(grid_t, tomogram_dimensions)
        reconstruction += torch.squeeze(
            F.grid_sample(
                einops.rearrange(weighted[i], "h w -> 1 1 1 h w"),
                einops.rearrange(
                    grid_sample_coordinates, "d h w coords -> 1 d h w coords"
                ),
                align_corners=True,
                mode="bilinear",
            )
        )
    return reconstruction, aligned_ts
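For reference, a minimal usage sketch with a synthetic tilt geometry. It assumes the function is re-exported as in the __init__.py above and that the subpackage imports as tttsa.back_projection (an assumption); tilt angles are given in degrees, as implied by the torch.deg2rad call in the exact filter:

import torch
from tttsa.back_projection import filtered_back_projection_3d  # assumed import path

n_tilts, size = 41, 64
tilt_series = torch.rand(n_tilts, size, size)
tilt_angles = torch.linspace(-60, 60, n_tilts)  # stage tilt per image, in degrees
tilt_axis_angles = torch.zeros(n_tilts)         # in-plane tilt-axis angle per image
shifts = torch.zeros(n_tilts, 2)                # (y, x) shift per image
tomogram_dimensions = (size, size, size)        # (d, h, w) of the reconstruction

reconstruction, aligned_ts = filtered_back_projection_3d(
    tilt_series,
    tomogram_dimensions,
    tilt_angles,
    tilt_axis_angles,
    shifts,
    weighting="exact",
    object_diameter=0.8 * size,  # in pixels; required for the exact filter
)
print(reconstruction.shape, aligned_ts.shape)  # (64, 64, 64) and (41, 64, 64)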