Skip to content

Commit

Permalink
Distance to point (#36)
Browse files Browse the repository at this point in the history
* Test for distance method

* Implement distance method
  • Loading branch information
faymanns authored Dec 12, 2024
1 parent bf604b8 commit 4fed17a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 0 deletions.
44 changes: 44 additions & 0 deletions src/splinebox/spline_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,50 @@ def rotate(self, rotation_matrix, centred=True):
if centred:
self.translate(centroid)

def distance(self, point, return_t=False):
"""
Computes the distance of point from the spline.
Parameters
----------
point : numpy.array
Array with the coordinates of the point.
return_t : bool
Whether to return the paramter t of the spline.
`spline.eval(t)` gives the location on the spline
closest to the point.
Returns
-------
distance : float
The distance between the point and the spline.
t : float
Only returned if `return_t=True`. This is the parameter
corresponding to the location on the spline closest
to the point.
"""
self._check_control_points()
if self.control_points.ndim == 1:
raise RuntimeError("Cannot compute distance for 1D splines.")

max_t = self.M if self.closed else self.M - 1
t = np.linspace(0, max_t, self.M * 10)
points_on_spline = self.eval(t)
distances = np.linalg.norm(points_on_spline - point[np.newaxis], axis=-1)
t_initial = t[np.argmin(distances)]

def _distance(t):
return np.linalg.norm(self.eval(t) - point)

result = scipy.optimize.minimize(_distance, np.array((t_initial,)), bounds=((0, max_t),))

min_distance = np.linalg.norm(self.eval(result.x) - point)

if return_t:
return (min_distance, result.x)

return min_distance


class HermiteSpline(Spline):
"""
Expand Down
23 changes: 23 additions & 0 deletions tests/test_spline_curves.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,26 @@ def test_saving_and_loading_of_multiple_splines(is_hermite_basis_function, tmpdi

for spline, loaded_spline in zip(splines, loaded_splines):
assert spline == loaded_spline


def test_distance(initialized_spline_curve):
spline = initialized_spline_curve

if spline.control_points.ndim == 1:
with pytest.raises(RuntimeError):
spline.distance(np.array((1.0,)))
return

point = np.array([1] * spline.control_points.shape[1])
distance = spline.distance(point)

# Check that a set of equality spaced points on the spline are at as far away as the distance
t = np.linspace(0, spline.M if spline.closed else spline.M - 1, spline.M * 10)
points_on_spline = spline.eval(t)
distances = np.linalg.norm(points_on_spline - point[np.newaxis], axis=-1)
assert np.all(distances >= distance)

# Check that the returned t corresponds to the correct distance
distance, t = spline.distance(point, return_t=True)
point_on_spline = spline.eval(t)
assert np.isclose(distance, np.linalg.norm(point - point_on_spline))

0 comments on commit 4fed17a

Please sign in to comment.