diff --git a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp index 925fcc81..ea00b1cb 100644 --- a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.cpp @@ -1,106 +1,113 @@ #include "./bspmm_sum_cpu.h" + #include + #include "ATen/core/TensorBody.h" -torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x){ - if (!x.is_contiguous()) { - x = x.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } +torch::Tensor bspmm_sum_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { + if (!x.is_contiguous()) { + x = x.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } - // int num_nodes = x.size(0); - int heads = x.size(1); - int out_channels = x.size(2); + // int num_nodes = x.size(0); + int heads = x.size(1); + int out_channels = x.size(2); - auto sizes = x.sizes().vec(); - // if(sizes[0] == 0) - // sizes[0] = index.max().item(); - - torch::Tensor out = torch::zeros(sizes, x.options()); - auto E = index.size(1); - // auto K = x.numel() / x.size(0); + auto sizes = x.sizes().vec(); + // if(sizes[0] == 0) + // sizes[0] = index.max().item(); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto x_data = x.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); + torch::Tensor out = torch::zeros(sizes, x.options()); + auto E = index.size(1); + // auto K = x.numel() / x.size(0); + + auto index_data = index.data_ptr(); + using scalar_t = float; + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto h = 0; h < heads; ++h){ - for (auto k = 0; k < out_channels; ++k){ + for (auto h = 0; h < heads; ++h) { + for (auto k = 0; k < out_channels; ++k) { #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - out_data[dst * out_channels * heads + h * out_channels + k] += - weight_data[e * heads + h] * x_data[src * out_channels * heads + h * out_channels + k]; - } - } + out_data[dst * out_channels * heads + h * out_channels + k] += + weight_data[e * heads + h] * + x_data[src * out_channels * heads + h * out_channels + k]; + } } - return out; + } + return out; } -std::tuple bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, torch::Tensor &grad) { - if (!grad.is_contiguous()) { - grad = grad.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } +std::tuple bspmm_sum_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, + torch::Tensor &grad) { + if (!grad.is_contiguous()) { + grad = grad.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } - // int num_nodes = grad.size(0); - int heads = grad.size(1); - int out_channels = grad.size(2); + // int num_nodes = grad.size(0); + int heads = grad.size(1); + int out_channels = grad.size(2); - torch::Tensor grad_x = 
torch::zeros_like(grad, grad.options()); - torch::Tensor grad_weight = torch::zeros_like(weight, weight.options()); - auto E = index.size(1); - // auto K = grad.numel() / grad.size(0); + torch::Tensor grad_x = torch::zeros_like(grad, grad.options()); + torch::Tensor grad_weight = torch::zeros_like(weight, weight.options()); + auto E = index.size(1); + // auto K = grad.numel() / grad.size(0); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto grad_data = grad.data_ptr(); - auto grad_x_data = grad_x.data_ptr(); - auto grad_weight_data = grad_weight.data_ptr(); - auto x_data = x.data_ptr(); - auto weight_data = weight.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto grad_data = grad.data_ptr(); + auto grad_x_data = grad_x.data_ptr(); + auto grad_weight_data = grad_weight.data_ptr(); + auto x_data = x.data_ptr(); + auto weight_data = weight.data_ptr(); // 计算反向传播的梯度 #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto h = 0; h < heads; ++h){ - for (auto k = 0; k < out_channels; ++k){ + for (auto h = 0; h < heads; ++h) { + for (auto k = 0; k < out_channels; ++k) { #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - grad_x_data[src * out_channels * heads + h * out_channels + k] += - weight_data[e * heads + h] * grad_data[dst * out_channels * heads + h * out_channels + k]; - - grad_weight_data[e * heads + h] += x_data[src * out_channels * heads + h * out_channels + k] * - grad_data[dst * out_channels * heads + h * out_channels + k]; + grad_x_data[src * out_channels * heads + h * out_channels + k] += + weight_data[e * heads + h] * + grad_data[dst * out_channels * heads + h * out_channels + k]; - } - } + grad_weight_data[e * heads + h] += + x_data[src * out_channels * heads + h * out_channels + k] * + grad_data[dst * out_channels * heads + h * out_channels + k]; + } } - // return {grad_x, grad_weight}; - return std::make_tuple(grad_x, grad_weight); + } + // return {grad_x, grad_weight}; + return std::make_tuple(grad_x, grad_weight); } \ No newline at end of file diff --git a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h index 478b67b8..e3cbcb5e 100644 --- a/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h +++ b/gammagl/mpops/torch_ext/cpu/bspmm_sum_cpu.h @@ -1,6 +1,7 @@ #include -torch::Tensor bspmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &x); -std::tuple bspmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, - torch::Tensor &grad); +torch::Tensor bspmm_sum_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x); +std::tuple bspmm_sum_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x, + torch::Tensor &grad); diff --git a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp index fade765b..25ceedb6 100644 --- a/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/segment_max_cpu.cpp @@ -34,35 +34,36 @@ std::tuple segment_max_cpu_forward( auto index_data = index.data_ptr(); auto arg_out_data = arg_out.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_max_cpu_forward", [&]() { - out.fill_(std::numeric_limits::lowest()); - auto x_data = x.data_ptr(); - auto 
out_data = out.data_ptr(); + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), + "segment_max_cpu_forward", [&]() { + out.fill_(std::numeric_limits::lowest()); + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); - int64_t idx; - #ifdef COMPILE_WITH_OMP - #pragma omp parallel for private(idx) - #endif - for (auto e = 0; e < E; ++e) { - idx = index_data[e]; - TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N); - for (auto k = 0; k < K; ++k) { - scalar_t current_val = x_data[e * K + k]; - scalar_t& max_val = out_data[idx * K + k]; - int64_t& max_idx = arg_out_data[idx * K + k]; - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif - { - if (max_val < current_val) { - max_val = current_val; - max_idx = e; + int64_t idx; +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for private(idx) +#endif + for (auto e = 0; e < E; ++e) { + idx = index_data[e]; + TORCH_CHECK_INDEX(idx < N, "Index out of bounds: ", idx, " >= ", N); + for (auto k = 0; k < K; ++k) { + scalar_t current_val = x_data[e * K + k]; + scalar_t& max_val = out_data[idx * K + k]; + int64_t& max_idx = arg_out_data[idx * K + k]; +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif + { + if (max_val < current_val) { + max_val = current_val; + max_idx = e; + } } } } - } - - }); + }); return std::make_tuple(out, arg_out); } diff --git a/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp index 98f2d1da..39455b7b 100644 --- a/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/segment_mean_cpu.cpp @@ -1,10 +1,10 @@ #include "segment_mean_cpu.h" +#include #include #include #include #include -#include #include #include @@ -35,46 +35,46 @@ torch::Tensor segment_mean_cpu_forward( auto index_data = index.data_ptr(); auto arg_out_data = arg_out.data_ptr(); - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_mean_cpu_forward", [&]() { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), + "segment_mean_cpu_forward", [&]() { + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); - auto x_data = x.data_ptr(); - auto out_data = out.data_ptr(); + torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options()); + auto degree_data = degree.data_ptr(); - torch::Tensor degree = torch::zeros({1, index.size(0)}, x.options()); - auto degree_data = degree.data_ptr(); - - #ifdef COMPILE_WITH_OMP - #pragma omp parallel for - #endif - for (auto e = 0; e < E; ++e) { - auto idx = index_data[e]; - degree_data[idx] += 1; - for (auto k = 0; k < K; ++k) { - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif - out_data[idx * K + k] += x_data[e * K + k]; - arg_out_data[idx * K + k] = e; - } - } - out = out.contiguous(); - degree = degree.contiguous(); - - #ifdef COMPILE_WITH_OMP - #pragma omp parallel for - #endif - for (auto e = 0; e < E; ++e) { - if (degree_data[e] > 1) { +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for +#endif + for (auto e = 0; e < E; ++e) { + auto idx = index_data[e]; + degree_data[idx] += 1; for (auto k = 0; k < K; ++k) { - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif - out_data[e * K + k] /= degree_data[e]; +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif + out_data[idx * K + k] += x_data[e * K + k]; + arg_out_data[idx * K + k] = e; } } - } + out = out.contiguous(); + degree = degree.contiguous(); - }); +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for +#endif + 
for (auto e = 0; e < E; ++e) { + if (degree_data[e] > 1) { + for (auto k = 0; k < K; ++k) { +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif + out_data[e * K + k] /= degree_data[e]; + } + } + } + }); return out; } diff --git a/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp b/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp index 5d38478f..52c421f0 100644 --- a/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/segment_sum_cpu.cpp @@ -1,10 +1,10 @@ #include "segment_sum_cpu.h" +#include #include #include #include #include -#include #include #include @@ -29,30 +29,32 @@ torch::Tensor segment_sum_cpu_forward( return out; } - AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), "segment_sum_cpu_forward", [&]() { + AT_DISPATCH_ALL_TYPES_AND2( + at::ScalarType::Half, at::ScalarType::BFloat16, x.scalar_type(), + "segment_sum_cpu_forward", [&]() { // Get data pointers for index, arg_out, and x. auto index_data = index.data_ptr(); auto x_data = x.data_ptr(); // Assuming x is of type float. auto out_data = out.data_ptr(); - auto E = index.size(0); // Number of elements to process. + auto E = index.size(0); // Number of elements to process. auto K = x.numel() / x.size(0); // Size of the inner dimension. - #ifdef COMPILE_WITH_OMP - #pragma omp parallel for - #endif +#ifdef COMPILE_WITH_OMP +#pragma omp parallel for +#endif // Iterate over each element in x. for (auto e = 0; e < E; ++e) { auto idx = index_data[e]; // Handle accumulation for different dimensions. for (auto k = 0; k < K; ++k) { - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif out_data[idx * K + k] += x_data[e * K + k]; } } - }); + }); return out; } diff --git a/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.cpp b/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.cpp index a89f74e9..6e4af188 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.cpp @@ -1,95 +1,99 @@ #include "spmm_max_cpu.h" + #include -std::tuple spmm_max_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { - if (!x.is_contiguous()) { - x = x.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - using scalar_t = float; - // 初始化输出张量为最小浮点数 - torch::Tensor out = torch::full_like(x, std::numeric_limits::lowest(), x.options()); - torch::Tensor max_indices = torch::zeros_like(x, torch::kInt64); // 保存最大值索引 +std::tuple spmm_max_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { + if (!x.is_contiguous()) { + x = x.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + using scalar_t = float; + // 初始化输出张量为最小浮点数 + torch::Tensor out = + torch::full_like(x, std::numeric_limits::lowest(), x.options()); + torch::Tensor max_indices = + torch::zeros_like(x, torch::kInt64); // 保存最大值索引 - auto E = index.size(1); - auto K = x.numel() / x.size(0); + auto E = index.size(1); + auto K = x.numel() / x.size(0); - auto index_data = index.data_ptr(); - auto x_data = x.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); - auto max_indices_data = max_indices.data_ptr(); + auto index_data = index.data_ptr(); + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); + auto max_indices_data = 
max_indices.data_ptr(); #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto k = 0; k < K; ++k) { - scalar_t weighted_value = weight_data[e] * x_data[src * K + k]; - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif - { - if (out_data[dst * K + k] < weighted_value) { - out_data[dst * K + k] = weighted_value; - max_indices_data[dst * K + k] = src; // 保存产生最大值的索引 - } - } + for (auto k = 0; k < K; ++k) { + scalar_t weighted_value = weight_data[e] * x_data[src * K + k]; +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif + { + if (out_data[dst * K + k] < weighted_value) { + out_data[dst * K + k] = weighted_value; + max_indices_data[dst * K + k] = src; // 保存产生最大值的索引 } + } } + } - // return out; - return std::make_tuple(out, max_indices); + // return out; + return std::make_tuple(out, max_indices); } -torch::Tensor spmm_max_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, torch::Tensor &max_indices) { - if (!grad.is_contiguous()) { - grad = grad.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - torch::Tensor out = torch::zeros_like(grad, grad.options()); - auto E = index.size(1); - auto K = grad.size(1); +torch::Tensor spmm_max_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, + torch::Tensor &max_indices) { + if (!grad.is_contiguous()) { + grad = grad.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + torch::Tensor out = torch::zeros_like(grad, grad.options()); + auto E = index.size(1); + auto K = grad.size(1); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto grad_data = grad.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); - auto max_indices_data = max_indices.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto grad_data = grad.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); + auto max_indices_data = max_indices.data_ptr(); #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto k = 0; k < K; ++k) { - if (max_indices_data[dst * K + k] == src) { // 检查是否是贡献最大的元素 - scalar_t weighted_value = weight_data[e] * grad_data[dst * K + k]; - #ifdef COMPILE_WITH_OMP - #pragma omp critical - #endif - { - out_data[src * K + k] += weighted_value; - } - } - } + for (auto k = 0; k < K; ++k) { + if (max_indices_data[dst * K + k] == src) { // 检查是否是贡献最大的元素 + scalar_t weighted_value = weight_data[e] * grad_data[dst * K + k]; +#ifdef COMPILE_WITH_OMP +#pragma omp critical +#endif + { out_data[src * K + k] += weighted_value; } + } } + } - return out; + return out; } diff --git a/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.h b/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.h index 68612d82..f3ab1b81 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.h +++ b/gammagl/mpops/torch_ext/cpu/spmm_max_cpu.h @@ -1,6 +1,7 @@ #include -std::tuple spmm_max_cpu_forward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &x); 
-torch::Tensor spmm_max_cpu_backward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &grad, torch::Tensor &max_indices); +std::tuple spmm_max_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x); +torch::Tensor spmm_max_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, + torch::Tensor &max_indices); diff --git a/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.cpp b/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.cpp index f9356634..35603551 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.cpp @@ -1,100 +1,105 @@ #include "spmm_mean_cpu.h" + #include -std::tuple spmm_mean_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { - if (!x.is_contiguous()) { - x = x.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - torch::Tensor out = torch::zeros_like(x, x.options()); - auto E = index.size(1); - auto K = x.numel() / x.size(0); +std::tuple spmm_mean_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { + if (!x.is_contiguous()) { + x = x.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + torch::Tensor out = torch::zeros_like(x, x.options()); + auto E = index.size(1); + auto K = x.numel() / x.size(0); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto x_data = x.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); - // 创建一个张量来存储每个节点的收到的消息数量(入度) - torch::Tensor messages_count = torch::zeros(x.size(0), torch::kInt64); - auto messages_count_data = messages_count.data_ptr(); + // 创建一个张量来存储每个节点的收到的消息数量(入度) + torch::Tensor messages_count = torch::zeros(x.size(0), torch::kInt64); + auto messages_count_data = messages_count.data_ptr(); - // 加权求和 + // 加权求和 #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; - messages_count_data[dst]++; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; + messages_count_data[dst]++; - for (auto k = 0; k < K; ++k) { + for (auto k = 0; k < K; ++k) { #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - out_data[dst * K + k] += weight_data[e] * x_data[src * K + k]; - } + out_data[dst * K + k] += weight_data[e] * x_data[src * K + k]; } + } - // 对每个节点的特征进行平均 + // 对每个节点的特征进行平均 #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto n = 0; n < x.size(0); ++n) { - auto msg_count_val = messages_count_data[n]; - if (msg_count_val > 0) { - for (auto k = 0; k < K; ++k) { - out_data[n * K + k] /= static_cast(msg_count_val); - } - } + for (auto n = 0; n < x.size(0); ++n) { + auto msg_count_val = messages_count_data[n]; + if (msg_count_val > 0) { + for (auto k = 0; k < K; ++k) { + out_data[n * K + k] /= static_cast(msg_count_val); + } } + } - return std::make_tuple(out, messages_count); + return std::make_tuple(out, messages_count); } +torch::Tensor spmm_mean_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, + torch::Tensor &messages_count) { + if (!grad.is_contiguous()) { + grad = grad.contiguous(); + } + if 
(!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } -torch::Tensor spmm_mean_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, torch::Tensor &messages_count) { - if (!grad.is_contiguous()) { - grad = grad.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - - torch::Tensor out = torch::zeros_like(grad, grad.options()); - auto E = index.size(1); - auto K = grad.numel() / grad.size(0); + torch::Tensor out = torch::zeros_like(grad, grad.options()); + auto E = index.size(1); + auto K = grad.numel() / grad.size(0); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto grad_data = grad.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto grad_data = grad.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); - // 计算反向传播的梯度 + // 计算反向传播的梯度 #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto k = 0; k < K; ++k) { - auto grad_contribution = grad_data[dst * K + k] / messages_count[dst].item() * weight_data[e]; + for (auto k = 0; k < K; ++k) { + auto grad_contribution = grad_data[dst * K + k] / + messages_count[dst].item() * + weight_data[e]; #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - out_data[src * K + k] += grad_contribution; - } + out_data[src * K + k] += grad_contribution; } - return out; + } + return out; } \ No newline at end of file diff --git a/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.h b/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.h index 4e3ccdbc..2180020c 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.h +++ b/gammagl/mpops/torch_ext/cpu/spmm_mean_cpu.h @@ -1,6 +1,7 @@ #include -std::tuple spmm_mean_cpu_forward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &x); -torch::Tensor spmm_mean_cpu_backward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &grad, torch::Tensor &messages_count); +std::tuple spmm_mean_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x); +torch::Tensor spmm_mean_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad, + torch::Tensor &messages_count); diff --git a/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.cpp b/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.cpp index 7e3cab21..4bbd4229 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.cpp +++ b/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.cpp @@ -1,78 +1,80 @@ #include "spmm_sum_cpu.h" + #include -torch::Tensor spmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { - if (!x.is_contiguous()) { - x = x.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - torch::Tensor out = torch::zeros_like(x, x.options()); - auto E = index.size(1); - auto K = x.numel() / x.size(0); +torch::Tensor spmm_sum_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x) { + if (!x.is_contiguous()) { + x = x.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + 
torch::Tensor out = torch::zeros_like(x, x.options()); + auto E = index.size(1); + auto K = x.numel() / x.size(0); - auto index_data = index.data_ptr(); - using scalar_t = float; - auto x_data = x.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto x_data = x.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto k = 0; k < K; ++k) { + for (auto k = 0; k < K; ++k) { #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - out_data[dst * K + k] += weight_data[e] * x_data[src * K + k]; - } + out_data[dst * K + k] += weight_data[e] * x_data[src * K + k]; } - return out; + } + return out; } +torch::Tensor spmm_sum_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad) { + if (!grad.is_contiguous()) { + grad = grad.contiguous(); + } + if (!weight.is_contiguous()) { + weight = weight.contiguous(); + } + if (!index.is_contiguous()) { + index = index.contiguous(); + } + torch::Tensor out = torch::zeros_like(grad, grad.options()); + auto E = index.size(1); + auto K = grad.numel() / grad.size(0); -torch::Tensor spmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad) { - if (!grad.is_contiguous()) { - grad = grad.contiguous(); - } - if (!weight.is_contiguous()) { - weight = weight.contiguous(); - } - if (!index.is_contiguous()) { - index = index.contiguous(); - } - torch::Tensor out = torch::zeros_like(grad, grad.options()); - auto E = index.size(1); - auto K = grad.numel() / grad.size(0); - - auto index_data = index.data_ptr(); - using scalar_t = float; - auto grad_data = grad.data_ptr(); - auto out_data = out.data_ptr(); - auto weight_data = weight.data_ptr(); + auto index_data = index.data_ptr(); + using scalar_t = float; + auto grad_data = grad.data_ptr(); + auto out_data = out.data_ptr(); + auto weight_data = weight.data_ptr(); // 计算反向传播的梯度 #ifdef COMPILE_WITH_OMP #pragma omp parallel for #endif - for (auto e = 0; e < E; ++e) { - auto src = index_data[e]; - auto dst = index_data[e + E]; + for (auto e = 0; e < E; ++e) { + auto src = index_data[e]; + auto dst = index_data[e + E]; - for (auto k = 0; k < K; ++k) { + for (auto k = 0; k < K; ++k) { #ifdef COMPILE_WITH_OMP #pragma omp atomic #endif - out_data[src * K + k] += weight_data[e] * grad_data[dst * K + k]; - } + out_data[src * K + k] += weight_data[e] * grad_data[dst * K + k]; } - return out; + } + return out; } \ No newline at end of file diff --git a/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.h b/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.h index ca2b71c1..7462951f 100644 --- a/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.h +++ b/gammagl/mpops/torch_ext/cpu/spmm_sum_cpu.h @@ -1,6 +1,6 @@ #include -torch::Tensor spmm_sum_cpu_forward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &x); -torch::Tensor spmm_sum_cpu_backward(torch::Tensor &index, torch::Tensor &weight, - torch::Tensor &grad); +torch::Tensor spmm_sum_cpu_forward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &x); +torch::Tensor spmm_sum_cpu_backward( + torch::Tensor &index, torch::Tensor &weight, torch::Tensor &grad); diff --git a/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu 
b/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu index fff03912..1dcfc018 100644 --- a/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu +++ b/gammagl/mpops/torch_ext/cuda/segment_max_cuda.cu @@ -31,33 +31,37 @@ using torch::autograd::variable_list; // } template -__device__ void atomic_max(scalar_t* const address, const scalar_t value); +__device__ void atomic_max(scalar_t *const address, const scalar_t value); template <> -__device__ void atomic_max(int32_t* const address, const int32_t value) { - atomicMax(address, value); +__device__ void atomic_max( + int32_t *const address, const int32_t value) { + atomicMax(address, value); } template <> -__device__ void atomic_max(float* const address, const float value) { - int* const address_as_i = (int*)address; - int old = *address_as_i, assumed; - do { - assumed = old; - old = atomicCAS(address_as_i, assumed, - __float_as_int(fmaxf(value, __int_as_float(assumed)))); - } while (assumed != old); +__device__ void atomic_max(float *const address, const float value) { + int *const address_as_i = (int *)address; + int old = *address_as_i, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_i, assumed, + __float_as_int(fmaxf(value, __int_as_float(assumed)))); + } while (assumed != old); } template <> -__device__ void atomic_max(double* const address, const double value) { - unsigned long long int* const address_as_ull = (unsigned long long int*)address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(fmax(value, __longlong_as_double(assumed)))); - } while (assumed != old); +__device__ void atomic_max(double *const address, const double value) { + unsigned long long int *const address_as_ull = + (unsigned long long int *)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS( + address_as_ull, assumed, + __double_as_longlong(fmax(value, __longlong_as_double(assumed)))); + } while (assumed != old); } template @@ -133,19 +137,21 @@ std::tuple segment_max_cuda_forward( auto K = x.numel() / x.size(0); auto stream = at::cuda::getCurrentCUDAStream(); - if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { - if (x.dtype() == torch::kInt8){ + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8) { out.fill_(std::numeric_limits::lowest()); - } else if (x.dtype() == torch::kInt16){ + } else if (x.dtype() == torch::kInt16) { out.fill_(std::numeric_limits::lowest()); - } else if (x.dtype() == torch::kInt32){ + } else if (x.dtype() == torch::kInt32) { out.fill_(std::numeric_limits::lowest()); - } else if (x.dtype() == torch::kInt64){ + } else if (x.dtype() == torch::kInt64) { out.fill_(std::numeric_limits::lowest()); } auto type = x.dtype(); using scalar_t = int; - if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt64) { x = x.to(torch::kInt32); out = out.to(torch::kInt32); } @@ -162,9 +168,9 @@ std::tuple segment_max_cuda_forward( <<>>( x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(), out.size(0)); - + out = out.to(type); - + } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) { auto type = x.dtype(); using scalar_t = float; @@ -187,7 +193,7 @@ std::tuple 
segment_max_cuda_forward( <<>>( x_data, index_data, out_data, arg_out_data, E, K, N, x.numel(), out.size(0)); - + out = out.to(type); } else if (x.dtype() == torch::kFloat64) { using scalar_t = double; diff --git a/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu b/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu index 72fdac13..e158cd3e 100644 --- a/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu +++ b/gammagl/mpops/torch_ext/cuda/segment_mean_cuda.cu @@ -19,7 +19,6 @@ using torch::autograd::variable_list; #define THREADS 1024 #define BLOCKS(N) (N + THREADS - 1) / THREADS - template __global__ void segment_mean_cuda_forward_kernel( const scalar_t *x_data, const int64_t *index_data, scalar_t *out_data, @@ -39,8 +38,7 @@ __global__ void segment_mean_cuda_forward_kernel( template __global__ void arg_segment_mean_cuda_forward_kernel( const scalar_t *x_data, const int64_t *index_data, scalar_t *out_data, - scalar_t *count_data, int64_t E, int64_t K, - int64_t N, int64_t numel) { + scalar_t *count_data, int64_t E, int64_t K, int64_t N, int64_t numel) { int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; if (thread_idx < numel) { @@ -81,10 +79,12 @@ torch::Tensor segment_mean_cuda_forward( auto K = x.numel() / x.size(0); auto stream = at::cuda::getCurrentCUDAStream(); - if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { auto type = x.dtype(); using scalar_t = int; - if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt64) { x = x.to(torch::kInt32); out = out.to(torch::kInt32); } @@ -102,9 +102,9 @@ torch::Tensor segment_mean_cuda_forward( arg_segment_mean_cuda_forward_kernel <<>>( - x_data, index_data, out_data, count_data, E, K, out.sizes().vec()[0], - out.numel()); - + x_data, index_data, out_data, count_data, E, K, + out.sizes().vec()[0], out.numel()); + out = out.to(type); } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) { auto type = x.dtype(); @@ -127,9 +127,8 @@ torch::Tensor segment_mean_cuda_forward( arg_segment_mean_cuda_forward_kernel <<>>( - x_data, index_data, out_data, count_data, E, K, N, - out.numel()); - + x_data, index_data, out_data, count_data, E, K, N, out.numel()); + out = out.to(type); } else if (x.dtype() == torch::kFloat64) { using scalar_t = double; @@ -146,8 +145,7 @@ torch::Tensor segment_mean_cuda_forward( arg_segment_mean_cuda_forward_kernel <<>>( - x_data, index_data, out_data, count_data, E, K, N, - out.numel()); + x_data, index_data, out_data, count_data, E, K, N, out.numel()); } return out; diff --git a/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu b/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu index d026d4ed..e57cb465 100644 --- a/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu +++ b/gammagl/mpops/torch_ext/cuda/segment_sum_cuda.cu @@ -66,10 +66,12 @@ torch::Tensor segment_sum_cuda_forward( auto K = x.numel() / x.size(0); auto stream = at::cuda::getCurrentCUDAStream(); - if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt32 || x.dtype() == torch::kInt64) { auto type = x.dtype(); using scalar_t = int; - if (x.dtype() == torch::kInt8 || 
x.dtype() == torch::kInt16 || x.dtype() == torch::kInt64) { + if (x.dtype() == torch::kInt8 || x.dtype() == torch::kInt16 || + x.dtype() == torch::kInt64) { x = x.to(torch::kInt32); out = out.to(torch::kInt32); } @@ -80,7 +82,7 @@ torch::Tensor segment_sum_cuda_forward( segment_sum_cuda_forward_kernel <<>>( x_data, index_data, out_data, E, K, N, x.numel()); - + out = out.to(type); } else if (x.dtype() == torch::kFloat16 || x.dtype() == torch::kFloat32) { auto type = x.dtype(); @@ -97,7 +99,7 @@ torch::Tensor segment_sum_cuda_forward( segment_sum_cuda_forward_kernel <<>>( x_data, index_data, out_data, E, K, N, x.numel()); - + out = out.to(type); } else if (x.dtype() == torch::kFloat64) { using scalar_t = double; diff --git a/gammagl/mpops/torch_ext/include/gspmm.h b/gammagl/mpops/torch_ext/include/gspmm.h index 10f332f9..0bbdb221 100644 --- a/gammagl/mpops/torch_ext/include/gspmm.h +++ b/gammagl/mpops/torch_ext/include/gspmm.h @@ -1,34 +1,41 @@ #include - class SpMMSum : public torch::autograd::Function { public: - static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x); - static std::vector backward(torch::autograd::AutogradContext *ctx, - std::vector grad_outs); + static torch::Tensor forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x); + static std::vector backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs); }; class SpMMMean : public torch::autograd::Function { public: - static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x); - static std::vector backward(torch::autograd::AutogradContext *ctx, - std::vector grad_outs); + static torch::Tensor forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x); + static std::vector backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs); }; class SpMMMax : public torch::autograd::Function { public: - static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x); - static std::vector backward(torch::autograd::AutogradContext *ctx, - std::vector grad_outs); + static torch::Tensor forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x); + static std::vector backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs); }; class BSpMMSum : public torch::autograd::Function { - public: - static torch::Tensor forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x); - static std::vector backward(torch::autograd::AutogradContext *ctx, - std::vector grad_outs); + public: + static torch::Tensor forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x); + static std::vector backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs); }; diff --git a/gammagl/mpops/torch_ext/src/gspmm.cpp b/gammagl/mpops/torch_ext/src/gspmm.cpp index 0840ac7e..1fd35408 100644 --- a/gammagl/mpops/torch_ext/src/gspmm.cpp +++ b/gammagl/mpops/torch_ext/src/gspmm.cpp @@ -1,15 +1,17 @@ #include "../include/gspmm.h" + #include -#include #include #include #include + +#include #include -#include "../cpu/spmm_sum_cpu.h" -#include "../cpu/spmm_mean_cpu.h" -#include "../cpu/spmm_max_cpu.h" #include "../cpu/bspmm_sum_cpu.h" +#include 
"../cpu/spmm_max_cpu.h" +#include "../cpu/spmm_mean_cpu.h" +#include "../cpu/spmm_sum_cpu.h" #ifdef COMPILE_WITH_CUDA #include "../cuda/spmm_sum_cuda.h" @@ -21,205 +23,238 @@ // 2. generalized operators to support more data // structures, such as csr, csc, etc. -torch::Tensor SpMMSum::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x) { - ctx->save_for_backward({index, weight, x}); - ctx->mark_non_differentiable({index, weight}); - torch::Tensor out; - // CUDA - if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { - #ifdef COMPILE_WITH_CUDA - out = spmm_sum_cuda_forward(index, weight, x); - #else - AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); - #endif - } - // CPU - else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { - out = spmm_sum_cpu_forward(index, weight, x); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return out; - } - -std::vector SpMMSum::backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs) { - auto saved = ctx->get_saved_variables(); - auto index = saved[0], weight = saved[1], x = saved[2]; - auto grad = grad_outs[0]; - torch::Tensor grad_x; - - // CUDA - if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { - #ifdef COMPILE_WITH_CUDA - grad_x = spmm_sum_cuda_backward(index, weight, grad); - #else - AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); - #endif - } - // CPU - else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { - grad_x = spmm_sum_cpu_backward(index, weight, grad); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return {torch::Tensor(), torch::Tensor(), grad_x}; +torch::Tensor SpMMSum::forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x) { + ctx->save_for_backward({index, weight, x}); + ctx->mark_non_differentiable({index, weight}); + torch::Tensor out; + // CUDA + if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { +#ifdef COMPILE_WITH_CUDA + out = spmm_sum_cuda_forward(index, weight, x); +#else + AT_ERROR( + "The program is not compiled with CUDA support, but tensors are " + "located on GPU. Please recompile with CUDA support or move tensors to " + "CPU."); +#endif + } + // CPU + else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { + out = spmm_sum_cpu_forward(index, weight, x); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return out; +} + +std::vector SpMMSum::backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs) { + auto saved = ctx->get_saved_variables(); + auto index = saved[0], weight = saved[1], x = saved[2]; + auto grad = grad_outs[0]; + torch::Tensor grad_x; + + // CUDA + if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { +#ifdef COMPILE_WITH_CUDA + grad_x = spmm_sum_cuda_backward(index, weight, grad); +#else + AT_ERROR( + "The program is not compiled with CUDA support, but tensors are " + "located on GPU. 
Please recompile with CUDA support or move tensors to " + "CPU."); +#endif + } + // CPU + else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { + grad_x = spmm_sum_cpu_backward(index, weight, grad); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return {torch::Tensor(), torch::Tensor(), grad_x}; } - -torch::Tensor SpMMMean::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x) { - ctx->mark_non_differentiable({index, weight}); - std::tuple result; - - // CUDA - if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { - AT_ERROR("The program is not support CUDA !"); - // #ifdef COMPILE_WITH_CUDA - // // grad_x = spmm_sum_cuda_backward(index, weight, grad, max_indices); - // grad_x = spmm_sum_cuda_backward(index, weight, grad); - // #else - // AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); - // #endif - } - // CPU - else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { - result = spmm_mean_cpu_forward(index, weight, x); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - auto out = std::get<0>(result); - auto arg_out = std::get<1>(result); - ctx->save_for_backward({index, weight, x, arg_out}); - return out; - } - -std::vector SpMMMean::backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs) { - auto saved = ctx->get_saved_variables(); - auto index = saved[0], weight = saved[1], x = saved[2], messages_count = saved[3]; - auto grad = grad_outs[0]; - torch::Tensor grad_x; - - // CUDA - if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { - AT_ERROR("The program is not support CUDA !"); + +torch::Tensor SpMMMean::forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x) { + ctx->mark_non_differentiable({index, weight}); + std::tuple result; + + // CUDA + if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { + AT_ERROR("The program is not support CUDA !"); // #ifdef COMPILE_WITH_CUDA - // result = spmm_sum_cuda_forward(index, weight, x); + // // grad_x = spmm_sum_cuda_backward(index, weight, grad, max_indices); + // grad_x = spmm_sum_cuda_backward(index, weight, grad); // #else - // AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); + // AT_ERROR("The program is not compiled with CUDA support, but tensors + // are located on GPU. 
Please recompile with CUDA support or move + // tensors to CPU."); // #endif - } - // CPU - else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { - grad_x = spmm_mean_cpu_backward(index, weight, grad, messages_count); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return {torch::Tensor(), torch::Tensor(), grad_x}; + } + // CPU + else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { + result = spmm_mean_cpu_forward(index, weight, x); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result); + ctx->save_for_backward({index, weight, x, arg_out}); + return out; } -torch::Tensor SpMMMax::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x) { - ctx->mark_non_differentiable({index, weight}); - std::tuple result; +std::vector SpMMMean::backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs) { + auto saved = ctx->get_saved_variables(); + auto index = saved[0], weight = saved[1], x = saved[2], + messages_count = saved[3]; + auto grad = grad_outs[0]; + torch::Tensor grad_x; - // CUDA - if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { - AT_ERROR("The program is not support CUDA !"); + // CUDA + if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { + AT_ERROR("The program is not support CUDA !"); // #ifdef COMPILE_WITH_CUDA - // result = spmm_sum_cuda_forward(index, weight, x); + // result = spmm_sum_cuda_forward(index, weight, x); // #else - // AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); + // AT_ERROR("The program is not compiled with CUDA support, but tensors + // are located on GPU. Please recompile with CUDA support or move + // tensors to CPU."); // #endif - } - // CPU - else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { - result = spmm_max_cpu_forward(index, weight, x); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - auto out = std::get<0>(result); - auto arg_out = std::get<1>(result); - ctx->save_for_backward({index, weight, x, arg_out}); - return out; + } + // CPU + else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { + grad_x = spmm_mean_cpu_backward(index, weight, grad, messages_count); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return {torch::Tensor(), torch::Tensor(), grad_x}; } -std::vector SpMMMax::backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs) { - auto saved = ctx->get_saved_variables(); - auto index = saved[0], weight = saved[1], x = saved[2], max_indices = saved[3]; - auto grad = grad_outs[0]; - torch::Tensor grad_x; - - // CUDA - if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { - AT_ERROR("The program is not support CUDA !"); - // #ifdef COMPILE_WITH_CUDA - // // grad_x = spmm_sum_cuda_backward(index, weight, grad, max_indices); - // grad_x = spmm_sum_cuda_backward(index, weight, grad); - // #else - // AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. 
Please recompile with CUDA support or move tensors to CPU."); - // #endif - } - // CPU - else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { - grad_x = spmm_max_cpu_backward(index, weight, grad, max_indices); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return {torch::Tensor(), torch::Tensor(), grad_x}; +torch::Tensor SpMMMax::forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x) { + ctx->mark_non_differentiable({index, weight}); + std::tuple result; + + // CUDA + if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { + AT_ERROR("The program is not support CUDA !"); + // #ifdef COMPILE_WITH_CUDA + // result = spmm_sum_cuda_forward(index, weight, x); + // #else + // AT_ERROR("The program is not compiled with CUDA support, but tensors + // are located on GPU. Please recompile with CUDA support or move + // tensors to CPU."); + // #endif + } + // CPU + else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { + result = spmm_max_cpu_forward(index, weight, x); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + auto out = std::get<0>(result); + auto arg_out = std::get<1>(result); + ctx->save_for_backward({index, weight, x, arg_out}); + return out; } +std::vector SpMMMax::backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs) { + auto saved = ctx->get_saved_variables(); + auto index = saved[0], weight = saved[1], x = saved[2], + max_indices = saved[3]; + auto grad = grad_outs[0]; + torch::Tensor grad_x; + + // CUDA + if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { + AT_ERROR("The program is not support CUDA !"); + // #ifdef COMPILE_WITH_CUDA + // // grad_x = spmm_sum_cuda_backward(index, weight, grad, max_indices); + // grad_x = spmm_sum_cuda_backward(index, weight, grad); + // #else + // AT_ERROR("The program is not compiled with CUDA support, but tensors + // are located on GPU. Please recompile with CUDA support or move + // tensors to CPU."); + // #endif + } + // CPU + else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { + grad_x = spmm_max_cpu_backward(index, weight, grad, max_indices); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return {torch::Tensor(), torch::Tensor(), grad_x}; +} -torch::Tensor BSpMMSum::forward(torch::autograd::AutogradContext *ctx, torch::Tensor index, - torch::Tensor weight, torch::Tensor x) { - ctx->save_for_backward({index, weight, x}); - ctx->mark_non_differentiable({index, weight}); - torch::Tensor out; - // CUDA - if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { +torch::Tensor BSpMMSum::forward( + torch::autograd::AutogradContext *ctx, torch::Tensor index, + torch::Tensor weight, torch::Tensor x) { + ctx->save_for_backward({index, weight, x}); + ctx->mark_non_differentiable({index, weight}); + torch::Tensor out; + // CUDA + if (x.is_cuda() && index.is_cuda() && weight.is_cuda()) { // #ifdef COMPILE_WITH_CUDA // out = bspmm_sum_cuda_forward(index, weight, x); // #else - AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); + AT_ERROR( + "The program is not compiled with CUDA support, but tensors are " + "located on GPU. 
Please recompile with CUDA support or move tensors to " + "CPU."); // #endif - } - // CPU - else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { - out = bspmm_sum_cpu_forward(index, weight, x); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return out; + } + // CPU + else if (x.is_cpu() && index.is_cpu() && weight.is_cpu()) { + out = bspmm_sum_cpu_forward(index, weight, x); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return out; } -std::vector BSpMMSum::backward(torch::autograd::AutogradContext *ctx, std::vector grad_outs) { - auto saved = ctx->get_saved_variables(); - auto index = saved[0], weight = saved[1], x = saved[2]; - auto grad = grad_outs[0]; - torch::Tensor grad_x, grad_weight; - - // CUDA - if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { - // #ifdef COMPILE_WITH_CUDA - // grad_x = bspmm_sum_cuda_backward(index, weight, grad); - // #else - AT_ERROR("The program is not compiled with CUDA support, but tensors are located on GPU. Please recompile with CUDA support or move tensors to CPU."); - // #endif - } - // CPU - else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { - auto result = bspmm_sum_cpu_backward(index, weight, x, grad); - grad_x = std::get<0>(result); - grad_weight = std::get<1>(result); - } else { - AT_ERROR("Tensor device inconsistent error."); - } - - return {torch::Tensor(), grad_weight, grad_x}; +std::vector BSpMMSum::backward( + torch::autograd::AutogradContext *ctx, + std::vector grad_outs) { + auto saved = ctx->get_saved_variables(); + auto index = saved[0], weight = saved[1], x = saved[2]; + auto grad = grad_outs[0]; + torch::Tensor grad_x, grad_weight; + + // CUDA + if (grad.is_cuda() && index.is_cuda() && weight.is_cuda()) { + // #ifdef COMPILE_WITH_CUDA + // grad_x = bspmm_sum_cuda_backward(index, weight, grad); + // #else + AT_ERROR( + "The program is not compiled with CUDA support, but tensors are " + "located on GPU. 
Please recompile with CUDA support or move tensors to " + "CPU."); + // #endif + } + // CPU + else if (grad.is_cpu() && index.is_cpu() && weight.is_cpu()) { + auto result = bspmm_sum_cpu_backward(index, weight, x, grad); + grad_x = std::get<0>(result); + grad_weight = std::get<1>(result); + } else { + AT_ERROR("Tensor device inconsistent error."); + } + + return {torch::Tensor(), grad_weight, grad_x}; } diff --git a/gammagl/mpops/torch_ext/src/operators.cpp b/gammagl/mpops/torch_ext/src/operators.cpp index 5f99cd92..dd5c5f70 100644 --- a/gammagl/mpops/torch_ext/src/operators.cpp +++ b/gammagl/mpops/torch_ext/src/operators.cpp @@ -1,7 +1,3 @@ -#include "../include/segment_max.h" -#include "../include/segment_sum.h" -#include "../include/segment_mean.h" -#include "../include/gspmm.h" #include #include #include @@ -10,6 +6,11 @@ #include #include +#include "../include/gspmm.h" +#include "../include/segment_max.h" +#include "../include/segment_mean.h" +#include "../include/segment_sum.h" + torch::Tensor segment_max(torch::Tensor x, torch::Tensor index, int64_t N) { auto result = SegmentMax::apply(x, index, N); return result; @@ -25,24 +26,26 @@ torch::Tensor segment_mean(torch::Tensor x, torch::Tensor index, int64_t N) { return result; } -torch::Tensor spmm_sum(torch::Tensor index, torch::Tensor weight, - torch::Tensor x) { +torch::Tensor spmm_sum( + torch::Tensor index, torch::Tensor weight, torch::Tensor x) { auto result = SpMMSum::apply(index, weight, x); return result; } -torch::Tensor spmm_mean(torch::Tensor index, torch::Tensor weight, torch::Tensor x) { - return SpMMMean::apply(index, weight, x); +torch::Tensor spmm_mean( + torch::Tensor index, torch::Tensor weight, torch::Tensor x) { + return SpMMMean::apply(index, weight, x); } -torch::Tensor spmm_max(torch::Tensor index, torch::Tensor weight, torch::Tensor x) { - return SpMMMax::apply(index, weight, x); +torch::Tensor spmm_max( + torch::Tensor index, torch::Tensor weight, torch::Tensor x) { + return SpMMMax::apply(index, weight, x); } -torch::Tensor bspmm_sum(torch::Tensor index, torch::Tensor weight, - torch::Tensor x) { - auto result = BSpMMSum::apply(index, weight, x); - return result; +torch::Tensor bspmm_sum( + torch::Tensor index, torch::Tensor weight, torch::Tensor x) { + auto result = BSpMMSum::apply(index, weight, x); + return result; } PYBIND11_MODULE(_torch_ext, m) {