Skip to content

Commit

Permalink
add config-based stitching
Browse files Browse the repository at this point in the history
  • Loading branch information
ieivanov committed Apr 2, 2024
1 parent 16edcb5 commit 97c322b
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 129 deletions.
70 changes: 42 additions & 28 deletions examples/slurmkit_example/slurm_stitch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,77 +11,90 @@
from slurmkit import SlurmParams, slurm_function, submit_function

from mantis.analysis.stitch import (
get_stitch_output_shape, calculate_shift, shift_image, get_grid_rows_cols, stitch_shifted_store
get_stitch_output_shape, calculate_shift, get_grid_rows_cols
)

from mantis.cli.stitch import (
_preprocess_and_shift, _stitch_shifted_store
)

from mantis.cli.utils import (
create_empty_hcs_zarr,
process_single_position_v2,
)

col_translation = (967.9, -7.45)
row_translation = (7.78, 969)
from mantis.analysis.AnalysisSettings import StitchSettings
from mantis.cli.utils import yaml_to_model


verbose = True

# io parameters
# dataset = 'B3_600k_20x20_timelapse_1'
# channels = ['Nucleus_prediction']
# input_paths = f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/live/2-virtual-stain/fcmae-2d/mean_projection/{dataset}.zarr/*/*/*"
# temp_path = Path(f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/live/3-stitch/TEMP_{dataset}.zarr")
# dataset = 'A3_600k_38x38_1'

# input_paths = f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/live/2-virtual-stain/fcmae-2d/{dataset}.zarr/*/*/*"
# temp_path = Path(f"/hpc/scratch/group.comp.micro/TEMP_{dataset}.zarr")
# output_path = Path(f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/live/3-stitch/{dataset}.zarr")
# config_filepath = Path("/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/live/3-stitch/stitch_settings.yml")

