Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add num_points to fps as an alternative to ratio #218

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion csrc/cluster.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ CLUSTER_INLINE_VARIABLE int64_t _cuda_version = cuda_version();
} // namespace cluster

CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, double ratio,
bool random_start);
int64_t num_points, bool random_start);

CLUSTER_API torch::Tensor graclus(torch::Tensor rowptr, torch::Tensor col,
torch::optional<torch::Tensor> optional_weight);
Expand Down
17 changes: 13 additions & 4 deletions csrc/cpu/fps_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,30 @@ inline torch::Tensor get_dist(torch::Tensor x, int64_t idx) {
}

torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
torch::Tensor num_points, bool random_start) {

CHECK_CPU(src);
CHECK_CPU(ptr);
CHECK_CPU(ratio);
CHECK_CPU(num_points);
CHECK_INPUT(ptr.dim() == 1);

src = src.view({src.size(0), -1}).contiguous();
ptr = ptr.contiguous();
auto batch_size = ptr.numel() - 1;

auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);

torch::Tensor out_ptr;
if (num_points.sum().item<int64_t>() == 0) {
out_ptr = deg.toType(torch::kFloat) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
} else {
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
"Passed tensor has fewer elements than requested number of returned points.")
out_ptr = deg.toType(torch::kLong)
.minimum(num_points.toType(torch::kLong))
.cumsum(0);
}
auto out = torch::empty({out_ptr[-1].data_ptr<int64_t>()[0]}, ptr.options());

auto ptr_data = ptr.data_ptr<int64_t>();
Expand Down
2 changes: 1 addition & 1 deletion csrc/cpu/fps_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
#include "../extensions.h"

torch::Tensor fps_cpu(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start);
torch::Tensor num_points, bool random_start);
18 changes: 15 additions & 3 deletions csrc/cuda/fps_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,13 @@ __global__ void fps_kernel(const scalar_t *src, const int64_t *ptr,
}

torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start) {
torch::Tensor ratio, torch::Tensor num_points,
bool random_start) {

CHECK_CUDA(src);
CHECK_CUDA(ptr);
CHECK_CUDA(ratio);
CHECK_CUDA(num_points);
CHECK_INPUT(ptr.dim() == 1);
c10::cuda::MaybeSetDevice(src.get_device());

Expand All @@ -78,8 +80,18 @@ torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
auto batch_size = ptr.numel() - 1;

auto deg = ptr.narrow(0, 1, batch_size) - ptr.narrow(0, 0, batch_size);
auto out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
torch::Tensor out_ptr;
if (num_points.sum().item<int64_t>() == 0) {
out_ptr = deg.toType(ratio.scalar_type()) * ratio;
out_ptr = out_ptr.ceil().toType(torch::kLong).cumsum(0);
} else {
TORCH_CHECK((deg.toType(torch::kLong) >= num_points.toType(torch::kLong)).all().item<int64_t>(),
"Passed tensor has fewer elements than requested number of returned points.")
out_ptr = deg.toType(torch::kLong)
.minimum(num_points.toType(torch::kLong))
.cumsum(0);
}

out_ptr = torch::cat({torch::zeros({1}, ptr.options()), out_ptr}, 0);

torch::Tensor start;
Expand Down
3 changes: 2 additions & 1 deletion csrc/cuda/fps_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
#include "../extensions.h"

torch::Tensor fps_cuda(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, bool random_start);
torch::Tensor ratio, torch::Tensor num_points,
bool random_start);
9 changes: 5 additions & 4 deletions csrc/fps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,17 @@ PyMODINIT_FUNC PyInit__fps_cpu(void) { return NULL; }
#endif
#endif

CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr, torch::Tensor ratio,
bool random_start) {
CLUSTER_API torch::Tensor fps(torch::Tensor src, torch::Tensor ptr,
torch::Tensor ratio, torch::Tensor num_points,
bool random_start) {
if (src.device().is_cuda()) {
#ifdef WITH_CUDA
return fps_cuda(src, ptr, ratio, random_start);
return fps_cuda(src, ptr, ratio, num_points, random_start);
#else
AT_ERROR("Not compiled with CUDA support");
#endif
} else {
return fps_cpu(src, ptr, ratio, random_start);
return fps_cpu(src, ptr, ratio, num_points, random_start);
}
}

Expand Down
25 changes: 23 additions & 2 deletions test/test_fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@torch.jit.script
def fps2(x: Tensor, ratio: Tensor) -> Tensor:
return fps(x, None, ratio, False)
return fps(x, None, ratio, None, False)


@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
Expand All @@ -33,26 +33,36 @@ def test_fps(dtype, device):

out = fps(x, batch, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]
out = fps(x, batch, num_points=2, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, batch, num_points=4, random_start=False)
assert out.tolist() == [0, 2, 1, 3, 4, 6, 5, 7]

ratio = torch.tensor(0.5, device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, ptr=ptr, ratio=0.5, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

ratio = torch.tensor([0.5, 0.5], device=device)
out = fps(x, batch, ratio=ratio, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

num = torch.tensor([2, 2], device=device)
out = fps(x, batch, num_points=num, random_start=False)
assert out.tolist() == [0, 2, 4, 6]

out = fps(x, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]

out = fps(x, ratio=0.5, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
out = fps(x, num_points=4, random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]

out = fps(x, ratio=torch.tensor(0.5, device=device), random_start=False)
assert out.sort()[0].tolist() == [0, 5, 6, 7]
Expand All @@ -63,6 +73,17 @@ def test_fps(dtype, device):
out = fps2(x, torch.tensor([0.5], device=device))
assert out.sort()[0].tolist() == [0, 5, 6, 7]

# requesting too many points
with pytest.raises(RuntimeError):
out = fps(x, batch, num_points=100, random_start=False)

with pytest.raises(RuntimeError):
out = fps(x, batch, num_points=5, random_start=False)

# invalid argument combination
with pytest.raises(ValueError):
out = fps(x, batch, ratio=0.0, num_points=0, random_start=False)


@pytest.mark.parametrize('device', devices)
def test_random_fps(device):
Expand Down
78 changes: 64 additions & 14 deletions torch_cluster/fps.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,57 @@


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[int], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover

@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[Tensor]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[float], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


@torch.jit._overload # noqa
def fps(src, batch, ratio, num_points, random_start, batch_size, ptr): # noqa
# type: (Tensor, Optional[Tensor], Optional[Tensor], Optional[Tensor], bool, Optional[int], Optional[List[int]]) -> Tensor # noqa
pass # pragma: no cover


def fps( # noqa
src: torch.Tensor,
batch: Optional[Tensor] = None,
ratio: Optional[Union[Tensor, float]] = None,
num_points: Optional[Union[Tensor, int]] = None,
random_start: bool = True,
batch_size: Optional[int] = None,
ptr: Optional[Union[Tensor, List[int]]] = None,
Expand All @@ -50,7 +74,11 @@ def fps( # noqa
:math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each
node to a specific example. (default: :obj:`None`)
ratio (float or Tensor, optional): Sampling ratio.
Only ratio or num_points can be specified.
(default: :obj:`0.5`)
num_points (int, optional): Number of returned points.
Only ratio or num_points can be specified.
(default: :obj:`None`)
random_start (bool, optional): If set to :obj:`False`, use the first
node in :math:`\mathbf{X}` as starting node. (default: obj:`True`)
batch_size (int, optional): The number of examples :math:`B`.
Expand All @@ -71,25 +99,47 @@ def fps( # noqa
batch = torch.tensor([0, 0, 0, 0])
index = fps(src, batch, ratio=0.5)
"""
# check if only of of ratio or num_points is set
# if no one is set, fallback to ratio = 0.5
if ratio is not None and num_points is not None:
raise ValueError("Only one of ratio and num_points can be specified.")

r: Optional[Tensor] = None
if ratio is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
if num_points is None:
r = torch.tensor(0.5, dtype=src.dtype, device=src.device)
else:
r = torch.tensor(0.0, dtype=src.dtype, device=src.device)
elif isinstance(ratio, float):
r = torch.tensor(ratio, dtype=src.dtype, device=src.device)
else:
r = ratio
assert r is not None

num: Optional[Tensor] = None
if num_points is None:
num = torch.tensor(0, dtype=torch.long, device=src.device)
elif isinstance(num_points, int):
num = torch.tensor(num_points, dtype=torch.long, device=src.device)
else:
num = num_points

assert r is not None and num is not None

if r.sum() == 0 and num.sum() == 0:
raise ValueError("At least one of ratio or num_points should be > 0")

if ptr is not None:
if isinstance(ptr, list) and torch_cluster.typing.WITH_PTR_LIST:
return torch.ops.torch_cluster.fps_ptr_list(
src, ptr, r, random_start)
src, ptr, r, random_start
)

if isinstance(ptr, list):
return torch.ops.torch_cluster.fps(
src, torch.tensor(ptr, device=src.device), r, random_start)
src, torch.tensor(ptr, device=src.device), r, num, random_start
)
else:
return torch.ops.torch_cluster.fps(src, ptr, r, random_start)
return torch.ops.torch_cluster.fps(src, ptr, r, num, random_start)

if batch is not None:
assert src.size(0) == batch.numel()
Expand All @@ -104,4 +154,4 @@ def fps( # noqa
else:
ptr_vec = torch.tensor([0, src.size(0)], device=src.device)

return torch.ops.torch_cluster.fps(src, ptr_vec, r, random_start)
return torch.ops.torch_cluster.fps(src, ptr_vec, r, num, random_start)
Loading