Commit

revise cpp files format

gyzhou2000 committed Jul 18, 2024
1 parent 1df8139 commit 8c6cc5d
Showing 17 changed files with 683 additions and 608 deletions.
153 changes: 80 additions & 73 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp
@@ -1,106 +1,113 @@
#include "./bspmm_sum_cpu.h"

#include <torch/torch.h>

#include "ATen/core/TensorBody.h"

torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x){
if (!x.is_contiguous()) {
x = x.contiguous();
}
if (!weight.is_contiguous()) {
weight = weight.contiguous();
}
if (!index.is_contiguous()) {
index = index.contiguous();
}
torch::Tensor bspmm_sum_cpu_forward(
torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) {
if (!x.is_contiguous()) {
x = x.contiguous();
}
if (!weight.is_contiguous()) {
weight = weight.contiguous();
}
if (!index.is_contiguous()) {
index = index.contiguous();
}

// int num_nodes = x.size(0);
int heads = x.size(1);
int out_channels = x.size(2);
// int num_nodes = x.size(0);
int heads = x.size(1);
int out_channels = x.size(2);

auto sizes = x.sizes().vec();
// if(sizes[0] == 0)
// sizes[0] = index.max().item<int64_t>();

torch::Tensor out = torch::zeros(sizes, x.options());
auto E = index.size(1);
// auto K = x.numel() / x.size(0);
auto sizes = x.sizes().vec();
// if(sizes[0] == 0)
// sizes[0] = index.max().item<int64_t>();

auto index_data = index.data_ptr<int64_t>();
using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();
torch::Tensor out = torch::zeros(sizes, x.options());
auto E = index.size(1);
// auto K = x.numel() / x.size(0);

auto index_data = index.data_ptr<int64_t>();
using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto src = index_data[e];
auto dst = index_data[e + E];
for (auto e = 0; e < E; ++e) {
auto src = index_data[e];
auto dst = index_data[e + E];

for (auto h = 0; h < heads; ++h){
for (auto k = 0; k < out_channels; ++k){
for (auto h = 0; h < heads; ++h) {
for (auto k = 0; k < out_channels; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp atomic
#endif
out_data[dst * out_channels * heads + h * out_channels + k] +=
weight_data[e * heads + h] * x_data[src * out_channels * heads + h * out_channels + k];
}
}
out_data[dst * out_channels * heads + h * out_channels + k] +=
weight_data[e * heads + h] *
x_data[src * out_channels * heads + h * out_channels + k];
}
}
return out;
}
return out;
}

-std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, torch::Tensor &grad) {
+std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(
+    torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x,
+    torch::Tensor &grad) {
   if (!grad.is_contiguous()) {
     grad = grad.contiguous();
   }
   if (!weight.is_contiguous()) {
     weight = weight.contiguous();
   }
   if (!index.is_contiguous()) {
     index = index.contiguous();
   }
 
   // int num_nodes = grad.size(0);
   int heads = grad.size(1);
   int out_channels = grad.size(2);
 
   torch::Tensor grad_x = torch::zeros_like(grad, grad.options());
   torch::Tensor grad_weight = torch::zeros_like(weight, weight.options());
   auto E = index.size(1);
   // auto K = grad.numel() / grad.size(0);
 
   auto index_data = index.data_ptr<int64_t>();
   using scalar_t = float;
   auto grad_data = grad.data_ptr<scalar_t>();
   auto grad_x_data = grad_x.data_ptr<scalar_t>();
   auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
   auto x_data = x.data_ptr<scalar_t>();
   auto weight_data = weight.data_ptr<scalar_t>();

   // Compute gradients for backpropagation
 #ifdef COMPILE_WITH_OMP
 #pragma omp parallel for
 #endif
   for (auto e = 0; e < E; ++e) {
     auto src = index_data[e];
     auto dst = index_data[e + E];
 
-    for (auto h = 0; h < heads; ++h){
-      for (auto k = 0; k < out_channels; ++k){
+    for (auto h = 0; h < heads; ++h) {
+      for (auto k = 0; k < out_channels; ++k) {
 #ifdef COMPILE_WITH_OMP
 #pragma omp atomic
 #endif
-        grad_x_data[src * out_channels * heads + h * out_channels + k] +=
-            weight_data[e * heads + h] * grad_data[dst * out_channels * heads + h * out_channels + k];
-
-        grad_weight_data[e * heads + h] += x_data[src * out_channels * heads + h * out_channels + k] *
-            grad_data[dst * out_channels * heads + h * out_channels + k];
+        grad_x_data[src * out_channels * heads + h * out_channels + k] +=
+            weight_data[e * heads + h] *
+            grad_data[dst * out_channels * heads + h * out_channels + k];
+
+        grad_weight_data[e * heads + h] +=
+            x_data[src * out_channels * heads + h * out_channels + k] *
+            grad_data[dst * out_channels * heads + h * out_channels + k];
       }
     }
   }
   // return {grad_x, grad_weight};
   return std::make_tuple(grad_x, grad_weight);
 }
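
For orientation, here is a minimal usage sketch of the forward kernel above — an illustration, not repository code. It assumes the translation unit is compiled against libtorch with bspmm_sum_cpu.h on the include path (the path is a guess). The kernel hardcodes scalar_t = float, expects index as a 2 x E int64 tensor with source nodes in row 0 and destinations in row 1, and reads x in the flattened [num_nodes, heads, out_channels] layout that the dst * out_channels * heads + h * out_channels + k indexing implies.

#include <iostream>
#include <torch/torch.h>

#include "bspmm_sum_cpu.h"  // include path is an assumption

int main() {
  // Toy graph: E = 2 edges over 3 nodes. Row 0 holds sources,
  // row 1 holds destinations (the kernel reads dst at index_data[e + E]).
  torch::Tensor index = torch::tensor({{0, 1}, {2, 2}}, torch::kInt64);
  int heads = 2, out_channels = 4;
  torch::Tensor x = torch::rand({3, heads, out_channels});  // float32, as required
  torch::Tensor weight = torch::rand({2, heads});  // one weight per edge per head

  // out[dst, h, k] += weight[e, h] * x[src, h, k], accumulated over edges
  torch::Tensor out = bspmm_sum_cpu_forward(index, weight, x);
  std::cout << out.sizes() << std::endl;  // [3, 2, 4]
  return 0;
}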
9 changes: 5 additions & 4 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h
@@ -1,6 +1,7 @@
 #include <torch/torch.h>
 
-torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight,
-                                    torch::Tensor &x);
-std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x,
-                                                                torch::Tensor &grad);
+torch::Tensor bspmm_sum_cpu_forward(
+    torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x);
+std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(
+    torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x,
+    torch::Tensor &grad);
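
These declarations are what a Python-facing extension module would bind. The commit does not show gammagl's actual registration code, so the pybind11 sketch below is purely illustrative: the module layout and exported names are assumptions.

#include <torch/extension.h>

#include "bspmm_sum_cpu.h"  // include path is an assumption

// Hypothetical binding sketch; the repository's real registration file is
// not part of this commit, so the names below are illustrative only.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("bspmm_sum_forward", &bspmm_sum_cpu_forward,
        "bspmm sum forward (CPU)");
  m.def("bspmm_sum_backward", &bspmm_sum_cpu_backward,
        "bspmm sum backward (CPU)");
}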
51 changes: 26 additions & 25 deletions gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp
@@ -34,35 +34,36 @@ std::tuple<torch::Tensor, torch::Tensor> segment_max_cpu_forward(
   auto index_data = index.data_ptr<int64_t>();
   auto arg_out_data = arg_out.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() {
-    out.fill_(std::numeric_limits<scalar_t>::lowest());
-    auto x_data = x.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
+      "segment_max_cpu_forward", [&]() {
+        out.fill_(std::numeric_limits<scalar_t>::lowest());
+        auto x_data = x.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
 
-    int64_t idx;
+        int64_t idx;
 #ifdef COMPILE_WITH_OMP
 #pragma omp parallel for private(idx)
 #endif
-    for (auto e = 0; e < E; ++e) {
-      idx = index_data[e];
-      TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
-      for (auto k = 0; k < K; ++k) {
-        scalar_t current_val = x_data[e * K + k];
-        scalar_t& max_val = out_data[idx * K + k];
-        int64_t& max_idx = arg_out_data[idx * K + k];
+        for (auto e = 0; e < E; ++e) {
+          idx = index_data[e];
+          TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N);
+          for (auto k = 0; k < K; ++k) {
+            scalar_t current_val = x_data[e * K + k];
+            scalar_t& max_val = out_data[idx * K + k];
+            int64_t& max_idx = arg_out_data[idx * K + k];
 #ifdef COMPILE_WITH_OMP
 #pragma omp critical
 #endif
-      {
-        if (max_val < current_val) {
-          max_val = current_val;
-          max_idx = e;
-        }
-      }
-    }
-  }
-
-  });
+            {
+              if (max_val < current_val) {
+                max_val = current_val;
+                max_idx = e;
+              }
+            }
+          }
+        }
+      });
 
   return std::make_tuple(out, arg_out);
 }
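
To make the semantics of the kernel above concrete, here is a single-threaded reference sketch — an illustration, not repository code (segment_max_reference is a hypothetical name). For each segment id in index it takes the elementwise maximum over the rows of x mapped to that segment, and arg_out records which row supplied each maximum. It is float-only for brevity, whereas the real kernel dispatches over all types including Half and BFloat16.

#include <torch/torch.h>

#include <limits>
#include <tuple>

// Illustrative reference semantics of segment max (float-only):
// out[index[e], k] = max over e of x[e, k]; arg_out holds the winning row e.
// Assumes x is [E, K] float32 and index is a 1-D int64 tensor of length E.
std::tuple<torch::Tensor, torch::Tensor> segment_max_reference(
    const torch::Tensor& x, const torch::Tensor& index, int64_t N) {
  int64_t E = x.size(0), K = x.size(1);
  torch::Tensor out =
      torch::full({N, K}, std::numeric_limits<float>::lowest());
  torch::Tensor arg_out = torch::zeros({N, K}, torch::kInt64);
  auto x_a = x.accessor<float, 2>();
  auto out_a = out.accessor<float, 2>();
  auto arg_a = arg_out.accessor<int64_t, 2>();
  auto idx_a = index.accessor<int64_t, 1>();
  for (int64_t e = 0; e < E; ++e) {
    int64_t idx = idx_a[e];
    for (int64_t k = 0; k < K; ++k) {
      if (x_a[e][k] > out_a[idx][k]) {  // strictly greater keeps first winner
        out_a[idx][k] = x_a[e][k];
        arg_a[idx][k] = e;
      }
    }
  }
  return std::make_tuple(out, arg_out);
}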
70 changes: 35 additions & 35 deletions gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp
@@ -1,10 +1,10 @@
#include "segment_mean_cpu.h"

#include <ATen/ATen.h>
#include <assert.h>
#include <torch/extension.h>
#include <torch/script.h>
#include <torch/torch.h>
#include <ATen/ATen.h>

#include <iostream>
#include <vector>
@@ -35,46 +35,46 @@ torch::Tensor segment_mean_cpu_forward(
   auto index_data = index.data_ptr<int64_t>();
   auto arg_out_data = arg_out.data_ptr<int64_t>();
 
-  AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() {
-
-    auto x_data = x.data_ptr<scalar_t>();
-    auto out_data = out.data_ptr<scalar_t>();
-
-    torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
-    auto degree_data = degree.data_ptr<scalar_t>();
-
+  AT_DISPATCH_ALL_TYPES_AND2(
+      at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(),
+      "segment_mean_cpu_forward", [&]() {
+        auto x_data = x.data_ptr<scalar_t>();
+        auto out_data = out.data_ptr<scalar_t>();
+
+        torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options());
+        auto degree_data = degree.data_ptr<scalar_t>();
+
 #ifdef COMPILE_WITH_OMP
 #pragma omp parallel for
 #endif
-    for (auto e = 0; e < E; ++e) {
-      auto idx = index_data[e];
-      degree_data[idx] += 1;
-      for (auto k = 0; k < K; ++k) {
+        for (auto e = 0; e < E; ++e) {
+          auto idx = index_data[e];
+          degree_data[idx] += 1;
+          for (auto k = 0; k < K; ++k) {
 #ifdef COMPILE_WITH_OMP
 #pragma omp critical
 #endif
-        out_data[idx * K + k] += x_data[e * K + k];
-        arg_out_data[idx * K + k] = e;
-      }
-    }
-    out = out.contiguous();
-    degree = degree.contiguous();
-
+            out_data[idx * K + k] += x_data[e * K + k];
+            arg_out_data[idx * K + k] = e;
+          }
+        }
+        out = out.contiguous();
+        degree = degree.contiguous();
+
 #ifdef COMPILE_WITH_OMP
 #pragma omp parallel for
 #endif
-    for (auto e = 0; e < E; ++e) {
-      if (degree_data[e] > 1) {
-        for (auto k = 0; k < K; ++k) {
+        for (auto e = 0; e < E; ++e) {
+          if (degree_data[e] > 1) {
+            for (auto k = 0; k < K; ++k) {
 #ifdef COMPILE_WITH_OMP
 #pragma omp critical
 #endif
-          out_data[e * K + k] /= degree_data[e];
-        }
-      }
-    }
-
-  });
+              out_data[e * K + k] /= degree_data[e];
+            }
+          }
+        }
+      });
 
   return out;
 }
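
As a cross-check on the kernel above, the same segment-mean computation can be sketched with stock ATen ops; this is an illustrative equivalent (segment_mean_reference is a hypothetical name, not a repository function). It sums rows of x into their segments with index_add_, counts per-segment degrees the same way, then divides.

#include <torch/torch.h>

// Illustrative equivalent of segment mean using stock ATen ops.
// Assumes x is [E, K] floating point and index is a 1-D int64 tensor
// of length E with values in [0, N).
torch::Tensor segment_mean_reference(
    const torch::Tensor& x, const torch::Tensor& index, int64_t N) {
  torch::Tensor out = torch::zeros({N, x.size(1)}, x.options());
  out.index_add_(0, index, x);  // per-segment sums
  torch::Tensor degree = torch::zeros({N}, x.options());
  degree.index_add_(0, index, torch::ones({index.size(0)}, x.options()));
  // Divide each segment's sum by its row count; the clamp avoids 0/0
  // for segments that received no rows.
  return out / degree.clamp_min(1).unsqueeze(1);
}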
