From 35272e724ef1c48304de0e400478585d816dfeb1 Mon Sep 17 00:00:00 2001
From: Elizabeth <eberrigan@salk.edu>
Date: Thu, 18 Apr 2024 16:24:16 -0700
Subject: [PATCH 1/4] add helper function for getting points from geometry

---
 sleap_roots/convhull.py | 19 ++++++++++-------
 sleap_roots/points.py   | 43 ++++++++++++++++++++++++++++++++++++-
 tests/test_convhull.py  |  1 -
 tests/test_points.py    | 47 +++++++++++++++++++++++++++++++++++++++--
 4 files changed, 98 insertions(+), 12 deletions(-)

diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py
index c15a169..12b3df2 100644
--- a/sleap_roots/convhull.py
+++ b/sleap_roots/convhull.py
@@ -4,7 +4,7 @@
 from scipy.spatial import ConvexHull
 from scipy.spatial.distance import pdist
 from typing import Tuple, Optional, Union
-from sleap_roots.points import get_line_equation_from_points
+from sleap_roots.points import extract_points_from_geometry, get_line_equation_from_points
 from shapely import box, LineString, normalize, Polygon
 
 
@@ -382,13 +382,9 @@ def get_chull_areas_via_intersection(
     # Find the intersection between the hull perimeter and the extended line
     intersection = extended_line.intersection(hull_perimeter)
 
-    # Add intersection points to both lists
+    # Compute the intersection points and add to lists
     if not intersection.is_empty:
-        intersect_points = (
-            np.array([[point.x, point.y] for point in intersection.geoms])
-            if intersection.geom_type == "MultiPoint"
-            else np.array([[intersection.x, intersection.y]])
-        )
+        intersect_points = extract_points_from_geometry(intersection)
         above_line.extend(intersect_points)
         below_line.extend(intersect_points)
 
@@ -452,6 +448,10 @@ def get_chull_intersection_vectors(
     Raises:
         ValueError: If pts does not have the expected shape.
     """
+    if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2:
+        print("Not enough instances or incorrect format to compute convex hull intersections.")
+        return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]]))
+
     # Check for valid pts input
     if not isinstance(pts, np.ndarray) or pts.ndim != 3 or pts.shape[-1] != 2:
         raise ValueError("pts must be a numpy array of shape (instances, nodes, 2).")
@@ -460,7 +460,7 @@ def get_chull_intersection_vectors(
         raise ValueError("rn_pts must be a numpy array of shape (instances, 2).")
     # Ensure r0_pts is a numpy array of shape (instances, 2)
     if not isinstance(r0_pts, np.ndarray) or r0_pts.ndim != 2 or r0_pts.shape[-1] != 2:
-        raise ValueError("r0_pts must be a numpy array of shape (instances, 2).")
+        raise ValueError(f"r0_pts must be a numpy array of shape (instances, 2).")
 
     # Flatten pts to 2D array and remove NaN values
     flattened_pts = pts.reshape(-1, 2)
@@ -481,6 +481,9 @@ def get_chull_intersection_vectors(
 
     # Ensuring r0_pts does not contain NaN values
     r0_pts_valid = r0_pts[~np.isnan(r0_pts).any(axis=1)]
+    # Expect two vectors in the end
+    if len(r0_pts_valid) < 2:
+        return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]]))
 
     # Get the vertices of the convex hull
     hull_vertices = hull.points[hull.vertices]
diff --git a/sleap_roots/points.py b/sleap_roots/points.py
index 479f564..22bdfcf 100644
--- a/sleap_roots/points.py
+++ b/sleap_roots/points.py
@@ -3,11 +3,52 @@
 import numpy as np
 from matplotlib import pyplot as plt
 from matplotlib.lines import Line2D
-from shapely.geometry import LineString
+from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
 from shapely.ops import nearest_points
 from typing import List, Optional, Tuple
 
 
+def extract_points_from_geometry(geometry):
+    """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object.
+    
+    This function supports Point, MultiPoint, LineString, and GeometryCollection types. 
+    It recursively extracts coordinates from complex geometries and aggregates them into a single list. 
+    For unsupported geometry types, it returns an empty list.
+    
+    Parameters:
+    - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points.
+    
+    Returns:
+    - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. 
+      The list will be empty if the geometry type is unsupported or contains no coordinates.
+    
+    Raises:
+    - TypeError: If the input is not a recognized Shapely geometry type.
+    
+    Example:
+    >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
+    >>> point = Point(1, 2)
+    >>> multipoint = MultiPoint([(1, 2), (3, 4)])
+    >>> linestring = LineString([(0, 0), (1, 1), (2, 2)])
+    >>> geom_col = GeometryCollection([point, multipoint, linestring])
+    >>> extract_points_from_geometry(geom_col)
+    [array([1, 2]), array([1, 2]), array([3, 4]), array([0, 0]), array([1, 1]), array([2, 2])]
+    """
+    if isinstance(geometry, Point):
+        return [np.array([geometry.x, geometry.y])]
+    elif isinstance(geometry, MultiPoint):
+        return [np.array([point.x, point.y]) for point in geometry.geoms]
+    elif isinstance(geometry, LineString):
+        return [np.array([x, y]) for x, y in zip(*geometry.xy)]
+    elif isinstance(geometry, GeometryCollection):
+        points = []
+        for geom in geometry.geoms:
+            points.extend(extract_points_from_geometry(geom))
+        return points
+    else:
+        raise TypeError(f"Unsupported geometry type: {type(geometry).__name__}")
+
+
 def get_count(pts: np.ndarray):
     """Get number of roots.
 
diff --git a/tests/test_convhull.py b/tests/test_convhull.py
index a33c279..f506312 100644
--- a/tests/test_convhull.py
+++ b/tests/test_convhull.py
@@ -314,7 +314,6 @@ def test_basic_functionality(pts_shape_3_6_2):
 @pytest.mark.parametrize(
     "invalid_input",
     [
-        (np.array([1, 2]), np.array([3, 4]), np.array([[[1, 2], [3, 4]]]), None),
         (np.array([[1, 2, 3]]), np.array([[3, 4]]), np.array([[[1, 2], [3, 4]]]), None),
         # Add more invalid inputs as needed
     ],
diff --git a/tests/test_points.py b/tests/test_points.py
index 6ac3a1d..ed042a8 100644
--- a/tests/test_points.py
+++ b/tests/test_points.py
@@ -1,9 +1,9 @@
 import numpy as np
 import pytest
-from shapely.geometry import LineString
+from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
 from sleap_roots import Series
 from sleap_roots.lengths import get_max_length_pts
-from sleap_roots.points import filter_plants_with_unexpected_ct, get_count, join_pts
+from sleap_roots.points import extract_points_from_geometry, filter_plants_with_unexpected_ct, get_count, join_pts
 from sleap_roots.points import (
     get_all_pts_array,
     get_nodes,
@@ -738,3 +738,46 @@ def test_filter_plants_with_unexpected_ct_incorrect_input_types():
     expected_count = "not a float"
     with pytest.raises(ValueError):
         filter_plants_with_unexpected_ct(primary_pts, lateral_pts, expected_count)
+
+
+def test_extract_from_point():
+    point = Point(1, 2)
+    expected = [np.array([1, 2])]
+    assert np.array_equal(extract_points_from_geometry(point), expected)
+
+def test_extract_from_multipoint():
+    multipoint = MultiPoint([(1, 2), (3, 4)])
+    expected = [np.array([1, 2]), np.array([3, 4])]
+    results = extract_points_from_geometry(multipoint)
+    assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
+
+def test_extract_from_linestring():
+    linestring = LineString([(0, 0), (1, 1), (2, 2)])
+    expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])]
+    results = extract_points_from_geometry(linestring)
+    assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
+
+def test_extract_from_geometrycollection():
+    geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])])
+    expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])]
+    results = extract_points_from_geometry(geom_collection)
+    assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
+
+def test_extract_from_empty_multipoint():
+    empty_multipoint = MultiPoint()
+    expected = []
+    assert extract_points_from_geometry(empty_multipoint) == expected
+
+def test_extract_from_empty_linestring():
+    empty_linestring = LineString()
+    expected = []
+    assert extract_points_from_geometry(empty_linestring) == expected
+
+def test_extract_from_unsupported_type():
+    with pytest.raises(NameError):
+        extract_points_from_geometry(Polygon([(0, 0), (1, 1), (1, 0)]))  # Polygon is unsupported
+
+def test_extract_from_empty_geometrycollection():
+    empty_geom_collection = GeometryCollection()
+    expected = []
+    assert extract_points_from_geometry(empty_geom_collection) == expected
\ No newline at end of file

