diff --git a/mantis/analysis/AnalysisSettings.py b/mantis/analysis/AnalysisSettings.py
index d7fc3f7a..98b75f2f 100644
--- a/mantis/analysis/AnalysisSettings.py
+++ b/mantis/analysis/AnalysisSettings.py
@@ -1,3 +1,5 @@
+import warnings
+
from typing import Literal, Optional, Union
import numpy as np
@@ -10,6 +12,11 @@ class MyBaseModel(BaseModel, extra=Extra.forbid):
pass
+class ProcessingSettings(MyBaseModel):
+ fliplr: Optional[bool] = False
+ flipud: Optional[bool] = False
+
+
class DeskewSettings(MyBaseModel):
pixel_size_um: PositiveFloat
ls_angle_deg: PositiveFloat
@@ -87,3 +94,27 @@ def check_affine_transform_list(cls, v):
raise ValueError("Each element in affine_transform_list must be a 4x4 ndarray")
return v
+
+
+class StitchSettings(MyBaseModel):
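+    """
+    Settings for stitching FOVs into a composite image.
+
+    A minimal construction sketch (FOV keys and values are hypothetical;
+    total_translation entries are (y, x) shifts in pixels):
+
+    >>> settings = StitchSettings(
+    ...     channels=["Phase3D"],
+    ...     preprocessing=ProcessingSettings(fliplr=True),
+    ...     total_translation={"A/1/000000": [0.0, 0.0], "A/1/000001": [900.0, 5.0]},
+    ... )
+    """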
+ channels: Optional[list[str]] = None
+ preprocessing: Optional[ProcessingSettings] = None
+ postprocessing: Optional[ProcessingSettings] = None
+    column_translation: Optional[list[float]] = None
+    row_translation: Optional[list[float]] = None
+    total_translation: Optional[dict[str, list[float]]] = None
+
+ def __init__(self, **data):
+ if data.get("total_translation") is None:
+ if any(
+ (data.get("column_translation") is None, data.get("row_translation") is None)
+ ):
+ raise ValueError(
+ "If total_translation is not provided, both column_translation and row_translation must be provided"
+ )
+ else:
+ warnings.warn(
+ "column_translation and row_translation are deprecated. Use total_translation instead.",
+ DeprecationWarning,
+ )
+ super().__init__(**data)
diff --git a/mantis/analysis/settings/example_stitch_settings.yml b/mantis/analysis/settings/example_stitch_settings.yml
new file mode 100644
index 00000000..0aa6ef33
--- /dev/null
+++ b/mantis/analysis/settings/example_stitch_settings.yml
@@ -0,0 +1,28 @@
+channels: [Phase3D] # may be null, in which case all channels will be stitched
+preprocessing:
+ fliplr: true
+ flipud: false
+postprocessing:
+ fliplr: false
+ flipud: true
+total_translation: # translation distance in (y, x) in pixels for each image
+ 0/2/000000:
+ - 0.0
+ - 39.7
+ 0/2/000001:
+ - 883.75
+ - 39.7
+ 0/2/001000:
+ - 1.8
+ - 920.5
+ 0/2/001001:
+ - 885.55
+ - 920.7
+# Instead of computing a total (y, x) shift for each image using estimate-stitch,
+# you can also supply column_translation and row_translation as (x, y) lists
+# that will be applied to all images. column_translation and row_translation
+# should be approximately (950, 0) and (0, 950) for 1000x1000 images with 5% overlap.
+# This method of stitching images is being deprecated as the stage often does not
+# make reproducible movements.
+column_translation:
+row_translation:
diff --git a/mantis/analysis/stitch.py b/mantis/analysis/stitch.py
new file mode 100644
index 00000000..421912a7
--- /dev/null
+++ b/mantis/analysis/stitch.py
@@ -0,0 +1,539 @@
+from pathlib import Path
+from typing import Literal
+
+import click
+import dask.array as da
+import numpy as np
+import pandas as pd
+import scipy.ndimage as ndi
+
+from iohub import open_ome_zarr
+from skimage.registration import phase_cross_correlation
+
+from mantis.analysis.AnalysisSettings import ProcessingSettings
+
+
+def estimate_shift(
+ im0: np.ndarray, im1: np.ndarray, percent_overlap: float, direction: Literal["row", "col"]
+):
+ """
+ Estimate the shift between two images based on a given percentage overlap and direction.
+
+ Parameters
+ ----------
+ im0 : np.ndarray
+ The first image.
+ im1 : np.ndarray
+ The second image.
+ percent_overlap : float
+ The percentage of overlap between the two images. Must be between 0 and 1.
+ direction : Literal["row", "col"]
+ The direction of the shift. Can be either "row" or "col". See estimate_zarr_fov_shifts
+
+ Returns
+ -------
+ np.ndarray
+        The estimated (x, y) shift between the two images.
+
+ Raises
+ ------
+ AssertionError
+ If percent_overlap is not between 0 and 1.
+ If direction is not "row" or "col".
+ If the shape of im0 and im1 are not the same.
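+
+    Examples
+    --------
+    A rough sketch with synthetic data (values are illustrative); the recovered shift
+    should be approximately (0, 230) in (x, y) order:
+
+    >>> rng = np.random.default_rng(0)
+    >>> im0 = rng.random((256, 256))
+    >>> im1 = np.roll(im0, -230, axis=0)  # fov1 sits ~230 px below fov0
+    >>> shift = estimate_shift(im0, im1, percent_overlap=0.1, direction="row")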
+ """
+ assert 0 <= percent_overlap <= 1, "percent_overlap must be between 0 and 1"
+ assert direction in ["row", "col"], "direction must be either 'row' or 'col'"
+ assert im0.shape == im1.shape, "Images must have the same shape"
+
+ sizeY, sizeX = im0.shape[-2:]
+
+ # TODO: there may be a one pixel error in the estimated shift
+ if direction == "row":
+ y_roi = int(sizeY * np.minimum(percent_overlap + 0.05, 1))
+ shift, _, _ = phase_cross_correlation(
+ im0[-y_roi:, :], im1[:y_roi, :], upsample_factor=10
+ )
+ shift[0] += sizeY - y_roi
+ elif direction == "col":
+ x_roi = int(sizeX * np.minimum(percent_overlap + 0.05, 1))
+ shift, _, _ = phase_cross_correlation(
+ im0[:, -x_roi:], im1[:, :x_roi], upsample_factor=10
+ )
+ shift[1] += sizeX - x_roi
+
+ # TODO: we shouldn't need to flip the order
+ return shift[::-1]
+
+
+def get_grid_rows_cols(dataset_path: str):
+ grid_rows = set()
+ grid_cols = set()
+
+ with open_ome_zarr(dataset_path) as dataset:
+
+ _, well = next(dataset.wells())
+ for position_name, _ in well.positions():
+ fov_name = Path(position_name).parts[-1]
+            grid_rows.add(fov_name[3:])  # 1-Pos<XXX>_<YYY> syntax
+            grid_cols.add(fov_name[:3])
+
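+    # For a hypothetical 2x2 grid named 000000, 000001, 001000, 001001 this returns
+    # (['000', '001'], ['000', '001']); the first three characters index the column,
+    # the last three the row.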
+ return sorted(grid_rows), sorted(grid_cols)
+
+
+def get_stitch_output_shape(n_rows, n_cols, sizeY, sizeX, col_translation, row_translation):
+ """
+ Compute the output shape of the stitched image and the global translation when only col and row translation are given
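+
+    For example, a hypothetical 2x3 grid (n_rows=2, n_cols=3) of 1000x1000 images with
+    col_translation=(950, 0) and row_translation=(0, 950) yields an output shape of
+    (1950, 2900) and a global translation of (0, 0).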
+ """
+ global_translation = (
+ np.ceil(np.abs(np.minimum(row_translation[0] * (n_rows - 1), 0))).astype(int),
+ np.ceil(np.abs(np.minimum(col_translation[1] * (n_cols - 1), 0))).astype(int),
+ )
+ xy_output_shape = (
+ np.ceil(
+ sizeY
+ + col_translation[1] * (n_cols - 1)
+ + row_translation[1] * (n_rows - 1)
+ + global_translation[1]
+ ).astype(int),
+ np.ceil(
+ sizeX
+ + col_translation[0] * (n_cols - 1)
+ + row_translation[0] * (n_rows - 1)
+ + global_translation[0]
+ ).astype(int),
+ )
+ return xy_output_shape, global_translation
+
+
+def get_image_shift(col_idx, row_idx, col_translation, row_translation, global_translation):
+ """
+ Compute total translation when only col and row translation are given
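+
+    For example, a hypothetical FOV at col_idx=2, row_idx=1 with col_translation=(950, 0),
+    row_translation=(0, 950), and global_translation=(0, 0) gets a (y, x) shift of (950, 1900).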
+ """
+ total_translation = (
+ col_translation[1] * col_idx + row_translation[1] * row_idx + global_translation[1],
+ col_translation[0] * col_idx + row_translation[0] * row_idx + global_translation[0],
+ )
+
+ return total_translation
+
+
+def shift_image(
+ czyx_data: np.ndarray,
+ yx_output_shape: tuple[float, float],
+ yx_shift: tuple[float, float],
+ verbose: bool = False,
+) -> np.ndarray:
+ assert czyx_data.ndim == 4, "Input data must be a CZYX array"
+ C, Z, Y, X = czyx_data.shape
+
+ if verbose:
+ print(f"Shifting image by {yx_shift}")
+ # Create array of output_shape and put input data at (0, 0)
+ output = np.zeros((C, Z) + yx_output_shape, dtype=np.float32)
+ output[..., :Y, :X] = czyx_data
+
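+    # Shift only the spatial axes using nearest-neighbor interpolation (order=0);
+    # the leading zeros leave the C and Z axes in place.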
+ return ndi.shift(output, (0, 0) + tuple(yx_shift), order=0)
+
+
+def _stitch_images(
+ data_array: np.ndarray,
+    total_translation: dict[str, tuple[float, float]] = None,
+ percent_overlap: float = None,
+ col_translation: float | tuple[float, float] = None,
+ row_translation: float | tuple[float, float] = None,
+) -> np.ndarray:
+ """
+ Stitch an array of 2D images together to create a larger composite image.
+ This function is not actively maintained.
+
+ Parameters
+ ----------
+ data_array : np.ndarray
+        The data array with shape (ROWS, COLS, Y, X) that will be stitched. Call this function multiple
+ times to stitch multiple channels, slices, or time points.
+ total_translation : dict[str: tuple[float, float]], optional
+ Shift to be applied to each fov, given as {fov: (y_shift, x_shift)}. Defaults to None.
+ percent_overlap : float, optional
+ The percentage of overlap between adjacent images. Must be between 0 and 1. Defaults to None.
+ col_translation : float | tuple[float, float], optional
+ The translation distance in pixels in the column direction. Can be a single value or a tuple
+ of (x_translation, y_translation) when moving across columns. Defaults to None.
+ row_translation : float | tuple[float, float], optional
+ See col_translation. Defaults to None.
+
+ Returns
+ -------
+ np.ndarray
+ The stitched composite 2D image
+
+ Raises
+ ------
+ AssertionError
+ If percent_overlap is not between 0 and 1.
+
+ """
+
+ n_rows, n_cols, sizeY, sizeX = data_array.shape
+
+ if total_translation is None:
+ if percent_overlap is not None:
+ assert 0 <= percent_overlap <= 1, "percent_overlap must be between 0 and 1"
+ col_translation = sizeX * (1 - percent_overlap)
+ row_translation = sizeY * (1 - percent_overlap)
+ if not isinstance(col_translation, tuple):
+ col_translation = (col_translation, 0)
+ if not isinstance(row_translation, tuple):
+ row_translation = (0, row_translation)
+ xy_output_shape, global_translation = get_stitch_output_shape(
+ n_rows, n_cols, sizeY, sizeX, col_translation, row_translation
+ )
+ else:
+ df = pd.DataFrame.from_dict(
+ total_translation, orient="index", columns=["shift-y", "shift-x"]
+ )
+ xy_output_shape = (
+ np.ceil(df["shift-y"].max() + sizeY).astype(int),
+ np.ceil(df["shift-x"].max() + sizeX).astype(int),
+ )
+ stitched_array = np.zeros(xy_output_shape, dtype=np.float32)
+
+ for row_idx in range(n_rows):
+ for col_idx in range(n_cols):
+ image = data_array[row_idx, col_idx]
+
+ if total_translation is None:
+ shift = get_image_shift(
+ col_idx, row_idx, col_translation, row_translation, global_translation
+ )
+ else:
+ shift = total_translation[f"{col_idx:03d}{row_idx:03d}"]
+
+ warped_image = shift_image(image, xy_output_shape, shift)
+ overlap = np.logical_and(stitched_array, warped_image)
+ stitched_array[:, :] += warped_image
+ stitched_array[overlap] /= 2 # average blending in the overlapping region
+
+ return stitched_array
+
+
+def process_dataset(
+ data_array: np.ndarray | da.Array,
+ settings: ProcessingSettings,
+ verbose: bool = True,
+) -> np.ndarray:
+ flip = np.flip
+ if isinstance(data_array, da.Array):
+ flip = da.flip
+
+ if settings:
+ if settings.flipud:
+ if verbose:
+ click.echo("Flipping data array up-down")
+ data_array = flip(data_array, axis=-2)
+
+ if settings.fliplr:
+ if verbose:
+ click.echo("Flipping data array left-right")
+ data_array = flip(data_array, axis=-1)
+
+ return data_array
+
+
+def preprocess_and_shift(
+ image,
+ settings: ProcessingSettings,
+ output_shape: tuple[int, int],
+ shift_x: float,
+ shift_y: float,
+ verbose=True,
+):
+ return shift_image(
+ process_dataset(image, settings, verbose), output_shape, (shift_y, shift_x), verbose
+ )
+
+
+def blend(array: da.Array, method: Literal["average"] = "average"):
+ """
+ Blend array of pre-shifted images stacked across axis=0
+
+ Parameters
+ ----------
+ array : da.Array
+ Input dask array
+ method : str, optional
+ Blending method. Defaults to "average".
+
+ Raises
+ ------
+ NotImplementedError
+        Raised if the blending method is not implemented.
+
+ Returns
+ -------
+ da.Array
+ Stitched array
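+
+    Examples
+    --------
+    A sketch with two hypothetical pre-shifted tiles; the overlapping non-zero pixel
+    is averaged and pixels with no contribution stay zero:
+
+    >>> a = da.from_array(np.array([[0.0, 2.0], [0.0, 4.0]]))
+    >>> b = da.from_array(np.array([[1.0, 4.0], [0.0, 0.0]]))
+    >>> result = blend(da.stack([a, b])).compute()  # [[1., 3.], [0., 4.]]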
+ """
+ if method == "average":
+ # Sum up all images
+ array_sum = array.sum(axis=0)
+ # Count how many images contribute to each pixel in the stitched image
+ array_bool_sum = (array != 0).sum(axis=0)
+ # Replace 0s with 1s to avoid division by zero
+ array_bool_sum[array_bool_sum == 0] = 1
+ # Divide the sum of images by the number of images contributing to each pixel
+ stitched_array = array_sum / array_bool_sum
+ else:
+ raise NotImplementedError(f"Blending method {method} is not implemented")
+
+ return stitched_array
+
+
+def stitch_shifted_store(
+ input_data_path: str,
+ output_data_path: str,
+ settings: ProcessingSettings,
+ blending="average",
+ verbose=True,
+):
+ """
+ Stitch a zarr store of pre-shifted images.
+
+ Parameters
+ ----------
+ input_data_path : str
+ Path to the input zarr store.
+ output_data_path : str
+ Path to the output zarr store.
+ settings : ProcessingSettings
+ Postprocessing settings.
+ blending : str, optional
+ Blending method. Defaults to "average".
+ verbose : bool, optional
+ Whether to print verbose output. Defaults to True.
+ """
+ click.echo(f'Stitching zarr store: {input_data_path}')
+ with open_ome_zarr(input_data_path, mode="r") as input_dataset:
+ for well_name, well in input_dataset.wells():
+ if verbose:
+ click.echo(f'Processing well {well_name}')
+
+ # Stack images along axis=0
+ dask_array = da.stack(
+ [da.from_zarr(pos.data) for _, pos in well.positions()], axis=0
+ )
+
+ # Blend images
+ stitched_array = blend(dask_array, method=blending)
+
+ # Postprocessing
+ stitched_array = process_dataset(stitched_array, settings, verbose)
+
+ # Save stitched array
+ click.echo('Computing and writing data')
+ with open_ome_zarr(
+ Path(output_data_path, well_name, '0'), mode="a"
+ ) as output_image:
+ da.to_zarr(stitched_array, output_image['0'])
+            click.echo(f'Finished writing data for well {well_name}')
+
+
+def estimate_zarr_fov_shifts(
+ fov0_zarr_path: str,
+ fov1_zarr_path: str,
+ tcz_index: tuple[int, int, int],
+ percent_overlap: float,
+ fliplr: bool,
+ flipud: bool,
+ direction: Literal["row", "col"],
+ output_dirname: str = None,
+):
+ """
+    Estimate the shift between two zarr FOVs using phase cross-correlation. Flips (fliplr, flipud) are applied as a preprocessing step.
+ Phase cross-correlation is computed only across an ROI defined by (percent_overlap + 0.05) for the given direction.
+
+ Parameters
+ ----------
+ fov0_zarr_path : str
+ Path to the first zarr FOV.
+ fov1_zarr_path : str
+ Path to the second zarr FOV.
+ tcz_index : tuple[int, int, int]
+ Index of the time, channel, and z-slice to use for the shift estimation.
+ percent_overlap : float
+ The percentage of overlap between the two FOVs. Can be approximate.
+ fliplr : bool
+ Flag indicating whether to flip the FOVs horizontally before estimating shift.
+ flipud : bool
+ Flag indicating whether to flip the FOVs vertically before estimating shift.
+ direction : Literal["row", "col"]
+ The direction in which to compute the shift.
+ "row" computes vertical overlap with fov1 below fov0.
+ "col" computes horizontal overlap with fov1 to the right of fov0.
+ output_dirname : str, optional
+ The directory to save the output csv file.
+ If None, the function returns a DataFrame with the estimated shift.
+
+ Returns
+ -------
+ pd.DataFrame
+ A DataFrame containing the estimated shift between the two FOVs.
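+
+    Examples
+    --------
+    A hypothetical call for two vertically adjacent FOVs (paths are illustrative):
+
+    >>> df = estimate_zarr_fov_shifts(
+    ...     "plate.zarr/A/1/000000",
+    ...     "plate.zarr/A/1/000001",
+    ...     tcz_index=(0, 0, 10),
+    ...     percent_overlap=0.05,
+    ...     fliplr=False,
+    ...     flipud=False,
+    ...     direction="row",
+    ... )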
+ """
+ fov0_zarr_path = Path(fov0_zarr_path)
+ fov1_zarr_path = Path(fov1_zarr_path)
+ well_name = Path(*fov0_zarr_path.parts[-3:-1])
+ fov0 = fov0_zarr_path.name
+ fov1 = fov1_zarr_path.name
+ click.echo(f'Estimating shift between FOVs {fov0} and {fov1} in well {well_name}...')
+
+ T, C, Z = tcz_index
+ im0 = open_ome_zarr(fov0_zarr_path).data[T, C, Z]
+ im1 = open_ome_zarr(fov1_zarr_path).data[T, C, Z]
+
+ if fliplr:
+ im0 = np.fliplr(im0)
+ im1 = np.fliplr(im1)
+ if flipud:
+ im0 = np.flipud(im0)
+ im1 = np.flipud(im1)
+
+ shift = estimate_shift(im0, im1, percent_overlap, direction)
+
+ df = pd.DataFrame(
+ {
+ "well": str(well_name),
+ "fov0": fov0,
+ "fov1": fov1,
+ "shift-x": shift[0],
+ "shift-y": shift[1],
+ "direction": direction,
+ },
+ index=[0],
+ )
+ click.echo(f'Estimated shift:\n {df.to_string(index=False)}')
+
+ if output_dirname:
+ df.to_csv(
+ Path(output_dirname, f"{'_'.join(well_name.parts + (fov0, fov1))}_shift.csv"),
+ index=False,
+ )
+ else:
+ return df
+
+
+def consolidate_zarr_fov_shifts(
+ input_dirname: str,
+ output_filepath: str,
+):
+ """
+ Consolidate all csv files in input_dirname into a single csv file.
+
+ Parameters
+ ----------
+ input_dirname : str
+ Directory containing "*_shift.csv" files
+ output_filepath : str
+ Path to output .csv file
+ """
+ # read all csv files in input_dirname and combine into a single dataframe
+ csv_files = Path(input_dirname).rglob("*_shift.csv")
+ df = pd.concat(
+ [pd.read_csv(csv_file, dtype={'fov0': str, 'fov1': str}) for csv_file in csv_files],
+ ignore_index=True,
+ )
+ df.to_csv(output_filepath, index=False)
+
+
+def cleanup_shifts(csv_filepath: str, pixel_size_um: float):
+ """
+ Clean up outlier FOV shifts within a larger grid in case the phase cross-correlation
+ between individual FOVs returned spurious results.
+
+ Since FOVs are acquired in snake fashion, FOVs in a given row should share the same vertical (i.e. row) shift.
+ Hence, the vertical shift for FOVs in a given row is replaced by the median value of all FOVs in that row.
+
+ FOVs across the grid should have similar horizontal (i.e. column) shifts.
+ Values outside of the median +/- MAX_STAGE_ERROR_UM are replaced by the median.
+
+ Parameters
+ ----------
+ csv_filepath : str
+        Path to .csv file containing FOV shifts
+    pixel_size_um : float
+        Dataset pixel size in micrometers, used to convert MAX_STAGE_ERROR_UM to pixels
+    """
+ MAX_STAGE_ERROR_UM = 5
+ max_stage_error_pix = MAX_STAGE_ERROR_UM / pixel_size_um
+
+ df = pd.read_csv(csv_filepath, dtype={'fov0': str, 'fov1': str})
+ df['shift-x-raw'] = df['shift-x']
+ df['shift-y-raw'] = df['shift-y']
+
+ # replace row shifts with median value calculated across all columns
+ _df = df[df['direction'] == 'row']
+ # group by well and last three characters of fov0
+ groupby = _df.groupby(['well', _df['fov0'].str[-3:]])
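+    # fov0[-3:] is the row index in the XXXYYY (col + row) naming scheme, so the
+    # median is taken across all columns within each grid row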
+ _df.loc[:, 'shift-x'] = groupby['shift-x-raw'].transform('median')
+ _df.loc[:, 'shift-y'] = groupby['shift-y-raw'].transform('median')
+ df.loc[df['direction'] == 'row', ['shift-x', 'shift-y']] = _df[['shift-x', 'shift-y']]
+
+ # replace col shifts outside of the median +/- MAX_STAGE_ERROR_UM with the median value
+ _df = df[df['direction'] == 'col']
+ x_median, y_median = _df['shift-x-raw'].median(), _df['shift-y-raw'].median()
+ x_low, x_hi = x_median - max_stage_error_pix, x_median + max_stage_error_pix
+ y_low, y_hi = y_median - max_stage_error_pix, y_median + max_stage_error_pix
+ x_outliers = (_df['shift-x-raw'] <= x_low) | (_df['shift-x-raw'] >= x_hi)
+ y_outliers = (_df['shift-y-raw'] <= y_low) | (_df['shift-y-raw'] >= y_hi)
+ outliers = x_outliers | y_outliers
+ num_outliers = sum(outliers)
+
+ _df.loc[outliers, ['shift-x', 'shift-y']] = (x_median, y_median)
+ df.loc[df['direction'] == 'col', ['shift-x', 'shift-y']] = _df[['shift-x', 'shift-y']]
+ if num_outliers > 0:
+ click.echo(f'Replaced {num_outliers} column shift outliers')
+
+ df.to_csv(csv_filepath, index=False)
+
+
+def compute_total_translation(csv_filepath: str) -> pd.DataFrame:
+ """
+ Compute the total translation for each FOV based on the estimated row and col translation shifts.
+
+ Parameters
+ ----------
+ csv_filepath : str
+ Path to .csv file containing FOV shifts
+
+ Returns
+ -------
+ pd.DataFrame
+ Dataframe with total translation shift per FOV
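+        indexed by '<well>/<fov>' (e.g. a hypothetical 'A/1/000001'),
+        with 'shift-x' and 'shift-y' columns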
+ """
+ df = pd.read_csv(csv_filepath, dtype={'fov0': str, 'fov1': str})
+
+ # create 'row' and 'col' number columns and sort the dataframe by 'fov1'
+ df['row'] = df['fov1'].str[-3:].astype(int)
+ df['col'] = df['fov1'].str[:3].astype(int)
+ df.set_index('fov1', inplace=True)
+ df.sort_index(inplace=True)
+
+ total_shift = []
+ for well in df['well'].unique():
+ # calculate cumulative shifts for each row and column
+ _df = df[(df['direction'] == 'col') & (df['well'] == well)]
+ col_shifts = _df.groupby('row')[['shift-x', 'shift-y']].cumsum()
+ _df = df[(df['direction'] == 'row') & (df['well'] == well)]
+ row_shifts = _df.groupby('col')[['shift-x', 'shift-y']].cumsum()
+ # total shift is the sum of row and column shifts
+ _total_shift = col_shifts.add(row_shifts, fill_value=0)
+
+ # add row 000000
+ _total_shift = pd.concat(
+ [pd.DataFrame({'shift-x': 0, 'shift-y': 0}, index=['000000']), _total_shift]
+ )
+
+ # add global offset to remove negative values
+ _total_shift['shift-x'] += -np.minimum(_total_shift['shift-x'].min(), 0)
+ _total_shift['shift-y'] += -np.minimum(_total_shift['shift-y'].min(), 0)
+ _total_shift.set_index(well + '/' + _total_shift.index, inplace=True)
+ total_shift.append(_total_shift)
+
+ return pd.concat(total_shift)
diff --git a/mantis/cli/estimate_stitch.py b/mantis/cli/estimate_stitch.py
new file mode 100644
index 00000000..fdd515f7
--- /dev/null
+++ b/mantis/cli/estimate_stitch.py
@@ -0,0 +1,185 @@
+import datetime
+import time
+
+from pathlib import Path
+
+import click
+import pandas as pd
+
+from iohub import open_ome_zarr
+from slurmkit import SlurmParams, slurm_function, submit_function
+
+from mantis.analysis.AnalysisSettings import ProcessingSettings, StitchSettings
+from mantis.analysis.stitch import (
+ cleanup_shifts,
+ compute_total_translation,
+ consolidate_zarr_fov_shifts,
+ estimate_zarr_fov_shifts,
+ get_grid_rows_cols,
+)
+from mantis.cli.parsing import input_position_dirpaths, output_filepath
+from mantis.cli.utils import model_to_yaml
+
+
+def write_config_file(
+ shifts: pd.DataFrame, output_filepath: str, channel: str, fliplr: bool, flipud: bool
+):
+ total_translation_dict = shifts.apply(
+ lambda row: [float(row['shift-y'].round(2)), float(row['shift-x'].round(2))], axis=1
+ ).to_dict()
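+    # Note: translations are written as [y, x] pixel shifts, matching the order
+    # expected by `mantis stitch` when reading total_translation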
+
+ settings = StitchSettings(
+ channels=[channel],
+ preprocessing=ProcessingSettings(fliplr=fliplr, flipud=flipud),
+ postprocessing=ProcessingSettings(),
+ total_translation=total_translation_dict,
+ )
+ model_to_yaml(settings, output_filepath)
+
+
+@click.command()
+@input_position_dirpaths()
+@output_filepath()
+@click.option(
+ "--channel",
+ required=True,
+ type=str,
+ help="Channel to use for estimating stitch parameters",
+)
+@click.option(
+ "--percent-overlap", "-p", required=True, type=float, help="Percent overlap between images"
+)
+@click.option("--fliplr", is_flag=True, help="Flip images left-right before stitching")
+@click.option("--flipud", is_flag=True, help="Flip images up-down before stitching")
+@click.option("--slurm", "-s", is_flag=True, help="Run stitching on SLURM")
+def estimate_stitch(
+ input_position_dirpaths: list[Path],
+ output_filepath: str,
+ channel: str,
+ percent_overlap: float,
+ fliplr: bool,
+ flipud: bool,
+ slurm: bool,
+):
+ """
+ Estimate stitching parameters for positions in wells of a zarr store.
+    Position names must follow the XXXYYY naming format, e.g. 000000, 000001, 001000, etc.,
+    as created by the Micro-Manager Tile Creator: https://micro-manager.org/Micro-Manager_User's_Guide#positioning
+    Assumes all wells have the same FOV grid layout.
+
+ >>> mantis estimate-stitch -i ./input.zarr/*/*/* -o ./stitch_params.yml --channel DAPI --percent-overlap 0.05 --slurm
+ """
+ assert 0 <= percent_overlap <= 1, "Percent overlap must be between 0 and 1"
+
+ input_zarr_path = Path(*input_position_dirpaths[0].parts[:-3])
+ output_filepath = Path(output_filepath)
+ csv_filepath = (
+ output_filepath.parent
+ / f"stitch_shifts_{input_zarr_path.name.replace('.zarr', '.csv')}"
+ )
+
+ with open_ome_zarr(input_position_dirpaths[0]) as dataset:
+ assert (
+ channel in dataset.channel_names
+ ), f"Channel {channel} not found in input zarr store"
+ tcz_idx = (0, dataset.channel_names.index(channel), dataset.data.shape[-3] // 2)
+ pixel_size_um = dataset.scale[-1]
+ if pixel_size_um == 1.0:
+            response = input(
+                'The pixel size is equal to the default value of 1.0 um. '
+                'An inaccurate pixel size will affect stitching outlier removal. '
+                'Continue? [y/N]: '
+            )
+ if response.lower() != 'y':
+ return
+
+ # here we assume that all wells have the same fov grid
+ click.echo('Indexing input zarr store')
+ wells = list(set([Path(*p.parts[-3:-1]) for p in input_position_dirpaths]))
+ grid_rows, grid_cols = get_grid_rows_cols(input_zarr_path)
+ row_fov0 = [col + row for row in grid_rows[:-1] for col in grid_cols]
+ row_fov1 = [col + row for row in grid_rows[1:] for col in grid_cols]
+ col_fov0 = [col + row for col in grid_cols[:-1] for row in grid_rows]
+ col_fov1 = [col + row for col in grid_cols[1:] for row in grid_rows]
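+    # "row" pairs are vertically adjacent FOVs (e.g. a hypothetical 000000 / 000001),
+    # "col" pairs are horizontally adjacent FOVs (e.g. 000000 / 001000)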
+ estimate_shift_params = {
+ "tcz_index": tcz_idx,
+ "percent_overlap": percent_overlap,
+ "fliplr": fliplr,
+ "flipud": flipud,
+ }
+
+ # define slurm parameters
+ if slurm:
+ slurm_out_path = output_filepath.parent / "slurm_output" / "shift-%j.out"
+ csv_dirpath = (
+ output_filepath.parent / 'raw_shifts' / input_zarr_path.name.replace('.zarr', '')
+ )
+ csv_dirpath.mkdir(parents=True, exist_ok=False)
+ params = SlurmParams(
+ partition="preempted",
+ cpus_per_task=1,
+ mem_per_cpu='8G',
+ time=datetime.timedelta(minutes=10),
+ output=slurm_out_path,
+ )
+ slurm_func = {
+ direction: slurm_function(estimate_zarr_fov_shifts)(
+ direction=direction,
+ output_dirname=csv_dirpath,
+ **estimate_shift_params,
+ )
+ for direction in ("row", "col")
+ }
+
+ click.echo('Estimating FOV shifts...')
+ shifts, jobs = [], []
+ for well_name in wells:
+ for direction, fovs in zip(
+ ("row", "col"), (zip(row_fov0, row_fov1), zip(col_fov0, col_fov1))
+ ):
+ for fov0, fov1 in fovs:
+ fov0_zarr_path = Path(input_zarr_path, well_name, fov0)
+ fov1_zarr_path = Path(input_zarr_path, well_name, fov1)
+ if slurm:
+ job_id = submit_function(
+ slurm_func[direction],
+ slurm_params=params,
+ fov0_zarr_path=fov0_zarr_path,
+ fov1_zarr_path=fov1_zarr_path,
+ )
+ jobs.append(job_id)
+ else:
+ shift_params = estimate_zarr_fov_shifts(
+ fov0_zarr_path=fov0_zarr_path,
+ fov1_zarr_path=fov1_zarr_path,
+ direction=direction,
+ **estimate_shift_params,
+ )
+ shifts.append(shift_params)
+
+ click.echo('Consolidating FOV shifts...')
+ if slurm:
+ submit_function(
+ slurm_function(consolidate_zarr_fov_shifts)(
+ input_dirname=csv_dirpath,
+ output_filepath=csv_filepath,
+ ),
+ slurm_params=params,
+ dependencies=jobs,
+ )
+
+ # wait for csv_filepath to be created, capped at 5 min
+ t_start = time.time()
+ while not csv_filepath.exists() and time.time() - t_start < 300:
+ time.sleep(1)
+ else:
+ df = pd.concat(shifts, ignore_index=True)
+ df.to_csv(csv_filepath, index=False)
+
+ cleanup_shifts(csv_filepath, pixel_size_um)
+ shifts = compute_total_translation(csv_filepath)
+ write_config_file(shifts, output_filepath, channel, fliplr, flipud)
+
+
+if __name__ == "__main__":
+ estimate_stitch()
diff --git a/mantis/cli/main.py b/mantis/cli/main.py
index f8cbceb1..5d08167d 100644
--- a/mantis/cli/main.py
+++ b/mantis/cli/main.py
@@ -6,9 +6,11 @@
from mantis.cli.estimate_bleaching import estimate_bleaching
from mantis.cli.estimate_deskew import estimate_deskew
from mantis.cli.estimate_stabilization import estimate_stabilization
+from mantis.cli.estimate_stitch import estimate_stitch
from mantis.cli.optimize_affine import optimize_affine
from mantis.cli.run_acquisition import run_acquisition
from mantis.cli.stabilize import stabilize
+from mantis.cli.stitch import stitch
from mantis.cli.update_scale_metadata import update_scale_metadata
CONTEXT = {"help_option_names": ["-h", "--help"]}
@@ -32,6 +34,8 @@ def cli():
cli.add_command(estimate_affine)
cli.add_command(optimize_affine)
cli.add_command(apply_affine)
+cli.add_command(estimate_stitch)
+cli.add_command(stitch)
cli.add_command(update_scale_metadata)
cli.add_command(estimate_stabilization)
cli.add_command(stabilize)
diff --git a/mantis/cli/parsing.py b/mantis/cli/parsing.py
index 848a0eb3..246e18b2 100644
--- a/mantis/cli/parsing.py
+++ b/mantis/cli/parsing.py
@@ -9,7 +9,9 @@
from mantis.cli.option_eat_all import OptionEatAll
-def _validate_and_process_paths(ctx: click.Context, opt: click.Option, value: str) -> None:
+def _validate_and_process_paths(
+ ctx: click.Context, opt: click.Option, value: str
+) -> list[Path]:
# Sort and validate the input paths
input_paths = [Path(path) for path in natsorted(value)]
for path in input_paths:
diff --git a/mantis/cli/stitch.py b/mantis/cli/stitch.py
new file mode 100644
index 00000000..4d26ff36
--- /dev/null
+++ b/mantis/cli/stitch.py
@@ -0,0 +1,199 @@
+import datetime
+import shutil
+import warnings
+
+from pathlib import Path
+
+import click
+import numpy as np
+import pandas as pd
+
+from iohub import open_ome_zarr
+from iohub.ngff_meta import TransformationMeta
+from slurmkit import HAS_SLURM, SlurmParams, slurm_function, submit_function
+
+from mantis.analysis.AnalysisSettings import StitchSettings
+from mantis.analysis.stitch import (
+ get_grid_rows_cols,
+ get_image_shift,
+ get_stitch_output_shape,
+ preprocess_and_shift,
+ stitch_shifted_store,
+)
+from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath
+from mantis.cli.utils import create_empty_hcs_zarr, process_single_position_v2, yaml_to_model
+
+
+@click.command()
+@input_position_dirpaths()
+@output_dirpath()
+@config_filepath()
+@click.option(
+ "--temp-path",
+ type=click.Path(exists=True, file_okay=False, dir_okay=True),
+ default='./',
+ help="Path to temporary directory, ideally with fast read/write speeds, e.g. /hpc/scratch/group.comp.micro/",
+)
+def stitch(
+ input_position_dirpaths: list[Path],
+ output_dirpath: str,
+ config_filepath: str,
+ temp_path: str,
+) -> None:
+ """
+ Stitch positions in wells of a zarr store using a configuration file generated by estimate-stitch.
+
+ >>> mantis stitch -i ./input.zarr/*/*/* -c ./stitch_params.yml -o ./output.zarr --temp-path /hpc/scratch/group.comp.micro/
+ """
+ if not HAS_SLURM:
+ warnings.warn(
+ "This function is intended to be used with SLURM. "
+ "Running on local machine instead."
+ )
+
+ slurm_out_path = Path(output_dirpath).parent / "slurm_output" / "stitch-%j.out"
+ shifted_store_path = Path(temp_path, f"TEMP_{input_position_dirpaths[0].parts[-4]}")
+ settings = yaml_to_model(config_filepath, StitchSettings)
+
+ with open_ome_zarr(str(input_position_dirpaths[0]), mode="r") as input_dataset:
+ input_dataset_channels = input_dataset.channel_names
+ T, C, Z, Y, X = input_dataset.data.shape
+ scale = tuple(input_dataset.scale)
+ chunks = input_dataset.data.chunks
+
+ if settings.channels is None:
+ settings.channels = input_dataset_channels
+
+ assert all(
+ channel in input_dataset_channels for channel in settings.channels
+ ), "Invalid channel(s) provided."
+
+ wells = list(set([Path(*p.parts[-3:-1]) for p in input_position_dirpaths]))
+ grid_rows, grid_cols = get_grid_rows_cols(Path(*input_position_dirpaths[0].parts[:-3]))
+ n_rows = len(grid_rows)
+ n_cols = len(grid_cols)
+
+ if settings.total_translation is None:
+ output_shape, global_translation = get_stitch_output_shape(
+ n_rows, n_cols, Y, X, settings.column_translation, settings.row_translation
+ )
+ else:
+ df = pd.DataFrame.from_dict(
+ settings.total_translation, orient="index", columns=["shift-y", "shift-x"]
+ )
+ output_shape = (
+ np.ceil(df["shift-y"].max() + Y).astype(int),
+ np.ceil(df["shift-x"].max() + X).astype(int),
+ )
+
+ # create temp zarr store
+ click.echo(f'Creating temporary zarr store at {shifted_store_path}')
+ stitched_shape = (T, len(settings.channels), Z) + output_shape
+ stitched_chunks = chunks[:3] + (4096, 4096)
+ create_empty_hcs_zarr(
+ store_path=shifted_store_path,
+ position_keys=[p.parts[-3:] for p in input_position_dirpaths],
+ shape=stitched_shape,
+ chunks=stitched_chunks,
+ channel_names=settings.channels,
+ dtype=np.float32,
+ )
+
+ # prepare slurm parameters
+ params = SlurmParams(
+ partition='preempted',
+ cpus_per_task=6,
+ mem_per_cpu='24G',
+ time=datetime.timedelta(minutes=30),
+ output=slurm_out_path,
+ )
+
+ # Shift each FOV to its final position in the stitched image
+ slurm_func = slurm_function(process_single_position_v2)(
+ preprocess_and_shift,
+ input_channel_idx=[input_dataset_channels.index(ch) for ch in settings.channels],
+ output_channel_idx=list(range(len(settings.channels))),
+ num_processes=6,
+ settings=settings.preprocessing,
+ output_shape=output_shape,
+ verbose=True,
+ )
+
+ click.echo('Submitting SLURM jobs')
+ shift_jobs = []
+ for in_path in input_position_dirpaths:
+ well = Path(*in_path.parts[-3:-1])
+ col, row = (in_path.name[:3], in_path.name[3:])
+
+ if settings.total_translation is None:
+ shift = get_image_shift(
+ int(col),
+ int(row),
+ settings.column_translation,
+ settings.row_translation,
+ global_translation,
+ )
+ else:
+ # COL+ROW order here is important
+ shift = settings.total_translation[str(well / (col + row))]
+
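+        # shift is (y, x): shift[-2] is the y translation and shift[-1] is the x translation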
+ shift_jobs.append(
+ submit_function(
+ slurm_func,
+ slurm_params=params,
+ shift_x=shift[-1],
+ shift_y=shift[-2],
+ input_data_path=in_path,
+ output_path=shifted_store_path,
+ )
+ )
+
+ # create output zarr store
+ with open_ome_zarr(
+ output_dirpath, layout='hcs', mode="w-", channel_names=settings.channels
+ ) as output_dataset:
+ for well in wells:
+ pos = output_dataset.create_position(*Path(well, '0').parts)
+ pos.create_zeros(
+ name='0',
+ shape=stitched_shape,
+ dtype=np.float32,
+ chunks=stitched_chunks,
+ transform=[TransformationMeta(type="scale", scale=scale)],
+ )
+
+ # Stitch pre-shifted images
+ stitch_job = submit_function(
+ slurm_function(stitch_shifted_store)(
+ shifted_store_path,
+ output_dirpath,
+ settings.postprocessing,
+ blending='average',
+ verbose=True,
+ ),
+ slurm_params=SlurmParams(
+ partition='cpu',
+ cpus_per_task=32,
+ mem_per_cpu='8G',
+ time=datetime.timedelta(hours=12),
+ output=slurm_out_path,
+ ),
+ dependencies=shift_jobs,
+ )
+
+ # Delete temporary store
+ submit_function(
+ slurm_function(shutil.rmtree)(shifted_store_path),
+ slurm_params=SlurmParams(
+ partition='cpu',
+ cpus_per_task=1,
+ mem_per_cpu='12G',
+ time=datetime.timedelta(hours=1),
+ output=slurm_out_path,
+ ),
+ dependencies=stitch_job,
+ )
+
+
+if __name__ == '__main__':
+ stitch()
diff --git a/mantis/cli/utils.py b/mantis/cli/utils.py
index e0efff7d..04b567bf 100644
--- a/mantis/cli/utils.py
+++ b/mantis/cli/utils.py
@@ -140,6 +140,7 @@ def create_empty_hcs_zarr(
output_plate = open_ome_zarr(
str(store_path), layout="hcs", mode="a", channel_names=channel_names
)
+ transform = [TransformationMeta(type="scale", scale=scale)]
# Create positions
for position_key in position_keys:
@@ -152,7 +153,7 @@ def create_empty_hcs_zarr(
shape=shape,
chunks=chunks,
dtype=dtype,
- transform=[TransformationMeta(type="scale", scale=scale)],
+ transform=transform,
)
else:
position = output_plate[position_key_string]
@@ -225,7 +226,7 @@ def apply_transform_to_zyx_and_save_v2(
kwargs["t_idx"] = t_idx
# Process CZYX vs ZYX
- if input_channel_indices is not None:
+ if input_channel_indices is not None and len(input_channel_indices) > 0:
click.echo(f"Processing t={t_idx}")
czyx_data = position.data.oindex[t_idx, input_channel_indices]
@@ -240,15 +241,15 @@ def apply_transform_to_zyx_and_save_v2(
else:
click.echo(f"Processing c={c_idx}, t={t_idx}")
- zyx_data = position.data.oindex[t_idx, c_idx]
+ czyx_data = position.data.oindex[t_idx, c_idx : c_idx + 1]
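+            # keep a singleton channel axis so that `func` always receives CZYX data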
# Checking if nans or zeros and skip processing
- if not _check_nan_n_zeros(zyx_data):
+ if not _check_nan_n_zeros(czyx_data):
# Apply transformation
- transformed_zyx = func(zyx_data, **kwargs)
+ transformed_czyx = func(czyx_data, **kwargs)
# Write to file
with open_ome_zarr(output_path, mode="r+") as output_dataset:
- output_dataset[0][t_idx_out, c_idx] = transformed_zyx
+ output_dataset[0][t_idx_out, c_idx : c_idx + 1] = transformed_czyx
click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}")
else:
@@ -387,8 +388,8 @@ def process_single_position_v2(
func,
input_dataset,
output_path / Path(*input_data_path.parts[-3:]),
- input_channel_indices=None,
- output_channel_indices=None,
+ input_channel_idx,
+ output_channel_idx,
**func_args,
)
else: