forked from cmuparlay/parlaylib
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathknn.cpp
70 lines (63 loc) · 2.13 KB
/
knn.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
#include <iostream>
#include <string>
#include <random>
#include "parlay/primitives.h"
#include "parlay/random.h"
#include "parlay/io.h"
#include "knn.h"
// **************************************************************
// Driver
// **************************************************************
// Spot-checks the k-nearest-neighbor graph G against a brute-force search.
// Samples up to 20 pseudo-random query points (fixed seed 27, so the choice is
// reproducible). For each query point p it computes the squared distance from
// p to every point, sorts them, and drops the first one (p's distance to
// itself, 0). It keeps the next k and compares them with the squared distances
// from p to its neighbors in G.
//
// The neighbor distances are reversed before the comparison, so G[i] is
// presumably ordered farthest-to-nearest. NOTE(review): confirm against knn.h.
//
// Returns the number of sampled points whose neighbor distances disagree with
// the brute-force answer (0 means no error was detected).
long check(const parlay::sequence<coords>& points, const knn_graph& G, int k) {
  long n = points.size();
  // Nothing to sample; also avoids building uniform_int_distribution(0, -1),
  // which violates its a <= b precondition.
  if (n == 0) return 0;
  long num_trials = std::min<long>(20, n);
  parlay::random_generator gen(27);
  std::uniform_int_distribution<long> dis(0, n - 1);
  auto distance_sq = [] (const coords& a, const coords& b) {
    double r = 0.0;
    for (int i = 0; i < dims; i++) {
      double diff = (a[i] - b[i]);
      r += diff * diff;
    }
    return r;
  };
  // Index 0 of the sorted distances is the query's distance to itself, so the
  // candidates are [1, k+1). Clamp to n so that fewer than k+1 points does not
  // read past the end of the sequence.
  long hi = std::min<long>(static_cast<long>(k) + 1, n);
  return parlay::reduce(parlay::tabulate(num_trials, [&] (long a) -> long {
    auto r = gen[a];
    idx i = dis(r);
    const coords& p = points[i];
    auto x = parlay::to_sequence(parlay::sort(parlay::map(points, [&] (const coords& q) {
      return distance_sq(p, q);
    })).cut(1, hi));
    auto y = parlay::reverse(parlay::map(G[i], [&] (long j) {
      return distance_sq(p, points[j]);
    }));
    return y != x;
  }));
}
// Benchmark driver. Generates n points in a cube, with each coordinate drawn
// uniformly from [0, 1e9] (seed 0). It then builds a k-nearest-neighbor graph
// (k = 10) five times, timing each run, and spot-checks the last result with
// check().
//
// Usage: knn <n>, where n is a non-negative integer point count.
// Returns 0 after a run and 1 on a usage error (bad argument count, unparsable
// or negative n).
int main(int argc, char* argv[]) {
  auto usage = "Usage: knn <n>";
  if (argc != 2) {
    std::cout << usage << std::endl;
    return 1;  // previously fell through and reported success
  }
  long n = 0;
  int k = 10;
  std::size_t parsed = 0;
  try { n = std::stol(argv[1], &parsed); }
  catch (...) { std::cout << usage << std::endl; return 1; }
  // stol stops at the first non-digit, so "100abc" would otherwise be
  // accepted. A negative point count is meaningless for tabulate.
  if (argv[1][parsed] != '\0' || n < 0) {
    std::cout << usage << std::endl;
    return 1;
  }
  parlay::random_generator gen(0);
  coord box_size = 1000000000;
  std::uniform_int_distribution<coord> dis(0, box_size);
  // generate n random points in a cube
  auto points = parlay::tabulate(n, [&] (long i) {
    auto r = gen[i];
    coords pnt;
    for (coord& c : pnt) c = dis(r);
    return pnt;
  });
  knn_graph r;
  parlay::internal::timer t("Time");
  // Repeat the build so the timer reports several runs.
  for (int i = 0; i < 5; i++) {
    r = build_knn_graph(points, k);
    t.next("knn");
  }
  if (check(points, r, k) > 0)
    std::cout << "found error" << std::endl;
  else
    std::cout << "generated " << k << " nearest neighbor graph for " << r.size()
              << " points." << std::endl;
  return 0;
}