[operator] add bspmm_sum operator (#210)
* add `rspmm_sum` operator

* update

* add other backend's bspmm

* update

---------

Co-authored-by: BuptTab <[email protected]>
gyzhou2000 authored Jul 5, 2024
1 parent e276485 commit 9e04783
Showing 10 changed files with 219 additions and 17 deletions.
36 changes: 20 additions & 16 deletions gammagl/layers/conv/gat_conv.py
@@ -1,8 +1,7 @@
 import tensorlayerx as tlx
 from gammagl.layers.conv import MessagePassing
 from gammagl.utils import segment_softmax
-
-
+from gammagl.mpops import bspmm
 
 
 class GATConv(MessagePassing):
@@ -79,10 +78,14 @@ def __init__(self,
         self.linear = tlx.layers.Linear(out_features=self.out_channels * self.heads,
                                         in_features=self.in_channels,
                                         b_init=None)
 
+        init_weight = tlx.initializers.TruncatedNormal()
+        self.w = tlx.nn.Parameter(
+            init_weight((in_channels, self.out_channels * self.heads)))
+
         initor = tlx.initializers.TruncatedNormal()
-        self.att_src = self._get_weights("att_src", shape=(1, self.heads, self.out_channels), init=initor, order=True)
-        self.att_dst = self._get_weights("att_dst", shape=(1, self.heads, self.out_channels), init=initor, order=True)
+        self.att = tlx.nn.Parameter(
+            initor((1, self.heads, self.out_channels * 2)))
 
         self.leaky_relu = tlx.layers.LeakyReLU(negative_slope)
         self.dropout = tlx.layers.Dropout(self.dropout_rate)
@@ -91,22 +94,23 @@ def __init__(self,
             self.bias = self._get_weights("bias", shape=(self.heads * self.out_channels,), init=initor)
         elif self.add_bias and not concat:
             self.bias = self._get_weights("bias", shape=(self.out_channels,), init=initor)
 
-    def message(self, x, edge_index, edge_weight=None, num_nodes=None):
+    def forward(self, x, edge_index, num_nodes=None):
+        x = tlx.matmul(x, self.w)
+        x = tlx.reshape(x, shape=(-1, self.heads, self.out_channels))
         node_src = edge_index[0, :]
         node_dst = edge_index[1, :]
-        weight_src = tlx.gather(tlx.reduce_sum(x * self.att_src, -1), node_src)
-        weight_dst = tlx.gather(tlx.reduce_sum(x * self.att_dst, -1), node_dst)
-        weight = self.leaky_relu(weight_src + weight_dst)
+        feat_src = tlx.gather(x, node_src)
+        feat_dst = tlx.gather(x, node_dst)
+        feat = tlx.concat((feat_src, feat_dst), axis=-1)
+        feat = tlx.reshape(feat, shape=(-1, self.heads, self.out_channels * 2))
+        e = tlx.reduce_sum(feat * self.att, axis = -1)
 
-        alpha = self.dropout(segment_softmax(weight, node_dst, num_nodes))
-        x = tlx.gather(x, node_src) * tlx.expand_dims(alpha, -1)
-        return x * edge_weight if edge_weight else x
+        e = self.leaky_relu(e)
+        alpha = self.dropout(segment_softmax(e, node_dst, num_nodes))
 
-
-    def forward(self, x, edge_index, num_nodes=None):
-        x = tlx.reshape(self.linear(x), shape=(-1, self.heads, self.out_channels))
-        x = self.propagate(x, edge_index, num_nodes=num_nodes)
+        x = self.propagate(x, edge_index, num_nodes=num_nodes, edge_weight=alpha)
+        # x = bspmm(edge_index, weight=alpha, x=x, reduce='sum')
 
         if self.concat:
             x = tlx.reshape(x, (-1, self.heads * self.out_channels))
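For orientation, a minimal usage sketch of the refactored layer follows; the constructor arguments, tlx helper names, and shapes are assumptions inferred from this diff, not part of the commit. The attention coefficients are now computed explicitly in forward and handed to propagate as edge_weight; the commented-out bspmm call marks the fused gather-multiply-scatter path, and the stubs added to the other backends below suggest it will be switched on once they are implemented.

import tensorlayerx as tlx
from gammagl.layers.conv import GATConv

# hypothetical sizes: 100 nodes, 16 input features, 8 output channels, 4 heads
conv = GATConv(in_channels=16, out_channels=8, heads=4)
x = tlx.random_normal((100, 16))
edge_index = tlx.convert_to_tensor([[0, 1, 2], [1, 2, 0]], dtype=tlx.int64)
out = conv(x, edge_index, num_nodes=100)  # shape [100, 8 * 4] when concat=True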
3 changes: 3 additions & 0 deletions gammagl/mpops/mindspore.py
@@ -60,3 +60,6 @@ def segment_max(x, segment_ids, num_segments=None):

def gspmm(index, weight=None, x=None, reduce='sum'):
    pass

def bspmm(index, weight=None, x=None, reduce='sum'):
    pass
3 changes: 3 additions & 0 deletions gammagl/mpops/paddle.py
@@ -222,3 +222,6 @@ def _scatter(x, index, updates, overwrite=True):

def gspmm(index, weight=None, x=None, reduce='sum'):
    pass

def bspmm(index, weight=None, x=None, reduce='sum'):
    pass
3 changes: 3 additions & 0 deletions gammagl/mpops/tensorflow.py
@@ -60,3 +60,6 @@ def segment_min(x, segment_ids, num_segments=None):

def gspmm(index, weight=None, x=None, reduce='sum'):
    pass

def bspmm(index, weight=None, x=None, reduce='sum'):
    pass
16 changes: 15 additions & 1 deletion gammagl/mpops/torch.py
@@ -1,7 +1,7 @@
 import torch
 use_ext = False
 try:
-    from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max
+    from .torch_ext._torch_ext import c_segment_sum, c_segment_mean, c_segment_max, c_spmm_sum, c_spmm_mean, c_spmm_max, c_bspmm_sum
     use_ext = True
 except:
     pass
@@ -297,3 +297,17 @@ def gspmm(index, weight=None, x=None, reduce='sum'):
        return c_spmm_max(index, weight, x)
    else:
        raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].")


def bspmm(index, weight=None, x=None, reduce='sum'):
    if weight is None:
        weight = torch.ones(size=(index.shape[1],), dtype=torch.float32)
    if reduce == 'sum':
        return c_bspmm_sum(index, weight, x)
    # elif reduce == 'mean':
    #     return c_spmm_mean(index, weight, x)
    # elif reduce == 'max':
    #     return c_spmm_max(index, weight, x)
    else:
        # raise Exception("Unsupported reduce type, please choose from ['sum', 'mean', 'max'].")
        raise Exception("Unsupported reduce type, please choose from ['sum'].")
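For reference, a dense PyTorch sketch of what bspmm(..., reduce='sum') computes, based on the CPU kernel below and not shipped in this commit: with index of shape [2, E], weight of shape [E, heads], and x of shape [N, heads, C], each edge adds a per-head weighted copy of its source features into its destination node.

import torch

def bspmm_sum_reference(index, weight, x):
    # index: [2, E] (row 0 = source nodes, row 1 = destination nodes)
    # weight: [E, heads]; x: [N, heads, C]
    src, dst = index[0], index[1]
    out = torch.zeros_like(x)
    # out[dst] += weight[e] (broadcast over channels) * x[src]
    out.index_add_(0, dst, weight.unsqueeze(-1) * x[src])
    return out

Note that the fallback weight built above, torch.ones(size=(index.shape[1],)), has one entry per edge, while the CPU kernel indexes weight as [E, heads]; multi-head callers therefore pass an explicit [E, heads] tensor, as the commented-out GATConv call does with alpha.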
102 changes: 102 additions & 0 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp
@@ -0,0 +1,102 @@
#include "./bspmm_sum_cpu.h"
#include <torch/torch.h>
#include "ATen/core/TensorBody.h"

torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x){
if (!x.is_contiguous()) {
x = x.contiguous();
}
if (!weight.is_contiguous()) {
weight = weight.contiguous();
}
if (!index.is_contiguous()) {
index = index.contiguous();
}

int num_nodes = x.size(0);
int heads = x.size(1);
int out_channels = x.size(2);

torch::Tensor out = torch::zeros_like(x, x.options());
auto E = index.size(1);
// auto K = x.numel() / x.size(0);

auto index_data = index.data_ptr<int64_t>();
using scalar_t = float;
auto x_data = x.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();
auto weight_data = weight.data_ptr<scalar_t>();

#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
for (auto e = 0; e < E; ++e) {
auto src = index_data[e];
auto dst = index_data[e + E];

for (auto h = 0; h < heads; ++h){
for (auto k = 0; k < out_channels; ++k){
#ifdef COMPILE_WITH_OMP
#pragma omp atomic
#endif
out_data[dst * out_channels * heads + h * out_channels + k] +=
weight_data[e * heads + h] * x_data[src * out_channels * heads + h * out_channels + k];
}
}
}
return out;
}

std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, torch::Tensor &grad) {
  if (!grad.is_contiguous()) {
    grad = grad.contiguous();
  }
  if (!weight.is_contiguous()) {
    weight = weight.contiguous();
  }
  if (!index.is_contiguous()) {
    index = index.contiguous();
  }

  int num_nodes = grad.size(0);
  int heads = grad.size(1);
  int out_channels = grad.size(2);

  torch::Tensor grad_x = torch::zeros_like(grad, grad.options());
  torch::Tensor grad_weight = torch::zeros_like(weight, weight.options());
  auto E = index.size(1);
  // auto K = grad.numel() / grad.size(0);

  auto index_data = index.data_ptr<int64_t>();
  using scalar_t = float;
  auto grad_data = grad.data_ptr<scalar_t>();
  auto grad_x_data = grad_x.data_ptr<scalar_t>();
  auto grad_weight_data = grad_weight.data_ptr<scalar_t>();
  auto x_data = x.data_ptr<scalar_t>();
  auto weight_data = weight.data_ptr<scalar_t>();

  // compute the gradients for the backward pass
#ifdef COMPILE_WITH_OMP
#pragma omp parallel for
#endif
  for (auto e = 0; e < E; ++e) {
    auto src = index_data[e];
    auto dst = index_data[e + E];

    for (auto h = 0; h < heads; ++h) {
      for (auto k = 0; k < out_channels; ++k) {
#ifdef COMPILE_WITH_OMP
#pragma omp atomic
#endif
        grad_x_data[src * out_channels * heads + h * out_channels + k] +=
            weight_data[e * heads + h] * grad_data[dst * out_channels * heads + h * out_channels + k];

        grad_weight_data[e * heads + h] += x_data[src * out_channels * heads + h * out_channels + k] *
            grad_data[dst * out_channels * heads + h * out_channels + k];
      }
    }
  }
  // return {grad_x, grad_weight};
  return std::make_tuple(grad_x, grad_weight);
}
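The backward loop implements grad_x[src] += weight[e] * grad[dst] per head and channel, and grad_weight[e, h] = sum over k of x[src, h, k] * grad[dst, h, k]. Two observations: the omp atomic protects only the grad_x accumulation on the statement that follows it, which is the one that can race across edges sharing a source node; grad_weight[e * heads + h] is touched by a single iteration of the parallel-for over e, so it needs no atomic. A hypothetical double-precision cross-check of the formulas against autograd, using a dense equivalent (not part of the commit):

import torch

E, N, H, C = 5, 4, 2, 3
index = torch.randint(0, N, (2, E))
x = torch.randn(N, H, C, dtype=torch.double, requires_grad=True)
weight = torch.randn(E, H, dtype=torch.double, requires_grad=True)

def f(weight, x):
    src, dst = index[0], index[1]
    out = torch.zeros_like(x)
    out.index_add_(0, dst, weight.unsqueeze(-1) * x[src])
    return out

print(torch.autograd.gradcheck(f, (weight, x)))  # True if the kernel's formulas agree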
6 changes: 6 additions & 0 deletions gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h
@@ -0,0 +1,6 @@
#include <torch/torch.h>

torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight,
torch::Tensor &x);
std::tuple<torch::Tensor, torch::Tensor> bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x,
torch::Tensor &grad);
8 changes: 8 additions & 0 deletions gammagl/mpops/torch_ext/include/gspmm.h
@@ -24,3 +24,11 @@ class SpMMMax : public torch::autograd::Function<SpMMMax> {
  static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext *ctx,
                                             std::vector<torch::Tensor> grad_outs);
};

class BSpMMSum : public torch::autograd::Function<BSpMMSum> {
public:
  static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index,
                               torch::Tensor weight, torch::Tensor x);
  static std::vector<torch::Tensor> backward(torch::autograd::AutogradContext *ctx,
                                             std::vector<torch::Tensor> grad_outs);
};
52 changes: 52 additions & 0 deletions gammagl/mpops/torch_ext/src/gspmm.cpp
@@ -9,6 +9,8 @@
#include "../cpu/spmm_sum_cpu.h"
#include "../cpu/spmm_mean_cpu.h"
#include "../cpu/spmm_max_cpu.h"
#include "../cpu/bspmm_sum_cpu.h"

#ifdef COMPILE_WITH_CUDA
#include "../cuda/spmm_sum_cuda.h"
#endif
@@ -171,3 +173,53 @@ std::vector<torch::Tensor> SpMMMax::backward(torch::autograd::AutogradContext *c

  return {torch::Tensor(), torch::Tensor(), grad_x};
}


torch::Tensor BSpMMSum::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index,
                                torch::Tensor weight, torch::Tensor x) {
  ctx->save_for_backward({index, weight, x});
  ctx->mark_non_differentiable({index, weight});
  torch::Tensor out;
  // CUDA
  if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) {
    // #ifdef COMPILE_WITH_CUDA
    //   out = bspmm_sum_cuda_forward(index, weight, x);
    // #else
    AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU.");
    // #endif
  }
  // CPU
  else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) {
    out = bspmm_sum_cpu_forward(index, weight, x);
  } else {
    AT_ERROR("Tensor device inconsistent error.");
  }

  return out;
}

std::vector<torch::Tensor> BSpMMSum::backward(torch::autograd::AutogradContext *ctx, std::vector<torch::Tensor> grad_outs) {
  auto saved = ctx->get_saved_variables();
  auto index = saved[0], weight = saved[1], x = saved[2];
  auto grad = grad_outs[0];
  torch::Tensor grad_x, grad_weight;

  // CUDA
  if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) {
    // #ifdef COMPILE_WITH_CUDA
    //   grad_x = bspmm_sum_cuda_backward(index, weight, grad);
    // #else
    AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU.");
    // #endif
  }
  // CPU
  else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) {
    auto result = bspmm_sum_cpu_backward(index, weight, x, grad);
    grad_x = std::get<0>(result);
    grad_weight = std::get<1>(result);
  } else {
    AT_ERROR("Tensor device inconsistent error.");
  }

  return {torch::Tensor(), grad_weight, grad_x};
}
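For readers more at home in Python, a rough torch.autograd.Function analogue of BSpMMSum, an illustrative sketch rather than code from this commit, mirrors the structure above, including the backward return convention of one slot per forward input (index, weight, x):

import torch

class BSpMMSumPy(torch.autograd.Function):
    @staticmethod
    def forward(ctx, index, weight, x):
        ctx.save_for_backward(index, weight, x)
        src, dst = index[0], index[1]
        out = torch.zeros_like(x)
        out.index_add_(0, dst, weight.unsqueeze(-1) * x[src])  # out[dst] += w[e] * x[src]
        return out

    @staticmethod
    def backward(ctx, grad):
        index, weight, x = ctx.saved_tensors
        src, dst = index[0], index[1]
        grad_x = torch.zeros_like(x)
        grad_x.index_add_(0, src, weight.unsqueeze(-1) * grad[dst])
        grad_weight = (x[src] * grad[dst]).sum(-1)
        # the integer index tensor gets no gradient
        return None, grad_weight, grad_x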
7 changes: 7 additions & 0 deletions gammagl/mpops/torch_ext/src/operators.cpp
@@ -39,11 +39,18 @@ torch::Tensor spmm_max(torch::Tensor index, torch::Tensor weight, torch::Tensor
  return SpMMMax::apply(index, weight, x);
}

torch::Tensor bspmm_sum(torch::Tensor index, torch::Tensor weight,
                        torch::Tensor x) {
  auto result = BSpMMSum::apply(index, weight, x);
  return result;
}

PYBIND11_MODULE(_torch_ext, m) {
  m.def("c_segment_max", segment_max);
  m.def("c_segment_sum", segment_sum);
  m.def("c_segment_mean", segment_mean);
  m.def("c_spmm_sum", spmm_sum);
  m.def("c_spmm_mean", spmm_mean);
  m.def("c_spmm_max", spmm_max);
  m.def("c_bspmm_sum", bspmm_sum);
}
