Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import via buffer protocol #57

Closed
wants to merge 14 commits into from
2 changes: 1 addition & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ geo-index = { path = "../", features = ["rayon"] }
# This is the fork used by polars
# https://github.com/pola-rs/polars/blob/fac700d9670feb57f1df32beaeee38377725fccf/py-polars/Cargo.toml#L33-L35
numpy = { git = "https://github.com/stinodego/rust-numpy.git", rev = "9ba9962ae57ba26e35babdce6f179edf5fe5b9c8", default-features = false }
pyo3 = { version = "0.21", features = ["abi3-py38"] }
pyo3 = { version = "0.21", features = [] }
thiserror = "1"

[profile.release]
Expand Down
2 changes: 2 additions & 0 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ fn ___version() -> &'static str {
fn _rust(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(___version))?;

m.add_wrapped(wrap_pyfunction!(rtree::search_rtree))?;

m.add_class::<rtree::RTree>()?;
m.add_class::<kdtree::KDTree>()?;

Expand Down
228 changes: 226 additions & 2 deletions python/src/rtree.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use geo_index::indices::Indices;
use geo_index::rtree::sort::{HilbertSort, STRSort};
use geo_index::rtree::util::f64_box_to_f32;
use geo_index::rtree::{OwnedRTree, RTreeBuilder, RTreeIndex};
use geo_index::IndexableNum;
use geo_index::rtree::{OwnedRTree, RTreeBuilder, RTreeIndex, TreeMetadata};
use geo_index::{CoordType, IndexableNum};
use numpy::ndarray::{ArrayView1, ArrayView2};
use numpy::{PyArray1, PyArrayMethods, PyReadonlyArray1, PyReadonlyArray2};
use pyo3::buffer::PyBuffer;
use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError};
use pyo3::intern;
use pyo3::prelude::*;
Expand All @@ -27,6 +29,228 @@ impl<'a> FromPyObject<'a> for RTreeMethod {
}
}

struct PyU8Buffer(PyBuffer<u8>);

impl<'py> FromPyObject<'py> for PyU8Buffer {
fn extract_bound(obj: &Bound<'py, PyAny>) -> PyResult<Self> {
let buffer = PyBuffer::<u8>::get_bound(obj)?;
if !buffer.readonly() {
return Err(PyValueError::new_err("Must be read-only byte buffer."));
}
if buffer.dimensions() != 1 {
return Err(PyValueError::new_err("Expected 1-dimensional array."));
}
// Note: this is probably superfluous for 1D array
if !buffer.is_c_contiguous() {
return Err(PyValueError::new_err("Expected c-contiguous array."));
}
if buffer.len_bytes() == 0 {
return Err(PyValueError::new_err("Buffer has no data."));
}

Ok(Self(buffer))
}
}

impl AsRef<[u8]> for PyU8Buffer {
fn as_ref(&self) -> &[u8] {
let len = self.0.item_count();
let data = self.0.buf_ptr() as *const u8;
unsafe { std::slice::from_raw_parts(data, len) }
kylebarron marked this conversation as resolved.
Show resolved Hide resolved
}
}

struct Pyf64RTreeRef {
buffer: PyU8Buffer,
metadata: TreeMetadata<f64>,
}

impl Pyf64RTreeRef {
fn try_new(buffer: PyU8Buffer) -> PyResult<Self> {
let metadata = TreeMetadata::try_new(buffer.as_ref()).unwrap();
Ok(Self { buffer, metadata })
}
}

impl AsRef<[u8]> for Pyf64RTreeRef {
fn as_ref(&self) -> &[u8] {
self.buffer.as_ref()
}
}

impl RTreeIndex<f64> for Pyf64RTreeRef {
fn boxes(&self) -> &[f64] {
self.metadata.boxes_slice(self.as_ref())
}

fn indices(&self) -> Indices {
self.metadata.indices_slice(self.as_ref())
}

fn level_bounds(&self) -> &[usize] {
self.metadata.level_bounds()
}

fn node_size(&self) -> usize {
self.metadata.node_size()
}

fn num_items(&self) -> usize {
self.metadata.num_items()
}

fn num_nodes(&self) -> usize {
self.metadata.num_nodes()
}
}

struct Pyf32RTreeRef {
buffer: PyU8Buffer,
metadata: TreeMetadata<f32>,
}

impl Pyf32RTreeRef {
fn try_new(buffer: PyU8Buffer) -> PyResult<Self> {
let metadata = TreeMetadata::try_new(buffer.as_ref()).unwrap();
Ok(Self { buffer, metadata })
}
}

impl AsRef<[u8]> for Pyf32RTreeRef {
fn as_ref(&self) -> &[u8] {
self.buffer.as_ref()
}
}

