Skip to content

Commit 00af9dd

Browse files
LarsHaalckLars Haalck
authored and
Lars Haalck
committed
Add num_points to fps as an alternative to ratio
1 parent 616704a commit 00af9dd

File tree

8 files changed

+124
-30
lines changed

8 files changed

+124
-30
lines changed

csrc/cluster.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
1111
} // namespace cluster
1212

1313
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
14-
bool random_start);
14+
int64_t num_points, bool random_start);
1515

1616
CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
1717
torch::optional<torch::Tensor> optional_weight);

csrc/cpu/fps_cpu.cpp

+13-4
Original file line numberDiff line numberDiff line change
@@ -9,21 +9,30 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
99
}
1010

1111
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
12-
bool random_start) {
12+
torch::Tensor num_points, bool random_start) {
1313

1414
CHECK_CPU(src);
1515
CHECK_CPU(ptr);
1616
CHECK_CPU(ratio);
17+
CHECK_CPU(num_points);
1718
CHECK_INPUT(ptr.dim() == 1);
1819

1920
src = src.view({src.size(0), -1}).contiguous();
2021
ptr = ptr.contiguous();
2122
auto batch_size = ptr.numel() - 1;
2223

2324
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
24-
auto out_ptr = deg.toType(torch::kFloat) * ratio;
25-
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
26-
25+
torch::Tensor out_ptr;
26+
if (num_points.sum().item<int64_t>() == 0) {
27+
out_ptr = deg.toType(torch::kFloat) * ratio;
28+
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
29+
} else {
30+
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
31+
"Passed tensor has fewer elements than requested number of returned points.")
32+
out_ptr = deg.toType(torch::kLong)
33+
.minimum(num_points.toType(torch::kLong))
34+
.cumsum(0);
35+
}
2736
auto out = torch::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());
2837

2938
auto ptr_data = ptr.data_ptr<int64_t>();

csrc/cpu/fps_cpu.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
#include "../extensions.h"
44

55
torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
6-
bool random_start);
6+
torch::Tensor num_points, bool random_start);

csrc/cuda/fps_cuda.cu

+15-3
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
6565
}
6666

6767
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
68-
torch::Tensor ratio, bool random_start) {
68+
torch::Tensor ratio, torch::Tensor num_points,
69+
bool random_start) {
6970

7071
CHECK_CUDA(src);
7172
CHECK_CUDA(ptr);
7273
CHECK_CUDA(ratio);
74+
CHECK_CUDA(num_points);
7375
CHECK_INPUT(ptr.dim() == 1);
7476
c10::cuda::MaybeSetDevice(src.get_device());
7577

@@ -78,8 +80,18 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
7880
auto batch_size = ptr.numel() - 1;
7981

8082
auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
81-
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
82-
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
83+
torch::Tensor out_ptr;
84+
if (num_points.sum().item<int64_t>() == 0) {
85+
out_ptr = deg.toType(ratio.scalar_type()) * ratio;
86+
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
87+
} else {
88+
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
89+
"Passed tensor has fewer elements than requested number of returned points.")
90+
out_ptr = deg.toType(torch::kLong)
91+
.minimum(num_points.toType(torch::kLong))
92+
.cumsum(0);
93+
}
94+
8395
out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);
8496

8597
torch::Tensor start;

csrc/cuda/fps_cuda.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,5 @@
33
#include "../extensions.h"
44

55
torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
6-
torch::Tensor ratio, bool random_start);
6+
torch::Tensor ratio, torch::Tensor num_points,
7+
bool random_start);

csrc/fps.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; }
1919
#endif
2020
#endif
2121

22-
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
23-
bool random_start) {
22+
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr,
23+
torch::Tensor ratio, torch::Tensor num_points,
24+
bool random_start) {
2425
if (src.device().is_cuda()) {
2526
#ifdef WITH_CUDA
26-
return fps_cuda(src, ptr, ratio, random_start);
27+
return fps_cuda(src, ptr, ratio, num_points, random_start);
2728
#else
2829
AT_ERROR("Not compiled with CUDA support");
2930
#endif
3031
} else {
31-
return fps_cpu(src, ptr, ratio, random_start);
32+
return fps_cpu(src, ptr, ratio, num_points, random_start);
3233
}
3334
}
3435

test/test_fps.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
@torch.jit.script
1111
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
12-
return fps(x, None, ratio, False)
12+
return fps(x, None, ratio, None, False)
1313

1414

1515
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
@@ -33,26 +33,36 @@ def test_fps(dtype, device):
3333

3434
out = fps(x, batch, ratio=0.5, random_start=False)
3535
assert out.tolist() == [0, 2, 4, 6]
36+
out = fps(x, batch, num_points=2, random_start=False)
37+
assert out.tolist() == [0, 2, 4, 6]
38+
39+
out = fps(x, batch, num_points=4, random_start=False)
40+
assert out.tolist() == [0, 2, 1, 3, 4, 6, 5, 7]
3641

3742
ratio = torch.tensor(0.5, device=device)
3843
out = fps(x, batch, ratio=ratio, random_start=False)
3944
assert out.tolist() == [0, 2, 4, 6]
4045

4146
out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
4247
assert out.tolist() == [0, 2, 4, 6]
43-
4448
out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
4549
assert out.tolist() == [0, 2, 4, 6]
4650

4751
ratio = torch.tensor([0.5, 0.5], device=device)
4852
out = fps(x, batch, ratio=ratio, random_start=False)
4953
assert out.tolist() == [0, 2, 4, 6]
5054