dataset = 'grid_test_3'
channels = ['Default']
input_paths = f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/fixed/0-convert/{dataset}.zarr/*/*/*"
temp_path = Path(f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/fixed/test-register/TEMP_{dataset}.zarr")
output_path = Path(f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/fixed/test-register/{dataset}.zarr")
dataset = 'kidney_grid_test_1'

input_paths = f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/kidney_tissue/0-convert/dragonfly/{dataset}.zarr/*/*/*"
temp_path = Path(f"/hpc/scratch/group.comp.micro/TEMP_{dataset}.zarr")
output_path = Path(f"/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/kidney_tissue/1-stitch/dragonfly/{dataset}_2.zarr")
config_filepath = Path("/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/kidney_tissue/1-stitch/dragonfly/stitch_settings.yml")

# sbatch and resource parameters
cpus_per_task = 1
mem_per_cpu = "16G"
time = 10 # minutes
partition = 'cpu'


# NOTE: parameters from here and below should not have to be changed
input_paths = [Path(path) for path in natsorted(glob.glob(input_paths))]
slurm_out_path = temp_path.parent / "slurm_output" / "stitch-%j.out"
slurm_out_path = output_path.parent / "slurm_output" / "stitch-%j.out"

settings = yaml_to_model(config_filepath, StitchSettings)

with open_ome_zarr(str(input_paths[0]), mode="r") as input_dataset:
dataset_channel_names = input_dataset.channel_names
input_dataset_channels = input_dataset.channel_names
T, C, Z, Y, X = input_dataset.data.shape
# scale = tuple(input_dataset.scale)

if settings.channels is None:
settings.channels = input_dataset_channels

grid_rows, grid_cols = get_grid_rows_cols(Path(*input_paths[0].parts[:-3]))
n_rows = len(grid_rows)
n_cols = len(grid_cols)

output_shape, global_translation = get_stitch_output_shape(
n_rows, n_cols, Y, X, col_translation, row_translation
n_rows, n_cols, Y, X, settings.column_translation, settings.row_translation
)

# Create the output zarr mirroring input positions
# Takes a while
# Takes a while, 10 minutes ?!
click.echo('Creating output zarr store')
create_empty_hcs_zarr(
store_path=temp_path,
position_keys=[p.parts[-3:] for p in input_paths],
shape=(T, len(channels), Z) + output_shape,
shape=(T, len(settings.channels), Z) + output_shape,
chunks=(1, 1, 1, 4096, 4096),
# scale=scale,
channel_names=channels,
channel_names=settings.channels,
dtype=np.float32,
)

# debug
# process_single_position_v2(
# shift_image,
# _preprocess_and_shift,
# time_indices='all',
# input_data_path=input_paths[0],
# output_path=output_path,
# input_channel_idx=[dataset_channel_names.index(ch) for ch in channels],
# output_channel_idx=list(range(len(channels))),
# output_path=temp_path,
# input_channel_idx=[input_dataset_channels.index(ch) for ch in settings.channels],
# output_channel_idx=list(range(len(settings.channels))),
# num_processes=cpus_per_task,
# settings=settings,
# output_shape=output_shape,
# shift=(0, 0),
# verbose=True,
Expand All @@ -105,15 +118,16 @@
for in_path in input_paths:
col_idx, row_idx = (int(in_path.name[:3]), int(in_path.name[3:]))
shift = calculate_shift(
col_idx, row_idx, col_translation, row_translation, global_translation
col_idx, row_idx, settings.column_translation, settings.row_translation, global_translation
)

func = slurm_func(
shift_image,
_preprocess_and_shift,
time_indices='all',
input_channel_idx=[dataset_channel_names.index(ch) for ch in channels],
output_channel_idx=list(range(len(channels))),
input_channel_idx=[input_dataset_channels.index(ch) for ch in settings.channels],
output_channel_idx=list(range(len(settings.channels))),
num_processes=cpus_per_task,
settings=settings,
output_shape=output_shape,
shift=shift,
verbose=True,
Expand All @@ -129,7 +143,7 @@
)

submit_function(
slurm_function(stitch_shifted_store)(temp_path, output_path, verbose),
slurm_function(_stitch_shifted_store)(temp_path, output_path, settings, verbose),
slurm_params=SlurmParams(
partition=partition,
cpus_per_task=8,
Expand Down
13 changes: 13 additions & 0 deletions mantis/analysis/AnalysisSettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@ class MyBaseModel(BaseModel, extra=Extra.forbid):
pass


class ProcessingSettings(MyBaseModel):
fliplr: Optional[bool] = False
flipud: Optional[bool] = False


class DeskewSettings(MyBaseModel):
pixel_size_um: PositiveFloat
ls_angle_deg: PositiveFloat
Expand Down Expand Up @@ -67,3 +72,11 @@ def check_affine_transform(cls, v):
raise ValueError("The array must contain valid numerical values.")

return v


class StitchSettings(MyBaseModel):
column_translation: tuple[float, float]
row_translation: tuple[float, float]
channels: Optional[list[str]] = None
preprocessing: Optional[ProcessingSettings] = None
postprocessing: Optional[ProcessingSettings] = None
37 changes: 29 additions & 8 deletions mantis/analysis/scripts/compute_image_translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,16 @@

# %%
data_dir = Path(
'/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/fixed/0-convert/'
'/hpc/projects/intracellular_dashboard/ops/2024_03_05_registration_test/kidney_tissue/0-convert/dragonfly/'
)
dataset = 'grid_test_3.zarr'
dataset = 'kidney_grid_test_1.zarr'
data_path = data_dir / dataset

fliplr = True
flipud = False

rows_limit = 3
cols_limit = 3
percent_overlap = 0.05

dataset = open_ome_zarr(data_path)
Expand All @@ -34,34 +39,50 @@
grid_rows = sorted(grid_rows)
grid_cols = sorted(grid_cols)

if rows_limit:
grid_rows = grid_rows[:rows_limit]
if cols_limit:
grid_cols = grid_cols[:cols_limit]

y_roi = int(sizeY * (percent_overlap + 0.05))


def fetch_image(dataset, well_name, col_name, row_name):
img = dataset[Path(well_name, col_name + row_name)].data[0, 0, 0]
if fliplr:
img = np.fliplr(img)
if flipud:
img = np.flipud(img)
return img


# %%
row_shifts = []
for i in range(len(grid_rows) - 1):
for col_idx, col_name in enumerate(grid_cols):
img0 = dataset[Path(well_name, col_name + grid_rows[i])].data[0, 0, 0]
img1 = dataset[Path(well_name, col_name + grid_rows[i + 1])].data[0, 0, 0]
img0 = fetch_image(dataset, well_name, col_name, grid_rows[i])
img1 = fetch_image(dataset, well_name, col_name, grid_rows[i + 1])

shift, _, _ = phase_cross_correlation(
img0[-y_roi:, :], img1[:y_roi, :], upsample_factor=10
)
shift[0] += sizeX - y_roi
row_shifts.append(shift)
row_translation = np.asarray(row_shifts).mean(axis=0)[::-1]
row_translation = np.median(row_shifts, axis=0)[::-1]

col_shifts = []
for j in range(len(grid_cols) - 1):
for row_idx, row_name in enumerate(grid_rows):
img0 = dataset[Path(well_name, grid_cols[j] + row_name)].data[0, 0, 0]
img1 = dataset[Path(well_name, grid_cols[j + 1] + row_name)].data[0, 0, 0]
img0 = fetch_image(dataset, well_name, grid_cols[j], row_name)
img1 = fetch_image(dataset, well_name, grid_cols[j + 1], row_name)

shift, _, _ = phase_cross_correlation(
img0[:, -y_roi:], img1[:, :y_roi], upsample_factor=10
)
shift[1] += sizeY - y_roi
col_shifts.append(shift)

col_translation = np.asarray(col_shifts).mean(axis=0)[::-1]
col_translation = np.median(col_shifts, axis=0)[::-1]

# %%
print(f'Column translation: {col_translation}, row translation: {row_translation}')
9 changes: 9 additions & 0 deletions mantis/analysis/settings/example_stitch_settings.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
column_translation: [967.9, -7.45] # translation distance in (x, y) in pixels when moving across columns
row_translation: [7.78, 969] # translation distance in (x, y) in pixels when moving across rows
channels: [Default] # may be null in which case all channels will be stitched
preprocessing:
- fliplr: true
- flipud: true
postprocessing:
- fliplr: false
- flipud: false
34 changes: 1 addition & 33 deletions mantis/analysis/stitch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pathlib import Path

import click
import numpy as np
import scipy.ndimage as ndi

Expand Down Expand Up @@ -59,11 +58,10 @@ def shift_image(
verbose: bool = False,
) -> np.ndarray:
ndims = image.ndim
sizeY, sizeX = image.shape[-2:]

if verbose:
print(f"Shifting image by {shift}")

sizeY, sizeX = image.shape[-2:]
output = np.zeros(output_shape, dtype=np.float32)
output[:sizeY, :sizeX] = np.squeeze(image)
output = ndi.shift(output, shift, order=0)
Expand Down Expand Up @@ -134,33 +132,3 @@ def stitch_images(
stitched_array[overlap] /= 2 # average blending in the overlapping region

return stitched_array


def stitch_shifted_store(input_data_path, output_data_path, verbose=True):
click.echo(f'Stitching zarr store: {input_data_path}')
with open_ome_zarr(input_data_path, mode="r") as input_dataset:
well_name, _ = next(input_dataset.wells())
_, sample_position = next(input_dataset.positions())
array_shape = sample_position.data.shape
channels = input_dataset.channel_names

stitched_array = np.zeros(array_shape, dtype=np.float32)
denominator = np.zeros(array_shape, dtype=np.uint8)

j = 0
for _, position in input_dataset.positions():
if verbose:
click.echo(f'Processing position {j}')
stitched_array += position.data
denominator += np.bool_(position.data)
j += 1

denominator[denominator == 0] = 1
stitched_array /= denominator

click.echo(f'Saving stitched array in :{output_data_path}')
with open_ome_zarr(
output_data_path, layout='hcs', channel_names=channels, mode="w-"
) as output_dataset:
position = output_dataset.create_position(*Path(well_name, '0').parts)
position.create_image('0', stitched_array, chunks=(1, 1, 1, 4096, 4096))
Loading

0 comments on commit 97c322b

Please sign in to comment.