diff --git a/mantis/acquisition/AcquisitionSettings.py b/mantis/acquisition/AcquisitionSettings.py
index d1225c58..eb485c4a 100644
--- a/mantis/acquisition/AcquisitionSettings.py
+++ b/mantis/acquisition/AcquisitionSettings.py
@@ -1,3 +1,4 @@
+import copy
 import warnings
 
 from dataclasses import field
@@ -36,10 +37,12 @@ class PositionSettings:
     position_labels: List[str] = field(default_factory=list)
     num_positions: int = field(init=False, default=0)
     well_ids: List[str] = field(init=False, default_factory=list)
+    xyz_positions_shift: list = field(init=False, default_factory=list)
 
     def __post_init__(self):
         assert len(self.xyz_positions) == len(self.position_labels)
         self.num_positions = len(self.xyz_positions)
+        self.xyz_positions_shift = copy.deepcopy(self.xyz_positions)
 
         try:
             # Look for "'A1-Site_0', 'H12-Site_1', ... " format
@@ -141,6 +144,7 @@ class MicroscopeSettings:
     use_o3_refocus: bool = False
     o3_refocus_config: Optional[ConfigSettings] = None
     o3_refocus_interval_min: Optional[int] = None
+    autotracker_config: Optional[ConfigSettings] = None
 
 
 @dataclass
@@ -193,3 +197,28 @@ def __post_init__(self):
         attr_val = getattr(self, attr)
         if attr_val is not None:
             setattr(self, attr, round(attr_val, 1))
+
+
+@dataclass
+class AutotrackerSettings:
+    tracking_method: Literal['phase_cross_correlation', 'template_matching', 'multi_otsu']
+    tracking_interval: Optional[int] = 1
+    scale_yx: Optional[float] = 1.0
+    shift_limit: Optional[Tuple[float, float, float]] = None
+    device: Optional[str] = 'cpu'
+    zyx_dampening_factor: Optional[Tuple[float, float, float]] = None
+    # TODO: maybe do the ROI like in the ls_microscope_settings
+    template_roi_zyx: Optional[Tuple[int, int, int]] = None
+    template_channel: Optional[str] = None
+
+    # Validate on the last of the template fields so that all previously
+    # declared fields are available in `values`; a validator on
+    # `tracking_method` cannot see fields declared after it.
+    @validator("template_channel", always=True)
+    def check_tracking_method_options(cls, v, values):
+        if values.get("tracking_method") == 'template_matching':
+            if v is None or values.get("template_roi_zyx") is None:
+                raise ValueError(
+                    'template_roi_zyx and template_channel must be provided for template matching'
+                )
+        return v
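For reference, a minimal sketch of constructing and validating these settings (assuming `AutotrackerSettings` is a pydantic dataclass, as the `@validator` decorator implies):

    from pydantic import ValidationError

    from mantis.acquisition.AcquisitionSettings import AutotrackerSettings

    # Mirrors the autotracker_settings block in the demo YAML below
    settings = AutotrackerSettings(
        tracking_method='phase_cross_correlation',
        tracking_interval=1,
        shift_limit=(30, 200, 200),
        scale_yx=0.075,
    )

    # template_matching without a template ROI and channel should be rejected
    try:
        AutotrackerSettings(tracking_method='template_matching')
    except (ValueError, ValidationError):
        print('rejected, as expected')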
diff --git a/mantis/acquisition/acq_engine.py b/mantis/acquisition/acq_engine.py
index ad9a53d0..a767230f 100644
--- a/mantis/acquisition/acq_engine.py
+++ b/mantis/acquisition/acq_engine.py
@@ -31,6 +31,7 @@
     SliceSettings,
     MicroscopeSettings,
     AutoexposureSettings,
+    AutotrackerSettings,
 )
 from mantis.acquisition.hook_functions.pre_hardware_hook_functions import (
     log_preparing_acquisition,
@@ -49,7 +50,8 @@
     check_ls_acq_finished,
 )
 
-# isort: on
+from copy import deepcopy
+from mantis.acquisition.autotracker import autotracker_hook_fn
 
 # Define constants
@@ -102,6 +103,7 @@ def __init__(
         self._slice_settings = SliceSettings()
         self._microscope_settings = MicroscopeSettings()
         self._autoexposure_settings = None
+        self._autotracker_settings = None
         self._z0 = None
         self.headless = False if mm_app_path is None else True
         self.type = 'light-sheet' if self.headless else 'label-free'
@@ -173,6 +175,10 @@ def microscope_settings(self):
     def autoexposure_settings(self):
         return self._autoexposure_settings
 
+    @property
+    def autotracker_settings(self):
+        return self._autotracker_settings
+
     @channel_settings.setter
     def channel_settings(self, settings: ChannelSettings):
         logger.debug(
@@ -204,6 +210,16 @@ def autoexposure_settings(self, settings: AutoexposureSettings):
         )
         self._autoexposure_settings = settings
 
+    @autotracker_settings.setter
+    def autotracker_settings(self, settings: AutotrackerSettings):
+        if settings is None:
+            logger.debug('Autotracker settings are not provided')
+        else:
+            logger.debug(
+                f'{self.type.capitalize()} acquisition will have the following autotracker settings: {asdict(settings)}'
+            )
+        self._autotracker_settings = settings
+
     def setup(self):
         """
         Apply acquisition settings as specified by the class properties
@@ -276,6 +292,7 @@ def reset(self):
     )
 
 
+# TODO: check that enable_ls_acq and enable_lf_acq work independently
 class MantisAcquisition(object):
     """
     Acquisition class for simultaneous label-free and light-sheet acquisition on
@@ -336,6 +353,8 @@ def __init__(
         self._lf_acq_obj = None
         self._ls_acq_obj = None
 
+        globals.demo_run = demo_run
+
         if not enable_lf_acq or not enable_ls_acq:
             raise Exception('Disabling LF or LS acquisition is not currently supported')
@@ -453,6 +472,8 @@ def update_position_settings(self):
                 xyz_positions=xyz_positions,
                 position_labels=position_labels,
             )
+            self.position_settings.xyz_positions_shift = deepcopy(xyz_positions)
+
         else:
             logger.debug('Position list is already populated and will not be updated')
 
@@ -671,6 +692,12 @@ def setup_autoexposure(self):
                 )
             )
 
+    def update_position_autotracker(self):
+        # Restore the working position list from the autotracker-shifted copy
+        self.position_settings.xyz_positions = deepcopy(
+            self.position_settings.xyz_positions_shift
+        )
+
     def go_to_position(self, position_index: int):
         # Move slowly for short distances such that autofocus can stay engaged.
         # Autofocus typically fails when moving long distances, so we can move
@@ -938,6 +965,18 @@ def run_autoexposure(
                 f'Autoexposure method {method} is not yet implemented.'
             )
 
+    def run_autotracker(
+        self,
+        acq: BaseChannelSliceAcquisition,
+        well_id: str,
+        method: str = 'manual',
+    ):
+        logger.debug('Running autotracker')
+        if acq.autotracker_settings is None:
+            return
+        # TODO: implement autotracker
+        microscope_operations.autotracker(acq.mmc, acq.autotracker_settings)
+
     def setup(self):
         """
         Setup the mantis acquisition. This method sets up the label-free
@@ -984,7 +1023,21 @@ def acquire(self):
             start_daq_counters, [self._lf_z_ctr_task, self._lf_channel_ctr_task]
         )
         lf_post_hardware_hook_fn = log_acquisition_start
-        lf_image_saved_fn = check_lf_acq_finished
+
+        # TODO: implement logic for the autotracker_img_saved_hook_fn
+        if self.lf_acq.microscope_settings.autotracker_config is not None:
+            lf_image_saved_fn = partial(
+                autotracker_hook_fn,
+                'lf',
+                self.lf_acq.autotracker_settings,
+                self._position_settings,
+                self.lf_acq.microscope_settings.autotracker_config,
+                self.lf_acq.slice_settings,
+                self._acq_dir,
+            )
+        else:
+            logger.info('No autotracker config found. Using default image saved hook')
+            lf_image_saved_fn = check_lf_acq_finished
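Pycromanager calls an `image_saved_fn` with only `(axes, dataset)`, so `functools.partial` is used above to pre-bind the acquisition-specific arguments. A minimal sketch of the same pattern (names are illustrative, not the mantis API):

    from functools import partial

    def image_saved_hook(arm, settings, axes, dataset):
        # `axes` and `dataset` are supplied by the acquisition engine at save time
        print(arm, settings, axes)

    hook = partial(image_saved_hook, 'lf', {'tracking_interval': 1})
    hook({'position': 'A1-Site_0', 'time': 0, 'z': 0}, dataset=None)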
 
         # define LF acquisition
         self._lf_acq_obj = Acquisition(
@@ -1014,7 +1067,21 @@
             self.ls_acq.channel_settings.channels,
         )
         ls_post_camera_hook_fn = partial(start_daq_counters, [self._ls_z_ctr_task])
-        ls_image_saved_fn = check_ls_acq_finished
+
+        # TODO: implement logic for the autotracker_img_saved_hook_fn
+        if self.ls_acq.microscope_settings.autotracker_config is not None:
+            ls_image_saved_fn = partial(
+                autotracker_hook_fn,
+                'ls',
+                self.ls_acq.autotracker_settings,
+                self._position_settings,
+                self.ls_acq.microscope_settings.autotracker_config,
+                self.ls_acq.slice_settings,
+                self._acq_dir,
+            )
+        else:
+            logger.info('No autotracker config found. Using default image saved hook')
+            ls_image_saved_fn = check_ls_acq_finished
 
         # define LS acquisition
         self._ls_acq_obj = Acquisition(
@@ -1050,8 +1115,23 @@
 
                 # move to the given position
                 if p_label != previous_position_label:
+                    # Check if the autotracker is enabled on either arm
+                    if (
+                        self.ls_acq.microscope_settings.autotracker_config is not None
+                        or self.lf_acq.microscope_settings.autotracker_config is not None
+                    ):
+                        # TODO: Should we get the coords from the csv file or the modified xyz_positions_shift?
+                        logger.debug('Updating the positions for autotracker')
+                        logger.debug(
+                            'Previous position: %f,%f',
+                            *self.position_settings.xyz_positions[p_idx][0:2],
+                        )
+                        self.update_position_autotracker()
                     self.go_to_position(p_idx)
 
+                # TODO: get the delta shifts
+                # read the files here and move separately
+
                 # autofocus
                 if self.lf_acq.microscope_settings.use_autofocus:
                     autofocus_success = microscope_operations.autofocus(
diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py
new file mode 100644
index 00000000..26739a1f
--- /dev/null
+++ b/mantis/acquisition/autotracker.py
@@ -0,0 +1,579 @@
+from pathlib import Path
+from time import sleep
+from typing import Callable, Optional, Tuple, cast
+
+import numpy as np
+import pandas as pd
+import skimage
+
+from numpy.typing import ArrayLike
+from scipy.fftpack import next_fast_len
+from skimage.exposure import rescale_intensity
+from skimage.feature import match_template
+from skimage.filters import gaussian
+from skimage.measure import label, regionprops
+
+from mantis import logger
+from mantis.acquisition.hook_functions import globals
+
+# FIXME fix the dependencies so that we can install and import dexpv2
+# from dexpv2.crosscorr import phase_cross_corr
+# from dexpv2.utils import center_crop, pad_to_shape, to_cpu
+
+# TODO: write test functions
+# TODO: consider splitting this file into two
+
+
+def calc_weighted_center(labeled_im):
+    """Calculate the area-weighted centroid of the regions in a labeled image.
+
+    Parameters
+    ----------
+    labeled_im : ndarray
+        labeled 3D image
+
+    Returns
+    -------
+    ndarray
+        (z, y, x) centroid of all regions, weighted by region area
+    """
+    regions = sorted(regionprops(labeled_im), key=lambda r: r.area)
+    centers = []
+    areas = []
+    for region in regions:
+        centers.append(region.centroid)
+        areas.append(region.area)
+    areas_norm = np.array(areas) / np.sum(areas)
+    centers = np.array(centers)
+    center_weighted = np.zeros(3)
+    for j in range(3):
+        center_weighted[j] = np.sum(centers[:, j] * areas_norm)
+
+    return center_weighted
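As a quick sanity check of the weighted centroid: for a single cube-shaped region it reduces to the geometric center. A hypothetical snippet:

    import numpy as np
    from skimage.measure import label

    vol = np.zeros((16, 16, 16), dtype=int)
    vol[4:8, 4:8, 4:8] = 1  # one 4x4x4 region
    print(calc_weighted_center(label(vol)))  # approximately [5.5, 5.5, 5.5]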
+
+
+def multiotsu_centroid(
+    ref_img: ArrayLike,
+    mov_img: ArrayLike,
+) -> list:
+    """
+    Computes the translation shifts_zyx with a multi-Otsu threshold approach,
+    by comparing the weighted centroids of the segmented regions
+
+    Parameters
+    ----------
+    ref_img : ndarray
+        reference image stack ZYX
+    mov_img : ndarray
+        moving image stack ZYX
+
+    Returns
+    -------
+    shifts_zyx : list
+        list of shifts_zyx in z, y, x order
+    """
+    # Process moving image
+    mov_img = rescale_intensity(mov_img, in_range='image', out_range=(0, 1.0))
+    stack_blur = gaussian(mov_img, sigma=5.0)
+    thresh = skimage.filters.threshold_multiotsu(stack_blur)
+    mov_img = stack_blur > thresh[0]
+    mov_img = label(mov_img)
+    # Process reference image
+    ref_img = rescale_intensity(ref_img, in_range='image', out_range=(0, 1.0))
+    stack_blur = gaussian(ref_img, sigma=5.0)
+    thresh = skimage.filters.threshold_multiotsu(stack_blur)
+    ref_img = stack_blur > thresh[0]
+    ref_img = label(ref_img)
+
+    # Get the centroids
+    moving_center = calc_weighted_center(mov_img)
+    target_center = calc_weighted_center(ref_img)
+
+    # Find the shifts_zyx
+    shifts_zyx = moving_center - target_center
+
+    logger.debug(
+        'moving_center (z,y,x): %f,%f,%f',
+        moving_center[0],
+        moving_center[1],
+        moving_center[2],
+    )
+    logger.debug(
+        'target_center (z,y,x): %f,%f,%f',
+        target_center[0],
+        target_center[1],
+        target_center[2],
+    )
+    logger.debug('shifts_zyx (z,y,x): %f,%f,%f', shifts_zyx[0], shifts_zyx[1], shifts_zyx[2])
+
+    return shifts_zyx
+
+
+def template_matching(ref_img, mov_img, template_slicing_zyx):
+    """
+    Uses template matching to determine the shift between two image stacks.
+
+    Parameters:
+    - ref_img: Reference 3D image stack (numpy array).
+    - mov_img: Moving 3D image stack (numpy array) to be aligned with the reference.
+    - template_slicing_zyx: Tuple or list of slice objects defining the region to be used as the template.
+
+    Returns:
+    - shift: The shift (displacement) needed to align mov_img with ref_img (numpy array).
+    """
+    template = ref_img[template_slicing_zyx]
+
+    result = match_template(mov_img, template)
+    zyx_1 = np.unravel_index(np.argmax(result), result.shape)
+
+    # Calculate the shift by subtracting the template's origin in the reference
+    # image from the best-match location in the moving image
+    shift = np.array(zyx_1) - np.array([s.start for s in template_slicing_zyx])
+
+    return shift
+
+
+def to_cpu(arr: ArrayLike) -> ArrayLike:
+    """
+    Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2
+    Moves array to cpu; if it's already there nothing is done.
+    """
+    if hasattr(arr, "cpu"):
+        arr = arr.cpu()
+    elif hasattr(arr, "get"):
+        arr = arr.get()
+    return arr
+
+
+def center_crop(arr: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike:
+    """Crops the center of `arr`"""
+    assert arr.ndim == len(shape)
+
+    starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape))
+
+    assert all(s >= 0 for s in starts)
+
+    slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape))
+
+    logger.info(
+        f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}"
+    )
+
+    return arr[slicing]
+
+
+def pad_to_shape(arr: ArrayLike, shape: Tuple[int, ...], mode: str, **kwargs) -> ArrayLike:
+    """Pads array to shape.
+
+    Parameters
+    ----------
+    arr : ArrayLike
+        Input array.
+    shape : Tuple[int]
+        Output shape.
+    mode : str
+        Padding mode (see np.pad).
+
+    Returns
+    -------
+    ArrayLike
+        Padded array.
+    """
+    assert arr.ndim == len(shape)
+
+    dif = tuple(s - a for s, a in zip(shape, arr.shape))
+    assert all(d >= 0 for d in dif)
+
+    pad_width = [[s // 2, s - s // 2] for s in dif]
+
+    logger.info(f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}")
+
+    return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs)
+
+
+def _match_shape(img: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike:
+    """Pad or crop array to match provided shape."""
+
+    # Compare elementwise; comparing plain tuples would be lexicographic
+    shape = np.asarray(shape)
+    if np.any(shape > img.shape):
+        padded_shape = np.maximum(img.shape, shape)
+        img = pad_to_shape(img, padded_shape, mode="reflect")
+
+    if np.any(shape < img.shape):
+        img = center_crop(img, shape)
+
+    return img
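A quick, hypothetical check of the pad/crop helper on mismatched shapes:

    import numpy as np

    a = np.ones((4, 10, 6))
    out = _match_shape(a, (8, 6, 6))
    print(out.shape)  # (8, 6, 6): z is reflect-padded, y is center-cropped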
+
+
+def phase_cross_corr(
+    ref_img: ArrayLike,
+    mov_img: ArrayLike,
+    maximum_shift: float = 1.0,
+    to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x,
+    transform: Optional[Callable[[ArrayLike], ArrayLike]] = np.log1p,
+) -> Tuple[int, ...]:
+    """
+    Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2
+
+    Computes the translation shift using the argument maximum of the phase cross-correlation.
+    Inputs are padded or cropped for fast FFT computation, assuming a maximum translation shift.
+
+    Parameters
+    ----------
+    ref_img : ArrayLike
+        Reference image.
+    mov_img : ArrayLike
+        Moved image.
+    maximum_shift : float, optional
+        Maximum location shift normalized by axis size, by default 1.0
+    to_device : Callable, optional
+        Moves arrays to the compute device (e.g. GPU); identity by default
+    transform : Callable, optional
+        Intensity transform applied before correlation; np.log1p by default
+
+    Returns
+    -------
+    Tuple[int, ...]
+        Shift between reference and moved image.
+    """
+    shape = tuple(
+        cast(int, next_fast_len(int(max(s1, s2) * maximum_shift)))
+        for s1, s2 in zip(ref_img.shape, mov_img.shape)
+    )
+
+    logger.debug(
+        f"phase cross corr. fft shape of {shape} for arrays of shape {ref_img.shape} and {mov_img.shape} "
+        f"with maximum shift of {maximum_shift}"
+    )
+
+    ref_img = _match_shape(ref_img, shape)
+    mov_img = _match_shape(mov_img, shape)
+
+    ref_img = to_device(ref_img)
+    mov_img = to_device(mov_img)
+
+    if transform is not None:
+        ref_img = transform(ref_img)
+        mov_img = transform(mov_img)
+
+    Fimg1 = np.fft.rfftn(ref_img)
+    Fimg2 = np.fft.rfftn(mov_img)
+    eps = np.finfo(Fimg1.dtype).eps
+    del ref_img, mov_img
+
+    prod = Fimg1 * Fimg2.conj()
+    del Fimg1, Fimg2
+
+    norm = np.fmax(np.abs(prod), eps)
+    corr = np.fft.irfftn(prod / norm)
+    del prod, norm
+
+    corr = np.fft.fftshift(np.abs(corr))
+
+    argmax = to_cpu(np.argmax(corr))
+    peak = np.unravel_index(argmax, corr.shape)
+    peak = tuple(s // 2 - p for s, p in zip(corr.shape, peak))
+
+    logger.debug(f"phase cross corr. peak at {peak}")
+
+    return peak
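A hedged smoke test of the phase-correlation estimator on a synthetic, wrapped shift (assuming this module's functions are importable as written):

    import numpy as np

    rng = np.random.default_rng(0)
    ref = rng.random((8, 64, 64))
    mov = np.roll(ref, shift=(2, -5, 3), axis=(0, 1, 2))
    print(phase_cross_corr(ref, mov))  # expected (2, -5, 3), up to the sign convention of `peak`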
+ """ + + autofocus_method_func = self._TRACKING_METHODS.get(self.tracking_method) + + if not autofocus_method_func: + raise ValueError(f'Unknown autofocus method: {self.tracking_method}') + + shifts_zyx = autofocus_method_func(ref_img=ref_img, mov_img=mov_img, **kwargs) + + # shifts_zyx in px to shifts_zyx in um + self.shifts_zyx = np.array(shifts_zyx) * self.scale + + if self.zyx_dampening is not None: + self.shifts_zyx = self.shifts_zyx * self.zyx_dampening + logger.info(f'shifts_zyx (z,y,x): {self.shifts_zyx}') + + return self.shifts_zyx + + # Function to log the shifts_zyx to a csv file + def save_shifts_to_file( + self, + output_file: str, + position_id: int, + timepoint_id: int, + shifts_zyx: Tuple[int, int, int] = None, + stage_coords: Tuple[int, int, int] = None, + overwrite: bool = False, + ) -> None: + """ + Saves the computed shifts_zyx to a CSV file. + + Parameters + ---------- + output_file : str + Path to the output CSV file. + shifts_zyx : Tuple[int, int, int] + The computed shifts_zyx (Z, Y, X). + position_id : int + Identifier for the position. + timepoint_id : int + Identifier for the timepoint. + overwrite : bool + If True, the file will be overwritten if it exists. + """ + # Convert output_file to a Path object + output_path = Path(output_file) + if shifts_zyx is None: + shifts_zyx = self.shifts_zyx + if stage_coords is None: + stage_coords = (0, 0, 0) + data = { + "PositionID": [position_id], + "TimepointID": [timepoint_id], + "ShiftZ": [shifts_zyx[-3]], + "ShiftY": [shifts_zyx[-2]], + "ShiftX": [shifts_zyx[-1]], + "StageZ": [stage_coords[-3]], + "StageY": [stage_coords[-2]], + "StageX": [stage_coords[-1]], + } + + df = pd.DataFrame(data) + + if overwrite or not output_path.exists(): + # Write the DataFrame to a new file, including the header + df.to_csv(output_path, mode='w', index=False) + else: + # Append the DataFrame to the existing file, without writing the header + df.to_csv(output_path, mode='a', header=False, index=False) + + def limit_shifts_zyx( + self, shifts_zyx: Tuple[int, int, int], limits: Tuple[int, int, int] = (5, 5, 5) + ) -> Tuple[int, int, int]: + """ + Limits the shifts_zyx to the specified limits. + + Parameters + ---------- + shifts_zyx : Tuple[int, int, int] + The computed shifts_zyx (Z, Y, X). + limits : Tuple[int, int, int] + The limits for the shifts_zyx (Z, Y, X). + + Returns + ------- + Tuple[int, int, int] + The limited shifts_zyx. + """ + shifts_zyx = np.array(shifts_zyx) + limits = np.array(limits) + shifts_zyx = np.where(np.abs(shifts_zyx) > limits, 0, shifts_zyx) + return tuple(shifts_zyx) + + +# TODO: logic for handling which t_idx to grab as reference. If the volume changes +# Drastically, we may need to grab the previous timepoint as reference + + +def get_volume(dataset, axes): + p_idx, t_idx, autotrack_channel, z_range = axes + images = [] + logger.debug( + f"Getting Zstack for p:{p_idx},t:{t_idx},c:{autotrack_channel},z_range:{z_range}" + ) + for z_id in range(z_range): + images.append( + dataset.read_image( + **{'channel': autotrack_channel, 'z': z_id, 'time': t_idx, 'position': p_idx} + ) + ) + return np.stack(images) + + +def autotracker_hook_fn( + arm, + autotracker_settings, + position_settings, + channel_config, + z_slice_settings, + output_shift_path, + axes, + dataset, +) -> None: + """ + Pycromanager hook function that is called when an image is saved. 
+
+
+# TODO: logic for handling which t_idx to grab as reference. If the volume changes
+# drastically, we may need to grab the previous timepoint as reference
+
+
+def get_volume(dataset, axes):
+    """Read a Z-stack for the given (position, time, channel, num_slices) axes."""
+    p_idx, t_idx, autotrack_channel, num_slices = axes
+    images = []
+    logger.debug(
+        f"Getting Z-stack for p:{p_idx},t:{t_idx},c:{autotrack_channel},num_slices:{num_slices}"
+    )
+    for z_idx in range(num_slices):
+        images.append(
+            dataset.read_image(
+                **{'channel': autotrack_channel, 'z': z_idx, 'time': t_idx, 'position': p_idx}
+            )
+        )
+    return np.stack(images)
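For orientation, a hypothetical call (assuming a pycromanager `Dataset` whose position axis is indexed the same way the hook below indexes it, and 10 acquired slices):

    # dataset: pycromanager Dataset opened from the acquisition directory
    stack = get_volume(dataset, (0, 3, 'DAPI', 10))
    print(stack.shape)  # (10, Y, X)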
+
+
+def autotracker_hook_fn(
+    arm,
+    autotracker_settings,
+    position_settings,
+    channel_config,
+    z_slice_settings,
+    output_shift_path,
+    axes,
+    dataset,
+) -> None:
+    """
+    Pycromanager hook function that is called when an image is saved.
+
+    Parameters
+    ----------
+    axes : dict
+        Position, Time, Channel and Z indices of the saved image
+    dataset : Dataset
+        Dataset saved on disk
+    """
+    # logger.info('Autotracker hook function called for axes %s', axes)
+
+    # TODO: handle the lf_acq or ls_acq arm
+    if arm == 'lf':
+        if axes == globals.lf_last_img_idx:
+            globals.lf_acq_finished = True
+    elif arm == 'ls':
+        if axes == globals.ls_last_img_idx:
+            globals.ls_acq_finished = True
+
+    # Get a reference to the acquisition engine and its settings
+    # TODO: This is a placeholder, the actual implementation will be different
+    num_slices = z_slice_settings.num_slices
+    scale = autotracker_settings.scale_yx
+    shift_limit = autotracker_settings.shift_limit
+    tracking_method = autotracker_settings.tracking_method
+    tracking_interval = autotracker_settings.tracking_interval
+    tracking_channel = channel_config.config_name
+    zyx_dampening_factor = autotracker_settings.zyx_dampening_factor
+    output_shift_path = Path(output_shift_path)
+
+    # Get axes info
+    p_label = axes['position']
+    p_idx = position_settings.position_labels.index(p_label)
+    t_idx = axes['time']
+    channel = axes['channel']
+    z_idx = axes['z']
+
+    tracker = Autotracker(
+        tracking_method=tracking_method,
+        scale=scale,
+        shift_limit=shift_limit,
+        zyx_dampening_factor=zyx_dampening_factor,
+    )
+    # Act only once the last slice of the tracking channel has been saved
+    if channel == tracking_channel and z_idx == (num_slices - 1):
+        # Skip the first timepoints; there is no earlier volume to track against yet
+        if t_idx > 1:
+            if t_idx % tracking_interval != 0:
+                logger.debug('Skipping autotracking t %d', t_idx)
+                return
+            logger.debug('Running autotracker for position %s at t %d', p_label, t_idx)
+
+            # Reference volume from the previous tracked timepoint, moving volume from the current one
+            volume_t0_axes = (p_idx, t_idx - tracking_interval, tracking_channel, num_slices)
+            volume_t1_axes = (p_idx, t_idx, tracking_channel, num_slices)
+            # Compute the shifts_zyx
+            logger.debug('Instantiating autotracker')
+            if globals.demo_run:
+                # Random shifting for demo purposes
+                shifts_zyx = np.random.randint(-50, 50, 3)
+                sleep(3)
+                logger.info(
+                    'shifts_zyx (z,y,x): %f,%f,%f', shifts_zyx[0], shifts_zyx[1], shifts_zyx[2]
+                )
+            else:
+                volume_t0 = get_volume(dataset, volume_t0_axes)
+                volume_t1 = get_volume(dataset, volume_t1_axes)
+                # Reference and moving volumes
+                shifts_zyx = tracker.estimate_shift(volume_t0, volume_t1)
+
+            position_id = str(axes['position']) + '.csv'
+            shift_coord_output = output_shift_path / position_id
+
+            # Read the previous stage coordinates
+            prev_x = position_settings.xyz_positions_shift[p_idx][0]
+            prev_y = position_settings.xyz_positions_shift[p_idx][1]
+            # Include Z if it is available
+            if position_settings.xyz_positions_shift[p_idx][2] is not None:
+                prev_z = position_settings.xyz_positions_shift[p_idx][2]
+                logger.info('Previous position: %f,%f,%f', prev_z, prev_y, prev_x)
+            else:
+                prev_z = None
+                logger.info('Previous position: %f,%f', prev_y, prev_x)
+            # Update the event coordinates
+            position_settings.xyz_positions_shift[p_idx][0] = prev_x + shifts_zyx[-1]
+            position_settings.xyz_positions_shift[p_idx][1] = prev_y + shifts_zyx[-2]
+            # Update Z if it is available
+            if position_settings.xyz_positions_shift[p_idx][2] is not None:
+                position_settings.xyz_positions_shift[p_idx][2] = prev_z + shifts_zyx[-3]
+                logger.info(
+                    'New positions: %f,%f,%f', *position_settings.xyz_positions_shift[p_idx]
+                )
+            else:
+                logger.info(
+                    'New positions: %f,%f', *position_settings.xyz_positions_shift[p_idx][0:2]
+                )
+            # Save the shifts_zyx
+            tracker.save_shifts_to_file(
+                shift_coord_output,
+                position_id=p_label,
+                timepoint_id=t_idx,
+                shifts_zyx=shifts_zyx,
+                stage_coords=(
+                    position_settings.xyz_positions_shift[p_idx][2],
+                    position_settings.xyz_positions_shift[p_idx][1],
+                    position_settings.xyz_positions_shift[p_idx][0],
+                ),
+            )
+        else:
+            # Save the starting positions for the first timepoints
+            position_id = str(axes['position']) + '.csv'
+            shift_coord_output = output_shift_path / position_id
+            prev_y = position_settings.xyz_positions_shift[p_idx][1]
+            prev_x = position_settings.xyz_positions_shift[p_idx][0]
+            if position_settings.xyz_positions_shift[p_idx][2] is not None:
+                prev_z = position_settings.xyz_positions_shift[p_idx][2]
+            else:
+                prev_z = None
+            tracker.save_shifts_to_file(
+                shift_coord_output,
+                position_id=p_label,
+                timepoint_id=t_idx,
+                shifts_zyx=(0, 0, 0),
+                stage_coords=(prev_z, prev_y, prev_x),
+            )
diff --git a/mantis/acquisition/hook_functions/globals.py b/mantis/acquisition/hook_functions/globals.py
index 7eb88aba..2af906d5 100644
--- a/mantis/acquisition/hook_functions/globals.py
+++ b/mantis/acquisition/hook_functions/globals.py
@@ -9,3 +9,5 @@
 ls_slice_acquisition_rates = None
 ls_laser_powers = None
 new_well = False
+
+demo_run = False
diff --git a/mantis/acquisition/microscope_operations.py b/mantis/acquisition/microscope_operations.py
index 324aa679..0b1cb653 100644
--- a/mantis/acquisition/microscope_operations.py
+++ b/mantis/acquisition/microscope_operations.py
@@ -12,7 +12,7 @@
 from pycromanager import Core, Studio
 from pylablib.devices.Thorlabs import KinesisPiezoMotor
 
-from mantis.acquisition.AcquisitionSettings import AutoexposureSettings
+from mantis.acquisition.AcquisitionSettings import AutoexposureSettings, AutotrackerSettings
 from mantis.acquisition.autoexposure import manual_autoexposure, mean_intensity_autoexposure
 
 logger = logging.getLogger(__name__)
@@ -670,3 +670,12 @@ def autoexposure(
     )
 
     return suggested_exposure_time, suggested_light_intensity
+
+
+def autotracker(
+    mmc: Core,
+    autotracker_settings: AutotrackerSettings,
+    **kwargs,
+):
+    # TODO: placeholder; the actual tracking happens in the image saved hook
+    logger.debug('microscope_operations autotracker')
diff --git a/mantis/acquisition/settings/demo_acquisition_settings.yaml b/mantis/acquisition/settings/demo_acquisition_settings.yaml
index a076ad1c..39a4718d 100644
--- a/mantis/acquisition/settings/demo_acquisition_settings.yaml
+++ b/mantis/acquisition/settings/demo_acquisition_settings.yaml
@@ -6,7 +6,7 @@ time_settings:
   time_interval_s: 5
 
 lf_channel_settings:
-  default_exposure_times_ms: [10, 10]
+  default_exposure_times_ms: [5, 5]
   channel_group: 'Channel-Multiband'
   channels: ['DAPI', 'FITC']
   use_sequencing: True
@@ -30,6 +30,8 @@ lf_microscope_settings:
   channel_sequencing_settings:
     - ['LED', 'Sequence', 'On']
   use_autofocus: False
+  autotracker_config: ['Channel-Multiband', 'DAPI']
+
 
 ls_channel_settings:
   default_exposure_times_ms: [20, 30]
@@ -37,7 +39,7 @@ ls_channel_settings:
   channel_group: 'Channel'
   channels: ['Rhodamine', 'Cy5']
   use_sequencing: False
-  use_autoexposure: [True, False]
+  use_autoexposure: [False, False]
 
 ls_slice_settings:
   z_stage_name: 'Z'
@@ -60,3 +62,10 @@ ls_microscope_settings:
 ls_autoexposure_settings:
   autoexposure_method: 'manual'
   rerun_each_timepoint: True
+
+autotracker_settings:
+  tracking_method: 'phase_cross_correlation'
+  tracking_interval: 1
+  shift_limit: [30, 200, 200]
+  device: 'cpu'
+  scale_yx: 0.075
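The new `autotracker_settings` block is parsed like the other settings groups; a hypothetical round-trip through YAML (path and type coercion per the pydantic dataclass above):

    import yaml

    from mantis.acquisition.AcquisitionSettings import AutotrackerSettings

    with open('mantis/acquisition/settings/demo_acquisition_settings.yaml') as f:
        raw_settings = yaml.safe_load(f)

    settings = AutotrackerSettings(**raw_settings['autotracker_settings'])
    print(settings.tracking_method, settings.shift_limit)
    # phase_cross_correlation (30.0, 200.0, 200.0)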
diff --git a/mantis/acquisition/settings/example_acquisition_settings.yaml b/mantis/acquisition/settings/example_acquisition_settings.yaml
index c62588f9..a6777b4a 100644
--- a/mantis/acquisition/settings/example_acquisition_settings.yaml
+++ b/mantis/acquisition/settings/example_acquisition_settings.yaml
@@ -118,6 +118,8 @@ lf_microscope_settings:
   use_autofocus: True
   autofocus_stage: 'ZDrive'
   autofocus_method: 'PFS'
+  # Autotracker parameters using the label-free channel
+  autotracker_config: ['Channel - LF', 'BF']
 
 # Microscope settings which will be applied when the light-sheet acquisition is
 # initialized
@@ -142,3 +144,10 @@ ls_microscope_settings:
 ls_autoexposure_settings:
   autoexposure_method: 'manual'
   rerun_each_timepoint: True
+
+autotracker_settings:
+  tracking_method: 'phase_cross_correlation'
+  tracking_interval: 1
+  shift_limit: [30, 200, 200]
+  device: 'cpu'
+  scale_yx: 0.075
diff --git a/mantis/cli/run_acquisition.py b/mantis/cli/run_acquisition.py
index 14250f8e..194b86df 100644
--- a/mantis/cli/run_acquisition.py
+++ b/mantis/cli/run_acquisition.py
@@ -53,6 +53,7 @@ def run_acquisition(
         SliceSettings,
         MicroscopeSettings,
         AutoexposureSettings,
+        AutotrackerSettings,
     )
 
     # isort: on
@@ -77,6 +78,25 @@ def run_acquisition(
     ls_autoexposure_settings = AutoexposureSettings(
         **raw_settings.get('ls_autoexposure_settings')
     )
+    raw_autotracker_settings = raw_settings.get('autotracker_settings')
+    autotracker_settings = (
+        AutotrackerSettings(**raw_autotracker_settings)
+        if raw_autotracker_settings is not None
+        else None
+    )
+
+    # An autotracker config may be given for at most one arm
+    ls_autotracker_settings = None
+    lf_autotracker_settings = None
+    if (
+        lf_microscope_settings.autotracker_config is not None
+        and ls_microscope_settings.autotracker_config is not None
+    ):
+        raise ValueError("Autotracker is active in both arms, please specify only one arm")
+    elif lf_microscope_settings.autotracker_config is not None:
+        lf_autotracker_settings = autotracker_settings
+    elif ls_microscope_settings.autotracker_config is not None:
+        ls_autotracker_settings = autotracker_settings
 
     with MantisAcquisition(
         acquisition_directory=acq_directory,
@@ -95,6 +110,8 @@ def run_acquisition(
         acq.ls_acq.slice_settings = ls_slice_settings
         acq.ls_acq.microscope_settings = ls_microscope_settings
         acq.ls_acq.autoexposure_settings = ls_autoexposure_settings
+        acq.ls_acq.autotracker_settings = ls_autotracker_settings
+        acq.lf_acq.autotracker_settings = lf_autotracker_settings
 
         acq.setup()
         acq.acquire()
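Since the module-level TODOs ask for test functions, here is a hypothetical pytest sketch for the centroid-based estimator on a synthetic translated blob (tolerances are illustrative):

    import numpy as np
    import scipy.ndimage as ndi

    from mantis.acquisition.autotracker import multiotsu_centroid

    def test_multiotsu_centroid_recovers_translation():
        rng = np.random.default_rng(0)
        ref = np.zeros((32, 128, 128))
        ref[10:18, 40:60, 40:60] = 1.0
        ref += 0.01 * rng.random(ref.shape)  # mild background so multi-Otsu sees three classes
        mov = ndi.shift(ref, shift=(2, -10, 15), order=0)
        est = multiotsu_centroid(ref_img=ref, mov_img=mov)
        assert np.allclose(est, (2, -10, 15), atol=2)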
diff --git a/mantis/tests/test_autotracker/test_autotracker.py b/mantis/tests/test_autotracker/test_autotracker.py
new file mode 100644
index 00000000..dc0aa457
--- /dev/null
+++ b/mantis/tests/test_autotracker/test_autotracker.py
@@ -0,0 +1,427 @@
+# %%
+from typing import Callable, Optional, Tuple, cast
+
+import matplotlib.pyplot as plt
+import numpy as np
+import skimage
+
+from iohub import open_ome_zarr
+from numpy.typing import ArrayLike
+from scipy.fftpack import next_fast_len
+from skimage.exposure import rescale_intensity
+from skimage.feature import match_template
+from skimage.filters import gaussian
+from skimage.measure import label, regionprops
+
+from mantis import logger
+
+# FIXME fix the dependencies so that we can install and import dexpv2
+# from dexpv2.crosscorr import phase_cross_corr
+# from dexpv2.utils import center_crop, pad_to_shape, to_cpu
+
+# TODO: Make toy datasets for testing
+# TODO: multiotsu should have an offset variable from the center of the image (X,Y)
+# TODO: Check why PCC is not working
+# TODO: write test functions
+# TODO: consider splitting this file into two
+
+
+def calc_weighted_center(labeled_im):
+    """Calculate the area-weighted centroid of the regions in a labeled image.
+
+    Parameters
+    ----------
+    labeled_im : ndarray
+        labeled 3D image
+    """
+    regions = sorted(regionprops(labeled_im), key=lambda r: r.area)
+    centers = []
+    areas = []
+    for region in regions:
+        centers.append(region.centroid)
+        areas.append(region.area)
+    areas_norm = np.array(areas) / np.sum(areas)
+    centers = np.array(centers)
+    center_weighted = np.zeros(3)
+    for j in range(3):
+        center_weighted[j] = np.sum(centers[:, j] * areas_norm)
+
+    return center_weighted
+
+
+def multiotsu_centroid(
+    ref_img: ArrayLike,
+    mov_img: ArrayLike,
+) -> list:
+    """
+    Computes the translation shifts with a multi-Otsu threshold approach,
+    by comparing the weighted centroids of the segmented regions
+
+    Parameters
+    ----------
+    ref_img : ndarray
+        reference image stack ZYX
+    mov_img : ndarray
+        moving image stack ZYX
+
+    Returns
+    -------
+    shifts : list
+        list of shifts in z, y, x order
+    """
+    # Process moving image
+    mov_img = rescale_intensity(mov_img, in_range='image', out_range=(0, 1.0))
+    stack_blur = gaussian(mov_img, sigma=5.0)
+    thresh = skimage.filters.threshold_multiotsu(stack_blur)
+    mov_img = stack_blur > thresh[0]
+    mov_img = label(mov_img)
+    # Process reference image
+    ref_img = rescale_intensity(ref_img, in_range='image', out_range=(0, 1.0))
+    stack_blur = gaussian(ref_img, sigma=5.0)
+    thresh = skimage.filters.threshold_multiotsu(stack_blur)
+    ref_img = stack_blur > thresh[0]
+    ref_img = label(ref_img)
+
+    # Get the centroids
+    moving_center = calc_weighted_center(mov_img)
+    target_center = calc_weighted_center(ref_img)
+
+    # Find the shifts
+    shifts = moving_center - target_center
+
+    logger.debug(
+        'moving_center (z,y,x): %f,%f,%f',
+        moving_center[0],
+        moving_center[1],
+        moving_center[2],
+    )
+    logger.debug(
+        'target_center (z,y,x): %f,%f,%f',
+        target_center[0],
+        target_center[1],
+        target_center[2],
+    )
+    logger.debug('shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2])
+
+    return shifts
+
+
+def template_matching(ref_img, moving_img, template_slicing_zyx):
+    """
+    Uses template matching to determine the shift between two image stacks.
+
+    Parameters:
+    - ref_img: Reference 3D image stack (numpy array).
+    - moving_img: Moving 3D image stack (numpy array) to be aligned with the reference.
+    - template_slicing_zyx: Tuple or list of slice objects defining the region to be used as the template.
+
+    Returns:
+    - shift: The shift (displacement) needed to align moving_img with ref_img (numpy array).
+    """
+    template = ref_img[template_slicing_zyx]
+
+    result = match_template(moving_img, template)
+    zyx_1 = np.unravel_index(np.argmax(result), result.shape)
+
+    # Calculate the shift by subtracting the template's origin in the reference
+    # image from the best-match location in the moving image
+    shift = np.array(zyx_1) - np.array([s.start for s in template_slicing_zyx])
+
+    return shift
+
+
+def to_cpu(arr: ArrayLike) -> ArrayLike:
+    """
+    Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2
+    Moves array to cpu; if it's already there nothing is done.
+    """
+    if hasattr(arr, "cpu"):
+        arr = arr.cpu()
+    elif hasattr(arr, "get"):
+        arr = arr.get()
+    return arr
+
+
+def _match_shape(img: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike:
+    """
+    Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2
+    Pad or crop array to match provided shape.
+    """
+
+    # Compare elementwise; comparing plain tuples would be lexicographic
+    shape = np.asarray(shape)
+    if np.any(shape > img.shape):
+        padded_shape = np.maximum(img.shape, shape)
+        img = pad_to_shape(img, padded_shape, mode="reflect")
+
+    if np.any(shape < img.shape):
+        img = center_crop(img, shape)
+
+    return img
+ """ + + if np.any(shape > img.shape): + padded_shape = np.maximum(img.shape, shape) + img = pad_to_shape(img, padded_shape, mode="reflect") + + if np.any(shape < img.shape): + img = center_crop(img, shape) + + return img + + +def center_crop(arr: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Crops the center of `arr` + """ + assert arr.ndim == len(shape) + + starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape)) + + assert all(s >= 0 for s in starts) + + slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape)) + + logger.info( + f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}" + ) + + return arr[slicing] + + +def pad_to_shape(arr: ArrayLike, shape: Tuple[int, ...], mode: str, **kwargs) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Pads array to shape. + + Parameters + ---------- + arr : ArrayLike + Input array. + shape : Tuple[int] + Output shape. + mode : str + Padding mode (see np.pad). + + Returns + ------- + ArrayLike + Padded array. + """ + assert arr.ndim == len(shape) + + dif = tuple(s - a for s, a in zip(shape, arr.shape)) + assert all(d >= 0 for d in dif) + + pad_width = [[s // 2, s - s // 2] for s in dif] + + logger.info(f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}") + + return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs) + + +def phase_cross_corr( + ref_img: ArrayLike, + mov_img: ArrayLike, + maximum_shift: float = 1.0, + to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x, + transform: Optional[Callable[[ArrayLike], ArrayLike]] = np.log1p, +) -> Tuple[int, ...]: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + + Computes translation shift using arg. maximum of phase cross correlation. + Input are padded or cropped for fast FFT computation assuming a maximum translation shift. + + Parameters + ---------- + ref_img : ArrayLike + Reference image. + mov_img : ArrayLike + Moved image. + maximum_shift : float, optional + Maximum location shift normalized by axis size, by default 1.0 + + Returns + ------- + Tuple[int, ...] + Shift between reference and moved image. + """ + shape = tuple( + cast(int, next_fast_len(int(max(s1, s2) * maximum_shift))) + for s1, s2 in zip(ref_img.shape, mov_img.shape) + ) + + logger.info( + f"phase cross corr. fft shape of {shape} for arrays of shape {ref_img.shape} and {mov_img.shape} " + f"with maximum shift of {maximum_shift}" + ) + + ref_img = _match_shape(ref_img, shape) + mov_img = _match_shape(mov_img, shape) + + ref_img = to_device(ref_img) + mov_img = to_device(mov_img) + + if transform is not None: + ref_img = transform(ref_img) + mov_img = transform(mov_img) + + Fimg1 = np.fft.rfftn(ref_img) + Fimg2 = np.fft.rfftn(mov_img) + eps = np.finfo(Fimg1.dtype).eps + del ref_img, mov_img + + prod = Fimg1 * Fimg2.conj() + del Fimg1, Fimg2 + + norm = np.fmax(np.abs(prod), eps) + corr = np.fft.irfftn(prod / norm) + del prod, norm + + corr = np.fft.fftshift(np.abs(corr)) + + argmax = to_cpu(np.argmax(corr)) + peak = np.unravel_index(argmax, corr.shape) + peak = tuple(s // 2 - p for s, p in zip(corr.shape, peak)) + + logger.info(f"phase cross corr. 
peak at {peak}") + + return peak + + +# %% +class Autotracker(object): + _AUTOFOCUS_METHODS = { + 'pcc': phase_cross_corr, + 'tm': template_matching, + 'multiotsu': multiotsu_centroid, + } + + def __init__(self, autofocus_method: str, xy_dapening: tuple[int] = None): + self.autofocus_method = autofocus_method + self.xy_dapening = xy_dapening + + def estimate_shift( + self, reference_array: ArrayLike, moving_array: ArrayLike, **kwargs + ) -> np.ndarray: + autofocus_method_func = self._AUTOFOCUS_METHODS.get(self.autofocus_method) + + if not autofocus_method_func: + raise ValueError(f'Unknown autofocus method: {self.autofocus_method}') + + shifts = autofocus_method_func(**kwargs) + + return shifts + + +# %% +def main(): + """ + Toy dataset + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + (0, 30, -80), # Shift for timepoint 4 + ] + """ + # %% + input_data_path = ( + '/home/eduardo.hirata/repos/mantis/mantis/tests/x-ed/toy_translate.zarr/0/0/0' + ) + dataset = open_ome_zarr(input_data_path) + T, C, Z, Y, X = dataset.data.shape + # print(channel_names := dataset.channel_names) + # autofocus_method = 'multiotsu' + # xy_dapening = (10, 10) + + c_idx = 0 + data_t0 = dataset.data[0, c_idx] + data_t1 = dataset.data[1, c_idx] + data_t2 = dataset.data[2, c_idx] + data_t3 = dataset.data[3, c_idx] + + # subplot + fig, ax = plt.subplots(1, 3) + ax[0].imshow(data_t0[10]) + ax[1].imshow(data_t1[10]) + ax[2].imshow(data_t2[10]) + plt.show() + # %% + # Testing Multiotsu + shift_0 = multiotsu_centroid(data_t0, data_t0) + shift_1 = multiotsu_centroid(data_t1, data_t0) + shift_2 = multiotsu_centroid(data_t2, data_t0) + shift_3 = multiotsu_centroid(data_t3, data_t0) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + ] # Compare shifts with expected translations + + tolerance = (5, 8, 8) # Define your tolerance level + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + # %% + # Testing PCC + shift_0 = phase_cross_corr(data_t0, data_t0) + shift_1 = phase_cross_corr(data_t0, data_t1) + shift_2 = phase_cross_corr(data_t0, data_t2) + shift_3 = phase_cross_corr(data_t0, data_t3) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + + # %% + # Testing template matching + + crop_z = slice(4, 8) + crop_y = slice(200, 300) + crop_x = slice(200, 300) + template = data_t0[crop_z, crop_y, crop_x] + + result = match_template(data_t1, template) + zyx = np.unravel_index(np.argmax(result), result.shape) + + print(zyx) + fig, ax = plt.subplots(1, 2) + ax[0].imshow(template[0]) + ax[0].set_title('template') + ax[1].imshow(data_t1[zyx[0]]) + rect = plt.Rectangle( + (zyx[2], zyx[1]), + template.shape[2], + template.shape[1], + edgecolor='red', + facecolor='none', + ) + ax[1].add_patch(rect) + ax[1].set_title('template matching result') + + # %% + # Calculate the shift, apply and check the result + shift_0 = 
+    shift_0 = template_matching(data_t0, data_t0, (crop_z, crop_y, crop_x))
+    shift_1 = template_matching(data_t0, data_t1, (crop_z, crop_y, crop_x))
+    shift_2 = template_matching(data_t0, data_t2, (crop_z, crop_y, crop_x))
+    shift_3 = template_matching(data_t0, data_t3, (crop_z, crop_y, crop_x))
+
+    shifts = [shift_0, shift_1, shift_2, shift_3]
+
+    for i, (calculated, expected) in enumerate(zip(shifts, translations)):
+        is_similar = np.allclose(calculated, expected, atol=tolerance)
+        print(
+            f'Timepoint {i} shift: {calculated}, Expected: {expected}, Similar: {is_similar}'
+        )
+
+
+# %%
+if __name__ == "__main__":
+    main()
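Finally, a worked conversion from pixel shifts to stage moves implied by these settings: with scale_yx: 0.075 (interpreted here as um per pixel), an image shift of (0, -80, 80) px maps to (0, -6.0, 6.0) um. A hypothetical helper:

    import numpy as np

    def px_shift_to_um(shifts_zyx_px, scale_yx=0.075, scale_z=1.0):
        # elementwise scaling, mirroring Autotracker.estimate_shift
        return np.array(shifts_zyx_px) * np.array([scale_z, scale_yx, scale_yx])

    print(px_shift_to_um((0, -80, 80)))  # [ 0. -6.  6.]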