Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement nearest neighbor searches on RTree #79

Merged
merged 6 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/kdtree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@
//! slice. If you don't know the coordinate type used in the index, you can use
//! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type.
//!
//! ## Coordinate types
//!
//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float
//! `NaN` is not supported and may panic.
//!
//! ## Example
//!
//! ```
Expand Down
5 changes: 5 additions & 0 deletions src/rtree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
//! slice. If you don't know the coordinate type used in the index, you can use
//! [`CoordType::from_buffer`][crate::CoordType::from_buffer] to infer the coordinate type.
//!
//! ## Coordinate types
//!
//! Supported coordinate types implement [`IndexableNum`][crate::IndexableNum]. Note that float
//! `NaN` is not supported and may panic.
//!
//! ## Example
//!
//! ```
Expand Down
139 changes: 113 additions & 26 deletions src/rtree/trait.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::cmp::Reverse;
use std::collections::BinaryHeap;

use geo_traits::{CoordTrait, RectTrait};

use crate::error::Result;
Expand Down Expand Up @@ -124,39 +127,101 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
)
}

// #[allow(unused_mut, unused_labels, unused_variables)]
// fn neighbors(&self, x: N, y: N, max_distance: Option<N>) -> Vec<usize> {
// let boxes = self.boxes();
// let indices = self.indices();
// let max_distance = max_distance.unwrap_or(N::max_value());
/// Search items in order of distance from the given point.
///
/// ```
/// use geo_index::rtree::{RTreeBuilder, RTreeIndex, RTreeRef};
/// use geo_index::rtree::sort::HilbertSort;
///
/// // Create an RTree
/// let mut builder = RTreeBuilder::<f64>::new(3);
/// builder.add(0., 0., 2., 2.);
/// builder.add(1., 1., 3., 3.);
/// builder.add(2., 2., 4., 4.);
/// let tree = builder.finish::<HilbertSort>();
///
/// let results = tree.neighbors(5., 5., None, None);
/// assert_eq!(results, vec![2, 1, 0]);
/// ```
fn neighbors(
&self,
x: N,
y: N,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<usize> {
let boxes = self.boxes();
let indices = self.indices();
let max_distance = max_distance.unwrap_or(N::max_value());

// let mut outer_node_index = Some(boxes.len() - 4);
let mut outer_node_index = Some(boxes.len() - 4);
let mut queue = BinaryHeap::new();
let mut results = vec![];
let max_dist_squared = max_distance * max_distance;

// let mut results = vec![];
// let max_dist_squared = max_distance * max_distance;
'outer: while let Some(node_index) = outer_node_index {
// find the end index of the node
let end = (node_index + self.node_size() as usize * 4)
.min(upper_bound(node_index, self.level_bounds()));

// 'outer: while let Some(node_index) = outer_node_index {
// // find the end index of the node
// let end = (node_index + self.node_size() * 4)
// .min(upper_bound(node_index, self.level_bounds()));
// add child nodes to the queue
for pos in (node_index..end).step_by(4) {
let index = indices.get(pos >> 2);

// // add child nodes to the queue
// for pos in (node_index..end).step_by(4) {
// let index = indices.get(pos >> 2);
let dx = axis_dist(x, boxes[pos], boxes[pos + 2]);
let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]);
let dist = dx * dx + dy * dy;
if dist > max_dist_squared {
continue;
}

// let dx = axis_dist(x, boxes[pos], boxes[pos + 2]);
// let dy = axis_dist(y, boxes[pos + 1], boxes[pos + 3]);
// let dist = dx * dx + dy * dy;
// if dist > max_dist_squared {
// continue;
// }
// }
if node_index >= self.num_items() as usize * 4 {
// node (use even id)
queue.push(Reverse(NeighborNode {
id: index << 1,
dist,
}));
} else {
// leaf item (use odd id)
queue.push(Reverse(NeighborNode {
id: (index << 1) + 1,
dist,
}));
}
}

// // break 'outer;
// }
// pop items from the queue
while !queue.is_empty() && queue.peek().is_some_and(|val| (val.0.id & 1) != 0) {
let dist = queue.peek().unwrap().0.dist;
if dist > max_dist_squared {
break 'outer;
}
let item = queue.pop().unwrap();
results.push(item.0.id >> 1);
if max_results.is_some_and(|max_results| results.len() == max_results) {
break 'outer;
}
}

if let Some(item) = queue.pop() {
outer_node_index = Some(item.0.id >> 1);
} else {
outer_node_index = None;
}
}

// results
// }
results
}

/// Search items in order of distance from the given coordinate.
fn neighbors_coord(
&self,
coord: &impl CoordTrait<T = N>,
max_results: Option<usize>,
max_distance: Option<N>,
) -> Vec<usize> {
self.neighbors(coord.x(), coord.y(), max_results, max_distance)
}

/// Returns an iterator over the indexes of objects in this and another tree that intersect.
///
Expand All @@ -175,6 +240,28 @@ pub trait RTreeIndex<N: IndexableNum>: Sized {
}
}

/// A wrapper around a node and its distance for use in the priority queue.
#[derive(Debug, Clone, Copy, PartialEq)]
struct NeighborNode<N: IndexableNum> {
id: usize,
dist: N,
}

impl<N: IndexableNum> Eq for NeighborNode<N> {}

impl<N: IndexableNum> Ord for NeighborNode<N> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
// We don't allow NaN. This should only panic on NaN
self.dist.partial_cmp(&other.dist).unwrap()
}
}

impl<N: IndexableNum> PartialOrd for NeighborNode<N> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.cmp(other))
}
}

impl<N: IndexableNum> RTreeIndex<N> for OwnedRTree<N> {
fn boxes(&self) -> &[N] {
self.metadata.boxes_slice(&self.buffer)
Expand Down
13 changes: 2 additions & 11 deletions src/type.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::fmt::Debug;

use num_traits::{Bounded, Num, NumCast, ToPrimitive};
use num_traits::{Bounded, Num, NumCast};

use crate::kdtree::constants::KDBUSH_MAGIC;
use crate::GeoIndexError;
Expand All @@ -12,16 +12,7 @@ use crate::GeoIndexError;
/// JavaScript ([rtree](https://github.com/mourner/flatbush),
/// [kdtree](https://github.com/mourner/kdbush))
pub trait IndexableNum:
private::Sealed
+ Num
+ NumCast
+ ToPrimitive
+ PartialOrd
+ Debug
+ Send
+ Sync
+ bytemuck::Pod
+ Bounded
private::Sealed + Num + NumCast + PartialOrd + Debug + Send + Sync + bytemuck::Pod + Bounded
{
/// The type index to match the array order of `ARRAY_TYPES` in flatbush JS
const TYPE_INDEX: u8;
Expand Down
Loading