Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GraphBolt][CUDA] gb.cat. #7786

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions graphbolt/include/graphbolt/cuda_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@
namespace graphbolt {
namespace ops {

/**
* @brief Is equivalent to `torch.cat(tensors, dim=0)`. Works only for
* contiguous tensors.
*
* @param tensors A vector of tensors to be concatenated.
*
* @return torch::cat(tensors, 0).
*/
torch::Tensor CatImpl(const std::vector<torch::Tensor>& tensors);

/**
* @brief Sorts the given input and optionally returns the original indexes.
*
Expand Down
54 changes: 54 additions & 0 deletions graphbolt/src/cat.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/**
* Copyright (c) 2024, mfbalin (Muhammed Fatih Balin)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @file cat.cc
* @brief Concatenation operation.
*/
#include <graphbolt/cuda_ops.h>
#include <torch/autograd.h>

#include "./macro.h"
#include "./utils.h"

namespace graphbolt {
namespace ops {

torch::Tensor Cat(const std::vector<torch::Tensor>& tensors) {
bool all_on_gpu = true;
for (const auto& tensor : tensors) {
all_on_gpu = all_on_gpu && utils::is_on_gpu(tensor);
if (!all_on_gpu) break;
}
if (all_on_gpu) {
GRAPHBOLT_DISPATCH_CUDA_ONLY_DEVICE(
c10::DeviceType::CUDA, "unique_and_compact",
{ return ops::CatImpl(tensors); });
}
return torch::cat(tensors, 0);
}

TORCH_LIBRARY_IMPL(graphbolt, CPU, m) { m.impl("cat", &Cat); }

#ifdef GRAPHBOLT_USE_CUDA
TORCH_LIBRARY_IMPL(graphbolt, CUDA, m) { m.impl("cat", &CatImpl); }
#endif

TORCH_LIBRARY_IMPL(graphbolt, Autograd, m) {
m.impl("cat", torch::autograd::autogradNotImplementedFallback());
}

} // namespace ops
} // namespace graphbolt
80 changes: 80 additions & 0 deletions graphbolt/src/cuda/cat.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @file cuda/cat.cu
* @brief ExpandIndptr operator implementation on CUDA.
*/
#include <cub/cub.cuh>
#include <limits>

#include "./common.h"
#include "./expand_indptr.cuh"

namespace graphbolt {
namespace ops {

torch::Tensor CatImpl(const std::vector<torch::Tensor>& tensors) {
const int64_t num_batches = tensors.size();
const int64_t original_feature_size = std::accumulate(
tensors.at(0).sizes().begin() + 1, tensors.at(0).sizes().end(),
tensors.at(0).element_size(), std::multiplies<>());
auto pointers_and_offsets = torch::empty(
num_batches * 2 + 1,
c10::TensorOptions().dtype(torch::kInt64).pinned_memory(true));
auto pointers_ptr =
reinterpret_cast<std::byte**>(pointers_and_offsets.data_ptr());
auto offsets_ptr = pointers_and_offsets.data_ptr<int64_t>() + num_batches;
int64_t i = 0;
offsets_ptr[0] = 0;
for (const auto& tensor : tensors) {
pointers_ptr[i++] = reinterpret_cast<std::byte*>(tensor.data_ptr());
offsets_ptr[i] =
offsets_ptr[i - 1] + tensor.size(0) * original_feature_size;
}
auto pointers_and_offsets_dev = torch::empty_like(
pointers_and_offsets,
tensors[0].options().dtype(pointers_and_offsets.scalar_type()));
CUDA_CALL(cudaMemcpyAsync(
pointers_and_offsets_dev.data_ptr<int64_t>(), pointers_ptr,
sizeof(int64_t) * pointers_and_offsets.numel(), cudaMemcpyHostToDevice,
cuda::GetCurrentStream()));
auto shape = tensors[0].sizes().vec();
shape[0] = offsets_ptr[num_batches] / original_feature_size;
auto output = torch::empty(shape, tensors[0].options());
auto output_ptr = reinterpret_cast<std::byte*>(output.data_ptr());

pointers_ptr =
reinterpret_cast<std::byte**>(pointers_and_offsets_dev.data_ptr());
offsets_ptr = pointers_and_offsets_dev.data_ptr<int64_t>() + num_batches;

thrust::counting_iterator<int64_t> iota(0);
auto output_buffer = thrust::make_transform_iterator(
iota, OutputBufferIndexer<int64_t, std::byte>{offsets_ptr, output_ptr});
auto buffer_sizes = thrust::make_transform_iterator(
iota, AdjacentDifference<int64_t>{offsets_ptr});

constexpr int64_t max_copy_at_once = std::numeric_limits<int32_t>::max();

for (int64_t i = 0; i < num_batches; i += max_copy_at_once) {
CUB_CALL(
DeviceMemcpy::Batched, pointers_ptr + i, output_buffer + i,
buffer_sizes + i, std::min(num_batches - i, max_copy_at_once));
}
return output;
}

} // namespace ops
} // namespace graphbolt
36 changes: 1 addition & 35 deletions graphbolt/src/cuda/expand_indptr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,49 +17,15 @@
* @file cuda/expand_indptr.cu
* @brief ExpandIndptr operator implementation on CUDA.
*/
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

#include <cub/cub.cuh>
#include <limits>

#include "./common.h"
#include "./expand_indptr.cuh"

namespace graphbolt {
namespace ops {

template <typename indices_t, typename nodes_t>
struct RepeatIndex {
const nodes_t* nodes;
__host__ __device__ auto operator()(indices_t i) {
return thrust::make_constant_iterator(nodes ? nodes[i] : i);
}
};

template <typename indices_t, typename nodes_t>
struct IotaIndex {
const nodes_t* nodes;
__host__ __device__ auto operator()(indices_t i) {
return thrust::make_counting_iterator(nodes ? nodes[i] : 0);
}
};

template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
const indptr_t* indptr;
indices_t* buffer;
__host__ __device__ auto operator()(int64_t i) { return buffer + indptr[i]; }
};

template <typename indptr_t>
struct AdjacentDifference {
const indptr_t* indptr;
__host__ __device__ auto operator()(int64_t i) {
return indptr[i + 1] - indptr[i];
}
};

torch::Tensor ExpandIndptrImpl(
torch::Tensor indptr, torch::ScalarType dtype,
torch::optional<torch::Tensor> nodes, torch::optional<int64_t> output_size,
Expand Down
59 changes: 59 additions & 0 deletions graphbolt/src/cuda/expand_indptr.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/**
* Copyright (c) 2023, GT-TDAlab (Muhammed Fatih Balin & Umit V. Catalyurek)
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* @file cuda/expand_indptr.cuh
* @brief ExpandIndptr helper class implementations on CUDA.
*/
#include <thrust/iterator/constant_iterator.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/iterator/transform_iterator.h>

namespace graphbolt {
namespace ops {

template <typename indices_t, typename nodes_t>
struct RepeatIndex {
const nodes_t* nodes;
__host__ __device__ auto operator()(indices_t i) {
return thrust::make_constant_iterator(nodes ? nodes[i] : i);
}
};

template <typename indices_t, typename nodes_t>
struct IotaIndex {
const nodes_t* nodes;
__host__ __device__ auto operator()(indices_t i) {
return thrust::make_counting_iterator(nodes ? nodes[i] : 0);
}
};

template <typename indptr_t, typename indices_t>
struct OutputBufferIndexer {
const indptr_t* indptr;
indices_t* buffer;
__host__ __device__ auto operator()(int64_t i) { return buffer + indptr[i]; }
};

template <typename indptr_t>
struct AdjacentDifference {
const indptr_t* indptr;
__host__ __device__ auto operator()(int64_t i) {
return indptr[i + 1] - indptr[i];
}
};

} // namespace ops
} // namespace graphbolt
7 changes: 7 additions & 0 deletions graphbolt/src/python_binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,13 @@ TORCH_LIBRARY(graphbolt, m) {
#ifdef HAS_PT2_COMPLIANT_TAG
,
{at::Tag::pt2_compliant_tag}
#endif
);
m.def(
"cat(Tensor[] tensors) -> Tensor"
#ifdef HAS_PT2_COMPLIANT_TAG
,
{at::Tag::pt2_compliant_tag}
#endif
);
}
Expand Down
38 changes: 38 additions & 0 deletions python/dgl/graphbolt/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"EndMarker",
"isin",
"index_select",
"cat",
"expand_indptr",
"indptr_edge_ids",
"CSCFormatBase",
Expand Down Expand Up @@ -98,6 +99,43 @@ def isin(elements, test_elements):
return torch.ops.graphbolt.isin(elements, test_elements)


if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):

torch_fake_decorator = (
torch.library.impl_abstract
if TorchVersion(torch.__version__) < TorchVersion("2.4.0a0")
else torch.library.register_fake
)

@torch_fake_decorator("graphbolt::cat")
def cat_fake(tensors):
"""Fake implementation of cat for torch.compile() support."""
size_0 = sum(t.size(0) for t in tensors)
return tensors[0].new_empty((size_0,) + tensors[0].shape[1:])


def cat(tensors):
"""Concatenates the given tensors along the first dimension.

This is equivalent to

.. code:: python

return torch.cat(tensors, dim=0)

Parameters
----------
tensors : List[torch.Tensor]
A list of tensors to be concatenated

Returns
-------
torch.Tensor
The concatenated tensors.
"""
return torch.ops.graphbolt.cat(tensors)


if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):

torch_fake_decorator = (
Expand Down
32 changes: 32 additions & 0 deletions tests/python/pytorch/graphbolt/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,38 @@ def test_indptr_edge_ids(offset, dtype):
assert explanation.graph_break_count == expected_breaks


@pytest.mark.parametrize("dtype", [torch.float16, torch.float32, torch.float64])
@pytest.mark.parametrize("shape", [tuple(), (13, 17), (5,)])
def test_cat(dtype, shape):
tensors = [
torch.randn((i,) + shape, dtype=dtype, device=F.ctx())
for i in [10, 21, 1]
]
torch_result = torch.cat(tensors, dim=0)
gb_result = gb.cat(tensors)
assert torch.equal(torch_result, gb_result)

if TorchVersion(torch.__version__) >= TorchVersion("2.2.0a0"):
import torch._dynamo as dynamo
from torch.testing._internal.optests import opcheck

# Tests torch.compile compatibility
opcheck(
torch.ops.graphbolt.cat,
(tensors,),
test_utils=[
"test_schema",
"test_autograd_registration",
"test_faketensor",
"test_aot_dispatch_dynamic",
],
raise_exception=True,
)

explanation = dynamo.explain(gb.cat)(tensors)
assert explanation.graph_break_count == 0


def test_csc_format_base_representation():
csc_format_base = gb.CSCFormatBase(
indptr=torch.tensor([0, 2, 4]),
Expand Down
Loading