Skip to content

Commit

Permalink
GPU deskew (#146)
Browse files Browse the repository at this point in the history
* non-working draft

* minimal monai deskew

* remove profiling

* isort

* depend on monai

* clean up

* revert changed test

* remove unused `order` and `cval` options

* use `to.("cuda")` instead of `transforms.ToDevice("cuda")`

* keep data on CPU by default

* style

* bumping iohub as well

* Update raw_data type casting

Co-authored-by: Ziwen Liu <[email protected]>

---------

Co-authored-by: Ivan Ivanov <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
Co-authored-by: Ziwen Liu <[email protected]>
  • Loading branch information
4 people authored Jul 11, 2024
1 parent 8c68275 commit 162da51
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 31 deletions.
56 changes: 26 additions & 30 deletions mantis/analysis/deskew.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import scipy
import torch

from monai.transforms.spatial.array import Affine


def _average_n_slices(data, average_window_width=1):
Expand Down Expand Up @@ -116,8 +118,7 @@ def deskew_data(
px_to_scan_ratio: float,
keep_overhang: bool,
average_n_slices: int = 1,
order: int = 1,
cval: float = None,
device='cpu',
):
"""Deskews fluorescence data from the mantis microscope
Parameters
Expand All @@ -127,65 +128,60 @@ def deskew_data(
- axis 0 corresponds to the scanning axis
- axis 1 corresponds to the "tilted" axis
- axis 2 corresponds to the axis in the plane of the coverslip
ls_angle_deg : float
angle of light sheet with respect to the optical axis in degrees
px_to_scan_ratio : float
(pixel spacing / scan spacing) in object space
e.g. if camera pixels = 6.5 um and mag = 1.4*40, then the pixel spacing
is 6.5/(1.4*40) = 0.116 um. If the scan spacing is 0.3 um, then
px_to_scan_ratio = 0.116 / 0.3 = 0.386
ls_angle_deg : float
angle of light sheet with respect to the optical axis in degrees
keep_overhang : bool
If true, compute the whole volume within the tilted parallelepiped.
If false, only compute the deskewed volume within a cuboid region.
average_n_slices : int, optional
after deskewing, averages every n slices (default = 1 applies no averaging)
order : int, optional
interpolation order (default 1 is linear interpolation)
cval : float, optional
fill value area outside of the measured volume (default None fills
with the minimum value of the input array)
device : str, optional
torch device to use for computation. Default is 'cpu'.
Returns
-------
deskewed_data : NDArray with ndim == 3
axis 0 is the Z axis, normal to the coverslip
axis 1 is the Y axis, input axis 2 in the plane of the coverslip
axis 2 is the X axis, the scanning axis
"""
if cval is None:
cval = np.min(np.ravel(raw_data))

# Prepare transforms
Z, Y, X = raw_data.shape

ct = np.cos(ls_angle_deg * np.pi / 180)
Z_shift = 0
if not keep_overhang:
Z_shift = int(np.floor(Y * ct * px_to_scan_ratio))

matrix = np.array(
[
[
-px_to_scan_ratio * ct,
0,
px_to_scan_ratio,
Z_shift,
0,
],
[-1, 0, 0, Y - 1],
[0, -1, 0, X - 1],
[-1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
output_shape, _ = get_deskewed_data_shape(
raw_data.shape, ls_angle_deg, px_to_scan_ratio, keep_overhang
)

# Apply transforms
deskewed_data = scipy.ndimage.affine_transform(
raw_data,
matrix,
output_shape=output_shape,
order=order,
cval=cval,
)
# convert to tensor on GPU
# convert raw_data to int32 if it is uint16
raw_data_tensor = torch.from_numpy(raw_data.astype(np.float32)).to(device)

# Returns callable
affine_func = Affine(affine=matrix, padding_mode="zeros", image_only=True)

# affine_func accepts CZYX array, so for ZYX input we need [None] and for ZYX output we need [0]
deskewed_data = affine_func(
raw_data_tensor[None], mode="bilinear", spatial_size=output_shape
)[0]

# to numpy array on CPU
deskewed_data = deskewed_data.cpu().numpy()

# Apply averaging
averaged_deskewed_data = _average_n_slices(
Expand Down
9 changes: 9 additions & 0 deletions mantis/cli/deskew.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import List

import click
import torch

from iohub.ngff import open_ome_zarr

Expand All @@ -11,6 +12,10 @@
from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath
from mantis.cli.utils import yaml_to_model

# Needed for multiprocessing with GPUs
# https://github.com/pytorch/pytorch/issues/40403#issuecomment-1422625325
torch.multiprocessing.set_start_method('spawn', force=True)


@click.command()
@input_position_dirpaths()
Expand Down Expand Up @@ -86,3 +91,7 @@ def deskew(
num_processes=num_processes,
**deskew_args,
)


if __name__ == "__main__":
deskew()
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
# list package dependencies here
dependencies = [
"copylot @ git+https://github.com/czbiohub-sf/coPylot",
"iohub==0.1.0.dev5",
"iohub==0.1.0",
"matplotlib",
"napari; 'arm64' in platform_machine", # without Qt5 and skimage
"napari[all]; 'arm64' not in platform_machine", # with Qt5 and skimage
Expand All @@ -30,6 +30,7 @@ dependencies = [
"ndtiff>=2.0",
"nidaqmx",
"numpy<2",
"monai",
"pandas~=2.1",
"pycromanager==0.28.1",
"pydantic",
Expand Down

0 comments on commit 162da51

Please sign in to comment.