diff --git a/mantis/analysis/AnalysisSettings.py b/mantis/analysis/AnalysisSettings.py index d7fc3f7a..98b75f2f 100644 --- a/mantis/analysis/AnalysisSettings.py +++ b/mantis/analysis/AnalysisSettings.py @@ -1,3 +1,5 @@ +import warnings + from typing import Literal, Optional, Union import numpy as np @@ -10,6 +12,11 @@ class MyBaseModel(BaseModel, extra=Extra.forbid): pass +class ProcessingSettings(MyBaseModel): + fliplr: Optional[bool] = False + flipud: Optional[bool] = False + + class DeskewSettings(MyBaseModel): pixel_size_um: PositiveFloat ls_angle_deg: PositiveFloat @@ -87,3 +94,27 @@ def check_affine_transform_list(cls, v): raise ValueError("Each element in affine_transform_list must be a 4x4 ndarray") return v + + +class StitchSettings(MyBaseModel): + channels: Optional[list[str]] = None + preprocessing: Optional[ProcessingSettings] = None + postprocessing: Optional[ProcessingSettings] = None + column_translation: Optional[list[float, float]] = None + row_translation: Optional[list[float, float]] = None + total_translation: Optional[dict[str, list[float, float]]] = None + + def __init__(self, **data): + if data.get("total_translation") is None: + if any( + (data.get("column_translation") is None, data.get("row_translation") is None) + ): + raise ValueError( + "If total_translation is not provided, both column_translation and row_translation must be provided" + ) + else: + warnings.warn( + "column_translation and row_translation are deprecated. Use total_translation instead.", + DeprecationWarning, + ) + super().__init__(**data) diff --git a/mantis/analysis/settings/example_stitch_settings.yml b/mantis/analysis/settings/example_stitch_settings.yml new file mode 100644 index 00000000..0aa6ef33 --- /dev/null +++ b/mantis/analysis/settings/example_stitch_settings.yml @@ -0,0 +1,28 @@ +channels: [Phase3D] # may be null in which case all channels will be stitched +preprocessing: + fliplr: true + flipud: false +postprocessing: + fliplr: false + flipud: true +total_translation: # translation distance in (x, y) in pixels for each image + 0/2/000000: + - 0.0 + - 39.7 + 0/2/000001: + - 883.75 + - 39.7 + 0/2/001000: + - 1.8 + - 920.5 + 0/2/001001: + - 885.55 + - 920.7 +# Instead of computing a total (x, y) shift for each image using estimate-stitch +# you can also supply column_translation and row_translation as (x, y) lists +# that will be applied to all images. column_translation and row_translation +# should be approximately (950, 0) and (0, 950) for 1000x1000 images with 5% overlap. +# This method of stitching images is being deprecated as the stage often does not +# make reproducible movements. +column_translation: +row_translation: diff --git a/mantis/analysis/stitch.py b/mantis/analysis/stitch.py new file mode 100644 index 00000000..421912a7 --- /dev/null +++ b/mantis/analysis/stitch.py @@ -0,0 +1,539 @@ +from pathlib import Path +from typing import Literal + +import click +import dask.array as da +import numpy as np +import pandas as pd +import scipy.ndimage as ndi + +from iohub import open_ome_zarr +from skimage.registration import phase_cross_correlation + +from mantis.analysis.AnalysisSettings import ProcessingSettings + + +def estimate_shift( + im0: np.ndarray, im1: np.ndarray, percent_overlap: float, direction: Literal["row", "col"] +): + """ + Estimate the shift between two images based on a given percentage overlap and direction. + + Parameters + ---------- + im0 : np.ndarray + The first image. + im1 : np.ndarray + The second image. 
+ percent_overlap : float + The percentage of overlap between the two images. Must be between 0 and 1. + direction : Literal["row", "col"] + The direction of the shift. Can be either "row" or "col". See estimate_zarr_fov_shifts + + Returns + ------- + np.ndarray + The estimated shift between the two images. + + Raises + ------ + AssertionError + If percent_overlap is not between 0 and 1. + If direction is not "row" or "col". + If the shape of im0 and im1 are not the same. + """ + assert 0 <= percent_overlap <= 1, "percent_overlap must be between 0 and 1" + assert direction in ["row", "col"], "direction must be either 'row' or 'col'" + assert im0.shape == im1.shape, "Images must have the same shape" + + sizeY, sizeX = im0.shape[-2:] + + # TODO: there may be a one pixel error in the estimated shift + if direction == "row": + y_roi = int(sizeY * np.minimum(percent_overlap + 0.05, 1)) + shift, _, _ = phase_cross_correlation( + im0[-y_roi:, :], im1[:y_roi, :], upsample_factor=10 + ) + shift[0] += sizeY - y_roi + elif direction == "col": + x_roi = int(sizeX * np.minimum(percent_overlap + 0.05, 1)) + shift, _, _ = phase_cross_correlation( + im0[:, -x_roi:], im1[:, :x_roi], upsample_factor=10 + ) + shift[1] += sizeX - x_roi + + # TODO: we shouldn't need to flip the order + return shift[::-1] + + +def get_grid_rows_cols(dataset_path: str): + grid_rows = set() + grid_cols = set() + + with open_ome_zarr(dataset_path) as dataset: + + _, well = next(dataset.wells()) + for position_name, _ in well.positions(): + fov_name = Path(position_name).parts[-1] + grid_rows.add(fov_name[3:]) # 1-Pos_ syntax + grid_cols.add(fov_name[:3]) + + return sorted(grid_rows), sorted(grid_cols) + + +def get_stitch_output_shape(n_rows, n_cols, sizeY, sizeX, col_translation, row_translation): + """ + Compute the output shape of the stitched image and the global translation when only col and row translation are given + """ + global_translation = ( + np.ceil(np.abs(np.minimum(row_translation[0] * (n_rows - 1), 0))).astype(int), + np.ceil(np.abs(np.minimum(col_translation[1] * (n_cols - 1), 0))).astype(int), + ) + xy_output_shape = ( + np.ceil( + sizeY + + col_translation[1] * (n_cols - 1) + + row_translation[1] * (n_rows - 1) + + global_translation[1] + ).astype(int), + np.ceil( + sizeX + + col_translation[0] * (n_cols - 1) + + row_translation[0] * (n_rows - 1) + + global_translation[0] + ).astype(int), + ) + return xy_output_shape, global_translation + + +def get_image_shift(col_idx, row_idx, col_translation, row_translation, global_translation): + """ + Compute total translation when only col and row translation are given + """ + total_translation = ( + col_translation[1] * col_idx + row_translation[1] * row_idx + global_translation[1], + col_translation[0] * col_idx + row_translation[0] * row_idx + global_translation[0], + ) + + return total_translation + + +def shift_image( + czyx_data: np.ndarray, + yx_output_shape: tuple[float, float], + yx_shift: tuple[float, float], + verbose: bool = False, +) -> np.ndarray: + assert czyx_data.ndim == 4, "Input data must be a CZYX array" + C, Z, Y, X = czyx_data.shape + + if verbose: + print(f"Shifting image by {yx_shift}") + # Create array of output_shape and put input data at (0, 0) + output = np.zeros((C, Z) + yx_output_shape, dtype=np.float32) + output[..., :Y, :X] = czyx_data + + return ndi.shift(output, (0, 0) + tuple(yx_shift), order=0) + + +def _stitch_images( + data_array: np.ndarray, + total_translation: dict[str : tuple[float, float]] = None, + percent_overlap: float = None, 
+ col_translation: float | tuple[float, float] = None, + row_translation: float | tuple[float, float] = None, +) -> np.ndarray: + """ + Stitch an array of 2D images together to create a larger composite image. + This function is not actively maintained. + + Parameters + ---------- + data_array : np.ndarray + The data array to with shape (ROWS, COLS, Y, X) that will be stitched. Call this function multiple + times to stitch multiple channels, slices, or time points. + total_translation : dict[str: tuple[float, float]], optional + Shift to be applied to each fov, given as {fov: (y_shift, x_shift)}. Defaults to None. + percent_overlap : float, optional + The percentage of overlap between adjacent images. Must be between 0 and 1. Defaults to None. + col_translation : float | tuple[float, float], optional + The translation distance in pixels in the column direction. Can be a single value or a tuple + of (x_translation, y_translation) when moving across columns. Defaults to None. + row_translation : float | tuple[float, float], optional + See col_translation. Defaults to None. + + Returns + ------- + np.ndarray + The stitched composite 2D image + + Raises + ------ + AssertionError + If percent_overlap is not between 0 and 1. + + """ + + n_rows, n_cols, sizeY, sizeX = data_array.shape + + if total_translation is None: + if percent_overlap is not None: + assert 0 <= percent_overlap <= 1, "percent_overlap must be between 0 and 1" + col_translation = sizeX * (1 - percent_overlap) + row_translation = sizeY * (1 - percent_overlap) + if not isinstance(col_translation, tuple): + col_translation = (col_translation, 0) + if not isinstance(row_translation, tuple): + row_translation = (0, row_translation) + xy_output_shape, global_translation = get_stitch_output_shape( + n_rows, n_cols, sizeY, sizeX, col_translation, row_translation + ) + else: + df = pd.DataFrame.from_dict( + total_translation, orient="index", columns=["shift-y", "shift-x"] + ) + xy_output_shape = ( + np.ceil(df["shift-y"].max() + sizeY).astype(int), + np.ceil(df["shift-x"].max() + sizeX).astype(int), + ) + stitched_array = np.zeros(xy_output_shape, dtype=np.float32) + + for row_idx in range(n_rows): + for col_idx in range(n_cols): + image = data_array[row_idx, col_idx] + + if total_translation is None: + shift = get_image_shift( + col_idx, row_idx, col_translation, row_translation, global_translation + ) + else: + shift = total_translation[f"{col_idx:03d}{row_idx:03d}"] + + warped_image = shift_image(image, xy_output_shape, shift) + overlap = np.logical_and(stitched_array, warped_image) + stitched_array[:, :] += warped_image + stitched_array[overlap] /= 2 # average blending in the overlapping region + + return stitched_array + + +def process_dataset( + data_array: np.ndarray | da.Array, + settings: ProcessingSettings, + verbose: bool = True, +) -> np.ndarray: + flip = np.flip + if isinstance(data_array, da.Array): + flip = da.flip + + if settings: + if settings.flipud: + if verbose: + click.echo("Flipping data array up-down") + data_array = flip(data_array, axis=-2) + + if settings.fliplr: + if verbose: + click.echo("Flipping data array left-right") + data_array = flip(data_array, axis=-1) + + return data_array + + +def preprocess_and_shift( + image, + settings: ProcessingSettings, + output_shape: tuple[int, int], + shift_x: float, + shift_y: float, + verbose=True, +): + return shift_image( + process_dataset(image, settings, verbose), output_shape, (shift_y, shift_x), verbose + ) + + +def blend(array: da.Array, method: Literal["average"] 
= "average"): + """ + Blend array of pre-shifted images stacked across axis=0 + + Parameters + ---------- + array : da.Array + Input dask array + method : str, optional + Blending method. Defaults to "average". + + Raises + ------ + NotImplementedError + Raise error is blending method is not implemented. + + Returns + ------- + da.Array + Stitched array + """ + if method == "average": + # Sum up all images + array_sum = array.sum(axis=0) + # Count how many images contribute to each pixel in the stitched image + array_bool_sum = (array != 0).sum(axis=0) + # Replace 0s with 1s to avoid division by zero + array_bool_sum[array_bool_sum == 0] = 1 + # Divide the sum of images by the number of images contributing to each pixel + stitched_array = array_sum / array_bool_sum + else: + raise NotImplementedError(f"Blending method {method} is not implemented") + + return stitched_array + + +def stitch_shifted_store( + input_data_path: str, + output_data_path: str, + settings: ProcessingSettings, + blending="average", + verbose=True, +): + """ + Stitch a zarr store of pre-shifted images. + + Parameters + ---------- + input_data_path : str + Path to the input zarr store. + output_data_path : str + Path to the output zarr store. + settings : ProcessingSettings + Postprocessing settings. + blending : str, optional + Blending method. Defaults to "average". + verbose : bool, optional + Whether to print verbose output. Defaults to True. + """ + click.echo(f'Stitching zarr store: {input_data_path}') + with open_ome_zarr(input_data_path, mode="r") as input_dataset: + for well_name, well in input_dataset.wells(): + if verbose: + click.echo(f'Processing well {well_name}') + + # Stack images along axis=0 + dask_array = da.stack( + [da.from_zarr(pos.data) for _, pos in well.positions()], axis=0 + ) + + # Blend images + stitched_array = blend(dask_array, method=blending) + + # Postprocessing + stitched_array = process_dataset(stitched_array, settings, verbose) + + # Save stitched array + click.echo('Computing and writing data') + with open_ome_zarr( + Path(output_data_path, well_name, '0'), mode="a" + ) as output_image: + da.to_zarr(stitched_array, output_image['0']) + click.echo(f'Finishing writing data for well {well_name}') + + +def estimate_zarr_fov_shifts( + fov0_zarr_path: str, + fov1_zarr_path: str, + tcz_index: tuple[int, int, int], + percent_overlap: float, + fliplr: bool, + flipud: bool, + direction: Literal["row", "col"], + output_dirname: str = None, +): + """ + Estimate shift between two zarr FOVs using phase cross-correlation.Apply flips (fliplr, flipud) as preprocessing step. + Phase cross-correlation is computed only across an ROI defined by (percent_overlap + 0.05) for the given direction. + + Parameters + ---------- + fov0_zarr_path : str + Path to the first zarr FOV. + fov1_zarr_path : str + Path to the second zarr FOV. + tcz_index : tuple[int, int, int] + Index of the time, channel, and z-slice to use for the shift estimation. + percent_overlap : float + The percentage of overlap between the two FOVs. Can be approximate. + fliplr : bool + Flag indicating whether to flip the FOVs horizontally before estimating shift. + flipud : bool + Flag indicating whether to flip the FOVs vertically before estimating shift. + direction : Literal["row", "col"] + The direction in which to compute the shift. + "row" computes vertical overlap with fov1 below fov0. + "col" computes horizontal overlap with fov1 to the right of fov0. + output_dirname : str, optional + The directory to save the output csv file. 
+ If None, the function returns a DataFrame with the estimated shift. + + Returns + ------- + pd.DataFrame + A DataFrame containing the estimated shift between the two FOVs. + """ + fov0_zarr_path = Path(fov0_zarr_path) + fov1_zarr_path = Path(fov1_zarr_path) + well_name = Path(*fov0_zarr_path.parts[-3:-1]) + fov0 = fov0_zarr_path.name + fov1 = fov1_zarr_path.name + click.echo(f'Estimating shift between FOVs {fov0} and {fov1} in well {well_name}...') + + T, C, Z = tcz_index + im0 = open_ome_zarr(fov0_zarr_path).data[T, C, Z] + im1 = open_ome_zarr(fov1_zarr_path).data[T, C, Z] + + if fliplr: + im0 = np.fliplr(im0) + im1 = np.fliplr(im1) + if flipud: + im0 = np.flipud(im0) + im1 = np.flipud(im1) + + shift = estimate_shift(im0, im1, percent_overlap, direction) + + df = pd.DataFrame( + { + "well": str(well_name), + "fov0": fov0, + "fov1": fov1, + "shift-x": shift[0], + "shift-y": shift[1], + "direction": direction, + }, + index=[0], + ) + click.echo(f'Estimated shift:\n {df.to_string(index=False)}') + + if output_dirname: + df.to_csv( + Path(output_dirname, f"{'_'.join(well_name.parts + (fov0, fov1))}_shift.csv"), + index=False, + ) + else: + return df + + +def consolidate_zarr_fov_shifts( + input_dirname: str, + output_filepath: str, +): + """ + Consolidate all csv files in input_dirname into a single csv file. + + Parameters + ---------- + input_dirname : str + Directory containing "*_shift.csv" files + output_filepath : str + Path to output .csv file + """ + # read all csv files in input_dirname and combine into a single dataframe + csv_files = Path(input_dirname).rglob("*_shift.csv") + df = pd.concat( + [pd.read_csv(csv_file, dtype={'fov0': str, 'fov1': str}) for csv_file in csv_files], + ignore_index=True, + ) + df.to_csv(output_filepath, index=False) + + +def cleanup_shifts(csv_filepath: str, pixel_size_um: float): + """ + Clean up outlier FOV shifts within a larger grid in case the phase cross-correlation + between individual FOVs returned spurious results. + + Since FOVs are acquired in snake fashion, FOVs in a given row should share the same vertical (i.e. row) shift. + Hence, the vertical shift for FOVs in a given row is replaced by the median value of all FOVs in that row. + + FOVs across the grid should have similar horizontal (i.e. column) shifts. + Values outside of the median +/- MAX_STAGE_ERROR_UM are replaced by the median. 
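+    For example, with a hypothetical pixel size of 0.5 um, MAX_STAGE_ERROR_UM = 5 allows
+    column shifts within 5 / 0.5 = 10 pixels of the median before they are replaced.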
+
+    Parameters
+    ----------
+    csv_filepath : str
+        Path to .csv file containing FOV shifts
+    pixel_size_um : float
+        Pixel size in micrometers, used to convert the maximum stage error to pixels
+    """
+    MAX_STAGE_ERROR_UM = 5
+    max_stage_error_pix = MAX_STAGE_ERROR_UM / pixel_size_um
+
+    df = pd.read_csv(csv_filepath, dtype={'fov0': str, 'fov1': str})
+    df['shift-x-raw'] = df['shift-x']
+    df['shift-y-raw'] = df['shift-y']
+
+    # replace row shifts with median value calculated across all columns
+    _df = df[df['direction'] == 'row']
+    # group by well and last three characters of fov0
+    groupby = _df.groupby(['well', _df['fov0'].str[-3:]])
+    _df.loc[:, 'shift-x'] = groupby['shift-x-raw'].transform('median')
+    _df.loc[:, 'shift-y'] = groupby['shift-y-raw'].transform('median')
+    df.loc[df['direction'] == 'row', ['shift-x', 'shift-y']] = _df[['shift-x', 'shift-y']]
+
+    # replace col shifts outside of the median +/- MAX_STAGE_ERROR_UM with the median value
+    _df = df[df['direction'] == 'col']
+    x_median, y_median = _df['shift-x-raw'].median(), _df['shift-y-raw'].median()
+    x_low, x_hi = x_median - max_stage_error_pix, x_median + max_stage_error_pix
+    y_low, y_hi = y_median - max_stage_error_pix, y_median + max_stage_error_pix
+    x_outliers = (_df['shift-x-raw'] <= x_low) | (_df['shift-x-raw'] >= x_hi)
+    y_outliers = (_df['shift-y-raw'] <= y_low) | (_df['shift-y-raw'] >= y_hi)
+    outliers = x_outliers | y_outliers
+    num_outliers = sum(outliers)
+
+    _df.loc[outliers, ['shift-x', 'shift-y']] = (x_median, y_median)
+    df.loc[df['direction'] == 'col', ['shift-x', 'shift-y']] = _df[['shift-x', 'shift-y']]
+    if num_outliers > 0:
+        click.echo(f'Replaced {num_outliers} column shift outliers')
+
+    df.to_csv(csv_filepath, index=False)
+
+
+def compute_total_translation(csv_filepath: str) -> pd.DataFrame:
+    """
+    Compute the total translation for each FOV based on the estimated row and col translation shifts.
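+    Column shifts are accumulated along each grid row and row shifts along each grid column.
+    For example, in a 2x2 grid, FOV 001001 (one column and one row away from FOV 000000)
+    receives the sum of one column shift and one row shift, before a global offset is added
+    to remove negative values.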
+ + Parameters + ---------- + csv_filepath : str + Path to .csv file containing FOV shifts + + Returns + ------- + pd.DataFrame + Dataframe with total translation shift per FOV + """ + df = pd.read_csv(csv_filepath, dtype={'fov0': str, 'fov1': str}) + + # create 'row' and 'col' number columns and sort the dataframe by 'fov1' + df['row'] = df['fov1'].str[-3:].astype(int) + df['col'] = df['fov1'].str[:3].astype(int) + df.set_index('fov1', inplace=True) + df.sort_index(inplace=True) + + total_shift = [] + for well in df['well'].unique(): + # calculate cumulative shifts for each row and column + _df = df[(df['direction'] == 'col') & (df['well'] == well)] + col_shifts = _df.groupby('row')[['shift-x', 'shift-y']].cumsum() + _df = df[(df['direction'] == 'row') & (df['well'] == well)] + row_shifts = _df.groupby('col')[['shift-x', 'shift-y']].cumsum() + # total shift is the sum of row and column shifts + _total_shift = col_shifts.add(row_shifts, fill_value=0) + + # add row 000000 + _total_shift = pd.concat( + [pd.DataFrame({'shift-x': 0, 'shift-y': 0}, index=['000000']), _total_shift] + ) + + # add global offset to remove negative values + _total_shift['shift-x'] += -np.minimum(_total_shift['shift-x'].min(), 0) + _total_shift['shift-y'] += -np.minimum(_total_shift['shift-y'].min(), 0) + _total_shift.set_index(well + '/' + _total_shift.index, inplace=True) + total_shift.append(_total_shift) + + return pd.concat(total_shift) diff --git a/mantis/cli/estimate_stitch.py b/mantis/cli/estimate_stitch.py new file mode 100644 index 00000000..fdd515f7 --- /dev/null +++ b/mantis/cli/estimate_stitch.py @@ -0,0 +1,185 @@ +import datetime +import time + +from pathlib import Path + +import click +import pandas as pd + +from iohub import open_ome_zarr +from slurmkit import SlurmParams, slurm_function, submit_function + +from mantis.analysis.AnalysisSettings import ProcessingSettings, StitchSettings +from mantis.analysis.stitch import ( + cleanup_shifts, + compute_total_translation, + consolidate_zarr_fov_shifts, + estimate_zarr_fov_shifts, + get_grid_rows_cols, +) +from mantis.cli.parsing import input_position_dirpaths, output_filepath +from mantis.cli.utils import model_to_yaml + + +def write_config_file( + shifts: pd.DataFrame, output_filepath: str, channel: str, fliplr: bool, flipud: bool +): + total_translation_dict = shifts.apply( + lambda row: [float(row['shift-y'].round(2)), float(row['shift-x'].round(2))], axis=1 + ).to_dict() + + settings = StitchSettings( + channels=[channel], + preprocessing=ProcessingSettings(fliplr=fliplr, flipud=flipud), + postprocessing=ProcessingSettings(), + total_translation=total_translation_dict, + ) + model_to_yaml(settings, output_filepath) + + +@click.command() +@input_position_dirpaths() +@output_filepath() +@click.option( + "--channel", + required=True, + type=str, + help="Channel to use for estimating stitch parameters", +) +@click.option( + "--percent-overlap", "-p", required=True, type=float, help="Percent overlap between images" +) +@click.option("--fliplr", is_flag=True, help="Flip images left-right before stitching") +@click.option("--flipud", is_flag=True, help="Flip images up-down before stitching") +@click.option("--slurm", "-s", is_flag=True, help="Run stitching on SLURM") +def estimate_stitch( + input_position_dirpaths: list[Path], + output_filepath: str, + channel: str, + percent_overlap: float, + fliplr: bool, + flipud: bool, + slurm: bool, +): + """ + Estimate stitching parameters for positions in wells of a zarr store. 
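+    Shifts between adjacent FOVs are estimated by phase cross-correlation along grid rows and
+    columns, outlier shifts are replaced by median values, and the resulting per-FOV total
+    translations are written to a config file that can be passed to the stitch command.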
+    Position names must follow the naming format XXXYYY, e.g. 000000, 000001, 001000, etc.
+    as created by the Micro-manager Tile Creator: https://micro-manager.org/Micro-Manager_User's_Guide#positioning
+    Assumes all wells have the same FOV grid layout.
+
+    >>> mantis estimate-stitch -i ./input.zarr/*/*/* -o ./stitch_params.yml --channel DAPI --percent-overlap 0.05 --slurm
+    """
+    assert 0 <= percent_overlap <= 1, "Percent overlap must be between 0 and 1"
+
+    input_zarr_path = Path(*input_position_dirpaths[0].parts[:-3])
+    output_filepath = Path(output_filepath)
+    csv_filepath = (
+        output_filepath.parent
+        / f"stitch_shifts_{input_zarr_path.name.replace('.zarr', '.csv')}"
+    )
+
+    with open_ome_zarr(input_position_dirpaths[0]) as dataset:
+        assert (
+            channel in dataset.channel_names
+        ), f"Channel {channel} not found in input zarr store"
+        tcz_idx = (0, dataset.channel_names.index(channel), dataset.data.shape[-3] // 2)
+        pixel_size_um = dataset.scale[-1]
+        if pixel_size_um == 1.0:
+            response = input(
+                'The pixel size is equal to the default value of 1.0 um. '
+                'Inaccurate pixel size will affect stitching outlier removal. '
+                'Continue? [y/N]: '
+            )
+            if response.lower() != 'y':
+                return
+
+    # here we assume that all wells have the same fov grid
+    click.echo('Indexing input zarr store')
+    wells = list(set([Path(*p.parts[-3:-1]) for p in input_position_dirpaths]))
+    grid_rows, grid_cols = get_grid_rows_cols(input_zarr_path)
+    row_fov0 = [col + row for row in grid_rows[:-1] for col in grid_cols]
+    row_fov1 = [col + row for row in grid_rows[1:] for col in grid_cols]
+    col_fov0 = [col + row for col in grid_cols[:-1] for row in grid_rows]
+    col_fov1 = [col + row for col in grid_cols[1:] for row in grid_rows]
+    estimate_shift_params = {
+        "tcz_index": tcz_idx,
+        "percent_overlap": percent_overlap,
+        "fliplr": fliplr,
+        "flipud": flipud,
+    }
+
+    # define slurm parameters
+    if slurm:
+        slurm_out_path = output_filepath.parent / "slurm_output" / "shift-%j.out"
+        csv_dirpath = (
+            output_filepath.parent / 'raw_shifts' / input_zarr_path.name.replace('.zarr', '')
+        )
+        csv_dirpath.mkdir(parents=True, exist_ok=False)
+        params = SlurmParams(
+            partition="preempted",
+            cpus_per_task=1,
+            mem_per_cpu='8G',
+            time=datetime.timedelta(minutes=10),
+            output=slurm_out_path,
+        )
+        slurm_func = {
+            direction: slurm_function(estimate_zarr_fov_shifts)(
+                direction=direction,
+                output_dirname=csv_dirpath,
+                **estimate_shift_params,
+            )
+            for direction in ("row", "col")
+        }
+
+    click.echo('Estimating FOV shifts...')
+    shifts, jobs = [], []
+    for well_name in wells:
+        for direction, fovs in zip(
+            ("row", "col"), (zip(row_fov0, row_fov1), zip(col_fov0, col_fov1))
+        ):
+            for fov0, fov1 in fovs:
+                fov0_zarr_path = Path(input_zarr_path, well_name, fov0)
+                fov1_zarr_path = Path(input_zarr_path, well_name, fov1)
+                if slurm:
+                    job_id = submit_function(
+                        slurm_func[direction],
+                        slurm_params=params,
+                        fov0_zarr_path=fov0_zarr_path,
+                        fov1_zarr_path=fov1_zarr_path,
+                    )
+                    jobs.append(job_id)
+                else:
+                    shift_params = estimate_zarr_fov_shifts(
+                        fov0_zarr_path=fov0_zarr_path,
+                        fov1_zarr_path=fov1_zarr_path,
+                        direction=direction,
+                        **estimate_shift_params,
+                    )
+                    shifts.append(shift_params)
+
+    click.echo('Consolidating FOV shifts...')
+    if slurm:
+        submit_function(
+            slurm_function(consolidate_zarr_fov_shifts)(
+                input_dirname=csv_dirpath,
+                output_filepath=csv_filepath,
+            ),
+            slurm_params=params,
+            dependencies=jobs,
+        )
+
+        # wait for csv_filepath to be created, capped at 5 min
+        t_start = time.time()
+        while not 
csv_filepath.exists() and time.time() - t_start < 300: + time.sleep(1) + else: + df = pd.concat(shifts, ignore_index=True) + df.to_csv(csv_filepath, index=False) + + cleanup_shifts(csv_filepath, pixel_size_um) + shifts = compute_total_translation(csv_filepath) + write_config_file(shifts, output_filepath, channel, fliplr, flipud) + + +if __name__ == "__main__": + estimate_stitch() diff --git a/mantis/cli/main.py b/mantis/cli/main.py index f8cbceb1..5d08167d 100644 --- a/mantis/cli/main.py +++ b/mantis/cli/main.py @@ -6,9 +6,11 @@ from mantis.cli.estimate_bleaching import estimate_bleaching from mantis.cli.estimate_deskew import estimate_deskew from mantis.cli.estimate_stabilization import estimate_stabilization +from mantis.cli.estimate_stitch import estimate_stitch from mantis.cli.optimize_affine import optimize_affine from mantis.cli.run_acquisition import run_acquisition from mantis.cli.stabilize import stabilize +from mantis.cli.stitch import stitch from mantis.cli.update_scale_metadata import update_scale_metadata CONTEXT = {"help_option_names": ["-h", "--help"]} @@ -32,6 +34,8 @@ def cli(): cli.add_command(estimate_affine) cli.add_command(optimize_affine) cli.add_command(apply_affine) +cli.add_command(estimate_stitch) +cli.add_command(stitch) cli.add_command(update_scale_metadata) cli.add_command(estimate_stabilization) cli.add_command(stabilize) diff --git a/mantis/cli/parsing.py b/mantis/cli/parsing.py index 848a0eb3..246e18b2 100644 --- a/mantis/cli/parsing.py +++ b/mantis/cli/parsing.py @@ -9,7 +9,9 @@ from mantis.cli.option_eat_all import OptionEatAll -def _validate_and_process_paths(ctx: click.Context, opt: click.Option, value: str) -> None: +def _validate_and_process_paths( + ctx: click.Context, opt: click.Option, value: str +) -> list[Path]: # Sort and validate the input paths input_paths = [Path(path) for path in natsorted(value)] for path in input_paths: diff --git a/mantis/cli/stitch.py b/mantis/cli/stitch.py new file mode 100644 index 00000000..4d26ff36 --- /dev/null +++ b/mantis/cli/stitch.py @@ -0,0 +1,199 @@ +import datetime +import shutil +import warnings + +from pathlib import Path + +import click +import numpy as np +import pandas as pd + +from iohub import open_ome_zarr +from iohub.ngff_meta import TransformationMeta +from slurmkit import HAS_SLURM, SlurmParams, slurm_function, submit_function + +from mantis.analysis.AnalysisSettings import StitchSettings +from mantis.analysis.stitch import ( + get_grid_rows_cols, + get_image_shift, + get_stitch_output_shape, + preprocess_and_shift, + stitch_shifted_store, +) +from mantis.cli.parsing import config_filepath, input_position_dirpaths, output_dirpath +from mantis.cli.utils import create_empty_hcs_zarr, process_single_position_v2, yaml_to_model + + +@click.command() +@input_position_dirpaths() +@output_dirpath() +@config_filepath() +@click.option( + "--temp-path", + type=click.Path(exists=True, file_okay=False, dir_okay=True), + default='./', + help="Path to temporary directory, ideally with fast read/write speeds, e.g. /hpc/scratch/group.comp.micro/", +) +def stitch( + input_position_dirpaths: list[Path], + output_dirpath: str, + config_filepath: str, + temp_path: str, +) -> None: + """ + Stitch positions in wells of a zarr store using a configuration file generated by estimate-stitch. + + >>> mantis stitch -i ./input.zarr/*/*/* -c ./stitch_params.yml -o ./output.zarr --temp-path /hpc/scratch/group.comp.micro/ + """ + if not HAS_SLURM: + warnings.warn( + "This function is intended to be used with SLURM. 
" + "Running on local machine instead." + ) + + slurm_out_path = Path(output_dirpath).parent / "slurm_output" / "stitch-%j.out" + shifted_store_path = Path(temp_path, f"TEMP_{input_position_dirpaths[0].parts[-4]}") + settings = yaml_to_model(config_filepath, StitchSettings) + + with open_ome_zarr(str(input_position_dirpaths[0]), mode="r") as input_dataset: + input_dataset_channels = input_dataset.channel_names + T, C, Z, Y, X = input_dataset.data.shape + scale = tuple(input_dataset.scale) + chunks = input_dataset.data.chunks + + if settings.channels is None: + settings.channels = input_dataset_channels + + assert all( + channel in input_dataset_channels for channel in settings.channels + ), "Invalid channel(s) provided." + + wells = list(set([Path(*p.parts[-3:-1]) for p in input_position_dirpaths])) + grid_rows, grid_cols = get_grid_rows_cols(Path(*input_position_dirpaths[0].parts[:-3])) + n_rows = len(grid_rows) + n_cols = len(grid_cols) + + if settings.total_translation is None: + output_shape, global_translation = get_stitch_output_shape( + n_rows, n_cols, Y, X, settings.column_translation, settings.row_translation + ) + else: + df = pd.DataFrame.from_dict( + settings.total_translation, orient="index", columns=["shift-y", "shift-x"] + ) + output_shape = ( + np.ceil(df["shift-y"].max() + Y).astype(int), + np.ceil(df["shift-x"].max() + X).astype(int), + ) + + # create temp zarr store + click.echo(f'Creating temporary zarr store at {shifted_store_path}') + stitched_shape = (T, len(settings.channels), Z) + output_shape + stitched_chunks = chunks[:3] + (4096, 4096) + create_empty_hcs_zarr( + store_path=shifted_store_path, + position_keys=[p.parts[-3:] for p in input_position_dirpaths], + shape=stitched_shape, + chunks=stitched_chunks, + channel_names=settings.channels, + dtype=np.float32, + ) + + # prepare slurm parameters + params = SlurmParams( + partition='preempted', + cpus_per_task=6, + mem_per_cpu='24G', + time=datetime.timedelta(minutes=30), + output=slurm_out_path, + ) + + # Shift each FOV to its final position in the stitched image + slurm_func = slurm_function(process_single_position_v2)( + preprocess_and_shift, + input_channel_idx=[input_dataset_channels.index(ch) for ch in settings.channels], + output_channel_idx=list(range(len(settings.channels))), + num_processes=6, + settings=settings.preprocessing, + output_shape=output_shape, + verbose=True, + ) + + click.echo('Submitting SLURM jobs') + shift_jobs = [] + for in_path in input_position_dirpaths: + well = Path(*in_path.parts[-3:-1]) + col, row = (in_path.name[:3], in_path.name[3:]) + + if settings.total_translation is None: + shift = get_image_shift( + int(col), + int(row), + settings.column_translation, + settings.row_translation, + global_translation, + ) + else: + # COL+ROW order here is important + shift = settings.total_translation[str(well / (col + row))] + + shift_jobs.append( + submit_function( + slurm_func, + slurm_params=params, + shift_x=shift[-1], + shift_y=shift[-2], + input_data_path=in_path, + output_path=shifted_store_path, + ) + ) + + # create output zarr store + with open_ome_zarr( + output_dirpath, layout='hcs', mode="w-", channel_names=settings.channels + ) as output_dataset: + for well in wells: + pos = output_dataset.create_position(*Path(well, '0').parts) + pos.create_zeros( + name='0', + shape=stitched_shape, + dtype=np.float32, + chunks=stitched_chunks, + transform=[TransformationMeta(type="scale", scale=scale)], + ) + + # Stitch pre-shifted images + stitch_job = submit_function( + 
slurm_function(stitch_shifted_store)( + shifted_store_path, + output_dirpath, + settings.postprocessing, + blending='average', + verbose=True, + ), + slurm_params=SlurmParams( + partition='cpu', + cpus_per_task=32, + mem_per_cpu='8G', + time=datetime.timedelta(hours=12), + output=slurm_out_path, + ), + dependencies=shift_jobs, + ) + + # Delete temporary store + submit_function( + slurm_function(shutil.rmtree)(shifted_store_path), + slurm_params=SlurmParams( + partition='cpu', + cpus_per_task=1, + mem_per_cpu='12G', + time=datetime.timedelta(hours=1), + output=slurm_out_path, + ), + dependencies=stitch_job, + ) + + +if __name__ == '__main__': + stitch() diff --git a/mantis/cli/utils.py b/mantis/cli/utils.py index e0efff7d..04b567bf 100644 --- a/mantis/cli/utils.py +++ b/mantis/cli/utils.py @@ -140,6 +140,7 @@ def create_empty_hcs_zarr( output_plate = open_ome_zarr( str(store_path), layout="hcs", mode="a", channel_names=channel_names ) + transform = [TransformationMeta(type="scale", scale=scale)] # Create positions for position_key in position_keys: @@ -152,7 +153,7 @@ def create_empty_hcs_zarr( shape=shape, chunks=chunks, dtype=dtype, - transform=[TransformationMeta(type="scale", scale=scale)], + transform=transform, ) else: position = output_plate[position_key_string] @@ -225,7 +226,7 @@ def apply_transform_to_zyx_and_save_v2( kwargs["t_idx"] = t_idx # Process CZYX vs ZYX - if input_channel_indices is not None: + if input_channel_indices is not None and len(input_channel_indices) > 0: click.echo(f"Processing t={t_idx}") czyx_data = position.data.oindex[t_idx, input_channel_indices] @@ -240,15 +241,15 @@ def apply_transform_to_zyx_and_save_v2( else: click.echo(f"Processing c={c_idx}, t={t_idx}") - zyx_data = position.data.oindex[t_idx, c_idx] + czyx_data = position.data.oindex[t_idx, c_idx : c_idx + 1] # Checking if nans or zeros and skip processing - if not _check_nan_n_zeros(zyx_data): + if not _check_nan_n_zeros(czyx_data): # Apply transformation - transformed_zyx = func(zyx_data, **kwargs) + transformed_czyx = func(czyx_data, **kwargs) # Write to file with open_ome_zarr(output_path, mode="r+") as output_dataset: - output_dataset[0][t_idx_out, c_idx] = transformed_zyx + output_dataset[0][t_idx_out, c_idx : c_idx + 1] = transformed_czyx click.echo(f"Finished Writing.. c={c_idx}, t={t_idx}") else: @@ -387,8 +388,8 @@ def process_single_position_v2( func, input_dataset, output_path / Path(*input_data_path.parts[-3:]), - input_channel_indices=None, - output_channel_indices=None, + input_channel_idx, + output_channel_idx, **func_args, ) else:
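
As an illustration of how the new helpers fit together, a minimal sketch that stitches two
synthetic, overlapping tiles with the functions introduced in mantis/analysis/stitch.py. The tile
sizes, overlap fraction, and random data below are assumptions chosen only for demonstration, not
values used by the CLI.

import numpy as np

from mantis.analysis.stitch import estimate_shift, shift_image

rng = np.random.default_rng(0)
scene = rng.random((200, 380)).astype(np.float32)
tile0 = scene[:, :200]  # left tile
tile1 = scene[:, 180:]  # right tile, ~10% overlap with tile0

# estimate_shift returns the (x, y) offset of tile1 relative to tile0, computed by
# phase cross-correlation over the overlap region only
shift_x, shift_y = estimate_shift(tile0, tile1, percent_overlap=0.1, direction="col")

# shift_image expects CZYX data, a (Y, X) output shape, and a (y, x) shift
canvas = (200, 380)
placed0 = shift_image(tile0[np.newaxis, np.newaxis], canvas, (0, 0))
placed1 = shift_image(tile1[np.newaxis, np.newaxis], canvas, (shift_y, shift_x))

# Average blending, mirroring the logic of blend(): divide the per-pixel sum by the
# number of tiles contributing to that pixel
stack = np.stack([placed0, placed1])
counts = (stack != 0).sum(axis=0)
counts[counts == 0] = 1
stitched = stack.sum(axis=0) / counts

print(shift_x, shift_y, stitched.shape)  # the estimated shift should be close to (180, 0)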