7
7
8
8
9
9
@torch .jit ._overload # noqa
10
- def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
11
- # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
10
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
11
+ # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
12
12
pass # pragma: no cover
13
13
14
14
15
15
@torch .jit ._overload # noqa
16
- def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
17
- # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
16
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
17
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
18
18
pass # pragma: no cover
19
19
20
20
21
21
@torch .jit ._overload # noqa
22
- def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
23
- # type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
22
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
23
+ # type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
24
24
pass # pragma: no cover
25
25
26
26
27
27
@torch .jit ._overload # noqa
28
- def fps (src , batch , ratio , random_start , batch_size , ptr ): # noqa
29
- # type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
28
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
29
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
30
+ pass # pragma: no cover
31
+
32
+ @torch .jit ._overload # noqa
33
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
34
+ # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
35
+ pass # pragma: no cover
36
+
37
+
38
+ @torch .jit ._overload # noqa
39
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
40
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
41
+ pass # pragma: no cover
42
+
43
+
44
+ @torch .jit ._overload # noqa
45
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
46
+ # type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
47
+ pass # pragma: no cover
48
+
49
+
50
+ @torch .jit ._overload # noqa
51
+ def fps (src , batch , ratio , num_points , random_start , batch_size , ptr ): # noqa
52
+ # type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
30
53
pass # pragma: no cover
31
54
32
55
33
56
def fps ( # noqa
34
57
src : torch .Tensor ,
35
58
batch : Optional [Tensor ] = None ,
36
59
ratio : Optional [Union [Tensor , float ]] = None ,
60
+ num_points : Optional [Union [Tensor , int ]] = None ,
37
61
random_start : bool = True ,
38
62
batch_size : Optional [int ] = None ,
39
63
ptr : Optional [Union [Tensor , List [int ]]] = None ,
@@ -50,7 +74,11 @@ def fps( # noqa
50
74
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
51
75
node to a specific example. (default: :obj:`None`)
52
76
ratio (float or Tensor, optional): Sampling ratio.
77
+ Only ratio or num_points can be specified.
53
78
(default: :obj:`0.5`)
79
+ num_points (int, optional): Number of returned points.
80
+ Only ratio or num_points can be specified.
81
+ (default: :obj:`None`)
54
82
random_start (bool, optional): If set to :obj:`False`, use the first
55
83
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
56
84
batch_size (int, optional): The number of examples :math:`B`.
@@ -71,25 +99,47 @@ def fps( # noqa
71
99
batch = torch.tensor([0, 0, 0, 0])
72
100
index = fps(src, batch, ratio=0.5)
73
101
"""
102
+ # check if only of of ratio or num_points is set
103
+ # if no one is set, fallback to ratio = 0.5
104
+ if ratio is not None and num_points is not None :
105
+ raise ValueError ("Only one of ratio and num_points can be specified." )
106
+
74
107
r : Optional [Tensor ] = None
75
108
if ratio is None :
76
- r = torch .tensor (0.5 , dtype = src .dtype , device = src .device )
109
+ if num_points is None :
110
+ r = torch .tensor (0.5 , dtype = src .dtype , device = src .device )
111
+ else :
112
+ r = torch .tensor (0.0 , dtype = src .dtype , device = src .device )
77
113
elif isinstance (ratio , float ):
78
114
r = torch .tensor (ratio , dtype = src .dtype , device = src .device )
79
115
else :
80
116
r = ratio
81
- assert r is not None
117
+
118
+ num : Optional [Tensor ] = None
119
+ if num_points is None :
120
+ num = torch .tensor (0 , dtype = torch .long , device = src .device )
121
+ elif isinstance (num_points , int ):
122
+ num = torch .tensor (num_points , dtype = torch .long , device = src .device )
123
+ else :
124
+ num = num_points
125
+
126
+ assert r is not None and num is not None
127
+
128
+ if r .sum () == 0 and num .sum () == 0 :
129
+ raise ValueError ("At least one of ratio or num_points should be > 0" )
82
130
83
131
if ptr is not None :
84
132
if isinstance (ptr , list ) and torch_cluster .typing .WITH_PTR_LIST :
85
133
return torch .ops .torch_cluster .fps_ptr_list (
86
- src , ptr , r , random_start )
134
+ src , ptr , r , random_start
135
+ )
87
136
88
137
if isinstance (ptr , list ):
89
138
return torch .ops .torch_cluster .fps (
90
- src , torch .tensor (ptr , device = src .device ), r , random_start )
139
+ src , torch .tensor (ptr , device = src .device ), r , num , random_start
140
+ )
91
141
else :
92
- return torch .ops .torch_cluster .fps (src , ptr , r , random_start )
142
+ return torch .ops .torch_cluster .fps (src , ptr , r , num , random_start )
93
143
94
144
if batch is not None :
95
145
assert src .size (0 ) == batch .numel ()
@@ -104,4 +154,4 @@ def fps( # noqa
104
154
else :
105
155
ptr_vec = torch .tensor ([0 , src .size (0 )], device = src .device )
106
156
107
- return torch .ops .torch_cluster .fps (src , ptr_vec , r , random_start )
157
+ return torch .ops .torch_cluster .fps (src , ptr_vec , r , num , random_start )
0 commit comments