From 611f4c01f448545db37bcb2e95f298d1b793177b Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Sat, 28 Dec 2024 23:48:03 +0000 Subject: [PATCH] update --- pyg_lib/csrc/ops/autograd/matmul_kernel.cpp | 17 ++++++++--------- pyg_lib/csrc/ops/autograd/sampled_kernel.cpp | 13 ++++++------- pyg_lib/csrc/ops/autograd/softmax_kernel.cpp | 9 ++++----- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pyg_lib/csrc/ops/autograd/matmul_kernel.cpp b/pyg_lib/csrc/ops/autograd/matmul_kernel.cpp index d421d076c..1a20d5bb1 100644 --- a/pyg_lib/csrc/ops/autograd/matmul_kernel.cpp +++ b/pyg_lib/csrc/ops/autograd/matmul_kernel.cpp @@ -9,7 +9,6 @@ namespace ops { namespace { -using torch::autograd::Variable; using torch::autograd::variable_list; std::vector concat(std::vector t1, @@ -50,7 +49,7 @@ class GroupedMatmul : public torch::autograd::Function { } } else { for (size_t i = 0; i < other.size(); ++i) - other_grad.push_back(Variable()); + other_grad.push_back(at::Tensor()); } variable_list input_grad; @@ -60,7 +59,7 @@ class GroupedMatmul : public torch::autograd::Function { input_grad = grouped_matmul(input, grad_outs); } else { for (size_t i = 0; i < input.size(); ++i) - input_grad.push_back(Variable()); + input_grad.push_back(at::Tensor()); } return concat(input_grad, other_grad); } @@ -69,11 +68,11 @@ class GroupedMatmul : public torch::autograd::Function { class SegmentMatmul : public torch::autograd::Function { public: static variable_list forward(torch::autograd::AutogradContext* ctx, - const Variable& input, + const at::Tensor& input, const at::Tensor& ptr, - const Variable& other) { + const at::Tensor& other) { at::AutoDispatchBelowADInplaceOrView g; - Variable out = segment_matmul(input, ptr, other); + at::Tensor out = segment_matmul(input, ptr, other); ctx->save_for_backward({input, ptr, other}); return {out}; } @@ -84,13 +83,13 @@ class SegmentMatmul : public torch::autograd::Function { auto saved = ctx->get_saved_variables(); auto input = saved[0], ptr = saved[1], other = saved[2]; - auto input_grad = Variable(); + auto input_grad = at::Tensor(); if (torch::autograd::any_variable_requires_grad({input})) { auto other_t = other.transpose(-2, -1); input_grad = segment_matmul(grad_out, ptr, other_t); } - auto other_grad = Variable(); + auto other_grad = at::Tensor(); if (torch::autograd::any_variable_requires_grad({other})) { auto size = pyg::utils::size_from_ptr(ptr).cpu(); // TODO (matthias) Allow for other types than `int64_t`. @@ -107,7 +106,7 @@ class SegmentMatmul : public torch::autograd::Function { other_grad = at::stack(others_grad); } - return {input_grad, Variable(), other_grad}; + return {input_grad, at::Tensor(), other_grad}; } }; diff --git a/pyg_lib/csrc/ops/autograd/sampled_kernel.cpp b/pyg_lib/csrc/ops/autograd/sampled_kernel.cpp index 7e9a228a6..4dde96d6e 100644 --- a/pyg_lib/csrc/ops/autograd/sampled_kernel.cpp +++ b/pyg_lib/csrc/ops/autograd/sampled_kernel.cpp @@ -7,19 +7,18 @@ namespace ops { namespace { -using torch::autograd::Variable; using torch::autograd::variable_list; class SampledOp : public torch::autograd::Function { public: static variable_list forward(torch::autograd::AutogradContext* ctx, - const Variable& left, - const Variable& right, + const at::Tensor& left, + const at::Tensor& right, const at::optional left_index, const at::optional right_index, const std::string fn) { at::AutoDispatchBelowADInplaceOrView g; - Variable out = sampled_op(left, right, left_index, right_index, fn); + at::Tensor out = sampled_op(left, right, left_index, right_index, fn); ctx->saved_data["has_left_index"] = left_index.has_value(); ctx->saved_data["has_right_index"] = right_index.has_value(); ctx->saved_data["fn"] = fn; @@ -48,7 +47,7 @@ class SampledOp : public torch::autograd::Function { } auto fn = ctx->saved_data["fn"].toStringRef(); - auto grad_left = Variable(); + auto grad_left = at::Tensor(); if (torch::autograd::any_variable_requires_grad({left})) { grad_left = grad_out; @@ -66,7 +65,7 @@ class SampledOp : public torch::autograd::Function { } } - auto grad_right = Variable(); + auto grad_right = at::Tensor(); if (torch::autograd::any_variable_requires_grad({right})) { grad_right = grad_out; @@ -91,7 +90,7 @@ class SampledOp : public torch::autograd::Function { } } - return {grad_left, grad_right, Variable(), Variable(), Variable()}; + return {grad_left, grad_right, at::Tensor(), at::Tensor(), at::Tensor()}; } }; diff --git a/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp b/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp index 52f4696b9..cbc35007f 100644 --- a/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp +++ b/pyg_lib/csrc/ops/autograd/softmax_kernel.cpp @@ -7,18 +7,17 @@ namespace ops { namespace { -using torch::autograd::Variable; using torch::autograd::variable_list; class SoftmaxCSR : public torch::autograd::Function { public: static variable_list forward(torch::autograd::AutogradContext* ctx, - const Variable& src, + const at::Tensor& src, const at::Tensor& ptr, const int64_t dim) { at::AutoDispatchBelowADInplaceOrView g; - Variable out = softmax_csr(src, ptr, dim); + at::Tensor out = softmax_csr(src, ptr, dim); ctx->saved_data["dim"] = dim; ctx->save_for_backward({src, out, ptr}); @@ -34,12 +33,12 @@ class SoftmaxCSR : public torch::autograd::Function { const auto ptr = saved[2]; const auto dim = ctx->saved_data["dim"].toInt(); - auto src_grad = Variable(); + auto src_grad = at::Tensor(); if (torch::autograd::any_variable_requires_grad({src})) { src_grad = softmax_csr_backward(out, out_grad, ptr, dim); } - return {src_grad, Variable(), Variable()}; + return {src_grad, at::Tensor(), at::Tensor()}; } };