Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support points 2d #108

Merged
merged 10 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions holonote/annotate/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ def points_2d(
cls, data, region_labels, fields_labels, invert_axes=False, groupby: str | None = None
):
"Vectorizes point regions to VLines * HLines. Note does not support hover info"
msg = "2D point regions not supported yet"
raise NotImplementedError(msg)
vdims = [*fields_labels, "__selected__"]
element = hv.Points(data, kdims=region_labels, vdims=vdims)
hover = cls._build_hover_tool(data)
Expand Down Expand Up @@ -236,6 +234,15 @@ class AnnotationDisplay(param.Parameterized):

data = param.DataFrame(doc="Combined dataframe of annotation data", constant=True)

nearest_2d_point_threshold = param.Number(
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
default=0.1,
bounds=(0, None),
doc="""
Threshold for selecting an existing 2D point; anything over
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
this threshold will create a new point instead.
""",
)

invert_axis = param.Boolean(default=False, doc="Switch the annotation axis")

_count = param.Integer(default=0, precedence=-1)
Expand Down Expand Up @@ -425,13 +432,24 @@ def get_indices_by_position(self, **inputs) -> list[Any]:
iter_mask = (
(df[f"start[{k}]"] <= v) & (v < df[f"end[{k}]"]) for k, v in inputs.items()
)
subset = reduce(np.logical_and, iter_mask)
out = list(df[subset].index)
elif self.region_format == "point-point":
xk, yk = list(inputs.keys())
distance = (
(df[f"point[{xk}]"] - inputs[xk]) ** 2 + (df[f"point[{yk}]"] - inputs[yk]) ** 2
) ** 0.5
if (distance > self.nearest_2d_point_threshold).all():
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
return []
out = [df.loc[distance.idxmin()].name] # index == name of series
elif "point" in self.region_format:
iter_mask = ((df[f"point[{k}]"] - v).abs().argmin() for k, v in inputs.items())
out = list(df[reduce(np.logical_and, iter_mask)].index)
else:
msg = f"{self.region_format} not implemented"
raise NotImplementedError(msg)

return list(df[reduce(np.logical_and, iter_mask)].index)
return out

def register_tap_selector(self, element: hv.Element) -> hv.Element:
def tap_selector(x, y) -> None: # Tap tool must be enabled on the element
Expand Down Expand Up @@ -483,7 +501,9 @@ def overlay(self, indicators=True, editor=True) -> hv.Overlay:

def static_indicators(self, **events):
fields_labels = self.annotator.all_fields
region_labels = [k for k in self.data.columns if k not in fields_labels]
region_labels = [
k for k in self.data.columns if k not in fields_labels and k != "__selected__"
ahuang11 marked this conversation as resolved.
Show resolved Hide resolved
]

self.data["__selected__"] = self.data.index.isin(self.annotator.selected_indices)

Expand Down
54 changes: 54 additions & 0 deletions holonote/tests/test_display.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import holoviews as hv

hv.extension("bokeh")


class TestPoint2D:
def test_get_indices_by_position_exact(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=x, y=y)
assert len(indices) == 1

def test_get_indices_by_position_nearest_2d_point_threshold(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5)
assert len(indices) == 0

display.nearest_2d_point_threshold = 5
indices = display.get_indices_by_position(x=x + 0.5, y=y + 0.5)
assert len(indices) == 1

def test_get_indices_by_position_empty(self, annotator_point2d):
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=0.5, y=0.3)
assert len(indices) == 0

def test_get_indices_by_position_no_position(self, annotator_point2d):
display = annotator_point2d.get_display("x", "y")
indices = display.get_indices_by_position(x=None, y=None)
assert len(indices) == 0

def test_get_indices_by_position_multi_choice(self, annotator_point2d):
x, y = 0.5, 0.3
description = "A test annotation!"
annotator_point2d.set_regions(x=x, y=y)
annotator_point2d.add_annotation(description=description)

x2, y2 = 0.51, 0.31
description = "A test annotation!"
annotator_point2d.set_regions(x=x2, y=y2)
annotator_point2d.add_annotation(description=description)

display = annotator_point2d.get_display("x", "y")
display.nearest_2d_point_threshold = 1000

indices = display.get_indices_by_position(x=x, y=y)
assert len(indices) == 1