-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathfps_cpu.cpp
67 lines (53 loc) · 2.2 KB
/
fps_cpu.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include "fps_cpu.h"
#include <ATen/Parallel.h>
#include "utils.h"
// Squared Euclidean distance from every row of `x` to row `idx`.
//
// `x` is expected to be 2-D (one point per row); the result is a 1-D tensor
// of length x.size(0). The `pow_` operates in place on the freshly created
// difference tensor `(x - x[idx])`, so `x` itself is never mutated.
//
// Pass by const reference: copying a torch::Tensor handle bumps an atomic
// refcount, which is wasted work in the hot sampling loop below.
inline torch::Tensor get_dist(const torch::Tensor &x, int64_t idx) {
  return (x - x[idx]).pow_(2).sum(1);
}
// Farthest point sampling on CPU.
//
// src:          (N, *) point coordinates; flattened to (N, D) below.
// ptr:          (B + 1,) CSR-style offsets delimiting each example's points
//               within `src`.
// ratio:        per-example fraction of points to keep; used only when
//               `num_points` sums to zero.
// num_points:   per-example absolute number of points to keep; an all-zero
//               tensor means "use `ratio` instead".
// random_start: pick a random seed point per example rather than index 0.
//
// Returns a 1-D int64 tensor of selected indices into `src`, with each
// example's selections stored contiguously in batch order.
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
                      torch::Tensor num_points, bool random_start) {
  CHECK_CPU(src);
  CHECK_CPU(ptr);
  CHECK_CPU(ratio);
  CHECK_CPU(num_points);
  CHECK_INPUT(ptr.dim() == 1);

  src = src.view({src.size(0), -1}).contiguous();
  ptr = ptr.contiguous();
  auto batch_size = ptr.numel() - 1;

  // Number of input points in each example.
  auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);

  // out_ptr[b] = exclusive end offset of example b's samples in `out`.
  torch::Tensor out_ptr;
  if (num_points.sum().item<int64_t>() == 0) {
    out_ptr = deg.toType(torch::kFloat) * ratio;
    out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
  } else {
    TORCH_CHECK(
        (deg.toType(torch::kLong) >= num_points.toType(torch::kLong))
            .all()
            .item<int64_t>(),
        "Passed tensor has fewer elements than requested number of returned points.")
    out_ptr = deg.toType(torch::kLong)
                  .minimum(num_points.toType(torch::kLong))
                  .cumsum(0);
  }

  auto out = torch::empty({out_ptr[-1].item<int64_t>()}, ptr.options());

  auto ptr_data = ptr.data_ptr<int64_t>();
  auto out_ptr_data = out_ptr.data_ptr<int64_t>();
  auto out_data = out.data_ptr<int64_t>();

  int64_t grain_size = 1; // Always parallelize over the batch dimension.
  at::parallel_for(0, batch_size, grain_size, [&](int64_t begin, int64_t end) {
    // `rand()` is not guaranteed thread-safe and serializes on an internal
    // lock in common libcs; use a thread-local engine so parallel workers
    // draw independent start indices without contention or modulo bias.
    thread_local std::mt19937 gen{std::random_device{}()};

    for (int64_t b = begin; b < end; b++) {
      int64_t src_start = ptr_data[b], src_end = ptr_data[b + 1];
      int64_t out_start = b == 0 ? 0 : out_ptr_data[b - 1];
      int64_t out_end = out_ptr_data[b];

      auto y = src.narrow(0, src_start, src_end - src_start);

      int64_t start_idx = 0;
      if (random_start) {
        std::uniform_int_distribution<int64_t> pick(0, y.size(0) - 1);
        start_idx = pick(gen);
      }
      out_data[out_start] = src_start + start_idx;

      // `dist[i]` holds the squared distance from point i to its nearest
      // already-selected point; greedily take the farthest point each round.
      auto dist = get_dist(y, start_idx);
      for (int64_t i = 1; i < out_end - out_start; i++) {
        int64_t argmax = dist.argmax().item<int64_t>();
        out_data[out_start + i] = src_start + argmax;
        dist = torch::min(dist, get_dist(y, argmax));
      }
    }
  });

  return out;
}