Skip to content

Commit

Permalink
Merge branch 'main' into concat_cli
Browse files Browse the repository at this point in the history
  • Loading branch information
edyoshikun committed Jul 15, 2024
2 parents 3133e56 + 162da51 commit 5828aa0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 36 deletions.
56 changes: 26 additions & 30 deletions mantis/analysis/deskew.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
import scipy
import torch

from monai.transforms.spatial.array import Affine


def _average_n_slices(data, average_window_width=1):
Expand Down Expand Up @@ -116,8 +118,7 @@ def deskew_data(
px_to_scan_ratio: float,
keep_overhang: bool,
average_n_slices: int = 1,
order: int = 1,
cval: float = None,
device='cpu',
):
"""Deskews fluorescence data from the mantis microscope
Parameters
Expand All @@ -127,65 +128,60 @@ def deskew_data(
- axis 0 corresponds to the scanning axis
- axis 1 corresponds to the "tilted" axis
- axis 2 corresponds to the axis in the plane of the coverslip
ls_angle_deg : float
angle of light sheet with respect to the optical axis in degrees
px_to_scan_ratio : float
(pixel spacing / scan spacing) in object space
e.g. if camera pixels = 6.5 um and mag = 1.4*40, then the pixel spacing
is 6.5/(1.4*40) = 0.116 um. If the scan spacing is 0.3 um, then
px_to_scan_ratio = 0.116 / 0.3 = 0.386
ls_angle_deg : float
angle of light sheet with respect to the optical axis in degrees
keep_overhang : bool
If true, compute the whole volume within the tilted parallelepiped.
If false, only compute the deskewed volume within a cuboid region.
average_n_slices : int, optional
after deskewing, averages every n slices (default = 1 applies no averaging)
order : int, optional
interpolation order (default 1 is linear interpolation)
cval : float, optional
fill value area outside of the measured volume (default None fills
with the minimum value of the input array)
device : str, optional
torch device to use for computation. Default is 'cpu'.
Returns
-------
deskewed_data : NDArray with ndim == 3
axis 0 is the Z axis, normal to the coverslip
axis 1 is the Y axis, input axis 2 in the plane of the coverslip
axis 2 is the X axis, the scanning axis
"""
if cval is None:
cval = np.min(np.ravel(raw_data))

# Prepare transforms
Z, Y, X = raw_data.shape

ct = np.cos(ls_angle_deg * np.pi / 180)
Z_shift = 0
if not keep_overhang:
Z_shift = int(np.floor(Y * ct * px_to_scan_ratio))

matrix = np.array(
[
[
-px_to_scan_ratio * ct,
0,
px_to_scan_ratio,
Z_shift,
0,
],
[-1, 0, 0, Y - 1],
[0, -1, 0, X - 1],
[-1, 0, 0, 0],
[0, -1, 0, 0],
[0, 0, 0, 1],
]
)
output_shape, _ = get_deskewed_data_shape(
raw_data.shape, ls_angle_deg, px_to_scan_ratio, keep_overhang
)

# Apply transforms
deskewed_data = scipy.ndimage.affine_transform(
raw_data,
matrix,
output_shape=output_shape,
order=order,
cval=cval,
)
# convert to tensor on GPU
# convert raw_data to int32 if it is uint16
raw_data_tensor = torch.from_numpy(raw_data.astype(np.float32)).to(device)

# Returns callable
affine_func = Affine(affine=matrix, padding_mode="zeros", image_only=True)

# affine_func accepts CZYX array, so for ZYX input we need [None] and for ZYX output we need [0]
deskewed_data = affine_func(
raw_data_tensor[None], mode="bilinear", spatial_size=output_shape
)[0]

# to numpy array on CPU
deskewed_data = deskewed_data.cpu().numpy()

# Apply averaging
averaged_deskewed_data = _average_n_slices(
Expand Down
13 changes: 10 additions & 3 deletions mantis/cli/deskew.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import multiprocessing as mp

from pathlib import Path
from typing import List

import click
import torch

from iohub.ngff import open_ome_zarr

Expand All @@ -13,6 +12,10 @@
from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath
from mantis.cli.utils import yaml_to_model

# Needed for multiprocessing with GPUs
# https://github.com/pytorch/pytorch/issues/40403#issuecomment-1422625325
torch.multiprocessing.set_start_method('spawn', force=True)


@click.command()
@input_position_dirpaths()
Expand All @@ -21,7 +24,7 @@
@click.option(
"--num-processes",
"-j",
default=mp.cpu_count(),
default=1,
help="Number of cores",
required=False,
type=int,
Expand Down Expand Up @@ -88,3 +91,7 @@ def deskew(
num_processes=num_processes,
**deskew_args,
)


if __name__ == "__main__":
deskew()
6 changes: 5 additions & 1 deletion mantis/cli/estimate_stabilization.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,11 @@ def estimate_xy_stabilization(
if (output_folder_path / "positions_focus.csv").exists():
df = pd.read_csv(output_folder_path / "positions_focus.csv")
pos_idx = str(Path(*input_data_path.parts[-3:]))
z_idx = list(df[df["position"] == pos_idx]["focus_idx"].replace(0, np.nan).ffill())
focus_idx = df[df["position"] == pos_idx]["focus_idx"]
# forward fill 0 values, when replace remaining NaN with the mean
focus_idx = focus_idx.replace(0, np.nan).ffill()
focus_idx = focus_idx.fillna(focus_idx.mean())
z_idx = focus_idx.astype(int).to_list()
else:
z_idx = [
focus_from_transverse_band(
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,16 @@ classifiers = [
# list package dependencies here
dependencies = [
"copylot @ git+https://github.com/czbiohub-sf/coPylot",
"iohub==0.1.0.dev5",
"iohub==0.1.0",
"matplotlib",
"napari; 'arm64' in platform_machine", # without Qt5 and skimage
"napari[all]; 'arm64' not in platform_machine", # with Qt5 and skimage
"PyQt6; 'arm64' in platform_machine",
"natsort",
"ndtiff>=2.0",
"nidaqmx",
"numpy",
"numpy<2",
"monai",
"pandas~=2.1",
"pycromanager==0.28.1",
"pydantic",
Expand Down

0 comments on commit 5828aa0

Please sign in to comment.