forked from crvs/KDTree
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKDTree.cpp
323 lines (268 loc) · 8.32 KB
/
KDTree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
/*
 * file: KDTree.cpp
 * author: J. Frederico Carvalho
 *
 * This is an adaptation of the K-d tree implementation on Rosetta Code
 * https://rosettacode.org/wiki/K-d_tree
 *
 * It reimplements the original C code in C++ and adds a few queries the
 * original lacks, notably finding all points within a given distance of a
 * query point.
 *
 */
#include <algorithm>
#include <cmath>
#include <functional>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "KDTree.hpp"
// Empty node: no coordinates. Empty nodes stand in for "no node" (leaves,
// empty trees, failed searches). The index is set explicitly because this
// out-of-line constructor is user-provided, so value-initialization (e.g.
// std::make_shared<KDNode>()) would otherwise leave it indeterminate, and
// nearest_index() on an empty tree reads it.
KDNode::KDNode() : index(0) {}

// Node holding point `pt` (original input position `idx_`) and the given
// children. Members are initialized directly rather than default-constructed
// and then assigned.
KDNode::KDNode(const point_t &pt, const size_t &idx_, const KDNodePtr &left_,
               const KDNodePtr &right_)
    : index(idx_), x(pt), left(left_), right(right_) {}

// Same as above, taking the point and its index as a single pair.
KDNode::KDNode(const pointIndex &pi, const KDNodePtr &left_,
               const KDNodePtr &right_)
    : index(pi.second), x(pi.first), left(left_), right(right_) {}

KDNode::~KDNode() = default;

// Coordinate `idx` of the stored point (bounds-checked through at()).
double KDNode::coord(const size_t &idx) { return x.at(idx); }
// True when the node stores a point; empty nodes act as leaves.
KDNode::operator bool() { return (!x.empty()); }
// Conversions: a copy of the point, its original index, or both as a pair.
KDNode::operator point_t() { return x; }
KDNode::operator size_t() { return index; }
KDNode::operator pointIndex() { return pointIndex(x, index); }
// Allocates a fresh, empty node (no coordinates). Empty nodes are used as
// the "no result" / "no subtree" marker throughout this file.
KDNodePtr NewKDNodePtr() { return std::make_shared< KDNode >(); }
inline double dist2(const point_t &a, const point_t &b) {
double distc = 0;
for (size_t i = 0; i < a.size(); i++) {
double di = a.at(i) - b.at(i);
distc += di * di;
}
return distc;
}
inline double dist2(const KDNodePtr &a, const KDNodePtr &b) {
return dist2(a->x, b->x);
}
inline double dist(const point_t &a, const point_t &b) {
return std::sqrt(dist2(a, b));
}
inline double dist(const KDNodePtr &a, const KDNodePtr &b) {
return std::sqrt(dist2(a, b));
}
// Remembers the coordinate axis to compare on.
comparer::comparer(size_t idx_) : idx(idx_) {}

