-
Notifications
You must be signed in to change notification settings - Fork 154
/
Copy pathfps.py
157 lines (123 loc) · 6.18 KB
/
fps.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from typing import List, Optional, Union
import torch
from torch import Tensor
import torch_cluster.typing
# TorchScript overload declarations for `fps`: one stub for each combination
# of `ratio` (float | Tensor), `num_points` (int | Tensor) and
# `ptr` (Tensor | List[int]). The signatures live in the `# type:` comments;
# the `pass` bodies are placeholders, and the real implementation is the
# non-overloaded `fps` definition further below.


# ratio: float, num_points: int, ptr: Tensor
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: Tensor, num_points: int, ptr: Tensor
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: float, num_points: int, ptr: List[int]
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: Tensor, num_points: int, ptr: List[int]
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: float, num_points: Tensor, ptr: Tensor
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: Tensor, num_points: Tensor, ptr: Tensor
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: float, num_points: Tensor, ptr: List[int]
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
    pass # pragma: no cover


# ratio: Tensor, num_points: Tensor, ptr: List[int]
@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
    # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
    pass # pragma: no cover
def fps(  # noqa
    src: torch.Tensor,
    batch: Optional[Tensor] = None,
    ratio: Optional[Union[Tensor, float]] = None,
    num_points: Optional[Union[Tensor, int]] = None,
    random_start: bool = True,
    batch_size: Optional[int] = None,
    ptr: Optional[Union[Tensor, List[int]]] = None,
):
    r"""A sampling algorithm from the `"PointNet++: Deep Hierarchical Feature
    Learning on Point Sets in a Metric Space"
    <https://arxiv.org/abs/1706.02413>`_ paper, which iteratively samples the
    most distant point with regard to the rest points.

    Args:
        src (Tensor): Point feature matrix
            :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`.
        batch (LongTensor, optional): Batch vector
            :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
            node to a specific example. Ignored when :obj:`ptr` is given.
            (default: :obj:`None`)
        ratio (float or Tensor, optional): Sampling ratio.
            Only one of :obj:`ratio` and :obj:`num_points` can be specified.
            If neither is given, a ratio of :obj:`0.5` is used.
            (default: :obj:`None`)
        num_points (int or Tensor, optional): Number of returned points.
            Only one of :obj:`ratio` and :obj:`num_points` can be specified.
            (default: :obj:`None`)
        random_start (bool, optional): If set to :obj:`False`, use the first
            node in :math:`\mathbf{X}` as starting node. (default: :obj:`True`)
        batch_size (int, optional): The number of examples :math:`B`.
            Automatically calculated if not given. (default: :obj:`None`)
        ptr (torch.Tensor or [int], optional): If given, batch assignment will
            be determined based on boundaries in CSR representation, *e.g.*,
            :obj:`batch=[0,0,1,1,1,2]` translates to :obj:`ptr=[0,2,5,6]`.
            (default: :obj:`None`)

    Raises:
        ValueError: If both :obj:`ratio` and :obj:`num_points` are given, if
            both evaluate to zero, or if :obj:`batch` does not have one entry
            per row of :obj:`src`.

    :rtype: :class:`LongTensor`

    .. code-block:: python

        import torch
        from torch_cluster import fps

        src = torch.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]])
        batch = torch.tensor([0, 0, 0, 0])
        index = fps(src, batch, ratio=0.5)
    """
    # Only one of `ratio` and `num_points` may be set; if neither is set,
    # fall back to `ratio = 0.5`.
    if ratio is not None and num_points is not None:
        raise ValueError("Only one of ratio and num_points can be specified.")

    # Normalize `ratio` into a tensor `r` matching `src`'s dtype/device.
    # When sampling by `num_points`, the ratio passed to the op is 0.
    r: Optional[Tensor] = None
    if ratio is None:
        if num_points is None:
            r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
        else:
            r = torch.tensor(0.0, dtype=src.dtype, device=src.device)
    elif isinstance(ratio, float):
        r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
    else:
        r = ratio

    # Normalize `num_points` into a long tensor; 0 when sampling by ratio.
    num: Optional[Tensor] = None
    if num_points is None:
        num = torch.tensor(0, dtype=torch.long, device=src.device)
    elif isinstance(num_points, int):
        num = torch.tensor(num_points, dtype=torch.long, device=src.device)
    else:
        num = num_points

    # Both are always assigned above; this narrows the Optional types
    # (presumably for TorchScript's type refinement).
    assert r is not None and num is not None

    if r.sum() == 0 and num.sum() == 0:
        raise ValueError("At least one of ratio or num_points should be > 0")

    if ptr is not None:
        if isinstance(ptr, list):
            # `fps_ptr_list` receives no `num` argument, so it can only honor
            # ratio-based sampling. When `num_points` is given, use the
            # tensor-`ptr` op instead so the requested count is not dropped.
            if torch_cluster.typing.WITH_PTR_LIST and num_points is None:
                return torch.ops.torch_cluster.fps_ptr_list(
                    src, ptr, r, random_start
                )
            return torch.ops.torch_cluster.fps(
                src, torch.tensor(ptr, device=src.device), r, num, random_start
            )
        return torch.ops.torch_cluster.fps(src, ptr, r, num, random_start)

    if batch is not None:
        # Explicit check instead of `assert`, which is stripped under -O.
        if src.size(0) != batch.numel():
            raise ValueError(
                "batch must contain exactly one entry per row of src."
            )
        if batch_size is None:
            batch_size = int(batch.max()) + 1
        # Count points per example, then prefix-sum the counts into CSR
        # boundaries: ptr_vec[0] = 0, ptr_vec[i + 1] = sum(deg[:i + 1]).
        deg = src.new_zeros(batch_size, dtype=torch.long)
        deg.scatter_add_(0, batch, torch.ones_like(batch))
        ptr_vec = deg.new_zeros(batch_size + 1)
        torch.cumsum(deg, 0, out=ptr_vec[1:])
    else:
        # No batch information: treat all of `src` as a single example.
        ptr_vec = torch.tensor([0, src.size(0)], device=src.device)

    return torch.ops.torch_cluster.fps(src, ptr_vec, r, num, random_start)