55+
num = torch.tensor([2, 2], device=device)
56+
out = fps(x, batch, num_points=num, random_start=False)
57+
assert out.tolist() == [0, 2, 4, 6]
58+
5159
out = fps(x, random_start=False)
5260
assert out.sort()[0].tolist() == [0, 5, 6, 7]
5361

5462
out = fps(x, ratio=0.5, random_start=False)
5563
assert out.sort()[0].tolist() == [0, 5, 6, 7]
64+
out = fps(x, num_points=4, random_start=False)
65+
assert out.sort()[0].tolist() == [0, 5, 6, 7]
5666

5767
out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
5868
assert out.sort()[0].tolist() == [0, 5, 6, 7]
@@ -63,6 +73,17 @@ def test_fps(dtype, device):
6373
out = fps2(x, torch.tensor([0.5], device=device))
6474
assert out.sort()[0].tolist() == [0, 5, 6, 7]
6575

76+
# requesting too many points
77+
with pytest.raises(RuntimeError):
78+
out = fps(x, batch, num_points=100, random_start=False)
79+
80+
with pytest.raises(RuntimeError):
81+
out = fps(x, batch, num_points=5, random_start=False)
82+
83+
# invalid argument combination
84+
with pytest.raises(ValueError):
85+
out = fps(x, batch, ratio=0.0, num_points=0, random_start=False)
86+
6687

6788
@pytest.mark.parametrize('device', devices)
6889
def test_random_fps(device):

torch_cluster/fps.py

+64-14
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,57 @@
77

88

99
@torch.jit._overload # noqa
10-
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
11-
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
10+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
11+
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
1212
pass # pragma: no cover
1313

1414

1515
@torch.jit._overload # noqa
16-
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
17-
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
16+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
17+
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
1818
pass # pragma: no cover
1919

2020

2121
@torch.jit._overload # noqa
22-
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
23-
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
22+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
23+
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
2424
pass # pragma: no cover
2525

2626

2727
@torch.jit._overload # noqa
28-
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
29-
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
28+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
29+
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
30+
pass # pragma: no cover
31+
32+
@torch.jit._overload # noqa
33+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
34+
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
35+
pass # pragma: no cover
36+
37+
38+
@torch.jit._overload # noqa
39+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
40+
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
41+
pass # pragma: no cover
42+
43+
44+
@torch.jit._overload # noqa
45+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
46+
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
47+
pass # pragma: no cover
48+
49+
50+
@torch.jit._overload # noqa
51+
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
52+
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
3053
pass # pragma: no cover
3154

3255

3356
def fps( # noqa
3457
src: torch.Tensor,
3558
batch: Optional[Tensor] = None,
3659
ratio: Optional[Union[Tensor, float]] = None,
60+
num_points: Optional[Union[Tensor, int]] = None,
3761
random_start: bool = True,
3862
batch_size: Optional[int] = None,
3963
ptr: Optional[Union[Tensor, List[int]]] = None,
@@ -50,7 +74,11 @@ def fps( # noqa
5074
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
5175
node to a specific example. (default: :obj:`None`)
5276
ratio (float or Tensor, optional): Sampling ratio.
77+
Only ratio or num_points can be specified.
5378
(default: :obj:`0.5`)
79+
num_points (int, optional): Number of returned points.
80+
Only ratio or num_points can be specified.
81+
(default: :obj:`None`)
5482
random_start (bool, optional): If set to :obj:`False`, use the first
5583
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
5684
batch_size (int, optional): The number of examples :math:`B`.
@@ -71,25 +99,47 @@ def fps( # noqa
7199
batch = torch.tensor([0, 0, 0, 0])
72100
index = fps(src, batch, ratio=0.5)
73101
"""
102+
# check if only of of ratio or num_points is set
103+
# if no one is set, fallback to ratio = 0.5
104+
if ratio is not None and num_points is not None:
105+
raise ValueError("Only one of ratio and num_points can be specified.")
106+
74107
r: Optional[Tensor] = None
75108
if ratio is None:
76-
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
109+
if num_points is None:
110+
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
111+
else:
112+
r = torch.tensor(0.0, dtype=src.dtype, device=src.device)
77113
elif isinstance(ratio, float):
78114
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
79115
else:
80116
r = ratio
81-
assert r is not None
117+
118+
num: Optional[Tensor] = None
119+
if num_points is None:
120+
num = torch.tensor(0, dtype=torch.long, device=src.device)
121+
elif isinstance(num_points, int):
122+
num = torch.tensor(num_points, dtype=torch.long, device=src.device)
123+
else:
124+
num = num_points
125+
126+
assert r is not None and num is not None
127+
128+
if r.sum() == 0 and num.sum() == 0:
129+
raise ValueError("At least one of ratio or num_points should be > 0")
82130

83131
if ptr is not None:
84132
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
85133
return torch.ops.torch_cluster.fps_ptr_list(
86-
src, ptr, r, random_start)
134+
src, ptr, r, random_start
135+
)
87136

88137
if isinstance(ptr, list):
89138
return torch.ops.torch_cluster.fps(
90-
src, torch.tensor(ptr, device=src.device), r, random_start)
139+
src, torch.tensor(ptr, device=src.device), r, num, random_start
140+
)
91141
else:
92-
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
142+
return torch.ops.torch_cluster.fps(src, ptr, r, num, random_start)
93143

94144
if batch is not None:
95145
assert src.size(0) == batch.numel()
@@ -104,4 +154,4 @@ def fps( # noqa
104154
else:
105155
ptr_vec = torch.tensor([0, src.size(0)], device=src.device)
106156

107-
return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
157+
return torch.ops.torch_cluster.fps(src, ptr_vec, r, num, random_start)

0 commit comments

Comments
 (0)