Skip to content

Commit

Permalink
Registration V2 (#102)
Browse files Browse the repository at this point in the history
* initial commit adding tools for cropping and registering with ants

* adding lir crop

* starting cleanup for registration

* converting nan to zero

* ants estimate affine

* adding config file for estimate_transform

* cleaning up some redundancies

* fix edge case of input not being float32

* channel is an int

* typo in napari add image for vs ants optimizer

* adding utils yaml

* separated the manual estimation from the optimizer

* added ants to numpy and vice versa

* adding input at the end to prevent from closing

* replacing the reading of .mat to matrix from yaml

* added changes to ants to numpy from Jordaos input

* cannot have recOrder as dependency given different pycromanager versions

* format

* add test for optimize_affine and update readme

* adding docstring

* removed the config dependency

* removing the focus_finding svg

* removing the estimate_affine.yml and making flags to choose channels and verbose

* added comment and code to use scipy.

* -ls and -lf to -s and -t

* document FOCUS_SLICE_ROI_WIDTH

* `estimate-source-to-target-affine` -> `estimate-affine`

* standardize `optimize-affine` arguments

* cleaner docs

* update RegistrationSettings

* fix tests

* remove duplicate

* write indices to output config

* read channels from file

* test fix

* echo improvements

* black

* style

* remove duplicate cli call

* more intuitive manual registration

* patch for deskewing

* testing slurmkit

* fix apply affine example

* matrix multiplication bugfix

* updating the lir_crop using new registration

* adding lir crop to the apply affine and modifying the manual estimation and the optimizer to output the input zyx shape to yaml

* convert nans to zeros

* adding test for affine

* removed main

* fix `apply_affine` docstring

* fixing formatting

* remove crop-output from cli args

* estimate affine cli docstrings

* reverting changes to AnalysisSettings model.

* optimize_affine cli docstrings

* Revert "reverting changes to AnalysisSettings model."

This reverts commit 0e2f8c4.

* fix points layer grid mode bug

* fix tests

* combining the cropping and merging to the apply-affine

* removing unnecessary prints()

* add registration and stabilization modules

* style

* flake8

* Registration refactor (#118)

* refactor apply_affine

* clean up

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>

* debug and add comments

* updating estimate_affine to match new yml

* don't start parallel pool for 1 iterable

* patch previous commit

* make sure the source_channel_used is first in the list for the optimizer

* make optimize_affine accept the new config parameters

* fix bug calling iohub info on paths with white spaces; rename target_channel_str

* better white space cli bug fix

* cleaner print statements

* don't use itertools.starmap

* clean up apply_affine print statements

* clean up optimize_affine

---------

Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>
Co-authored-by: Eduardo Hirata-Miyasaki <[email protected]>

* removing the source n target shape zyx and testing keep overhang

* updating estimate affine with previous commit changes

* style

* Update example_apply_affine_settings.yml

* refactor registration and stabilization functions

* fixing pytests

* more intuitive function names

* minor refactor of estimate_affine

* bugfix

---------

Co-authored-by: Talon Chandler <[email protected]>
Co-authored-by: Ivan Ivanov <[email protected]>
  • Loading branch information
3 people authored Jan 19, 2024
1 parent 9155d58 commit 0df6135
Show file tree
Hide file tree
Showing 20 changed files with 1,807 additions and 463 deletions.
209 changes: 157 additions & 52 deletions examples/slurmkit_example/slurm_apply_affine.py
Original file line number Diff line number Diff line change
@@ -1,89 +1,194 @@
import datetime
import os
import glob
from mantis.cli import utils
from slurmkit import SlurmParams, slurm_function, submit_function
from natsort import natsorted
import os

from pathlib import Path

import click
from mantis.cli.apply_affine import registration_params_from_file, rotate_n_affine_transform
import numpy as np

from iohub import open_ome_zarr
from natsort import natsorted
from slurmkit import SlurmParams, slurm_function, submit_function

from mantis.analysis.AnalysisSettings import RegistrationSettings
from mantis.analysis.register import apply_affine_transform, find_overlapping_volume
from mantis.cli.apply_affine import rescale_voxel_size
from mantis.cli.utils import (
copy_n_paste_czyx,
create_empty_hcs_zarr,
process_single_position_v2,
yaml_to_model,
)

# io parameters
labelfree_data_paths = '/hpc/projects/comp.micro/mantis/2023_08_09_HEK_PCNA_H2B/2-phase3D/pcna_rac1_virtual_staining_b1_redo_1/phase3D.zarr/0/0/0'
lightsheet_data_paths = '/hpc/projects/comp.micro/mantis/2023_08_09_HEK_PCNA_H2B/1-deskew/pcna_rac1_virtual_staining_b1_redo_1/deskewed.zarr/0/0/0'
output_data_path = './registered_output.zarr'
registration_param_path = './register.yml'
source_position_dirpaths = '/input_source.zarr/*/*/*'
target_position_dirpaths = '/input_target.zarr/*/*/*'
config_filepath = (
'../mantis/analysis/settings/example_apply_affine_settings.yml'
)
output_dirpath = './test_output.zarr'

# sbatch and resource parameters
cpus_per_task = 16
cpus_per_task = 4
mem_per_cpu = "16G"
time = 40 # minutes
simultaneous_processes_per_node = 5

# path handling
labelfree_data_paths = natsorted(glob.glob(labelfree_data_paths))
lightsheet_data_paths = natsorted(glob.glob(lightsheet_data_paths))
output_dir = os.path.dirname(output_data_path)
output_paths = utils.get_output_paths(labelfree_data_paths, output_data_path)
click.echo(f"in: {labelfree_data_paths}, out: {output_paths}")
slurm_out_path = str(os.path.join(output_dir, "slurm_output/register-%j.out"))

# Additional registration arguments
time = 60 # minutes
partition = 'cpu'
simultaneous_processes_per_node = (
8 # number of processes that are run in parallel on a single node
)

# NOTE: parameters from here and below should not have to be changed
source_position_dirpaths = [
Path(path) for path in natsorted(glob.glob(source_position_dirpaths))
]
target_position_dirpaths = [
Path(path) for path in natsorted(glob.glob(target_position_dirpaths))
]
output_dirpath = Path(output_dirpath)
config_filepath = Path(config_filepath)

click.echo(f"in_path: {source_position_dirpaths[0]}, out_path: {output_dirpath}")
slurm_out_path = output_dirpath.parent / "slurm_output" / "register-%j.out"

# Parse from the yaml file
settings = registration_params_from_file(registration_param_path)
settings = yaml_to_model(config_filepath, RegistrationSettings)
matrix = np.array(settings.affine_transform_zyx)
output_shape_zyx = tuple(settings.output_shape_zyx)
keep_overhang = settings.keep_overhang

# Calculate the output voxel size from the input scale and affine transform
with open_ome_zarr(source_position_dirpaths[0]) as source_dataset:
T, C, Z, Y, X = source_dataset.data.shape
source_channel_names = source_dataset.channel_names
source_shape_zyx = source_dataset.data.shape[-3:]
source_voxel_size = source_dataset.scale[-3:]
output_voxel_size = rescale_voxel_size(matrix[:3, :3], source_voxel_size)

with open_ome_zarr(target_position_dirpaths[0]) as target_dataset:
target_channel_names = target_dataset.channel_names
Z_target, Y_target, X_target = target_dataset.data.shape[-3:]
target_shape_zyx = target_dataset.data.shape[-3:]

# Get the output voxel_size
with open_ome_zarr(lightsheet_data_paths[0]) as light_sheet_position:
voxel_size = tuple(light_sheet_position.scale[-3:])
click.echo('\nREGISTRATION PARAMETERS:')
click.echo(f'Transformation matrix:\n{matrix}')
click.echo(f'Voxel size: {output_voxel_size}')

# Logic to parse time indices
if settings.time_indices == "all":
time_indices = list(range(T))
elif isinstance(settings.time_indices, list):
time_indices = settings.time_indices
elif isinstance(settings.time_indices, int):
time_indices = [settings.time_indices]

# Build the output channel list on a copy. `target_channel_names` would
# otherwise be bound to the same object, and (if it is a list) the in-place
# `+=` below would extend it too, so later loops over the target channels
# would also see the source channel names.
output_channel_names = list(target_channel_names)
if target_position_dirpaths != source_position_dirpaths:
    # Source and target are different datastores: the output holds the target
    # channels followed by the source channels.
    output_channel_names += source_channel_names

if not keep_overhang:
    # Find the largest interior rectangle
    click.echo('\nFinding largest overlapping volume between source and target datasets')
    Z_slice, Y_slice, X_slice = find_overlapping_volume(
        source_shape_zyx, target_shape_zyx, matrix
    )
    # TODO: start or stop may be None
    cropped_target_shape_zyx = (
        Z_slice.stop - Z_slice.start,
        Y_slice.stop - Y_slice.start,
        X_slice.stop - X_slice.start,
    )
    # Overwrite the previous target shape with the cropped extent; this is the
    # ZYX shape used for the output array metadata below.
    Z_target, Y_target, X_target = cropped_target_shape_zyx
    # Report the cropped shape (not the original `target_shape_zyx`).
    click.echo(f'Shape of cropped output dataset: {cropped_target_shape_zyx}\n')
else:
    # Keep the full target extent: the slices span the whole target volume.
    Z_slice, Y_slice, X_slice = (
        slice(0, Z_target),
        slice(0, Y_target),
        slice(0, X_target),
    )

output_metadata = {
"shape": (len(time_indices), len(output_channel_names), Z_target, Y_target, X_target),
"chunks": None,
"scale": (1,) * 2 + tuple(output_voxel_size),
"channel_names": output_channel_names,
"dtype": np.float32,
}

# Create the output zarr mirroring source_position_dirpaths
create_empty_hcs_zarr(
store_path=output_dirpath,
position_keys=[p.parts[-3:] for p in source_position_dirpaths],
**output_metadata,
)

# Get the affine transformation matrix
# NOTE: add any extra metadata if needed:
extra_metadata = {
'registration': {
'affine_matrix': matrix.tolist(),
'pre_affine_90degree_rotations_about_z': settings.pre_affine_90degree_rotations_about_z,
'affine_transformation': {
'transform_matrix': matrix.tolist(),
}
}

affine_transform_args = {
'matrix': matrix,
'output_shape_zyx': settings.output_shape_zyx,
'pre_affine_90degree_rotations_about_z': settings.pre_affine_90degree_rotations_about_z,
'output_shape_zyx': target_shape_zyx, # NOTE: this is the shape of the original target dataset
'crop_output_slicing': ([Z_slice, Y_slice, X_slice] if not keep_overhang else None),
'extra_metadata': extra_metadata,
}
utils.create_empty_zarr(
position_paths=labelfree_data_paths,
output_path=output_data_path,
output_zyx_shape=output_shape_zyx,
chunk_zyx_shape=None,
voxel_size=voxel_size,
)

copy_n_paste_kwargs = {"czyx_slicing_params": ([Z_slice, Y_slice, X_slice])}

# prepare slurm parameters
params = SlurmParams(
partition="cpu",
partition=partition,
cpus_per_task=cpus_per_task,
mem_per_cpu=mem_per_cpu,
time=datetime.timedelta(minutes=time),
output=slurm_out_path,
)

# wrap our utils.process_single_position() function with slurmkit
slurm_process_single_position = slurm_function(utils.process_single_position)
slurm_process_single_position = slurm_function(process_single_position_v2)
register_func = slurm_process_single_position(
func=rotate_n_affine_transform,
func=apply_affine_transform,
output_path=output_dirpath,
time_indices=time_indices,
num_processes=simultaneous_processes_per_node,
**affine_transform_args,
)

# generate an array of jobs by passing the in_path and out_path to slurm wrapped function
register_jobs = [
submit_function(
register_func,
slurm_params=params,
input_data_path=in_path,
output_path=out_path,
)
for in_path, out_path in zip(labelfree_data_paths, output_paths)
]
copy_n_paste_func = slurm_process_single_position(
func=copy_n_paste_czyx,
output_path=output_dirpath,
time_indices=time_indices,
num_processes=simultaneous_processes_per_node,
**copy_n_paste_kwargs,
)

# NOTE: channels will not be processed in parallel
# NOTE: the source and target datastores may be the same (e.g. Hummingbird datasets)
# apply affine transform to channels in the source datastore that should be registered
# as given in the config file (i.e. settings.source_channel_names)
for input_position_path in source_position_dirpaths:
for channel_name in source_channel_names:
if channel_name in settings.source_channel_names:
submit_function(
register_func,
slurm_params=params,
input_data_path=input_position_path,
input_channel_idx=[source_channel_names.index(channel_name)],
output_channel_idx=[output_channel_names.index(channel_name)],
)

# Copy over the channels that were not processed
for input_position_path in target_position_dirpaths:
for channel_name in target_channel_names:
if channel_name not in settings.source_channel_names:
submit_function(
copy_n_paste_func,
slurm_params=params,
input_data_path=input_position_path,
input_channel_idx=[target_channel_names.index(channel_name)],
output_channel_idx=[output_channel_names.index(channel_name)],
)
45 changes: 23 additions & 22 deletions mantis/analysis/AnalysisSettings.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from typing import Optional
from typing import Literal, Optional, Union

import numpy as np

from pydantic import ConfigDict, PositiveFloat, PositiveInt, validator
from pydantic.dataclasses import dataclass
from pydantic import BaseModel, Extra, NonNegativeInt, PositiveFloat, PositiveInt, validator

config = ConfigDict(extra="forbid")

# All settings classes inherit from MyBaseModel, which forbids extra parameters to guard against typos
class MyBaseModel(BaseModel, extra=Extra.forbid):
pass

@dataclass(config=config)
class DeskewSettings:

class DeskewSettings(MyBaseModel):
pixel_size_um: PositiveFloat
ls_angle_deg: PositiveFloat
px_to_scan_ratio: Optional[PositiveFloat] = None
scan_step_um: Optional[PositiveFloat] = None
keep_overhang: bool = True
keep_overhang: bool = False
average_n_slices: PositiveInt = 3

@validator("ls_angle_deg")
Expand All @@ -28,19 +29,25 @@ def px_to_scan_ratio_check(cls, v):
if v is not None:
return round(float(v), 3)

def __post_init__(self):
if self.px_to_scan_ratio is None:
if self.scan_step_um is not None:
self.px_to_scan_ratio = round(self.pixel_size_um / self.scan_step_um, 3)
def __init__(self, **data):
if data.get("px_to_scan_ratio") is None:
if data.get("scan_step_um") is not None:
data["px_to_scan_ratio"] = round(
data["pixel_size_um"] / data["scan_step_um"], 3
)
else:
raise TypeError("px_to_scan_ratio is not valid")
raise ValueError(
"If px_to_scan_ratio is not provided, both pixel_size_um and scan_step_um must be provided"
)
super().__init__(**data)


@dataclass(config=config)
class RegistrationSettings:
class RegistrationSettings(MyBaseModel):
source_channel_names: list[str]
target_channel_name: str
affine_transform_zyx: list
output_shape_zyx: list
pre_affine_90degree_rotations_about_z: Optional[int] = 1
keep_overhang: bool = False
time_indices: Union[NonNegativeInt, list[NonNegativeInt], Literal["all"]] = "all"

@validator("affine_transform_zyx")
def check_affine_transform(cls, v):
Expand All @@ -60,9 +67,3 @@ def check_affine_transform(cls, v):
raise ValueError("The array must contain valid numerical values.")

return v

@validator("output_shape_zyx")
def check_output_shape_zyx(cls, v):
if not isinstance(v, list) or len(v) != 3:
raise ValueError("The output shape zyx must be a list of length 3.")
return v
Loading

0 comments on commit 0df6135

Please sign in to comment.