Skip to content

Commit

Permalink
Fix 3D point adding by passing in dims as well as layer (#48)
Browse files Browse the repository at this point in the history
* Fix 3D point adding by passing in dims as well as layer
* Fix tests
* Fix click test
  • Loading branch information
jni authored Nov 2, 2023
1 parent bcda478 commit 7ecd6f9
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
11 changes: 7 additions & 4 deletions src/zarpaint/_add_3d_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_data_ray(data, start_point, end_point):
return clipped_coords, ray


def find_midpoint_of_first_segment(layer, event):
def find_midpoint_of_first_segment(layer, dims, event):
"""Return the world coordinate of a Labels layer mouse event in 2D or 3D.
In 2D, this is just the event's position.
Expand All @@ -58,14 +58,14 @@ def find_midpoint_of_first_segment(layer, event):
coordinates : array of int
The world coordinates for the mouse event.
"""
ndim = len(layer._dims_displayed)
ndim = len(dims.displayed)
if ndim == 2:
coordinates = event.position
else: # 3d
start, end = layer.get_ray_intersections(
position=event.position,
view_direction=event.view_direction,
dims_displayed=layer._dims_displayed,
dims_displayed=list(dims.displayed),
world=True,
)
coordinates, ray = get_data_ray(layer.data, start, end)
Expand All @@ -85,6 +85,7 @@ def find_midpoint_of_first_segment(layer, event):

@magicgui.magic_factory
def add_points_3d_with_alt_click(
viewer: napari.Viewer,
labels: napari.layers.Labels,
points: napari.layers.Points,
):
Expand All @@ -95,7 +96,9 @@ def click_callback(layer, event):
if not (len(event.modifiers) == 1
and event.modifiers[0].name == 'Alt'):
return
world_click_coordinates = find_midpoint_of_first_segment(layer, event)
world_click_coordinates = find_midpoint_of_first_segment(
layer, viewer.dims, event
)
if world_click_coordinates is not None:
pts_coordinates = pts_world2data(world_click_coordinates)
points.add(pts_coordinates)
16 changes: 10 additions & 6 deletions src/zarpaint/_tests/test_zarpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def test_get_ray_coordinates():

def test_midpoint_2d_empty_ray(make_napari_viewer):
viewer = make_napari_viewer()
dims = viewer.dims

mock_data = np.zeros(shape=(5, 5), dtype="uint8")
layer_data = Labels(mock_data)
Expand All @@ -55,12 +56,13 @@ def test_midpoint_2d_empty_ray(make_napari_viewer):
view_direction = [1, 0]
mouse_event = MockMouseEvent(position, view_direction)

result = find_midpoint_of_first_segment(layer_data, mouse_event)
result = find_midpoint_of_first_segment(layer_data, dims, mouse_event)
assert result == (0, 0)


def test_midpoint_2d_nonempty_ray(make_napari_viewer):
viewer = make_napari_viewer()
dims = viewer.dims

mock_data = np.ones(shape=(5, 5), dtype="uint8")
layer_data = Labels(mock_data)
Expand All @@ -70,21 +72,22 @@ def test_midpoint_2d_nonempty_ray(make_napari_viewer):
view_direction = [1, 0]
mouse_event = MockMouseEvent(position, view_direction)

result = find_midpoint_of_first_segment(layer_data, mouse_event)
result = find_midpoint_of_first_segment(layer_data, dims, mouse_event)
assert result == (3, 0)


def test_midpoint_3d_empty_ray(make_napari_viewer):
viewer = make_napari_viewer()
viewer.dims.ndisplay = 3
dims = viewer.dims

mock_data = np.zeros(shape=(5, 5, 5), dtype="uint8")
layer_data = Labels(mock_data)
mouse_event = MockMouseEvent((2, 2, 0), [1, 0, 0])

viewer.add_layer(layer_data)

result = find_midpoint_of_first_segment(layer_data, mouse_event)
result = find_midpoint_of_first_segment(layer_data, dims, mouse_event)
assert result is None

mock_data[1:4, 1:4, 1:4] = 1
Expand All @@ -93,21 +96,22 @@ def test_midpoint_3d_empty_ray(make_napari_viewer):

viewer.add_layer(layer_data)

result = find_midpoint_of_first_segment(layer_data, mouse_event)
result = find_midpoint_of_first_segment(layer_data, dims, mouse_event)
assert result is None


def test_midpoint_3d_nonempty_ray(make_napari_viewer):
viewer = make_napari_viewer()
viewer.dims.ndisplay = 3
dims = viewer.dims

mock_data = np.zeros(shape=(5, 5, 5), dtype="uint8")
mock_data[1:4, 1:4, 1:4] = 1
layer_data = Labels(mock_data)
viewer.add_layer(layer_data)

mouse_event = MockMouseEvent((2, 2, 0), [0, 1, 1])
result = find_midpoint_of_first_segment(layer_data, mouse_event)
result = find_midpoint_of_first_segment(layer_data, dims, mouse_event)
np.testing.assert_allclose(result, [2., 3.5, 1.5])


Expand All @@ -124,7 +128,7 @@ def test_add_point_3d_alt_click(make_napari_viewer):
viewer.layers.selection.active = label_layer

point_widget = add_points_3d_with_alt_click()
point_widget(label_layer, points_layer)
point_widget(viewer, label_layer, points_layer)

view = viewer.window._qt_viewer
click_coordinates = (view.canvas.size[0] / 2, view.canvas.size[1] / 2, 0)
Expand Down

0 comments on commit 7ecd6f9

Please sign in to comment.