// Orders (point, index) pairs by their coordinate on axis `idx`
// (bounds-checked through at()).
inline bool comparer::compare_idx(const pointIndex &a, const pointIndex &b) {
  const auto lhs = a.first.at(idx);
  const auto rhs = b.first.at(idx);
  return lhs < rhs;
}
// Sorts [begin, end) in ascending order of coordinate `idx`.
inline void sort_on_idx(const pointIndexArr::iterator &begin, //
                        const pointIndexArr::iterator &end,   //
                        size_t idx) {
  comparer comp(idx); // the constructor already stores idx
  // A lambda replaces std::bind: clearer, and it lets the compiler inline the
  // comparison instead of going through a bound member-function pointer.
  std::sort(begin, end, [&comp](const pointIndex &a, const pointIndex &b) {
    return comp.compare_idx(a, b);
  });
}
using pointVec = std::vector< point_t >;
// Recursively builds a balanced k-d (sub)tree from the range [begin, end).
//
// `length` must equal the number of elements in the range, and `level` is the
// coordinate axis this node splits on. Children split on (level + 1) % dim,
// the same axis schedule nearest_ and neighborhood_ follow while descending.
// The range is reordered in place. An empty range yields an empty node;
// missing children are represented by the shared `leaf` node.
KDNodePtr KDTree::make_tree(const pointIndexArr::iterator &begin, //
                            const pointIndexArr::iterator &end,   //
                            const size_t &length,                 //
                            const size_t &level                   //
) {
  if (begin == end) {
    return NewKDNodePtr(); // empty tree
  }
  const size_t dim = begin->first.size();
  const auto middle = begin + (length / 2);
  if (length > 1) {
    // Only the median has to reach its sorted slot. nth_element puts it there
    // and leaves every element before it <= it and every element after it
    // >= it on this axis. That is exactly the invariant the searches rely
    // on, and it costs O(n) instead of a full sort at every level.
    std::nth_element(begin, middle, end,
                     [level](const pointIndex &a, const pointIndex &b) {
                       return a.first.at(level) < b.first.at(level);
                     });
  }
  const size_t l_len = length / 2;
  const size_t r_len = length - l_len - 1;

  KDNodePtr left = leaf;
  if (l_len > 0 && dim > 0) {
    left = make_tree(begin, middle, l_len, (level + 1) % dim);
  }
  KDNodePtr right = leaf;
  if (r_len > 0 && dim > 0) {
    right = make_tree(middle + 1, end, r_len, (level + 1) % dim);
  }
  return std::make_shared< KDNode >(*middle, left, right);
}
// Builds the tree from `point_array`, tagging each point with its position in
// the input so queries can report the original index.
KDTree::KDTree(pointVec point_array) {
  leaf = std::make_shared< KDNode >();
  pointIndexArr arr;
  arr.reserve(point_array.size());
  for (size_t i = 0; i < point_array.size(); i++) {
    // point_array is this constructor's own by-value copy, so its points can
    // be moved into arr instead of being copied a second time.
    arr.emplace_back(std::move(point_array[i]), i);
  }
  root = KDTree::make_tree(arr.begin(), arr.end(), arr.size(), 0);
}
// Recursive nearest-neighbour search in the subtree rooted at `branch`.
//
// `level` is the axis `branch` splits on. `best` / `best_dist` are the
// closest node found so far and its *squared* distance to `pt`. Returns the
// closest node found (possibly `best` itself), or an empty node when
// `branch` is empty.
KDNodePtr KDTree::nearest_( //
    const KDNodePtr &branch, //
    const point_t &pt,       //
    const size_t &level,     //
    const KDNodePtr &best,   //
    const double &best_dist  //
) {
  if (!bool(*branch)) {
    return NewKDNodePtr(); // basically, null
  }
  // Bind by reference: point_t(*branch) would copy the coordinate vector at
  // every visited node.
  const point_t &branch_pt = branch->x;
  const size_t dim = branch_pt.size();
  const double d = dist2(branch_pt, pt);
  const double dx = branch_pt.at(level) - pt.at(level);
  const double dx2 = dx * dx;

  KDNodePtr best_l = best;
  double best_dist_l = best_dist;
  if (d < best_dist) {
    best_dist_l = d;
    best_l = branch;
  }

  const size_t next_lv = (level + 1) % dim;
  // Search first on the side of the splitting plane that contains pt. Left
  // holds coordinates <= the split value and right holds >= (see make_tree).
  const KDNodePtr &section = (dx > 0) ? branch->left : branch->right;
  const KDNodePtr &other = (dx > 0) ? branch->right : branch->left;

  // keep nearest neighbor from further down the tree
  KDNodePtr further = nearest_(section, pt, next_lv, best_l, best_dist_l);
  if (!further->x.empty()) {
    const double dl = dist2(further->x, pt);
    if (dl < best_dist_l) {
      best_dist_l = dl;
      best_l = further;
    }
  }
  // The far side can only hold something closer if the splitting plane
  // itself is closer than the current best.
  if (dx2 < best_dist_l) {
    further = nearest_(other, pt, next_lv, best_l, best_dist_l);
    if (!further->x.empty()) {
      const double dl = dist2(further->x, pt);
      if (dl < best_dist_l) {
        best_dist_l = dl;
        best_l = further;
      }
    }
  }
  return best_l;
}
// default caller
// Starts the recursive search at the root, seeding the running best with the
// root node and its squared distance to `pt`.
KDNodePtr KDTree::nearest_(const point_t &pt) {
  const double root_dist = dist2(root->x, pt);
  const size_t start_level = 0;
  return nearest_(root, pt, start_level, root, root_dist);
}
point_t KDTree::nearest_point(const point_t &pt) {
return point_t(*nearest_(pt));
};
size_t KDTree::nearest_index(const point_t &pt) {
return size_t(*nearest_(pt));
};
pointIndex KDTree::nearest_pointIndex(const point_t &pt) {
KDNodePtr Nearest = nearest_(pt);
return pointIndex(point_t(*Nearest), size_t(*Nearest));
}
// Collects every stored point in the subtree rooted at `branch` whose
// Euclidean distance to `pt` is at most `rad`. `level` is the axis `branch`
// splits on. The result follows traversal order and is not sorted.
pointIndexArr KDTree::neighborhood_( //
    const KDNodePtr &branch,         //
    const point_t &pt,               //
    const double &rad,               //
    const size_t &level              //
) {
  if (!bool(*branch)) {
    // branch has no point, means it is a leaf,
    // no points to add
    return pointIndexArr();
  }
  const point_t &branch_pt = branch->x; // no per-node copy of the point
  // Cycle axes by the tree's own dimension (as make_tree and nearest_ do),
  // not by pt's, so the axis schedule matches the one used to build the tree.
  const size_t dim = branch_pt.size();
  const double r2 = rad * rad;
  const double d = dist2(branch_pt, pt);
  const double dx = branch_pt.at(level) - pt.at(level);
  const double dx2 = dx * dx;

  pointIndexArr nbh;
  if (d <= r2) {
    nbh.push_back(pointIndex(*branch));
  }

  const KDNodePtr &section = (dx > 0) ? branch->left : branch->right;
  const KDNodePtr &other = (dx > 0) ? branch->right : branch->left;
  const size_t next_lv = (level + 1) % dim;

  pointIndexArr nbh_s = neighborhood_(section, pt, rad, next_lv);
  nbh.insert(nbh.end(), std::make_move_iterator(nbh_s.begin()),
             std::make_move_iterator(nbh_s.end()));
  // The far side may hold points exactly `rad` away (e.g. lying on the
  // splitting plane), so this test must be inclusive to agree with
  // `d <= r2` above. A strict `<` silently dropped such boundary points,
  // including exact duplicates of pt when rad == 0.
  if (dx2 <= r2) {
    pointIndexArr nbh_o = neighborhood_(other, pt, rad, next_lv);
    nbh.insert(nbh.end(), std::make_move_iterator(nbh_o.begin()),
               std::make_move_iterator(nbh_o.end()));
  }
  return nbh;
}
// Stored points within distance `rad` of `pt`, each paired with its
// input-array index. Searches the whole tree starting from axis 0.
pointIndexArr KDTree::neighborhood( //
    const point_t &pt,              //
    const double &rad) {
  const size_t start_level = 0;
  return neighborhood_(root, pt, rad, start_level);
}
// Coordinates of every stored point within distance `rad` of `pt`.
pointVec KDTree::neighborhood_points( //
    const point_t &pt,                //
    const double &rad) {
  size_t level = 0;
  pointIndexArr nbh = neighborhood_(root, pt, rad, level);
  pointVec nbhp;
  nbhp.reserve(nbh.size());
  // nbh is a local result, so its coordinate vectors can be moved out.
  // The previous by-value lambda copied every point twice.
  for (pointIndex &entry : nbh) {
    nbhp.push_back(std::move(entry.first));
  }
  return nbhp;
}
// Input-array indices of every stored point within distance `rad` of `pt`.
indexArr KDTree::neighborhood_indices( //
    const point_t &pt,                 //
    const double &rad) {
  size_t level = 0;
  pointIndexArr nbh = neighborhood_(root, pt, rad, level);
  indexArr nbhi;
  nbhi.resize(nbh.size());
  // Take each pair by const reference: only the index is needed, and a
  // by-value parameter copied the whole coordinate vector per element.
  std::transform(nbh.begin(), nbh.end(), nbhi.begin(),
                 [](const pointIndex &x) { return x.second; });
  return nbhi;
}