Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPU deskew #146

Merged
merged 15 commits into from
Jul 11, 2024
43 changes: 20 additions & 23 deletions mantis/analysis/deskew.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import numpy as np
import scipy
import torch

from monai import transforms
from monai.transforms.spatial.array import Affine


def _average_n_slices(data, average_window_width=1):
Expand Down Expand Up @@ -139,53 +142,47 @@ def deskew_data(
If false, only compute the deskewed volume within a cuboid region.
average_n_slices : int, optional
after deskewing, averages every n slices (default = 1 applies no averaging)
order : int, optional
interpolation order (default 1 is linear interpolation)
cval : float, optional
fill value area outside of the measured volume (default None fills
with the minimum value of the input array)
Returns
-------
deskewed_data : NDArray with ndim == 3
axis 0 is the Z axis, normal to the coverslip
axis 1 is the Y axis, input axis 2 in the plane of the coverslip
axis 2 is the X axis, the scanning axis
"""
if cval is None:
cval = np.min(np.ravel(raw_data))

# Prepare transforms
Z, Y, X = raw_data.shape

ct = np.cos(ls_angle_deg * np.pi / 180)
Z_shift = 0
if not keep_overhang:
Z_shift = int(np.floor(Y * ct * px_to_scan_ratio))
ieivanov marked this conversation as resolved.
Show resolved Hide resolved

matrix = np.array(
[
[
-px_to_scan_ratio * ct,
0,
px_to_scan_ratio,
Z_shift,
0,
],
[-1, 0, 0, Y - 1],
[0, -1, 0, X - 1],
[-1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
edyoshikun marked this conversation as resolved.
Show resolved Hide resolved
]
)
output_shape, _ = get_deskewed_data_shape(
raw_data.shape, ls_angle_deg, px_to_scan_ratio, keep_overhang
)

# Apply transforms
deskewed_data = scipy.ndimage.affine_transform(
raw_data,
matrix,
output_shape=output_shape,
order=order,
cval=cval,
)
# to tensor on GPU
if torch.cuda.is_available():
raw_data = transforms.ToDevice("cuda")(torch.tensor(raw_data))
ieivanov marked this conversation as resolved.
Show resolved Hide resolved

# Returns callable
affine_func = Affine(affine=matrix, padding_mode="zeros", image_only=True)

# affine_func accepts CZYX array, so for ZYX input we need [None] and for ZYX output we need [0]
deskewed_data = affine_func(raw_data[None], mode="bilinear", spatial_size=output_shape)[0]

# to numpy array on CPU
deskewed_data = deskewed_data.cpu().numpy()
ieivanov marked this conversation as resolved.
Show resolved Hide resolved

# Apply averaging
averaged_deskewed_data = _average_n_slices(
Expand Down
5 changes: 5 additions & 0 deletions mantis/cli/deskew.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import List

import click
import torch

from iohub.ngff import open_ome_zarr

Expand All @@ -13,6 +14,10 @@
from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath
from mantis.cli.utils import yaml_to_model

# Needed for multiprocessing with GPUs
# https://github.com/pytorch/pytorch/issues/40403#issuecomment-1422625325
torch.multiprocessing.set_start_method('spawn', force=True)


@click.command()
@input_position_dirpaths()
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ dependencies = [
"ndtiff>=2.0",
"nidaqmx",
"numpy",
"monai",
"pandas~=2.1",
"pycromanager==0.28.1",
"pydantic",
Expand Down
Loading