impl RTreeIndex<f32> for Pyf32RTreeRef {
fn boxes(&self) -> &[f32] {
self.metadata.boxes_slice(self.as_ref())
}

fn indices(&self) -> Indices {
self.metadata.indices_slice(self.as_ref())
}

fn level_bounds(&self) -> &[usize] {
self.metadata.level_bounds()
}

fn node_size(&self) -> usize {
self.metadata.node_size()
}

fn num_items(&self) -> usize {
self.metadata.num_items()
}

fn num_nodes(&self) -> usize {
self.metadata.num_nodes()
}
}

pub(crate) enum PyRTreeRef {
Float32(Pyf32RTreeRef),
Float64(Pyf64RTreeRef),
}
kylebarron marked this conversation as resolved.
Show resolved Hide resolved

impl<'py> FromPyObject<'py> for PyRTreeRef {
fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
let buffer = PyU8Buffer::extract_bound(ob)?;
let ct = CoordType::from_buffer(&buffer.as_ref()).unwrap();
match ct {
CoordType::Float32 => Ok(Self::Float32(Pyf32RTreeRef::try_new(buffer)?)),
CoordType::Float64 => Ok(Self::Float64(Pyf64RTreeRef::try_new(buffer)?)),
_ => todo!(),
}
}
}

impl PyRTreeRef {
fn num_items(&self) -> usize {
match self {
Self::Float32(index) => index.num_items(),
Self::Float64(index) => index.num_items(),
}
}

fn num_nodes(&self) -> usize {
match self {
Self::Float32(index) => index.num_nodes(),
Self::Float64(index) => index.num_nodes(),
}
}

fn node_size(&self) -> usize {
match self {
Self::Float32(index) => index.node_size(),
Self::Float64(index) => index.node_size(),
}
}

fn num_levels(&self) -> usize {
match self {
Self::Float32(index) => index.num_levels(),
Self::Float64(index) => index.num_levels(),
}
}

fn num_bytes(&self) -> usize {
match self {
Self::Float32(index) => AsRef::as_ref(index).len(),
Self::Float64(index) => AsRef::as_ref(index).len(),
}
}

fn boxes_at_level<'py>(&'py self, py: Python<'py>, level: usize) -> PyResult<PyObject> {
match self {
Self::Float32(index) => {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
let array = PyArray1::from_slice_bound(py, boxes);
Ok(array.reshape([boxes.len() / 4, 4])?.into_py(py))
}
Self::Float64(index) => {
let boxes = index
.boxes_at_level(level)
.map_err(|err| PyIndexError::new_err(err.to_string()))?;
let array = PyArray1::from_slice_bound(py, boxes);
Ok(array.reshape([boxes.len() / 4, 4])?.into_py(py))
}
}
}
}

/// Search an RTree given the provided bounding box.
///
/// Results are the indexes of the inserted objects in insertion order.
///
/// Args:
/// tree: tree or buffer to search
/// min_x: min x coordinate of bounding box
/// min_y: min y coordinate of bounding box
/// max_x: max x coordinate of bounding box
/// max_y: max y coordinate of bounding box
#[pyfunction]
pub(crate) fn search_rtree(
py: Python,
tree: PyRTreeRef,
min_x: f64,
min_y: f64,
max_x: f64,
max_y: f64,
) -> Bound<'_, PyArray1<usize>> {
let result = py.allow_threads(|| match tree {
PyRTreeRef::Float32(tree) => {
let (min_x, min_y, max_x, max_y) = f64_box_to_f32(min_x, min_y, max_x, max_y);
tree.search(min_x, min_y, max_x, max_y)
}
PyRTreeRef::Float64(tree) => tree.search(min_x, min_y, max_x, max_y),
});

PyArray1::from_vec_bound(py, result)
}

enum RTreeInner {
Float32(OwnedRTree<f32>),
Float64(OwnedRTree<f64>),
Expand Down
15 changes: 14 additions & 1 deletion src/rtree/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::rtree::util::compute_num_nodes;

/// Common metadata to describe a tree
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct TreeMetadata<N: IndexableNum> {
pub struct TreeMetadata<N: IndexableNum> {
pub(crate) node_size: usize,
pub(crate) num_items: usize,
pub(crate) num_nodes: usize,
Expand Down Expand Up @@ -112,6 +112,19 @@ impl<N: IndexableNum> TreeMetadata<N> {
[8 + self.nodes_byte_length..8 + self.nodes_byte_length + self.indices_byte_length];
Indices::new(indices_buf, self.num_nodes)
}

pub fn node_size(&self) -> usize {
self.node_size
}
pub fn num_items(&self) -> usize {
self.num_items
}
pub fn num_nodes(&self) -> usize {
self.num_nodes
}
pub fn level_bounds(&self) -> &[usize] {
self.level_bounds.as_slice()
}
}

/// An owned RTree buffer.
Expand Down
2 changes: 1 addition & 1 deletion src/rtree/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ pub mod traversal;
pub mod util;

pub use builder::RTreeBuilder;
pub use index::{OwnedRTree, RTreeRef};
pub use index::{OwnedRTree, RTreeRef, TreeMetadata};
pub use r#trait::RTreeIndex;
Loading