Skip to content

Commit

Permalink
chore: Update pybind11 submodule to commit 3e9dfa2
Browse files Browse the repository at this point in the history
  • Loading branch information
atksh committed Apr 13, 2024
1 parent 4c42f1a commit 4d92cd3
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 15 deletions.
28 changes: 16 additions & 12 deletions cpp/prtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,36 +51,38 @@ namespace py = pybind11;
template <class T>
using vec = std::vector<T>;

template <typename Sequence >
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence& seq) {
template <typename Sequence>
inline py::array_t<typename Sequence::value_type> as_pyarray(Sequence &seq)
{

auto size = seq.size();
auto data = seq.data();
std::unique_ptr<Sequence> seq_ptr = std::make_unique<Sequence>(std::move(seq));
auto capsule = py::capsule(seq_ptr.get(), [](void *p) { std::unique_ptr<Sequence>(reinterpret_cast<Sequence*>(p)); });
auto capsule = py::capsule(seq_ptr.get(), [](void *p)
{ std::unique_ptr<Sequence>(reinterpret_cast<Sequence *>(p)); });
seq_ptr.release();
return py::array(size, data, capsule);
}

template <typename T>
auto list_list_to_arrays(vec<vec<T>> out_ll){
auto list_list_to_arrays(vec<vec<T>> out_ll)
{
vec<T> out_s;
out_s.reserve(out_ll.size());
std::size_t sum = 0;
for (auto &&i : out_ll) {
for (auto &&i : out_ll)
{
out_s.push_back(i.size());
sum += i.size();
}
vec<T> out;
out.reserve(sum);
for(const auto &v: out_ll)
for (const auto &v : out_ll)
out.insert(out.end(), v.begin(), v.end());

return make_tuple(
std::move(as_pyarray(out_s))
,
std::move(as_pyarray(out))
);
std::move(as_pyarray(out_s)),
std::move(as_pyarray(out)));
}

template <class T, size_t StaticCapacity>
Expand Down Expand Up @@ -242,7 +244,7 @@ class BB
}
for (int i = 0; i < D; ++i)
{
flags[i] = -minima[i] < maxima[i];
flags[i] = -minima[i] <= maxima[i];
}
for (int i = 0; i < D; ++i)
{
Expand Down Expand Up @@ -1291,7 +1293,8 @@ class PRTree
return out;
}

auto find_all_array(const py::array_t<float> &x){
auto find_all_array(const py::array_t<float> &x)
{
return list_list_to_arrays(std::move(find_all(x)));
}

Expand Down Expand Up @@ -1334,6 +1337,7 @@ class PRTree
};

bfs<T, B, D>(std::move(find_func), flat_tree, target);
std::sort(out.begin(), out.end());
return out;
}

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
numpy>=1.16
numpy>=1.16,<2.0
pybind11; platform_machine != "x86_64" and platform_machine != "amd64" and platform_machine != "AMD64" and sys_platform == 'darwin' # for m1 mac
cmake; platform_machine != "x86_64" and platform_machine != "amd64" and platform_machine != "AMD64" and sys_platform == 'darwin' # for m1 mac
6 changes: 5 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from setuptools import Extension, find_packages, setup
from setuptools.command.build_ext import build_ext

version = "v0.6.0"
version = "v0.6.1"

sys.path.append("./tests")

Expand Down Expand Up @@ -109,6 +109,10 @@ def build_extension(self, ext):
classifiers=[
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
],
)
29 changes: 29 additions & 0 deletions tests/test_PRTree.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,32 @@ def test_obj(seed, PRTree, dim, tmp_path):
idx = prtree.query(q)
return_obj = prtree2.query(q, return_obj=True)
assert set(return_obj) == set([obj[i] for i in idx])


def test_readme():
idxes = np.array([1, 2])
rects = np.array([[0.0, 0.0, 1.0, 0.5], [1.0, 1.5, 1.2, 3.0]])
prtree = PRTree2D(idxes, rects)

# batch query
q = np.array([[0.5, 0.2, 0.6, 0.3], [0.8, 0.5, 1.5, 3.5]])
result = prtree.batch_query(q)
assert result == [[1], [1, 2]]

# Insert
prtree.insert(3, [1.0, 1.0, 2.0, 2.0])
q = np.array([[0.5, 0.2, 0.6, 0.3], [0.8, 0.5, 1.5, 3.5]])
result = prtree.batch_query(q)
assert result == [[1], [1, 2, 3]]

# Erase
prtree.erase(2)
result = prtree.batch_query(q)
assert result == [[1], [1, 3]]

# non-batch query
assert prtree.query([0.5, 0.5, 1.0, 1.0]) == [1, 3]

# point query
assert prtree.query([0.5, 0.5]) == [1]
assert prtree.query(0.5, 0.5) == [1]
2 changes: 1 addition & 1 deletion third/pybind11
Submodule pybind11 updated 212 files

0 comments on commit 4d92cd3

Please sign in to comment.