Skip to content

Commit

Permalink
K-D tree
Browse files Browse the repository at this point in the history
  • Loading branch information
MKuranowski committed Jul 26, 2024
1 parent 32ade18 commit bbe825f
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 9 deletions.
2 changes: 2 additions & 0 deletions pyroutelib3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from . import nx, osm, protocols
from .distance import euclidean_distance, haversine_earth_distance, taxicab_distance
from .kd import KDTree
from .router import (
DEFAULT_STEP_LIMIT,
StepLimitExceeded,
Expand All @@ -28,6 +29,7 @@
"find_route_without_turn_around",
"find_route",
"haversine_earth_distance",
"KDTree",
"nx",
"osm",
"protocols",
Expand Down
101 changes: 101 additions & 0 deletions pyroutelib3/kd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from dataclasses import dataclass
from typing import Generic, Iterable, List, Optional, Tuple

from typing_extensions import Self

from .distance import haversine_earth_distance
from .protocols import DistanceFunction, Position, WithPositionT


@dataclass(frozen=True)
class KDTree(Generic[WithPositionT]):
"""KDTree implements the `k-d tree data structure <https://en.wikipedia.org/wiki/K-d_tree>`_,
which can be used to speed up nearest-neighbor search for large datasets. Practice shows
that :py:meth:`osm.Graph.find_nearest_neighbor` takes significantly more time than
:py:func:`find_route` when generating multiple routes with ``pyroutelib3``.A k-d tree
can help with that, trading memory usage for CPU time.
This implementation assumes euclidean geometry, even though the default distance function
used is :py:func:`haversine_earth_distance`. This results in undefined behavior when
points are close to the ante meridian (180°/-180° longitude) or poles (90°/-90° latitude),
or when the data spans multiple continents.
"""

pivot: WithPositionT
left: Optional["KDTree[WithPositionT]"] = None
right: Optional["KDTree[WithPositionT]"] = None

def _find_nearest_neighbor_impl(
self,
root: Position,
distance: DistanceFunction = haversine_earth_distance,
axis: int = 0,
) -> Tuple[WithPositionT, float]:
# Start by assuming that pivot is the closest
best = self.pivot
best_distance = distance(root, self.pivot.position)

# Select which branch to recurse into first
first_left = root[0] < best.position[0] if axis == 0 else root[1] < best.position[1]
first = self.left if first_left else self.right
second = self.right if first_left else self.left

# Recurse into the first branch
if first:
alt, alt_distance = first._find_nearest_neighbor_impl(root, distance, axis ^ 1)
if alt_distance < best_distance:
best = alt
best_distance = alt_distance

# (Optionally) recurse into the second branch
if second:
# A closer node is possible in the second branch if and only if
# the splitting axis (as determined by pivot[axis]) is closer than
# the current best candidate
pt_on_axis = (
(self.pivot.position[0], root[1])
if axis == 0
else (root[0], self.pivot.position[1])
)
dist_to_axis = distance(root, pt_on_axis)

if dist_to_axis < best_distance:
alt, alt_distance = second._find_nearest_neighbor_impl(root, distance, axis ^ 1)
if alt_distance < best_distance:
best = alt
best_distance = alt_distance

return best, best_distance

def find_nearest_neighbor(
self,
root: Position,
distance: DistanceFunction = haversine_earth_distance,
) -> WithPositionT:
"""Find the closest node to ``root``, as determined by the provided distance function."""
return self._find_nearest_neighbor_impl(root, distance, 0)[0]

@classmethod
def _build_impl(cls, points: List[WithPositionT], axis: int = 0) -> Optional[Self]:
if not points:
return None
elif len(points) == 1:
return cls(points[0])
else:
points.sort(key=lambda pt: pt.position[axis])
median = len(points) // 2
return cls(
points[median],
cls._build_impl(points[:median], axis ^ 1),
cls._build_impl(points[median + 1 :], axis ^ 1),
)

@classmethod
def build(cls, points: Iterable[WithPositionT]) -> Optional[Self]:
"""Creates a new K-D tree with all of the provided objects with a :py:obj:`Position`.
Note that the type-complaint usage of class methods on generic types requires
explicitly providing the type argument, e.g.::
tree = KDTree[Node].build(nodes)
"""
return cls._build_impl(list(points), 0)
20 changes: 11 additions & 9 deletions pyroutelib3/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,24 @@
"""


class NodeLike(Protocol):
class WithPosition(Protocol):
"""WithPosition describes any object with a ``position`` property of :py:obj:`Position` type."""

@property
def position(self) -> Position: ...


WithPositionT = TypeVar("WithPositionT", bound=WithPosition)


class NodeLike(WithPosition, Protocol):
"""NodeLike describes the protocol of a *node* in a *graph*."""

@property
def id(self) -> int:
"""id property must uniquely identify this *node* in its *graph*."""
...

@property
def position(self) -> Position:
"""position property must describe the position of this *node* in
reference to other nodes in its *graph*. For real-life graphs,
this should be WGS84 degrees, first latitude, then longitude.
"""
...


class ExternalNodeLike(NodeLike, Protocol):
"""ExternalNodeLike is an extension of the :py:class:`NodeLike` protocol
Expand Down
28 changes: 28 additions & 0 deletions pyroutelib3/test_kd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from unittest import TestCase

from .kd import KDTree
from .simple_graph import SimpleNode


class TestKDTree(TestCase):
def test(self) -> None:
tree = KDTree[SimpleNode].build(
[
SimpleNode(1, (1.0, 1.0)),
SimpleNode(2, (1.0, 5.0)),
SimpleNode(3, (3.0, 9.0)),
SimpleNode(4, (4.0, 3.0)),
SimpleNode(5, (4.0, 7.0)),
SimpleNode(6, (6.0, 3.0)),
SimpleNode(7, (7.0, 1.0)),
SimpleNode(8, (8.0, 5.0)),
SimpleNode(9, (8.0, 9.0)),
]
)
self.assertIsNotNone(tree)
assert tree is not None # for type checker

self.assertEqual(tree.find_nearest_neighbor((2.0, 2.0)).id, 1)
self.assertEqual(tree.find_nearest_neighbor((5.0, 3.0)).id, 4)
self.assertEqual(tree.find_nearest_neighbor((5.0, 8.0)).id, 5)
self.assertEqual(tree.find_nearest_neighbor((9.0, 6.0)).id, 8)

0 comments on commit bbe825f

Please sign in to comment.