From 7169c038973e80262a493f0c309a3c2227c18c0a Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 8 Aug 2024 18:08:44 -0700 Subject: [PATCH 1/6] initial test of the multiotsu --- mantis/acquisition/autotracker.py | 245 ++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 mantis/acquisition/autotracker.py diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py new file mode 100644 index 00000000..a9051e53 --- /dev/null +++ b/mantis/acquisition/autotracker.py @@ -0,0 +1,245 @@ +# %% + +from typing import Literal + +import matplotlib.pyplot as plt +import numpy as np +import skimage + +from iohub import open_ome_zarr +from numpy.typing import ArrayLike +from skimage.exposure import rescale_intensity +from skimage.feature import match_template +from skimage.filters import gaussian +from skimage.measure import label, regionprops +from skimage.registration import phase_cross_correlation + +from mantis import logger + +# import cv2 + + +# TODO: Make toy datasets for testing +# TODO: multiotsu should have an offset variable from the center of the image (X,Y) +# TODO: Check why PCC is not working +# TODO: write test functions +# TODO: consider splitting this file into two + + +# %% + + +def calc_weighted_center(labeled_im): + """calculates weighted centroid based on the area of the regions + + Parameters + ---------- + labeled_im : ndarray + labeled image + """ + regions = sorted(regionprops(labeled_im), key=lambda r: r.area) + n_regions = len(regions) + centers = [] + areas = [] + for i in range(n_regions): + centroid = regions[i].centroid + centers.append(centroid) + area = regions[i].area + areas.append(area) + areas_norm = np.array(areas) / np.sum(areas) + centers = np.array(centers) + center_weighted = np.zeros(3) + for j in range(3): + center_weighted[j] = np.sum(centers[:, j] * areas_norm) + + return center_weighted + + +def multiotsu_centroid(moving: ArrayLike, reference: ArrayLike) -> list: + """ + Multiotsu centroid method for finding the shift between two volumes (ZYX) + + Parameters + ---------- + moving : ndarray + moving stack ZYX + reference : ndarray + reference image stack ZYX + + Returns + ------- + shifts : list + list of shifts in z, y, x order + """ + # Process moving image + moving = rescale_intensity(moving, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(moving, sigma=5.0) + thresh = skimage.filters.threshold_multiotsu(stack_blur) + moving = stack_blur > thresh[0] + moving = label(moving) + # Process reference image + reference = rescale_intensity(reference, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(reference, sigma=5.0) + thresh = skimage.filters.threshold_multiotsu(stack_blur) + reference = stack_blur > thresh[0] + reference = label(reference) + + # Get the centroids + moving_center = calc_weighted_center(moving) + target_center = calc_weighted_center(reference) + + # Find the shifts + shifts = moving_center - target_center + + logger.debug( + 'moving_center (z,y,x): %f,%f,%f', moving_center[0], moving_center[1], moving_center[2] + ) + logger.debug( + 'target_center (z,y,x): %f,%f,%f', target_center[0], target_center[1], target_center[2] + ) + logger.debug('shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) + + return shifts + + +_AUTOFOCUS_METHODS = { + 'pcc': phase_cross_correlation, + 'tm': match_template, + 'multiotsu': multiotsu_centroid, +} + + +# %% +def main(): + print('AUTOTRACKER') + """ + Toy dataset + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + (0, 30, -80), # Shift for timepoint 4 + ] + """ + # %% + input_data_path = ( + '/home/eduardo.hirata/repos/mantis/mantis/tests/x-ed/toy_translate.zarr/0/0/0' + ) + dataset = open_ome_zarr(input_data_path) + T, C, Z, Y, X = dataset.data.shape + # print(channel_names := dataset.channel_names) + # autofocus_method = 'multiotsu' + # xy_dapening = (10, 10) + + c_idx = 2 + data_t0 = dataset.data[0, c_idx] + data_t1 = dataset.data[1, c_idx] + data_t2 = dataset.data[2, c_idx] + data_t3 = dataset.data[3, c_idx] + + # subplot + fig, ax = plt.subplots(1, 3) + ax[0].imshow(data_t0[10]) + ax[1].imshow(data_t1[10]) + ax[2].imshow(data_t2[10]) + plt.show() + + shift_0 = multiotsu_centroid(data_t0, data_t0) + shift_1 = multiotsu_centroid(data_t1, data_t0) + shift_2 = multiotsu_centroid(data_t2, data_t0) + shift_3 = multiotsu_centroid(data_t3, data_t0) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + ] # Compare shifts with expected translations + + tolerance = (5, 8, 8) # Define your tolerance level + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + + +# %% + + +def estimate_shift( + reference_array: ArrayLike, + moving_array: ArrayLike, + autofocus_method: Literal['pcc', 'tm', 'multiotsu'], + xy_dapening: tuple[int] = None, + **kwargs, +) -> np.ndarray: + autofocus_method_func = _AUTOFOCUS_METHODS.get(autofocus_method) + + if not autofocus_method_func: + raise ValueError(f'Unknown autofocus method: {autofocus_method}') + + shifts = autofocus_method_func(**kwargs) + + return shifts + + +def get_shift_centroid(autofocus_method, im_moving, im_ref): + """finds the centroid of the images and returns the shift + between the two volumes in z, y, and x + + Parameters + ---------- + im_moving : ndarray + moving image volume + im_ref : ndarray + reference image volume + + Returns + ------- + shift : ndarray + array of shift in z, y, x order + """ + Z, Y, X = im_moving.shape + if autofocus_method == 'multiotsu': + centroid_moving = multiotsu_centroid(im_moving) + centroid_ref = [int(Z / 2), int((Y) / 2), int((X) / 2)] + if autofocus_method == 'com': + centroid_moving = calc_com(im_moving) + centroid_ref = [int(Z / 2), int((Y) / 2), int((X) / 2)] + + return centroid_moving - centroid_ref + + +def calc_com(im): + """ """ + stack = np.clip(im - np.percentile(im, 10), a_min=0, a_max=None) + Z, Y, X = im.shape + center = [int(Z // 2), int(Y // 2), int(X // 2)] + num_z = 0 + num_y = 0 + num_x = 0 + for z in range(Z): + for y in range(Y): + for x in range(X): + val = stack[z, y, x] + num_z = num_z + (z - center[0]) * val + num_y = num_y + (y - center[1]) * val + num_x = num_x + (x - center[2]) * val + centroid = [ + int(num_z / np.sum(stack)), + int(num_y / np.sum(stack)), + int(num_x / np.sum(stack)), + ] + centroid = np.array(centroid) + np.array(center) + print(f'centroid: {centroid}') + return centroid + + +# %% +# %% +if __name__ == "__main__": + main() From ca341bdf028ea56e04735f1c7361afb498635e55 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 9 Aug 2024 16:45:25 -0700 Subject: [PATCH 2/6] adding working pcc, template matching and multiotsu. borrowing funcs from dexv2 --- mantis/acquisition/autotracker.py | 366 ++++++++++++++++++++++-------- 1 file changed, 270 insertions(+), 96 deletions(-) diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py index a9051e53..e403df12 100644 --- a/mantis/acquisition/autotracker.py +++ b/mantis/acquisition/autotracker.py @@ -1,6 +1,5 @@ # %% - -from typing import Literal +from typing import Callable, Optional, Tuple, cast import matplotlib.pyplot as plt import numpy as np @@ -8,16 +7,17 @@ from iohub import open_ome_zarr from numpy.typing import ArrayLike +from scipy.fftpack import next_fast_len from skimage.exposure import rescale_intensity from skimage.feature import match_template from skimage.filters import gaussian from skimage.measure import label, regionprops -from skimage.registration import phase_cross_correlation from mantis import logger -# import cv2 - +# FIXME fix the dependencies so that we can install and import dexpv2 +# from dexpv2.crosscorr import phase_cross_corr +# from dexpv2.utils import center_crop, pad_to_shape, to_cpu # TODO: Make toy datasets for testing # TODO: multiotsu should have an offset variable from the center of the image (X,Y) @@ -26,9 +26,6 @@ # TODO: consider splitting this file into two -# %% - - def calc_weighted_center(labeled_im): """calculates weighted centroid based on the area of the regions @@ -55,9 +52,12 @@ def calc_weighted_center(labeled_im): return center_weighted -def multiotsu_centroid(moving: ArrayLike, reference: ArrayLike) -> list: +def multiotsu_centroid( + ref_img: ArrayLike, + mov_img: ArrayLike, +) -> list: """ - Multiotsu centroid method for finding the shift between two volumes (ZYX) + Computes the translation shifts using a multiotsu threshold approach by finding the centroid of the regions Parameters ---------- @@ -72,46 +72,236 @@ def multiotsu_centroid(moving: ArrayLike, reference: ArrayLike) -> list: list of shifts in z, y, x order """ # Process moving image - moving = rescale_intensity(moving, in_range='image', out_range=(0, 1.0)) - stack_blur = gaussian(moving, sigma=5.0) + mov_img = rescale_intensity(mov_img, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(mov_img, sigma=5.0) thresh = skimage.filters.threshold_multiotsu(stack_blur) - moving = stack_blur > thresh[0] - moving = label(moving) + mov_img = stack_blur > thresh[0] + mov_img = label(mov_img) # Process reference image - reference = rescale_intensity(reference, in_range='image', out_range=(0, 1.0)) - stack_blur = gaussian(reference, sigma=5.0) + ref_img = rescale_intensity(ref_img, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(ref_img, sigma=5.0) thresh = skimage.filters.threshold_multiotsu(stack_blur) - reference = stack_blur > thresh[0] - reference = label(reference) + ref_img = stack_blur > thresh[0] + ref_img = label(ref_img) # Get the centroids - moving_center = calc_weighted_center(moving) - target_center = calc_weighted_center(reference) + moving_center = calc_weighted_center(mov_img) + target_center = calc_weighted_center(ref_img) # Find the shifts shifts = moving_center - target_center logger.debug( - 'moving_center (z,y,x): %f,%f,%f', moving_center[0], moving_center[1], moving_center[2] + 'moving_center (z,y,x): %f,%f,%f', + moving_center[0], + moving_center[1], + moving_center[2], ) logger.debug( - 'target_center (z,y,x): %f,%f,%f', target_center[0], target_center[1], target_center[2] + 'target_center (z,y,x): %f,%f,%f', + target_center[0], + target_center[1], + target_center[2], ) logger.debug('shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) return shifts -_AUTOFOCUS_METHODS = { - 'pcc': phase_cross_correlation, - 'tm': match_template, - 'multiotsu': multiotsu_centroid, -} +def template_matching(ref_img, moving_img, template_slicing_zyx): + """ + Uses template matching to determine shift between two image stacks. + + Parameters: + - ref_img: Reference 3D image stack (numpy array). + - moving_img: Moving 3D image stack (numpy array) to be aligned with the reference. + - template_slicing_zyx: Tuple or list of slice objects defining the region to be used as the template. + + Returns: + - shift: The shift (displacement) needed to align moving_img with ref_img (numpy array). + """ + template = ref_img[template_slicing_zyx] + + result = match_template(moving_img, template) + zyx_1 = np.unravel_index(np.argmax(result), result.shape) + + # Calculate the shift based on template slicing coordinates and match result + # Subtracting the coordinates of the + + shift = np.array(zyx_1) - np.array([s.start for s in template_slicing_zyx]) + + return shift + + +def to_cpu(arr: ArrayLike) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Moves array to cpu, if it's already there nothing is done. + + """ + if hasattr(arr, "cpu"): + arr = arr.cpu() + elif hasattr(arr, "get"): + arr = arr.get() + return arr + + +def center_crop(arr: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike: + """Crops the center of `arr`""" + assert arr.ndim == len(shape) + + starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape)) + + assert all(s >= 0 for s in starts) + + slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape)) + + logger.info( + f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}" + ) + + return arr[slicing] + + +def pad_to_shape(arr: ArrayLike, shape: Tuple[int, ...], mode: str, **kwargs) -> ArrayLike: + """Pads array to shape. + + Parameters + ---------- + arr : ArrayLike + Input array. + shape : Tuple[int] + Output shape. + mode : str + Padding mode (see np.pad). + + Returns + ------- + ArrayLike + Padded array. + """ + assert arr.ndim == len(shape) + + dif = tuple(s - a for s, a in zip(shape, arr.shape)) + assert all(d >= 0 for d in dif) + + pad_width = [[s // 2, s - s // 2] for s in dif] + + logger.info(f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}") + + return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs) + + +def _match_shape(img: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike: + """Pad or crop array to match provided shape.""" + + if np.any(shape > img.shape): + padded_shape = np.maximum(img.shape, shape) + img = pad_to_shape(img, padded_shape, mode="reflect") + + if np.any(shape < img.shape): + img = center_crop(img, shape) + + return img + + +def phase_cross_corr( + ref_img: ArrayLike, + mov_img: ArrayLike, + maximum_shift: float = 1.0, + to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x, + transform: Optional[Callable[[ArrayLike], ArrayLike]] = np.log1p, +) -> Tuple[int, ...]: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + + Computes translation shift using arg. maximum of phase cross correlation. + Input are padded or cropped for fast FFT computation assuming a maximum translation shift. + + Parameters + ---------- + ref_img : ArrayLike + Reference image. + mov_img : ArrayLike + Moved image. + maximum_shift : float, optional + Maximum location shift normalized by axis size, by default 1.0 + + Returns + ------- + Tuple[int, ...] + Shift between reference and moved image. + """ + shape = tuple( + cast(int, next_fast_len(int(max(s1, s2) * maximum_shift))) + for s1, s2 in zip(ref_img.shape, mov_img.shape) + ) + + logger.info( + f"phase cross corr. fft shape of {shape} for arrays of shape {ref_img.shape} and {mov_img.shape} " + f"with maximum shift of {maximum_shift}" + ) + + ref_img = _match_shape(ref_img, shape) + mov_img = _match_shape(mov_img, shape) + + ref_img = to_device(ref_img) + mov_img = to_device(mov_img) + + if transform is not None: + ref_img = transform(ref_img) + mov_img = transform(mov_img) + + Fimg1 = np.fft.rfftn(ref_img) + Fimg2 = np.fft.rfftn(mov_img) + eps = np.finfo(Fimg1.dtype).eps + del ref_img, mov_img + + prod = Fimg1 * Fimg2.conj() + del Fimg1, Fimg2 + + norm = np.fmax(np.abs(prod), eps) + corr = np.fft.irfftn(prod / norm) + del prod, norm + + corr = np.fft.fftshift(np.abs(corr)) + + argmax = to_cpu(np.argmax(corr)) + peak = np.unravel_index(argmax, corr.shape) + peak = tuple(s // 2 - p for s, p in zip(corr.shape, peak)) + + logger.info(f"phase cross corr. peak at {peak}") + + return peak + + +# %% +class Autotracker(object): + _AUTOFOCUS_METHODS = { + 'pcc': phase_cross_corr, + 'tm': template_matching, + 'multiotsu': multiotsu_centroid, + } + + def __init__(self, autofocus_method: str, xy_dapening: tuple[int] = None): + self.autofocus_method = autofocus_method + self.xy_dapening = xy_dapening + + def estimate_shift( + self, reference_array: ArrayLike, moving_array: ArrayLike, **kwargs + ) -> np.ndarray: + autofocus_method_func = self._AUTOFOCUS_METHODS.get(self.autofocus_method) + + if not autofocus_method_func: + raise ValueError(f'Unknown autofocus method: {self.autofocus_method}') + + shifts = autofocus_method_func(**kwargs) + + return shifts # %% def main(): - print('AUTOTRACKER') """ Toy dataset translations = [ @@ -132,7 +322,7 @@ def main(): # autofocus_method = 'multiotsu' # xy_dapening = (10, 10) - c_idx = 2 + c_idx = 0 data_t0 = dataset.data[0, c_idx] data_t1 = dataset.data[1, c_idx] data_t2 = dataset.data[2, c_idx] @@ -144,7 +334,8 @@ def main(): ax[1].imshow(data_t1[10]) ax[2].imshow(data_t2[10]) plt.show() - + # %% + # Testing Multiotsu shift_0 = multiotsu_centroid(data_t0, data_t0) shift_1 = multiotsu_centroid(data_t1, data_t0) shift_2 = multiotsu_centroid(data_t2, data_t0) @@ -165,78 +356,61 @@ def main(): print( f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' ) + # %% + # Testing PCC + shift_0 = phase_cross_corr(data_t0, data_t0) + shift_1 = phase_cross_corr(data_t0, data_t1) + shift_2 = phase_cross_corr(data_t0, data_t2) + shift_3 = phase_cross_corr(data_t0, data_t3) + shifts = [shift_0, shift_1, shift_2, shift_3] -# %% - - -def estimate_shift( - reference_array: ArrayLike, - moving_array: ArrayLike, - autofocus_method: Literal['pcc', 'tm', 'multiotsu'], - xy_dapening: tuple[int] = None, - **kwargs, -) -> np.ndarray: - autofocus_method_func = _AUTOFOCUS_METHODS.get(autofocus_method) - - if not autofocus_method_func: - raise ValueError(f'Unknown autofocus method: {autofocus_method}') - - shifts = autofocus_method_func(**kwargs) - - return shifts + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + # %% + # Testing template matching + + crop_z = slice(4, 8) + crop_y = slice(200, 300) + crop_x = slice(200, 300) + template = data_t0[crop_z, crop_y, crop_x] + + result = match_template(data_t1, template) + zyx = np.unravel_index(np.argmax(result), result.shape) + + print(zyx) + fig, ax = plt.subplots(1, 2) + ax[0].imshow(template[0]) + ax[0].set_title('template') + ax[1].imshow(data_t1[zyx[0]]) + rect = plt.Rectangle( + (zyx[2], zyx[1]), + template.shape[2], + template.shape[1], + edgecolor='red', + facecolor='none', + ) + ax[1].add_patch(rect) + ax[1].set_title('template matching result') -def get_shift_centroid(autofocus_method, im_moving, im_ref): - """finds the centroid of the images and returns the shift - between the two volumes in z, y, and x + # %% + # Calculate the shift, apply and check the result + shift_0 = template_matching(data_t0, data_t0, (crop_z, crop_y, crop_x)) + shift_1 = template_matching(data_t0, data_t1, (crop_z, crop_y, crop_x)) + shift_2 = template_matching(data_t0, data_t2, (crop_z, crop_y, crop_x)) + shift_3 = template_matching(data_t0, data_t3, (crop_z, crop_y, crop_x)) - Parameters - ---------- - im_moving : ndarray - moving image volume - im_ref : ndarray - reference image volume + shifts = [shift_0, shift_1, shift_2, shift_3] - Returns - ------- - shift : ndarray - array of shift in z, y, x order - """ - Z, Y, X = im_moving.shape - if autofocus_method == 'multiotsu': - centroid_moving = multiotsu_centroid(im_moving) - centroid_ref = [int(Z / 2), int((Y) / 2), int((X) / 2)] - if autofocus_method == 'com': - centroid_moving = calc_com(im_moving) - centroid_ref = [int(Z / 2), int((Y) / 2), int((X) / 2)] - - return centroid_moving - centroid_ref - - -def calc_com(im): - """ """ - stack = np.clip(im - np.percentile(im, 10), a_min=0, a_max=None) - Z, Y, X = im.shape - center = [int(Z // 2), int(Y // 2), int(X // 2)] - num_z = 0 - num_y = 0 - num_x = 0 - for z in range(Z): - for y in range(Y): - for x in range(X): - val = stack[z, y, x] - num_z = num_z + (z - center[0]) * val - num_y = num_y + (y - center[1]) * val - num_x = num_x + (x - center[2]) * val - centroid = [ - int(num_z / np.sum(stack)), - int(num_y / np.sum(stack)), - int(num_x / np.sum(stack)), - ] - centroid = np.array(centroid) + np.array(center) - print(f'centroid: {centroid}') - return centroid + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) # %% From 321b76ecbf5d46d4936d55c5a834dc942611ea88 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 13 Aug 2024 09:57:48 -0700 Subject: [PATCH 3/6] adding autotracker shrimpy backbone and tracking methods --- mantis/acquisition/AcquisitionSettings.py | 23 + mantis/acquisition/acq_engine.py | 48 ++ mantis/acquisition/autotracker.py | 157 ++++++- .../autotracker_hook_function.py | 48 ++ mantis/acquisition/microscope_operations.py | 11 +- .../settings/demo_acquisition_settings.yaml | 6 + mantis/cli/run_acquisition.py | 11 + .../test_autotracker/test_autotracker.py | 427 ++++++++++++++++++ 8 files changed, 706 insertions(+), 25 deletions(-) create mode 100644 mantis/acquisition/hook_functions/autotracker_hook_function.py create mode 100644 mantis/tests/test_autotracker/test_autotracker.py diff --git a/mantis/acquisition/AcquisitionSettings.py b/mantis/acquisition/AcquisitionSettings.py index d1225c58..e4cbb780 100644 --- a/mantis/acquisition/AcquisitionSettings.py +++ b/mantis/acquisition/AcquisitionSettings.py @@ -193,3 +193,26 @@ def __post_init__(self): attr_val = getattr(self, attr) if attr_val is not None: setattr(self, attr, round(attr_val, 1)) + + +@dataclass +class AutotrackerSettings: + tracking_method: Literal['phase_cross_correlation', 'template_matching', 'multi_otsu'] + device: Optional[str] = 'cpu' + tracking_arm: Literal['lf', 'ls'] = 'lf' + channel_to_track: Optional[str] = None + zyx_dampening_factor: Optional[float, float, float] = None + re_run_every_n_timepoints: Optional[int] = None + # TODO: maybe do the ROI like in the ls_microscope_settings + template_roi_zyx: Optional[Tuple[int, int, int]] = None + template_channel: Optional[str] = None + + @validator("autotracker_method") + def check_autotracker_methods_options(cls, v): + # Check if template matching options are provided and are not None + if v == 'template_matching': + if not all([cls.template_roi_zyx, cls.template_channel]): + raise ValueError( + 'template_roi_zyx and template_channel must be provided for template matching' + ) + return v diff --git a/mantis/acquisition/acq_engine.py b/mantis/acquisition/acq_engine.py index 822fda24..a37985f7 100644 --- a/mantis/acquisition/acq_engine.py +++ b/mantis/acquisition/acq_engine.py @@ -31,6 +31,7 @@ SliceSettings, MicroscopeSettings, AutoexposureSettings, + AutotrackerSettings, ) from mantis.acquisition.hook_functions.pre_hardware_hook_functions import ( log_preparing_acquisition, @@ -102,6 +103,7 @@ def __init__( self._slice_settings = SliceSettings() self._microscope_settings = MicroscopeSettings() self._autoexposure_settings = None + self._autotracker_settings = None self._z0 = None self.headless = False if mm_app_path is None else True self.type = 'light-sheet' if self.headless else 'label-free' @@ -173,6 +175,10 @@ def microscope_settings(self): def autoexposure_settings(self): return self._autoexposure_settings + @property + def autotracker_settings(self): + return self._autotracker_settings + @channel_settings.setter def channel_settings(self, settings: ChannelSettings): logger.debug( @@ -204,6 +210,13 @@ def autoexposure_settings(self, settings: AutoexposureSettings): ) self._autoexposure_settings = settings + @autotracker_settings.setter + def autotracker_settings(self, settings: AutotrackerSettings): + logger.debug( + f"{self.type.capitalize()} acquisition will have the following settings:{asdict(settings)}" + ) + self._autotracker_settings = settings + def setup(self): """ Apply acquisition settings as specified by the class properties @@ -276,6 +289,7 @@ def reset(self): ) +# TODO: check the enable_ls_acq and enable_lf_acq work independently class MantisAcquisition(object): """ Acquisition class for simultaneous label-free and light-sheet acquisition on @@ -671,6 +685,22 @@ def setup_autoexposure(self): ) ) + def setup_autotracker(self): + if self._demo_run: + # TODO: implement autotracker in demo mode + logger.debug('Autotracker is not supported in demo mode') + return + + if self.lf_acq.microscope_settings.use_autotracker: + logger.debug('Setting up autotracker') + microscope_operations.setup_autotracker( + self.lf_acq.mmc, + self.lf_acq.microscope_settings.autotracker_channel, + self.lf_acq.microscope_settings.autotracker_threshold, + ) + else: + logger.debug('Autotracker is not enabled in the microscope settings') + def go_to_position(self, position_index: int): # Move slowly for short distances such that autofocus can stay engaged. # Autofocus typically fails when moving long distances, so we can move @@ -938,6 +968,18 @@ def run_autoexposure( f'Autoexposure method {method} is not yet implemented.' ) + def run_autotracker( + self, + acq: BaseChannelSliceAcquisition, + well_id: str, + method: str = 'manual', + ): + logging.debug('running autotracker') + if not any(acq.channel_settings.use_autoexposure): + return + # TODO: implement autotracker + microscope_operations.autotracker(acq.mmc, acq.autoexposure_settings) + def setup(self): """ Setup the mantis acquisition. This method sets up the label-free @@ -965,6 +1007,9 @@ def setup(self): logger.debug('Setting up autoexposure') self.setup_autoexposure() + logger.debug('Setting up auotracker') + self.setup_autotracker() + def acquire(self): """ Simultaneously acquire label-free and light-sheet data over multiple @@ -1089,6 +1134,9 @@ def acquire(self): well_id=well_id, method=self.ls_acq.autoexposure_settings.autoexposure_method, ) + # TODO: add logic to handle skipping timepoints + if t_idx < 2: + self.run_autotracker(acq=self.lf_acq, well_id=well_id) # Acq rate needs to be updated even if autoexposure was not rerun in this well # Only do that if we are using autoexposure? self.update_ls_acquisition_rates( diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py index e403df12..d1cf9866 100644 --- a/mantis/acquisition/autotracker.py +++ b/mantis/acquisition/autotracker.py @@ -1,8 +1,10 @@ # %% +from pathlib import Path from typing import Callable, Optional, Tuple, cast import matplotlib.pyplot as plt import numpy as np +import pandas as pd import skimage from iohub import open_ome_zarr @@ -19,9 +21,6 @@ # from dexpv2.crosscorr import phase_cross_corr # from dexpv2.utils import center_crop, pad_to_shape, to_cpu -# TODO: Make toy datasets for testing -# TODO: multiotsu should have an offset variable from the center of the image (X,Y) -# TODO: Check why PCC is not working # TODO: write test functions # TODO: consider splitting this file into two @@ -237,7 +236,7 @@ def phase_cross_corr( for s1, s2 in zip(ref_img.shape, mov_img.shape) ) - logger.info( + logger.debug( f"phase cross corr. fft shape of {shape} for arrays of shape {ref_img.shape} and {mov_img.shape} " f"with maximum shift of {maximum_shift}" ) @@ -270,34 +269,145 @@ def phase_cross_corr( peak = np.unravel_index(argmax, corr.shape) peak = tuple(s // 2 - p for s, p in zip(corr.shape, peak)) - logger.info(f"phase cross corr. peak at {peak}") + logger.debug(f"phase cross corr. peak at {peak}") return peak # %% class Autotracker(object): - _AUTOFOCUS_METHODS = { - 'pcc': phase_cross_corr, - 'tm': template_matching, - 'multiotsu': multiotsu_centroid, + _TRACKING_METHODS = { + 'phase_cross_correlation': phase_cross_corr, + 'template_matching': template_matching, + 'multi_otsu': multiotsu_centroid, } - def __init__(self, autofocus_method: str, xy_dapening: tuple[int] = None): - self.autofocus_method = autofocus_method - self.xy_dapening = xy_dapening - - def estimate_shift( - self, reference_array: ArrayLike, moving_array: ArrayLike, **kwargs - ) -> np.ndarray: - autofocus_method_func = self._AUTOFOCUS_METHODS.get(self.autofocus_method) + def __init__( + self, + tracking_method: str, + scale: ArrayLike[float, float, float], + zyx_dampening: ArrayLike[float, float, float] = None, + output_shifts_path: Path = './shifts.csv', + ): + """ + Autotracker object + + Parameters + ---------- + tracking_method : str + Method to use for autofocus. Options are 'phase_cross_correlation', 'template_matching', 'multi_otsu' + scale : ArrayLike[float, float, float] + Scale factor to convert shifts from px to um + xy_dampening : tuple[int] + Dampening factor for xy shifts + """ + self.tracking_method = tracking_method + self.zyx_dampening = zyx_dampening + self.scale = scale + self.shifts = None + self.output_shifts_path = output_shifts_path + # TODO: hook to the config logs + + def estimate_shift(self, ref_img: ArrayLike, mov_img: ArrayLike, **kwargs) -> np.ndarray: + """ + Estimates the shift between two images using the specified autofocus method. + + Parameters + ---------- + ref_img : ArrayLike + Reference image. + mov_img : ArrayLike + Image to be aligned with the reference. + kwargs : dict + Additional keyword arguments to be passed to the autofocus method. + + Returns + ------- + np.ndarray + The estimated shift in scale provided by the user (typically um). + """ + + autofocus_method_func = self._TRACKING_METHODS.get(self.tracking_method) if not autofocus_method_func: - raise ValueError(f'Unknown autofocus method: {self.autofocus_method}') - - shifts = autofocus_method_func(**kwargs) - - return shifts + raise ValueError(f'Unknown autofocus method: {self.tracking_method}') + + shifts = autofocus_method_func(ref_img=ref_img, mov_img=mov_img, **kwargs) + + # Shifts in px to shifts in um + self.shifts = np.array(shifts) * self.scale + + if self.zyx_dampening is not None: + self.shifts = self.shifts * self.zyx_dampening + logger.info(f'Shifts (z,y,x): {self.shifts}') + + return self.shifts + + # Function to log the shifts to a csv file + def save_shifts_to_file( + output_file: str, + shifts: Tuple[int, int, int], + position_id: int, + timepoint_id: int, + overwrite: bool = False, + ) -> None: + """ + Saves the computed shifts to a CSV file. + + Parameters + ---------- + output_file : str + Path to the output CSV file. + shifts : Tuple[int, int, int] + The computed shifts (Z, Y, X). + position_id : int + Identifier for the position. + timepoint_id : int + Identifier for the timepoint. + overwrite : bool + If True, the file will be overwritten if it exists. + """ + # Convert output_file to a Path object + output_path = Path(output_file) + data = { + "PositionID": [position_id], + "TimepointID": [timepoint_id], + "ShiftZ": [shifts[0]], + "ShiftY": [shifts[1]], + "ShiftX": [shifts[2]], + } + + df = pd.DataFrame(data) + + if overwrite or not output_path.exists(): + # Write the DataFrame to a new file, including the header + df.to_csv(output_path, mode='w', index=False) + else: + # Append the DataFrame to the existing file, without writing the header + df.to_csv(output_path, mode='a', header=False, index=False) + + def limit_shifts_zyx( + self, shifts: Tuple[int, int, int], limits: Tuple[int, int, int] = (5, 5, 5) + ) -> Tuple[int, int, int]: + """ + Limits the shifts to the specified limits. + + Parameters + ---------- + shifts : Tuple[int, int, int] + The computed shifts (Z, Y, X). + limits : Tuple[int, int, int] + The limits for the shifts (Z, Y, X). + + Returns + ------- + Tuple[int, int, int] + The limited shifts. + """ + shifts = np.array(shifts) + limits = np.array(limits) + shifts = np.where(np.abs(shifts) > limits, 0, shifts) + return tuple(shifts) # %% @@ -319,7 +429,7 @@ def main(): dataset = open_ome_zarr(input_data_path) T, C, Z, Y, X = dataset.data.shape # print(channel_names := dataset.channel_names) - # autofocus_method = 'multiotsu' + # tracking_method = 'multiotsu' # xy_dapening = (10, 10) c_idx = 0 @@ -413,7 +523,6 @@ def main(): ) -# %% # %% if __name__ == "__main__": main() diff --git a/mantis/acquisition/hook_functions/autotracker_hook_function.py b/mantis/acquisition/hook_functions/autotracker_hook_function.py new file mode 100644 index 00000000..7862e55c --- /dev/null +++ b/mantis/acquisition/hook_functions/autotracker_hook_function.py @@ -0,0 +1,48 @@ +from mantis.acquisition.autotracker import Autotracker + + +def autotracker_hook_fn(axes, dataset) -> None: + """ + Pycromanager hook function that is called when an image is saved. + + Parameters + ---------- + axes : Position, Time, Channel, Z_slice + dataset: Dataset saved in disk + """ + + # Get reference to the acquisition engine and it's settings + # TODO: This is a placeholder, the actual implementation will be different + acq = "reference to the acquisition engine" + shift_limit = acq.autofocus_settings.shift_limit + tracking_method = acq.autofocus_settings.tracking_method + output_shift_path = './output.csv' + + # Get axes info + p_idx = axes['position'] + t_idx = axes['time'] + channel = axes['channel'] + z_idx = axes['z'] + + # Logic to get the volumes + # TODO: This is a placeholder, the actual implementation will be different + volume_t0_axes = (p_idx, t_idx, channel, z_idx) + volume_t1_axes = (p_idx, t_idx, channel, z_idx) + + volume_t0 = dataset.read_image(**volume_t0_axes) + volume_t1 = dataset.read_image(**volume_t1_axes) + + # Compute the shifts + tracker = Autotracker( + autofocus_method=tracking_method, + shift_limit=shift_limit, + output_shifts_path=output_shift_path, + ) + # Reference and moving volumes + tracker.estimate_shift(volume_t0, volume_t1) + + # Save the shifts + # TODO: This is a placeholder, the actual implementation will be different + + # Update the event coordinates + # TODO: This is a placeholder, the actual implementation will be different diff --git a/mantis/acquisition/microscope_operations.py b/mantis/acquisition/microscope_operations.py index 324aa679..50370487 100644 --- a/mantis/acquisition/microscope_operations.py +++ b/mantis/acquisition/microscope_operations.py @@ -12,7 +12,7 @@ from pycromanager import Core, Studio from pylablib.devices.Thorlabs import KinesisPiezoMotor -from mantis.acquisition.AcquisitionSettings import AutoexposureSettings +from mantis.acquisition.AcquisitionSettings import AutoexposureSettings, AutotrackerSettings from mantis.acquisition.autoexposure import manual_autoexposure, mean_intensity_autoexposure logger = logging.getLogger(__name__) @@ -670,3 +670,12 @@ def autoexposure( ) return suggested_exposure_time, suggested_light_intensity + + +def autotracker( + mmc: Core, + autotracker_settings: AutotrackerSettings, + **kwargs, +): + logging.debug('Autotracker is not implemented yet') + pass diff --git a/mantis/acquisition/settings/demo_acquisition_settings.yaml b/mantis/acquisition/settings/demo_acquisition_settings.yaml index a076ad1c..417b3206 100644 --- a/mantis/acquisition/settings/demo_acquisition_settings.yaml +++ b/mantis/acquisition/settings/demo_acquisition_settings.yaml @@ -60,3 +60,9 @@ ls_microscope_settings: ls_autoexposure_settings: autoexposure_method: 'manual' rerun_each_timepoint: True + +autotracker_settings: + tracking_method: 'phase_cross_correlation' + tracking_arm: 'lf' + channel_to_track: 'Channel-Multiband' + device: 'cpu' diff --git a/mantis/cli/run_acquisition.py b/mantis/cli/run_acquisition.py index 14250f8e..7a718714 100644 --- a/mantis/cli/run_acquisition.py +++ b/mantis/cli/run_acquisition.py @@ -53,6 +53,7 @@ def run_acquisition( SliceSettings, MicroscopeSettings, AutoexposureSettings, + AutotrackerSettings, ) # isort: on @@ -77,6 +78,14 @@ def run_acquisition( ls_autoexposure_settings = AutoexposureSettings( **raw_settings.get('ls_autoexposure_settings') ) + # TODO: decide ls or lf autoexposure settings + ls_autotracker_settings = None + lf_autotracker_settings = None + autotracker_settings = AutotrackerSettings(**raw_settings.get('autotracker_settings')) + if autotracker_settings.tracking_arm == 'lf': + lf_autotracker_settings = autotracker_settings + elif autotracker_settings.tracking_arm == 'ls': + ls_autotracker_settings = autotracker_settings with MantisAcquisition( acquisition_directory=acq_directory, @@ -95,6 +104,8 @@ def run_acquisition( acq.ls_acq.slice_settings = ls_slice_settings acq.ls_acq.microscope_settings = ls_microscope_settings acq.ls_acq.autoexposure_settings = ls_autoexposure_settings + acq.ls_acq.autotracker_settings = ls_autotracker_settings + acq.lf_acq.autotracker_settings = lf_autotracker_settings acq.setup() acq.acquire() diff --git a/mantis/tests/test_autotracker/test_autotracker.py b/mantis/tests/test_autotracker/test_autotracker.py new file mode 100644 index 00000000..dc0aa457 --- /dev/null +++ b/mantis/tests/test_autotracker/test_autotracker.py @@ -0,0 +1,427 @@ +# %% +from typing import Callable, Optional, Tuple, cast + +import matplotlib.pyplot as plt +import numpy as np +import skimage + +from iohub import open_ome_zarr +from numpy.typing import ArrayLike +from scipy.fftpack import next_fast_len +from skimage.exposure import rescale_intensity +from skimage.feature import match_template +from skimage.filters import gaussian +from skimage.measure import label, regionprops + +from mantis import logger + +# FIXME fix the dependencies so that we can install and import dexpv2 +# from dexpv2.crosscorr import phase_cross_corr +# from dexpv2.utils import center_crop, pad_to_shape, to_cpu + +# TODO: Make toy datasets for testing +# TODO: multiotsu should have an offset variable from the center of the image (X,Y) +# TODO: Check why PCC is not working +# TODO: write test functions +# TODO: consider splitting this file into two + + +def calc_weighted_center(labeled_im): + """calculates weighted centroid based on the area of the regions + + Parameters + ---------- + labeled_im : ndarray + labeled image + """ + regions = sorted(regionprops(labeled_im), key=lambda r: r.area) + n_regions = len(regions) + centers = [] + areas = [] + for i in range(n_regions): + centroid = regions[i].centroid + centers.append(centroid) + area = regions[i].area + areas.append(area) + areas_norm = np.array(areas) / np.sum(areas) + centers = np.array(centers) + center_weighted = np.zeros(3) + for j in range(3): + center_weighted[j] = np.sum(centers[:, j] * areas_norm) + + return center_weighted + + +def multiotsu_centroid( + ref_img: ArrayLike, + mov_img: ArrayLike, +) -> list: + """ + Computes the translation shifts using a multiotsu threshold approach by finding the centroid of the regions + + Parameters + ---------- + moving : ndarray + moving stack ZYX + reference : ndarray + reference image stack ZYX + + Returns + ------- + shifts : list + list of shifts in z, y, x order + """ + # Process moving image + mov_img = rescale_intensity(mov_img, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(mov_img, sigma=5.0) + thresh = skimage.filters.threshold_multiotsu(stack_blur) + mov_img = stack_blur > thresh[0] + mov_img = label(mov_img) + # Process reference image + ref_img = rescale_intensity(ref_img, in_range='image', out_range=(0, 1.0)) + stack_blur = gaussian(ref_img, sigma=5.0) + thresh = skimage.filters.threshold_multiotsu(stack_blur) + ref_img = stack_blur > thresh[0] + ref_img = label(ref_img) + + # Get the centroids + moving_center = calc_weighted_center(mov_img) + target_center = calc_weighted_center(ref_img) + + # Find the shifts + shifts = moving_center - target_center + + logger.debug( + 'moving_center (z,y,x): %f,%f,%f', + moving_center[0], + moving_center[1], + moving_center[2], + ) + logger.debug( + 'target_center (z,y,x): %f,%f,%f', + target_center[0], + target_center[1], + target_center[2], + ) + logger.debug('shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) + + return shifts + + +def template_matching(ref_img, moving_img, template_slicing_zyx): + """ + Uses template matching to determine shift between two image stacks. + + Parameters: + - ref_img: Reference 3D image stack (numpy array). + - moving_img: Moving 3D image stack (numpy array) to be aligned with the reference. + - template_slicing_zyx: Tuple or list of slice objects defining the region to be used as the template. + + Returns: + - shift: The shift (displacement) needed to align moving_img with ref_img (numpy array). + """ + template = ref_img[template_slicing_zyx] + + result = match_template(moving_img, template) + zyx_1 = np.unravel_index(np.argmax(result), result.shape) + + # Calculate the shift based on template slicing coordinates and match result + # Subtracting the coordinates of the + + shift = np.array(zyx_1) - np.array([s.start for s in template_slicing_zyx]) + + return shift + + +def to_cpu(arr: ArrayLike) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Moves array to cpu, if it's already there nothing is done. + + """ + if hasattr(arr, "cpu"): + arr = arr.cpu() + elif hasattr(arr, "get"): + arr = arr.get() + return arr + + +def _match_shape(img: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Pad or crop array to match provided shape. + """ + + if np.any(shape > img.shape): + padded_shape = np.maximum(img.shape, shape) + img = pad_to_shape(img, padded_shape, mode="reflect") + + if np.any(shape < img.shape): + img = center_crop(img, shape) + + return img + + +def center_crop(arr: ArrayLike, shape: Tuple[int, ...]) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Crops the center of `arr` + """ + assert arr.ndim == len(shape) + + starts = tuple((cur_s - s) // 2 for cur_s, s in zip(arr.shape, shape)) + + assert all(s >= 0 for s in starts) + + slicing = tuple(slice(s, s + d) for s, d in zip(starts, shape)) + + logger.info( + f"center crop: input shape {arr.shape}, output shape {shape}, slicing {slicing}" + ) + + return arr[slicing] + + +def pad_to_shape(arr: ArrayLike, shape: Tuple[int, ...], mode: str, **kwargs) -> ArrayLike: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + Pads array to shape. + + Parameters + ---------- + arr : ArrayLike + Input array. + shape : Tuple[int] + Output shape. + mode : str + Padding mode (see np.pad). + + Returns + ------- + ArrayLike + Padded array. + """ + assert arr.ndim == len(shape) + + dif = tuple(s - a for s, a in zip(shape, arr.shape)) + assert all(d >= 0 for d in dif) + + pad_width = [[s // 2, s - s // 2] for s in dif] + + logger.info(f"padding: input shape {arr.shape}, output shape {shape}, padding {pad_width}") + + return np.pad(arr, pad_width=pad_width, mode=mode, **kwargs) + + +def phase_cross_corr( + ref_img: ArrayLike, + mov_img: ArrayLike, + maximum_shift: float = 1.0, + to_device: Callable[[ArrayLike], ArrayLike] = lambda x: x, + transform: Optional[Callable[[ArrayLike], ArrayLike]] = np.log1p, +) -> Tuple[int, ...]: + """ + Borrowing from Jordao dexpv2.crosscorr https://github.com/royerlab/dexpv2 + + Computes translation shift using arg. maximum of phase cross correlation. + Input are padded or cropped for fast FFT computation assuming a maximum translation shift. + + Parameters + ---------- + ref_img : ArrayLike + Reference image. + mov_img : ArrayLike + Moved image. + maximum_shift : float, optional + Maximum location shift normalized by axis size, by default 1.0 + + Returns + ------- + Tuple[int, ...] + Shift between reference and moved image. + """ + shape = tuple( + cast(int, next_fast_len(int(max(s1, s2) * maximum_shift))) + for s1, s2 in zip(ref_img.shape, mov_img.shape) + ) + + logger.info( + f"phase cross corr. fft shape of {shape} for arrays of shape {ref_img.shape} and {mov_img.shape} " + f"with maximum shift of {maximum_shift}" + ) + + ref_img = _match_shape(ref_img, shape) + mov_img = _match_shape(mov_img, shape) + + ref_img = to_device(ref_img) + mov_img = to_device(mov_img) + + if transform is not None: + ref_img = transform(ref_img) + mov_img = transform(mov_img) + + Fimg1 = np.fft.rfftn(ref_img) + Fimg2 = np.fft.rfftn(mov_img) + eps = np.finfo(Fimg1.dtype).eps + del ref_img, mov_img + + prod = Fimg1 * Fimg2.conj() + del Fimg1, Fimg2 + + norm = np.fmax(np.abs(prod), eps) + corr = np.fft.irfftn(prod / norm) + del prod, norm + + corr = np.fft.fftshift(np.abs(corr)) + + argmax = to_cpu(np.argmax(corr)) + peak = np.unravel_index(argmax, corr.shape) + peak = tuple(s // 2 - p for s, p in zip(corr.shape, peak)) + + logger.info(f"phase cross corr. peak at {peak}") + + return peak + + +# %% +class Autotracker(object): + _AUTOFOCUS_METHODS = { + 'pcc': phase_cross_corr, + 'tm': template_matching, + 'multiotsu': multiotsu_centroid, + } + + def __init__(self, autofocus_method: str, xy_dapening: tuple[int] = None): + self.autofocus_method = autofocus_method + self.xy_dapening = xy_dapening + + def estimate_shift( + self, reference_array: ArrayLike, moving_array: ArrayLike, **kwargs + ) -> np.ndarray: + autofocus_method_func = self._AUTOFOCUS_METHODS.get(self.autofocus_method) + + if not autofocus_method_func: + raise ValueError(f'Unknown autofocus method: {self.autofocus_method}') + + shifts = autofocus_method_func(**kwargs) + + return shifts + + +# %% +def main(): + """ + Toy dataset + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + (0, 30, -80), # Shift for timepoint 4 + ] + """ + # %% + input_data_path = ( + '/home/eduardo.hirata/repos/mantis/mantis/tests/x-ed/toy_translate.zarr/0/0/0' + ) + dataset = open_ome_zarr(input_data_path) + T, C, Z, Y, X = dataset.data.shape + # print(channel_names := dataset.channel_names) + # autofocus_method = 'multiotsu' + # xy_dapening = (10, 10) + + c_idx = 0 + data_t0 = dataset.data[0, c_idx] + data_t1 = dataset.data[1, c_idx] + data_t2 = dataset.data[2, c_idx] + data_t3 = dataset.data[3, c_idx] + + # subplot + fig, ax = plt.subplots(1, 3) + ax[0].imshow(data_t0[10]) + ax[1].imshow(data_t1[10]) + ax[2].imshow(data_t2[10]) + plt.show() + # %% + # Testing Multiotsu + shift_0 = multiotsu_centroid(data_t0, data_t0) + shift_1 = multiotsu_centroid(data_t1, data_t0) + shift_2 = multiotsu_centroid(data_t2, data_t0) + shift_3 = multiotsu_centroid(data_t3, data_t0) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + translations = [ + (0, 0, 0), # Shift for timepoint 0 + (5, -80, 80), # Shift for timepoint 1 + (9, -50, -50), # Shift for timepoint 2 + (-5, 30, -60), # Shift for timepoint 3 + ] # Compare shifts with expected translations + + tolerance = (5, 8, 8) # Define your tolerance level + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + # %% + # Testing PCC + shift_0 = phase_cross_corr(data_t0, data_t0) + shift_1 = phase_cross_corr(data_t0, data_t1) + shift_2 = phase_cross_corr(data_t0, data_t2) + shift_3 = phase_cross_corr(data_t0, data_t3) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + + # %% + # Testing template matching + + crop_z = slice(4, 8) + crop_y = slice(200, 300) + crop_x = slice(200, 300) + template = data_t0[crop_z, crop_y, crop_x] + + result = match_template(data_t1, template) + zyx = np.unravel_index(np.argmax(result), result.shape) + + print(zyx) + fig, ax = plt.subplots(1, 2) + ax[0].imshow(template[0]) + ax[0].set_title('template') + ax[1].imshow(data_t1[zyx[0]]) + rect = plt.Rectangle( + (zyx[2], zyx[1]), + template.shape[2], + template.shape[1], + edgecolor='red', + facecolor='none', + ) + ax[1].add_patch(rect) + ax[1].set_title('template matching result') + + # %% + # Calculate the shift, apply and check the result + shift_0 = template_matching(data_t0, data_t0, (crop_z, crop_y, crop_x)) + shift_1 = template_matching(data_t0, data_t1, (crop_z, crop_y, crop_x)) + shift_2 = template_matching(data_t0, data_t2, (crop_z, crop_y, crop_x)) + shift_3 = template_matching(data_t0, data_t3, (crop_z, crop_y, crop_x)) + + shifts = [shift_0, shift_1, shift_2, shift_3] + + for i, (calculated, expected) in enumerate(zip(shifts, translations)): + is_similar = np.allclose(calculated, expected, atol=tolerance) + print( + f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + ) + + +# %% +# %% +if __name__ == "__main__": + main() From 1b0f208dd1b1758221a91883851f0e0d9aa6fd6c Mon Sep 17 00:00:00 2001 From: edhirata Date: Tue, 13 Aug 2024 17:11:47 -0700 Subject: [PATCH 4/6] this runs on demo mode without updating the events for the stage. Assumes that there are shifts computed and returned. --- mantis/acquisition/AcquisitionSettings.py | 13 +- mantis/acquisition/acq_engine.py | 63 +++-- mantis/acquisition/autotracker.py | 238 +++++++++--------- .../autotracker_hook_function.py | 48 ---- mantis/acquisition/hook_functions/globals.py | 2 + mantis/acquisition/microscope_operations.py | 2 +- .../settings/demo_acquisition_settings.yaml | 11 +- mantis/cli/run_acquisition.py | 14 +- 8 files changed, 180 insertions(+), 211 deletions(-) delete mode 100644 mantis/acquisition/hook_functions/autotracker_hook_function.py diff --git a/mantis/acquisition/AcquisitionSettings.py b/mantis/acquisition/AcquisitionSettings.py index e4cbb780..3e73c4c5 100644 --- a/mantis/acquisition/AcquisitionSettings.py +++ b/mantis/acquisition/AcquisitionSettings.py @@ -141,6 +141,7 @@ class MicroscopeSettings: use_o3_refocus: bool = False o3_refocus_config: Optional[ConfigSettings] = None o3_refocus_interval_min: Optional[int] = None + autotracker_config: Optional[ConfigSettings] = None @dataclass @@ -198,17 +199,17 @@ def __post_init__(self): @dataclass class AutotrackerSettings: tracking_method: Literal['phase_cross_correlation', 'template_matching', 'multi_otsu'] + tracking_interval: Optional[int] = 1 + scale_yx: Optional[float] = 1.0 + shift_limit: Optional[Union[Tuple[float, float, float], 'None']] = None device: Optional[str] = 'cpu' - tracking_arm: Literal['lf', 'ls'] = 'lf' - channel_to_track: Optional[str] = None - zyx_dampening_factor: Optional[float, float, float] = None - re_run_every_n_timepoints: Optional[int] = None + zyx_dampening_factor: Optional[Union[Tuple[float, float, float], None]] = None # TODO: maybe do the ROI like in the ls_microscope_settings template_roi_zyx: Optional[Tuple[int, int, int]] = None template_channel: Optional[str] = None - @validator("autotracker_method") - def check_autotracker_methods_options(cls, v): + @validator("tracking_method") + def check_tracking_method_options(cls, v): # Check if template matching options are provided and are not None if v == 'template_matching': if not all([cls.template_roi_zyx, cls.template_channel]): diff --git a/mantis/acquisition/acq_engine.py b/mantis/acquisition/acq_engine.py index 7d927d81..fc4a0b4d 100644 --- a/mantis/acquisition/acq_engine.py +++ b/mantis/acquisition/acq_engine.py @@ -50,7 +50,7 @@ check_ls_acq_finished, ) -# isort: on +from mantis.acquisition.autotracker import autotracker_hook_fn # Define constants @@ -212,9 +212,12 @@ def autoexposure_settings(self, settings: AutoexposureSettings): @autotracker_settings.setter def autotracker_settings(self, settings: AutotrackerSettings): - logger.debug( - f"{self.type.capitalize()} acquisition will have the following settings:{asdict(settings)}" - ) + if settings is None: + logger.debug('Autotracker settings are not provided') + else: + logger.debug( + f"{self.type.capitalize()} acquisition will have the following settings:{asdict(settings)}" + ) self._autotracker_settings = settings def setup(self): @@ -350,6 +353,8 @@ def __init__( self._lf_acq_obj = None self._ls_acq_obj = None + globals.demo_run = demo_run + if not enable_lf_acq or not enable_ls_acq: raise Exception('Disabling LF or LS acquisition is not currently supported') @@ -686,20 +691,8 @@ def setup_autoexposure(self): ) def setup_autotracker(self): - if self._demo_run: - # TODO: implement autotracker in demo mode - logger.debug('Autotracker is not supported in demo mode') - return - - if self.lf_acq.microscope_settings.use_autotracker: - logger.debug('Setting up autotracker') - microscope_operations.setup_autotracker( - self.lf_acq.mmc, - self.lf_acq.microscope_settings.autotracker_channel, - self.lf_acq.microscope_settings.autotracker_threshold, - ) - else: - logger.debug('Autotracker is not enabled in the microscope settings') + logger.info('Setting up autotracker') + # TODO: probably setup the GPU/CPU settings here def go_to_position(self, position_index: int): # Move slowly for short distances such that autofocus can stay engaged. @@ -1007,7 +1000,7 @@ def setup(self): logger.debug('Setting up autoexposure') self.setup_autoexposure() - logger.debug('Setting up auotracker') + logger.debug('Setting up autotracker') self.setup_autotracker() def acquire(self): @@ -1029,7 +1022,20 @@ def acquire(self): start_daq_counters, [self._lf_z_ctr_task, self._lf_channel_ctr_task] ) lf_post_hardware_hook_fn = log_acquisition_start - lf_image_saved_fn = check_lf_acq_finished + + # TODO: implement logic for the autotracker_img_saved_hook_fn + if self.lf_acq.microscope_settings.autotracker_config is not None: + lf_image_saved_fn = partial( + autotracker_hook_fn, + 'lf', + self.lf_acq.autotracker_settings, + self.lf_acq.microscope_settings.autotracker_config, + self.lf_acq.slice_settings, + self._acq_dir, + ) + else: + logger.info('No autotracker config found. Using default image saved hook') + lf_image_saved_fn = check_lf_acq_finished # define LF acquisition self._lf_acq_obj = Acquisition( @@ -1059,7 +1065,19 @@ def acquire(self): self.ls_acq.channel_settings.channels, ) ls_post_camera_hook_fn = partial(start_daq_counters, [self._ls_z_ctr_task]) - ls_image_saved_fn = check_ls_acq_finished + + # TODO: implement logic for the autotracker_img_saved_hook_fn + if self.ls_acq.microscope_settings.autotracker_config is not None: + ls_image_saved_fn = partial( + autotracker_hook_fn, + 'ls', + self.ls_acq.autotracker_settings, + self.ls_acq.slice_settings, + self._acq_dir, + ) + else: + logger.info('No autotracker config found. Using default image saved hook') + ls_image_saved_fn = check_ls_acq_finished # define LS acquisition self._ls_acq_obj = Acquisition( @@ -1134,9 +1152,6 @@ def acquire(self): well_id=well_id, method=self.ls_acq.autoexposure_settings.autoexposure_method, ) - # TODO: add logic to handle skipping timepoints - if t_idx < 2: - self.run_autotracker(acq=self.lf_acq, well_id=well_id) # Acq rate needs to be updated even if autoexposure was not rerun in this well # Only do that if we are using autoexposure? self.update_ls_acquisition_rates( diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py index d1cf9866..eda215b2 100644 --- a/mantis/acquisition/autotracker.py +++ b/mantis/acquisition/autotracker.py @@ -2,12 +2,10 @@ from pathlib import Path from typing import Callable, Optional, Tuple, cast -import matplotlib.pyplot as plt import numpy as np import pandas as pd import skimage -from iohub import open_ome_zarr from numpy.typing import ArrayLike from scipy.fftpack import next_fast_len from skimage.exposure import rescale_intensity @@ -16,6 +14,7 @@ from skimage.measure import label, regionprops from mantis import logger +from mantis.acquisition.hook_functions import globals # FIXME fix the dependencies so that we can install and import dexpv2 # from dexpv2.crosscorr import phase_cross_corr @@ -285,9 +284,9 @@ class Autotracker(object): def __init__( self, tracking_method: str, - scale: ArrayLike[float, float, float], - zyx_dampening: ArrayLike[float, float, float] = None, - output_shifts_path: Path = './shifts.csv', + shift_limit: Tuple[float, float, float], + scale: ArrayLike, + zyx_dampening_factor: ArrayLike = None, ): """ Autotracker object @@ -302,11 +301,9 @@ def __init__( Dampening factor for xy shifts """ self.tracking_method = tracking_method - self.zyx_dampening = zyx_dampening + self.zyx_dampening = zyx_dampening_factor self.scale = scale self.shifts = None - self.output_shifts_path = output_shifts_path - # TODO: hook to the config logs def estimate_shift(self, ref_img: ArrayLike, mov_img: ArrayLike, **kwargs) -> np.ndarray: """ @@ -345,10 +342,11 @@ def estimate_shift(self, ref_img: ArrayLike, mov_img: ArrayLike, **kwargs) -> np # Function to log the shifts to a csv file def save_shifts_to_file( + self, output_file: str, - shifts: Tuple[int, int, int], position_id: int, timepoint_id: int, + shifts: Tuple[int, int, int] = None, overwrite: bool = False, ) -> None: """ @@ -369,12 +367,14 @@ def save_shifts_to_file( """ # Convert output_file to a Path object output_path = Path(output_file) + if shifts is None: + shifts = self.shifts data = { "PositionID": [position_id], "TimepointID": [timepoint_id], - "ShiftZ": [shifts[0]], - "ShiftY": [shifts[1]], - "ShiftX": [shifts[2]], + "ShiftZ": [shifts[-3]], + "ShiftY": [shifts[-2]], + "ShiftX": [shifts[-1]], } df = pd.DataFrame(data) @@ -410,119 +410,109 @@ def limit_shifts_zyx( return tuple(shifts) -# %% -def main(): - """ - Toy dataset - translations = [ - (0, 0, 0), # Shift for timepoint 0 - (5, -80, 80), # Shift for timepoint 1 - (9, -50, -50), # Shift for timepoint 2 - (-5, 30, -60), # Shift for timepoint 3 - (0, 30, -80), # Shift for timepoint 4 - ] - """ - # %% - input_data_path = ( - '/home/eduardo.hirata/repos/mantis/mantis/tests/x-ed/toy_translate.zarr/0/0/0' - ) - dataset = open_ome_zarr(input_data_path) - T, C, Z, Y, X = dataset.data.shape - # print(channel_names := dataset.channel_names) - # tracking_method = 'multiotsu' - # xy_dapening = (10, 10) - - c_idx = 0 - data_t0 = dataset.data[0, c_idx] - data_t1 = dataset.data[1, c_idx] - data_t2 = dataset.data[2, c_idx] - data_t3 = dataset.data[3, c_idx] - - # subplot - fig, ax = plt.subplots(1, 3) - ax[0].imshow(data_t0[10]) - ax[1].imshow(data_t1[10]) - ax[2].imshow(data_t2[10]) - plt.show() - # %% - # Testing Multiotsu - shift_0 = multiotsu_centroid(data_t0, data_t0) - shift_1 = multiotsu_centroid(data_t1, data_t0) - shift_2 = multiotsu_centroid(data_t2, data_t0) - shift_3 = multiotsu_centroid(data_t3, data_t0) - - shifts = [shift_0, shift_1, shift_2, shift_3] - - translations = [ - (0, 0, 0), # Shift for timepoint 0 - (5, -80, 80), # Shift for timepoint 1 - (9, -50, -50), # Shift for timepoint 2 - (-5, 30, -60), # Shift for timepoint 3 - ] # Compare shifts with expected translations - - tolerance = (5, 8, 8) # Define your tolerance level - for i, (calculated, expected) in enumerate(zip(shifts, translations)): - is_similar = np.allclose(calculated, expected, atol=tolerance) - print( - f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' - ) - # %% - # Testing PCC - shift_0 = phase_cross_corr(data_t0, data_t0) - shift_1 = phase_cross_corr(data_t0, data_t1) - shift_2 = phase_cross_corr(data_t0, data_t2) - shift_3 = phase_cross_corr(data_t0, data_t3) - - shifts = [shift_0, shift_1, shift_2, shift_3] - - for i, (calculated, expected) in enumerate(zip(shifts, translations)): - is_similar = np.allclose(calculated, expected, atol=tolerance) - print( - f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' - ) +# TODO: logic for handling which t_idx to grab as reference. If the volume changes +# Drastically, we may need to grab the previous timepoint as reference + - # %% - # Testing template matching - - crop_z = slice(4, 8) - crop_y = slice(200, 300) - crop_x = slice(200, 300) - template = data_t0[crop_z, crop_y, crop_x] - - result = match_template(data_t1, template) - zyx = np.unravel_index(np.argmax(result), result.shape) - - print(zyx) - fig, ax = plt.subplots(1, 2) - ax[0].imshow(template[0]) - ax[0].set_title('template') - ax[1].imshow(data_t1[zyx[0]]) - rect = plt.Rectangle( - (zyx[2], zyx[1]), - template.shape[2], - template.shape[1], - edgecolor='red', - facecolor='none', +def get_volume(dataset, axes): + p_idx, t_idx, autotrack_channel, z_range = axes + images = [] + logger.debug( + f"Getting Zstack for p:{p_idx},t:{t_idx},c:{autotrack_channel},z_range:{z_range}" ) - ax[1].add_patch(rect) - ax[1].set_title('template matching result') - - # %% - # Calculate the shift, apply and check the result - shift_0 = template_matching(data_t0, data_t0, (crop_z, crop_y, crop_x)) - shift_1 = template_matching(data_t0, data_t1, (crop_z, crop_y, crop_x)) - shift_2 = template_matching(data_t0, data_t2, (crop_z, crop_y, crop_x)) - shift_3 = template_matching(data_t0, data_t3, (crop_z, crop_y, crop_x)) - - shifts = [shift_0, shift_1, shift_2, shift_3] - - for i, (calculated, expected) in enumerate(zip(shifts, translations)): - is_similar = np.allclose(calculated, expected, atol=tolerance) - print( - f'Timepoint {i+1} shift: {calculated}, Expected: {expected}, Similar: {is_similar}' + for z_id in range(z_range): + images.append( + dataset.read_image( + **{'channel': autotrack_channel, 'z': z_id, 'time': t_idx, 'position': p_idx} + ) ) + return np.stack(images) + + +def autotracker_hook_fn( + arm, + autotracker_settings, + channel_config, + z_slice_settings, + output_shift_path, + axes, + dataset, +) -> None: + """ + Pycromanager hook function that is called when an image is saved. - -# %% -if __name__ == "__main__": - main() + Parameters + ---------- + axes : Position, Time, Channel, Z_slice + dataset: Dataset saved in disk + """ + # TODO: handle the lf acq or ls_a + if arm == 'lf': + if axes == globals.lf_last_img_idx: + globals.lf_acq_finished = True + elif arm == 'ls': + if axes == globals.ls_last_img_idx: + globals.ls_acq_finished = True + + # Get reference to the acquisition engine and it's settings + # TODO: This is a placeholder, the actual implementation will be different + z_range = z_slice_settings.z_range + num_slices = z_slice_settings.num_slices + scale = autotracker_settings.scale_yx + shift_limit = autotracker_settings.shift_limit + tracking_method = autotracker_settings.tracking_method + tracking_interval = autotracker_settings.tracking_interval + tracking_channel = channel_config.config_name + zyx_dampening_factor = autotracker_settings.zyx_dampening_factor + output_shift_path = Path(output_shift_path) + + # Get axes info + p_idx = axes['position'] + t_idx = axes['time'] + channel = axes['channel'] + z_idx = axes['z'] + + # Skip the 1st timepoint + if t_idx > 0: + if t_idx % tracking_interval != 0: + logger.debug('Skipping autotracking t %d', t_idx) + return + # Get the z_max + if channel == tracking_channel and z_idx == (num_slices - 1): + logger.debug("WELCOME TO THE FOCUS ZONE") + logger.debug('Curr axes :P:%s, T:%d, C:%s, Z:%d', p_idx, t_idx, channel, z_idx) + + # Logic to get the volumes + # TODO: This is a placeholder, the actual implementation will be different + z_volume = z_range + volume_t0_axes = (p_idx, t_idx, tracking_channel, z_volume) + volume_t1_axes = (p_idx, t_idx, tracking_channel, z_volume) + # Compute the shifts + logger.debug('Instantiating autotracker') + tracker = Autotracker( + tracking_method=tracking_method, + scale=scale, + shift_limit=shift_limit, + zyx_dampening_factor=zyx_dampening_factor, + ) + if globals.demo_run: + # Random shifting for demo purposes + shifts = np.random.randint(-50, 50, 3) + logger.info('Shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) + else: + volume_t0 = get_volume(dataset, volume_t0_axes) + volume_t1 = get_volume(dataset, volume_t1_axes) + # Reference and moving volumes + shifts = tracker.estimate_shifts(volume_t0, volume_t1) + + # Save the shifts + # TODO: This is a placeholder, the actual implementation will be different + position_id = str(axes['position']) + '.csv' + shift_coord_output = output_shift_path / position_id + tracker.save_shifts_to_file( + shift_coord_output, position_id=p_idx, timepoint_id=t_idx, shifts=shifts + ) + + # Update the event coordinates + # TODO: This is a placeholder, the actual implementation will be different + # event_coords = {'Z': shifts[0], 'Y': shifts[1], 'X': shifts[2]} diff --git a/mantis/acquisition/hook_functions/autotracker_hook_function.py b/mantis/acquisition/hook_functions/autotracker_hook_function.py deleted file mode 100644 index 7862e55c..00000000 --- a/mantis/acquisition/hook_functions/autotracker_hook_function.py +++ /dev/null @@ -1,48 +0,0 @@ -from mantis.acquisition.autotracker import Autotracker - - -def autotracker_hook_fn(axes, dataset) -> None: - """ - Pycromanager hook function that is called when an image is saved. - - Parameters - ---------- - axes : Position, Time, Channel, Z_slice - dataset: Dataset saved in disk - """ - - # Get reference to the acquisition engine and it's settings - # TODO: This is a placeholder, the actual implementation will be different - acq = "reference to the acquisition engine" - shift_limit = acq.autofocus_settings.shift_limit - tracking_method = acq.autofocus_settings.tracking_method - output_shift_path = './output.csv' - - # Get axes info - p_idx = axes['position'] - t_idx = axes['time'] - channel = axes['channel'] - z_idx = axes['z'] - - # Logic to get the volumes - # TODO: This is a placeholder, the actual implementation will be different - volume_t0_axes = (p_idx, t_idx, channel, z_idx) - volume_t1_axes = (p_idx, t_idx, channel, z_idx) - - volume_t0 = dataset.read_image(**volume_t0_axes) - volume_t1 = dataset.read_image(**volume_t1_axes) - - # Compute the shifts - tracker = Autotracker( - autofocus_method=tracking_method, - shift_limit=shift_limit, - output_shifts_path=output_shift_path, - ) - # Reference and moving volumes - tracker.estimate_shift(volume_t0, volume_t1) - - # Save the shifts - # TODO: This is a placeholder, the actual implementation will be different - - # Update the event coordinates - # TODO: This is a placeholder, the actual implementation will be different diff --git a/mantis/acquisition/hook_functions/globals.py b/mantis/acquisition/hook_functions/globals.py index 7eb88aba..2af906d5 100644 --- a/mantis/acquisition/hook_functions/globals.py +++ b/mantis/acquisition/hook_functions/globals.py @@ -9,3 +9,5 @@ ls_slice_acquisition_rates = None ls_laser_powers = None new_well = False + +demo_run = False diff --git a/mantis/acquisition/microscope_operations.py b/mantis/acquisition/microscope_operations.py index 50370487..0b1cb653 100644 --- a/mantis/acquisition/microscope_operations.py +++ b/mantis/acquisition/microscope_operations.py @@ -677,5 +677,5 @@ def autotracker( autotracker_settings: AutotrackerSettings, **kwargs, ): - logging.debug('Autotracker is not implemented yet') + logging.debug('microscope operateions autotracker') pass diff --git a/mantis/acquisition/settings/demo_acquisition_settings.yaml b/mantis/acquisition/settings/demo_acquisition_settings.yaml index 417b3206..39a4718d 100644 --- a/mantis/acquisition/settings/demo_acquisition_settings.yaml +++ b/mantis/acquisition/settings/demo_acquisition_settings.yaml @@ -6,7 +6,7 @@ time_settings: time_interval_s: 5 lf_channel_settings: - default_exposure_times_ms: [10, 10] + default_exposure_times_ms: [5, 5] channel_group: 'Channel-Multiband' channels: ['DAPI', 'FITC'] use_sequencing: True @@ -30,6 +30,8 @@ lf_microscope_settings: channel_sequencing_settings: - ['LED', 'Sequence', 'On'] use_autofocus: False + autotracker_config: ['Channel-Multiband', 'DAPI'] + ls_channel_settings: default_exposure_times_ms: [20, 30] @@ -37,7 +39,7 @@ ls_channel_settings: channel_group: 'Channel' channels: ['Rhodamine', 'Cy5'] use_sequencing: False - use_autoexposure: [True, False] + use_autoexposure: [False, False] ls_slice_settings: z_stage_name: 'Z' @@ -63,6 +65,7 @@ ls_autoexposure_settings: autotracker_settings: tracking_method: 'phase_cross_correlation' - tracking_arm: 'lf' - channel_to_track: 'Channel-Multiband' + tracking_interval: 1 + shift_limit: [30,200,200] device: 'cpu' + scale_yx: 0.075 diff --git a/mantis/cli/run_acquisition.py b/mantis/cli/run_acquisition.py index 7a718714..194b86df 100644 --- a/mantis/cli/run_acquisition.py +++ b/mantis/cli/run_acquisition.py @@ -78,13 +78,19 @@ def run_acquisition( ls_autoexposure_settings = AutoexposureSettings( **raw_settings.get('ls_autoexposure_settings') ) - # TODO: decide ls or lf autoexposure settings + autotracker_settings = AutotrackerSettings(**raw_settings.get('autotracker_settings')) + + # Handle logic if autotracker is active in both arms ls_autotracker_settings = None lf_autotracker_settings = None - autotracker_settings = AutotrackerSettings(**raw_settings.get('autotracker_settings')) - if autotracker_settings.tracking_arm == 'lf': + if ( + lf_microscope_settings.autotracker_config is not None + and ls_microscope_settings.autotracker_config is not None + ): + raise ValueError("Autotracker is active in both arms, please specify only one arm") + elif lf_microscope_settings.autotracker_config is not None: lf_autotracker_settings = autotracker_settings - elif autotracker_settings.tracking_arm == 'ls': + elif ls_microscope_settings.autotracker_config is not None: ls_autotracker_settings = autotracker_settings with MantisAcquisition( From 52574ed34ef96194333148ed46d17e637606fa3a Mon Sep 17 00:00:00 2001 From: edhirata Date: Thu, 15 Aug 2024 17:40:26 -0700 Subject: [PATCH 5/6] autotracker with demo mode working --- mantis/acquisition/AcquisitionSettings.py | 3 + mantis/acquisition/acq_engine.py | 29 +++- mantis/acquisition/autotracker.py | 185 ++++++++++++++-------- 3 files changed, 149 insertions(+), 68 deletions(-) diff --git a/mantis/acquisition/AcquisitionSettings.py b/mantis/acquisition/AcquisitionSettings.py index 3e73c4c5..eb485c4a 100644 --- a/mantis/acquisition/AcquisitionSettings.py +++ b/mantis/acquisition/AcquisitionSettings.py @@ -1,3 +1,4 @@ +import copy import warnings from dataclasses import field @@ -36,10 +37,12 @@ class PositionSettings: position_labels: List[str] = field(default_factory=list) num_positions: int = field(init=False, default=0) well_ids: List[str] = field(init=False, default_factory=list) + xyz_positions_shift: list = field(init=False, default_factory=list) def __post_init__(self): assert len(self.xyz_positions) == len(self.position_labels) self.num_positions = len(self.xyz_positions) + self.xyz_positions_shift = copy.deepcopy(self.xyz_positions) try: # Look for "'A1-Site_0', 'H12-Site_1', ... " format diff --git a/mantis/acquisition/acq_engine.py b/mantis/acquisition/acq_engine.py index fc4a0b4d..a767230f 100644 --- a/mantis/acquisition/acq_engine.py +++ b/mantis/acquisition/acq_engine.py @@ -472,6 +472,8 @@ def update_position_settings(self): xyz_positions=xyz_positions, position_labels=position_labels, ) + self.position_settings.xyz_positions_shift = deepcopy(xyz_positions) + else: logger.debug('Position list is already populated and will not be updated') @@ -690,9 +692,11 @@ def setup_autoexposure(self): ) ) - def setup_autotracker(self): - logger.info('Setting up autotracker') - # TODO: probably setup the GPU/CPU settings here + def update_position_autotracker(self): + # Update the position list from the backup + self.position_settings.xyz_positions = deepcopy( + self.position_settings.xyz_positions_shift + ) def go_to_position(self, position_index: int): # Move slowly for short distances such that autofocus can stay engaged. @@ -1000,9 +1004,6 @@ def setup(self): logger.debug('Setting up autoexposure') self.setup_autoexposure() - logger.debug('Setting up autotracker') - self.setup_autotracker() - def acquire(self): """ Simultaneously acquire label-free and light-sheet data over multiple @@ -1029,6 +1030,7 @@ def acquire(self): autotracker_hook_fn, 'lf', self.lf_acq.autotracker_settings, + self._position_settings, self.lf_acq.microscope_settings.autotracker_config, self.lf_acq.slice_settings, self._acq_dir, @@ -1113,8 +1115,23 @@ def acquire(self): # move to the given position if p_label != previous_position_label: + # Check if autotracker is on either arm + if ( + self.ls_acq.microscope_settings.autotracker_config is not None + or self.lf_acq.microscope_settings.autotracker_config is not None + ): + # TODO: Should we get the corods from the csv file or the modified xyz_positions_shifts + logger.debug('Updating the positions for autotracker') + logger.debug( + 'Previous position: %f,%f ', + *self.position_settings.xyz_positions[p_idx][0:2], + ) + self.update_position_autotracker() self.go_to_position(p_idx) + # TODO get the delta shifts + # read the files here and move separately + # autofocus if self.lf_acq.microscope_settings.use_autofocus: autofocus_success = microscope_operations.autofocus( diff --git a/mantis/acquisition/autotracker.py b/mantis/acquisition/autotracker.py index eda215b2..26739a1f 100644 --- a/mantis/acquisition/autotracker.py +++ b/mantis/acquisition/autotracker.py @@ -1,5 +1,6 @@ # %% from pathlib import Path +from time import sleep from typing import Callable, Optional, Tuple, cast import numpy as np @@ -55,7 +56,7 @@ def multiotsu_centroid( mov_img: ArrayLike, ) -> list: """ - Computes the translation shifts using a multiotsu threshold approach by finding the centroid of the regions + Computes the translation shifts_zyx using a multiotsu threshold approach by finding the centroid of the regions Parameters ---------- @@ -66,8 +67,8 @@ def multiotsu_centroid( Returns ------- - shifts : list - list of shifts in z, y, x order + shifts_zyx : list + list of shifts_zyx in z, y, x order """ # Process moving image mov_img = rescale_intensity(mov_img, in_range='image', out_range=(0, 1.0)) @@ -86,8 +87,8 @@ def multiotsu_centroid( moving_center = calc_weighted_center(mov_img) target_center = calc_weighted_center(ref_img) - # Find the shifts - shifts = moving_center - target_center + # Find the shifts_zyx + shifts_zyx = moving_center - target_center logger.debug( 'moving_center (z,y,x): %f,%f,%f', @@ -101,9 +102,9 @@ def multiotsu_centroid( target_center[1], target_center[2], ) - logger.debug('shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) + logger.debug('shifts_zyx (z,y,x): %f,%f,%f', shifts_zyx[0], shifts_zyx[1], shifts_zyx[2]) - return shifts + return shifts_zyx def template_matching(ref_img, moving_img, template_slicing_zyx): @@ -296,14 +297,14 @@ def __init__( tracking_method : str Method to use for autofocus. Options are 'phase_cross_correlation', 'template_matching', 'multi_otsu' scale : ArrayLike[float, float, float] - Scale factor to convert shifts from px to um + Scale factor to convert shifts_zyx from px to um xy_dampening : tuple[int] - Dampening factor for xy shifts + Dampening factor for xy shifts_zyx """ self.tracking_method = tracking_method self.zyx_dampening = zyx_dampening_factor self.scale = scale - self.shifts = None + self.shifts_zyx = None def estimate_shift(self, ref_img: ArrayLike, mov_img: ArrayLike, **kwargs) -> np.ndarray: """ @@ -329,35 +330,36 @@ def estimate_shift(self, ref_img: ArrayLike, mov_img: ArrayLike, **kwargs) -> np if not autofocus_method_func: raise ValueError(f'Unknown autofocus method: {self.tracking_method}') - shifts = autofocus_method_func(ref_img=ref_img, mov_img=mov_img, **kwargs) + shifts_zyx = autofocus_method_func(ref_img=ref_img, mov_img=mov_img, **kwargs) - # Shifts in px to shifts in um - self.shifts = np.array(shifts) * self.scale + # shifts_zyx in px to shifts_zyx in um + self.shifts_zyx = np.array(shifts_zyx) * self.scale if self.zyx_dampening is not None: - self.shifts = self.shifts * self.zyx_dampening - logger.info(f'Shifts (z,y,x): {self.shifts}') + self.shifts_zyx = self.shifts_zyx * self.zyx_dampening + logger.info(f'shifts_zyx (z,y,x): {self.shifts_zyx}') - return self.shifts + return self.shifts_zyx - # Function to log the shifts to a csv file + # Function to log the shifts_zyx to a csv file def save_shifts_to_file( self, output_file: str, position_id: int, timepoint_id: int, - shifts: Tuple[int, int, int] = None, + shifts_zyx: Tuple[int, int, int] = None, + stage_coords: Tuple[int, int, int] = None, overwrite: bool = False, ) -> None: """ - Saves the computed shifts to a CSV file. + Saves the computed shifts_zyx to a CSV file. Parameters ---------- output_file : str Path to the output CSV file. - shifts : Tuple[int, int, int] - The computed shifts (Z, Y, X). + shifts_zyx : Tuple[int, int, int] + The computed shifts_zyx (Z, Y, X). position_id : int Identifier for the position. timepoint_id : int @@ -367,14 +369,19 @@ def save_shifts_to_file( """ # Convert output_file to a Path object output_path = Path(output_file) - if shifts is None: - shifts = self.shifts + if shifts_zyx is None: + shifts_zyx = self.shifts_zyx + if stage_coords is None: + stage_coords = (0, 0, 0) data = { "PositionID": [position_id], "TimepointID": [timepoint_id], - "ShiftZ": [shifts[-3]], - "ShiftY": [shifts[-2]], - "ShiftX": [shifts[-1]], + "ShiftZ": [shifts_zyx[-3]], + "ShiftY": [shifts_zyx[-2]], + "ShiftX": [shifts_zyx[-1]], + "StageZ": [stage_coords[-3]], + "StageY": [stage_coords[-2]], + "StageX": [stage_coords[-1]], } df = pd.DataFrame(data) @@ -387,27 +394,27 @@ def save_shifts_to_file( df.to_csv(output_path, mode='a', header=False, index=False) def limit_shifts_zyx( - self, shifts: Tuple[int, int, int], limits: Tuple[int, int, int] = (5, 5, 5) + self, shifts_zyx: Tuple[int, int, int], limits: Tuple[int, int, int] = (5, 5, 5) ) -> Tuple[int, int, int]: """ - Limits the shifts to the specified limits. + Limits the shifts_zyx to the specified limits. Parameters ---------- - shifts : Tuple[int, int, int] - The computed shifts (Z, Y, X). + shifts_zyx : Tuple[int, int, int] + The computed shifts_zyx (Z, Y, X). limits : Tuple[int, int, int] - The limits for the shifts (Z, Y, X). + The limits for the shifts_zyx (Z, Y, X). Returns ------- Tuple[int, int, int] - The limited shifts. + The limited shifts_zyx. """ - shifts = np.array(shifts) + shifts_zyx = np.array(shifts_zyx) limits = np.array(limits) - shifts = np.where(np.abs(shifts) > limits, 0, shifts) - return tuple(shifts) + shifts_zyx = np.where(np.abs(shifts_zyx) > limits, 0, shifts_zyx) + return tuple(shifts_zyx) # TODO: logic for handling which t_idx to grab as reference. If the volume changes @@ -432,6 +439,7 @@ def get_volume(dataset, axes): def autotracker_hook_fn( arm, autotracker_settings, + position_settings, channel_config, z_slice_settings, output_shift_path, @@ -446,6 +454,8 @@ def autotracker_hook_fn( axes : Position, Time, Channel, Z_slice dataset: Dataset saved in disk """ + # logger.info('Autotracker hook function called for axes %s', axes) + # TODO: handle the lf acq or ls_a if arm == 'lf': if axes == globals.lf_last_img_idx: @@ -467,52 +477,103 @@ def autotracker_hook_fn( output_shift_path = Path(output_shift_path) # Get axes info - p_idx = axes['position'] + p_label = axes['position'] + p_idx = position_settings.position_labels.index(p_label) t_idx = axes['time'] channel = axes['channel'] z_idx = axes['z'] - # Skip the 1st timepoint - if t_idx > 0: - if t_idx % tracking_interval != 0: - logger.debug('Skipping autotracking t %d', t_idx) - return - # Get the z_max - if channel == tracking_channel and z_idx == (num_slices - 1): + tracker = Autotracker( + tracking_method=tracking_method, + scale=scale, + shift_limit=shift_limit, + zyx_dampening_factor=zyx_dampening_factor, + ) + # Get the z_max + if channel == tracking_channel and z_idx == (num_slices - 1): + # Skip the 1st timepoint + if t_idx > 1: + if t_idx % tracking_interval != 0: + logger.debug('Skipping autotracking t %d', t_idx) + return logger.debug("WELCOME TO THE FOCUS ZONE") - logger.debug('Curr axes :P:%s, T:%d, C:%s, Z:%d', p_idx, t_idx, channel, z_idx) + # logger.debug('Curr axes :P:%s, T:%d, C:%s, Z:%d', p_idx, t_idx, channel, z_idx) # Logic to get the volumes - # TODO: This is a placeholder, the actual implementation will be different z_volume = z_range volume_t0_axes = (p_idx, t_idx, tracking_channel, z_volume) volume_t1_axes = (p_idx, t_idx, tracking_channel, z_volume) - # Compute the shifts + # Compute the shifts_zyx logger.debug('Instantiating autotracker') - tracker = Autotracker( - tracking_method=tracking_method, - scale=scale, - shift_limit=shift_limit, - zyx_dampening_factor=zyx_dampening_factor, - ) if globals.demo_run: # Random shifting for demo purposes - shifts = np.random.randint(-50, 50, 3) - logger.info('Shifts (z,y,x): %f,%f,%f', shifts[0], shifts[1], shifts[2]) + shifts_zyx = np.random.randint(-50, 50, 3) + sleep(3) + logger.info( + 'shifts_zyx (z,y,x): %f,%f,%f', shifts_zyx[0], shifts_zyx[1], shifts_zyx[2] + ) else: volume_t0 = get_volume(dataset, volume_t0_axes) volume_t1 = get_volume(dataset, volume_t1_axes) # Reference and moving volumes - shifts = tracker.estimate_shifts(volume_t0, volume_t1) + shifts_zyx = tracker.estimate_shifts(volume_t0, volume_t1) - # Save the shifts - # TODO: This is a placeholder, the actual implementation will be different position_id = str(axes['position']) + '.csv' shift_coord_output = output_shift_path / position_id - tracker.save_shifts_to_file( - shift_coord_output, position_id=p_idx, timepoint_id=t_idx, shifts=shifts - ) + # Read the previous shifts_zyx and coords + prev_shifts = pd.read_csv(shift_coord_output) + prev_shifts = prev_shifts.iloc[-1] + + # Read the previous shifts_zyx + prev_x = position_settings.xyz_positions_shift[p_idx][0] + prev_y = position_settings.xyz_positions_shift[p_idx][1] + # Update Z shifts_zyx if available + if position_settings.xyz_positions_shift[p_idx][2] is not None: + prev_z = position_settings.xyz_positions_shift[p_idx][2] + logger.info('Previous shifts_zyx: %f,%f,%f', prev_z, prev_y, prev_x) + else: + prev_z = None + logger.info('Previous shifts_yx:,%f,%f', prev_y, prev_x) # Update the event coordinates - # TODO: This is a placeholder, the actual implementation will be different - # event_coords = {'Z': shifts[0], 'Y': shifts[1], 'X': shifts[2]} + position_settings.xyz_positions_shift[p_idx][0] = prev_x + shifts_zyx[-1] + position_settings.xyz_positions_shift[p_idx][1] = prev_y + shifts_zyx[-2] + # Update Z shifts_zyx if available + if position_settings.xyz_positions_shift[p_idx][2] is not None: + position_settings.xyz_positions_shift[p_idx][2] = prev_z + shifts_zyx[-3] + logger.info( + 'New positions: %f,%f,%f', *position_settings.xyz_positions_shift[p_idx] + ) + else: + logger.info( + 'New positions: %f,%f', *position_settings.xyz_positions_shift[p_idx][0:2] + ) + # Save the shifts_zyx + tracker.save_shifts_to_file( + shift_coord_output, + position_id=p_label, + timepoint_id=t_idx, + shifts_zyx=shifts_zyx, + stage_coords=( + position_settings.xyz_positions_shift[p_idx][2], + position_settings.xyz_positions_shift[p_idx][1], + position_settings.xyz_positions_shift[p_idx][0], + ), + ) + else: + # Save the positions at t=0 + position_id = str(axes['position']) + '.csv' + shift_coord_output = output_shift_path / position_id + prev_y = position_settings.xyz_positions_shift[p_idx][1] + prev_x = position_settings.xyz_positions_shift[p_idx][0] + if position_settings.xyz_positions_shift[p_idx][2] is not None: + prev_z = position_settings.xyz_positions_shift[p_idx][2] + else: + prev_z = None + tracker.save_shifts_to_file( + shift_coord_output, + position_id=p_label, + timepoint_id=t_idx, + shifts_zyx=(0, 0, 0), + stage_coords=(prev_z, prev_y, prev_x), + ) From 161643f985f0b4572f330d5db282f8795f8e17ac Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Tue, 24 Sep 2024 09:24:24 -0700 Subject: [PATCH 6/6] adding autotracker parameters to the mantis config --- .../settings/example_acquisition_settings.yaml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/mantis/acquisition/settings/example_acquisition_settings.yaml b/mantis/acquisition/settings/example_acquisition_settings.yaml index c62588f9..a6777b4a 100644 --- a/mantis/acquisition/settings/example_acquisition_settings.yaml +++ b/mantis/acquisition/settings/example_acquisition_settings.yaml @@ -118,6 +118,8 @@ lf_microscope_settings: use_autofocus: True autofocus_stage: 'ZDrive' autofocus_method: 'PFS' + # Autotracker parameters using the label-free channel + autotracker_config: ['Channel - LF', 'BF'] # Microscope settings which will be applied when the light-sheet acquisition is # initialized @@ -142,3 +144,10 @@ ls_microscope_settings: ls_autoexposure_settings: autoexposure_method: 'manual' rerun_each_timepoint: True + +autotracker_settings: + tracking_method: 'phase_cross_correlation' + tracking_interval: 1 + shift_limit: [30,200,200] + device: 'cpu' + scale_yx: 0.075