Skip to content

Commit

Permalink
- fixing bug when source and target channels are the same.
Browse files Browse the repository at this point in the history
- removing extra parameters not used
-support for 2D and 3D
-renaming flag for similarity-> similarity
  • Loading branch information
edyoshikun committed Jun 30, 2024
1 parent 0f47090 commit f0ae5ee
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 76 deletions.
17 changes: 7 additions & 10 deletions mantis/cli/apply_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ def apply_affine(

with open_ome_zarr(target_position_dirpaths[0]) as target_dataset:
target_channel_names = target_dataset.channel_names
Z_target, Y_target, X_target = target_dataset.data.shape[-3:]
target_shape_zyx = target_dataset.data.shape[-3:]

click.echo('\nREGISTRATION PARAMETERS:')
Expand All @@ -101,24 +100,22 @@ def apply_affine(
source_shape_zyx, target_shape_zyx, matrix
)
# TODO: start or stop may be None
cropped_target_shape_zyx = (
# Overwrite the previous target shape
target_shape_zyx = (
Z_slice.stop - Z_slice.start,
Y_slice.stop - Y_slice.start,
X_slice.stop - X_slice.start,
)
# Overwrite the previous target shape
Z_target, Y_target, X_target = cropped_target_shape_zyx[-3:]
cropped_target_shape_zyx = Z_target, Y_target, X_target
click.echo(f'Shape of cropped output dataset: {cropped_target_shape_zyx}\n')
click.echo(f'Shape of cropped output dataset: {target_shape_zyx}\n')
else:
Z_slice, Y_slice, X_slice = (
slice(0, Z_target),
slice(0, Y_target),
slice(0, X_target),
slice(0, target_shape_zyx[-3]),
slice(0, target_shape_zyx[-2]),
slice(0, target_shape_zyx[-1]),
)

output_metadata = {
"shape": (len(time_indices), len(output_channel_names), Z_target, Y_target, X_target),
"shape": (len(time_indices), len(output_channel_names)) + tuple(target_shape_zyx),
"chunks": None,
"scale": (1,) * 2 + tuple(output_voxel_size),
"channel_names": output_channel_names,
Expand Down
122 changes: 56 additions & 66 deletions mantis/cli/estimate_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,13 @@
@target_position_dirpaths()
@output_filepath()
@click.option(
"--similarity-flag",
"--similarity",
'-x',
is_flag=True,
help='flag to use similarity transform (rotation, translation, scaling) default:Eucledian (rotation, translation)',
help='Flag to use similarity transform (rotation, translation, scaling) default:Eucledian (rotation, translation)',
)
def estimate_affine(
source_position_dirpaths, target_position_dirpaths, output_filepath, similarity_flag
source_position_dirpaths, target_position_dirpaths, output_filepath, similarity
):
"""
Estimate the affine transform between a source (i.e. moving) and a target (i.e.
Expand Down Expand Up @@ -81,55 +81,50 @@ def estimate_affine(
source_channel_Z, source_channel_Y, source_channel_X = source_channel_volume.shape[-3:]
target_channel_Z, target_channel_Y, target_channel_X = target_channel_volume.shape[-3:]

if source_channel_Z < 2 or target_channel_Z < 2:
focus_source_channel_idx = 1
focus_target_channel_idx = 1
else:
focus_source_channel_idx = focus_from_transverse_band(
source_channel_volume[
:,
source_channel_Y // 2
- FOCUS_SLICE_ROI_WIDTH : source_channel_Y // 2
+ FOCUS_SLICE_ROI_WIDTH,
source_channel_X // 2
- FOCUS_SLICE_ROI_WIDTH : source_channel_X // 2
+ FOCUS_SLICE_ROI_WIDTH,
],
NA_det=NA_DETECTION_SOURCE,
lambda_ill=WAVELENGTH_EMISSION_SOURCE_CHANNEL,
pixel_size=source_channel_voxel_size[-1],
)
source_channel_focus_idx = focus_from_transverse_band(
source_channel_volume[
:,
source_channel_Y // 2
- FOCUS_SLICE_ROI_WIDTH : source_channel_Y // 2
+ FOCUS_SLICE_ROI_WIDTH,
source_channel_X // 2
- FOCUS_SLICE_ROI_WIDTH : source_channel_X // 2
+ FOCUS_SLICE_ROI_WIDTH,
],
NA_det=NA_DETECTION_SOURCE,
lambda_ill=WAVELENGTH_EMISSION_SOURCE_CHANNEL,
pixel_size=source_channel_voxel_size[-1],
)

focus_target_channel_idx = focus_from_transverse_band(
target_channel_volume[
:,
target_channel_Y // 2
- FOCUS_SLICE_ROI_WIDTH : target_channel_Y // 2
+ FOCUS_SLICE_ROI_WIDTH,
target_channel_X // 2
- FOCUS_SLICE_ROI_WIDTH : target_channel_X // 2
+ FOCUS_SLICE_ROI_WIDTH,
],
NA_det=NA_DETECTION_TARGET,
lambda_ill=WAVELENGTH_EMISSION_TARGET_CHANNEL,
pixel_size=target_channel_voxel_size[-1],
)
target_channel_focus_idx = focus_from_transverse_band(
target_channel_volume[
:,
target_channel_Y // 2
- FOCUS_SLICE_ROI_WIDTH : target_channel_Y // 2
+ FOCUS_SLICE_ROI_WIDTH,
target_channel_X // 2
- FOCUS_SLICE_ROI_WIDTH : target_channel_X // 2
+ FOCUS_SLICE_ROI_WIDTH,
],
NA_det=NA_DETECTION_TARGET,
lambda_ill=WAVELENGTH_EMISSION_TARGET_CHANNEL,
pixel_size=target_channel_voxel_size[-1],
)

click.echo()
if focus_source_channel_idx not in (0, source_channel_Z - 1):
click.echo(f"Best source channel focus slice: {focus_source_channel_idx}")
if source_channel_focus_idx not in (0, source_channel_Z - 1):
click.echo(f"Best source channel focus slice: {source_channel_focus_idx}")
else:
focus_source_channel_idx = source_channel_Z // 2
source_channel_focus_idx = source_channel_Z // 2
click.echo(
f"Could not determine best source channel focus slice, using {focus_source_channel_idx}"
f"Could not determine best source channel focus slice, using {source_channel_focus_idx}"
)

if focus_target_channel_idx not in (0, target_channel_Z - 1):
click.echo(f"Best target channel focus slice: {focus_target_channel_idx}")
if target_channel_focus_idx not in (0, target_channel_Z - 1):
click.echo(f"Best target channel focus slice: {target_channel_focus_idx}")
else:
focus_target_channel_idx = target_channel_Z // 2
target_channel_focus_idx = target_channel_Z // 2
click.echo(
f"Could not determine best target channel focus slice, using {focus_target_channel_idx}"
f"Could not determine best target channel focus slice, using {target_channel_focus_idx}"
)

# Calculate scaling factors for displaying data
Expand All @@ -138,9 +133,9 @@ def estimate_affine(
click.echo(
f"Z scaling factor: {scaling_factor_z:.3f}; XY scaling factor: {scaling_factor_yx:.3f}\n"
)

# Add layers to napari with and transform
# Rotate the image if needed here

# Convert to ants objects
source_zyx_ants = ants.from_numpy(source_channel_volume.astype(np.float32))
target_zyx_ants = ants.from_numpy(target_channel_volume.astype(np.float32))
Expand Down Expand Up @@ -181,19 +176,19 @@ def estimate_affine(
"magenta",
]

viewer.add_image(target_channel_volume, name=target_channel_name)
viewer.add_image(target_channel_volume, name=f"target_{target_channel_name}")
points_target_channel = viewer.add_points(
ndim=3, name=f"pts_{target_channel_name}", size=50, face_color=COLOR_CYCLE[0]
ndim=3, name=f"pts_target_{target_channel_name}", size=50, face_color=COLOR_CYCLE[0]
)

viewer.add_image(
source_layer = viewer.add_image(
source_zxy_pre_reg.numpy(),
name=source_channel_name,
name=f"source_{source_channel_name}",
blending='additive',
colormap='bop blue',
)
points_source_channel = viewer.add_points(
ndim=3, name=f"pts_{source_channel_name}", size=50, face_color=COLOR_CYCLE[0]
ndim=3, name=f"pts_source_{source_channel_name}", size=50, face_color=COLOR_CYCLE[0]
)

# setup viewer
Expand Down Expand Up @@ -264,7 +259,7 @@ def next_on_click(layer, event, in_focus):
viewer.dims.current_step = prev_step_source_channel

# Bind the mouse click callback to both point layers
in_focus = (focus_source_channel_idx, focus_target_channel_idx)
in_focus = (source_channel_focus_idx, target_channel_focus_idx)

def lambda_callback(layer, event):
return next_on_click(layer=layer, event=event, in_focus=in_focus)
Expand All @@ -285,23 +280,23 @@ def lambda_callback(layer, event):
)

# Get the data from the layers
pts_source_channel = points_source_channel.data
pts_target_channel = points_target_channel.data
pts_source_channel_data = points_source_channel.data
pts_target_channel_data = points_target_channel.data

# Estimate the affine transform between the points xy to make sure registration is good
if similarity_flag:
if similarity:
# Similarity transform (rotation, translation, scaling)
transform = SimilarityTransform()
transform.estimate(pts_source_channel, pts_target_channel)
transform.estimate(pts_source_channel_data, pts_target_channel_data)
manual_estimated_transform = transform.params @ compound_affine

else:
# Euclidean transform (rotation, translation) limiting this dataset's scale and just z-translation
transform = EuclideanTransform()
transform.estimate(pts_source_channel[:, 1:], pts_target_channel[:, 1:])
transform.estimate(pts_source_channel_data[:, 1:], pts_target_channel_data[:, 1:])
yx_points_transformation_matrix = transform.params

z_translation = pts_target_channel[0, 0] - pts_source_channel[0, 0]
z_translation = pts_target_channel_data[0, 0] - pts_source_channel_data[0, 0]

z_scale_translate_matrix = np.array([[1, 0, 0, z_translation]])

Expand All @@ -312,12 +307,6 @@ def lambda_callback(layer, event):
np.insert(yx_points_transformation_matrix, 0, 0, axis=1),
)
) # Insert 0 in the third entry of each row

scaling_affine = get_3D_rescaling_matrix(
(1, target_channel_Y, target_channel_X),
(scaling_factor_z, scaling_factor_yx, scaling_factor_yx),
)

manual_estimated_transform = euclidian_transform @ compound_affine

# NOTE: these two functions are key to pass the function properly to ANTs
Expand All @@ -341,9 +330,10 @@ def lambda_callback(layer, event):
colormap="magenta",
blending='additive',
)
viewer.layers.remove(f"pts_{source_channel_name}")
viewer.layers.remove(f"pts_{target_channel_name}")
viewer.layers[source_channel_name].visible = False
# Cleanup
viewer.layers.remove(points_source_channel)
viewer.layers.remove(points_target_channel)
source_layer.visible = False

# Ants affine transforms
T_manual_numpy = convert_transform_to_numpy(tx_manual)
Expand Down

0 comments on commit f0ae5ee

Please sign in to comment.