From 93741fcecaf0c2c6dd478b10dde0ed05ff6d8976 Mon Sep 17 00:00:00 2001
From: eberrigan <berri104@gmail.com>
Date: Tue, 23 Apr 2024 10:15:05 -0700
Subject: [PATCH 2/4] lint

---
 sleap_roots/convhull.py |  9 +++++++--
 sleap_roots/points.py   | 16 ++++++++--------
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/sleap_roots/convhull.py b/sleap_roots/convhull.py
index 12b3df2..88ad97c 100644
--- a/sleap_roots/convhull.py
+++ b/sleap_roots/convhull.py
@@ -4,7 +4,10 @@
 from scipy.spatial import ConvexHull
 from scipy.spatial.distance import pdist
 from typing import Tuple, Optional, Union
-from sleap_roots.points import extract_points_from_geometry, get_line_equation_from_points
+from sleap_roots.points import (
+    extract_points_from_geometry,
+    get_line_equation_from_points,
+)
 from shapely import box, LineString, normalize, Polygon
 
 
@@ -449,7 +452,9 @@ def get_chull_intersection_vectors(
         ValueError: If pts does not have the expected shape.
     """
     if r0_pts.ndim == 1 or rn_pts.ndim == 1 or pts.ndim == 2:
-        print("Not enough instances or incorrect format to compute convex hull intersections.")
+        print(
+            "Not enough instances or incorrect format to compute convex hull intersections."
+        )
         return (np.array([[np.nan, np.nan]]), np.array([[np.nan, np.nan]]))
 
     # Check for valid pts input
diff --git a/sleap_roots/points.py b/sleap_roots/points.py
index 22bdfcf..fd9acbc 100644
--- a/sleap_roots/points.py
+++ b/sleap_roots/points.py
@@ -10,21 +10,21 @@
 
 def extract_points_from_geometry(geometry):
     """Extracts coordinates as a list of numpy arrays from any given Shapely geometry object.
-    
-    This function supports Point, MultiPoint, LineString, and GeometryCollection types. 
-    It recursively extracts coordinates from complex geometries and aggregates them into a single list. 
+
+    This function supports Point, MultiPoint, LineString, and GeometryCollection types.
+    It recursively extracts coordinates from complex geometries and aggregates them into a single list.
     For unsupported geometry types, it returns an empty list.
-    
+
     Parameters:
     - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points.
-    
+
     Returns:
-    - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point. 
+    - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point.
       The list will be empty if the geometry type is unsupported or contains no coordinates.
-    
+
     Raises:
     - TypeError: If the input is not a recognized Shapely geometry type.
-    
+
     Example:
     >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
     >>> point = Point(1, 2)

From 5033de5979837741bc76f816ae99d04c13bdfae5 Mon Sep 17 00:00:00 2001
From: eberrigan <berri104@gmail.com>
Date: Tue, 23 Apr 2024 10:52:16 -0700
Subject: [PATCH 3/4] Black

---
 tests/test_points.py | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/tests/test_points.py b/tests/test_points.py
index ed042a8..54c37d6 100644
--- a/tests/test_points.py
+++ b/tests/test_points.py
@@ -3,7 +3,12 @@
 from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection
 from sleap_roots import Series
 from sleap_roots.lengths import get_max_length_pts
-from sleap_roots.points import extract_points_from_geometry, filter_plants_with_unexpected_ct, get_count, join_pts
+from sleap_roots.points import (
+    extract_points_from_geometry,
+    filter_plants_with_unexpected_ct,
+    get_count,
+    join_pts,
+)
 from sleap_roots.points import (
     get_all_pts_array,
     get_nodes,
@@ -745,39 +750,48 @@ def test_extract_from_point():
     expected = [np.array([1, 2])]
     assert np.array_equal(extract_points_from_geometry(point), expected)
 
+
 def test_extract_from_multipoint():
     multipoint = MultiPoint([(1, 2), (3, 4)])
     expected = [np.array([1, 2]), np.array([3, 4])]
     results = extract_points_from_geometry(multipoint)
     assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
 
+
 def test_extract_from_linestring():
     linestring = LineString([(0, 0), (1, 1), (2, 2)])
     expected = [np.array([0, 0]), np.array([1, 1]), np.array([2, 2])]
     results = extract_points_from_geometry(linestring)
     assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
 
+
 def test_extract_from_geometrycollection():
     geom_collection = GeometryCollection([Point(1, 2), LineString([(0, 0), (1, 1)])])
     expected = [np.array([1, 2]), np.array([0, 0]), np.array([1, 1])]
     results = extract_points_from_geometry(geom_collection)
     assert all(np.array_equal(result, exp) for result, exp in zip(results, expected))
 
+
 def test_extract_from_empty_multipoint():
     empty_multipoint = MultiPoint()
     expected = []
     assert extract_points_from_geometry(empty_multipoint) == expected
 
+
 def test_extract_from_empty_linestring():
     empty_linestring = LineString()
     expected = []
     assert extract_points_from_geometry(empty_linestring) == expected
 
+
 def test_extract_from_unsupported_type():
     with pytest.raises(NameError):
-        extract_points_from_geometry(Polygon([(0, 0), (1, 1), (1, 0)]))  # Polygon is unsupported
+        extract_points_from_geometry(
+            Polygon([(0, 0), (1, 1), (1, 0)])
+        )  # Polygon is unsupported
+
 
 def test_extract_from_empty_geometrycollection():
     empty_geom_collection = GeometryCollection()
     expected = []
-    assert extract_points_from_geometry(empty_geom_collection) == expected
\ No newline at end of file
+    assert extract_points_from_geometry(empty_geom_collection) == expected

From 5a51eab3f530fd242eeb5d18a124dd9457b8f411 Mon Sep 17 00:00:00 2001
From: eberrigan <berri104@gmail.com>
Date: Tue, 23 Apr 2024 10:58:53 -0700
Subject: [PATCH 4/4] pydoc style

---
 sleap_roots/points.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/sleap_roots/points.py b/sleap_roots/points.py
index fd9acbc..6d5c5c1 100644
--- a/sleap_roots/points.py
+++ b/sleap_roots/points.py
@@ -15,15 +15,12 @@ def extract_points_from_geometry(geometry):
     It recursively extracts coordinates from complex geometries and aggregates them into a single list.
     For unsupported geometry types, it returns an empty list.
 
-    Parameters:
-    - geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points.
+    Args:
+        geometry (shapely.geometry.base.BaseGeometry): A Shapely geometry object from which to extract points.
 
     Returns:
-    - List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point.
-      The list will be empty if the geometry type is unsupported or contains no coordinates.
-
-    Raises:
-    - TypeError: If the input is not a recognized Shapely geometry type.
+        List[np.ndarray]: A list of numpy arrays, where each array represents the coordinates of a point.
+        The list will be empty if the geometry type is unsupported or contains no coordinates.
 
     Example:
     >>> from shapely.geometry import Point, MultiPoint, LineString, GeometryCollection