Remove usage of torch::autograd::Variable (#369)
akihironitta authored Dec 30, 2024
1 parent 706c9ae commit fdec9bd
Showing 3 changed files with 18 additions and 21 deletions.
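
The change is mechanical: torch::autograd::Variable is nowadays only an alias for at::Tensor, so the custom autograd kernels can name the tensor type directly, and a default-constructed at::Tensor() returned from backward() still signals "no gradient for this input". The sketch below is not taken from the repository (ScaleMul and its tensors are made-up names); it only illustrates the before/after pattern that each of the three files follows.

// Minimal sketch of the pattern applied in this commit: use at::Tensor
// directly instead of the torch::autograd::Variable alias, and return an
// undefined at::Tensor() from backward() for inputs without a gradient.
#include <torch/torch.h>

using torch::autograd::AutogradContext;
using torch::autograd::variable_list;

class ScaleMul : public torch::autograd::Function<ScaleMul> {
 public:
  static variable_list forward(AutogradContext* ctx,
                               const at::Tensor& x,      // was: const Variable&
                               const at::Tensor& scale) {
    at::Tensor out = x * scale;                          // was: Variable out = ...
    ctx->save_for_backward({x, scale});
    return {out};
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto x = saved[0], scale = saved[1];

    auto grad_x = at::Tensor();                          // was: Variable()
    if (torch::autograd::any_variable_requires_grad({x}))
      grad_x = grad_out * scale;  // d(x * scale)/dx, same-shape elementwise case

    auto grad_scale = at::Tensor();                      // was: Variable()
    if (torch::autograd::any_variable_requires_grad({scale}))
      grad_scale = grad_out * x;  // d(x * scale)/dscale

    // One entry per forward() input; undefined tensors mean "no gradient".
    return {grad_x, grad_scale};
  }
};

// Usage: auto out = ScaleMul::apply(x, scale)[0];
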
17 changes: 8 additions & 9 deletions pyg_lib/csrc/ops/autograd/matmul_kernel.cpp
@@ -9,7 +9,6 @@ namespace ops {

namespace {

-using torch::autograd::Variable;
using torch::autograd::variable_list;

std::vector<at::Tensor> concat(std::vector<at::Tensor> t1,
@@ -50,7 +49,7 @@ class GroupedMatmul : public torch::autograd::Function<GroupedMatmul> {
}
} else {
for (size_t i = 0; i < other.size(); ++i)
-other_grad.push_back(Variable());
+other_grad.push_back(at::Tensor());
}

variable_list input_grad;
@@ -60,7 +59,7 @@ class GroupedMatmul : public torch::autograd::Function<GroupedMatmul> {
input_grad = grouped_matmul(input, grad_outs);
} else {
for (size_t i = 0; i < input.size(); ++i)
-input_grad.push_back(Variable());
+input_grad.push_back(at::Tensor());
}
return concat(input_grad, other_grad);
}
@@ -69,11 +68,11 @@ class GroupedMatmul : public torch::autograd::Function<GroupedMatmul> {
class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
-const Variable& input,
+const at::Tensor& input,
const at::Tensor& ptr,
-const Variable& other) {
+const at::Tensor& other) {
at::AutoDispatchBelowADInplaceOrView g;
-Variable out = segment_matmul(input, ptr, other);
+at::Tensor out = segment_matmul(input, ptr, other);
ctx->save_for_backward({input, ptr, other});
return {out};
}
@@ -84,13 +83,13 @@ class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
auto saved = ctx->get_saved_variables();
auto input = saved[0], ptr = saved[1], other = saved[2];

-auto input_grad = Variable();
+auto input_grad = at::Tensor();
if (torch::autograd::any_variable_requires_grad({input})) {
auto other_t = other.transpose(-2, -1);
input_grad = segment_matmul(grad_out, ptr, other_t);
}

-auto other_grad = Variable();
+auto other_grad = at::Tensor();
if (torch::autograd::any_variable_requires_grad({other})) {
auto size = pyg::utils::size_from_ptr(ptr).cpu();
// TODO (matthias) Allow for other types than `int64_t`.
@@ -107,7 +106,7 @@ class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
other_grad = at::stack(others_grad);
}

-return {input_grad, Variable(), other_grad};
+return {input_grad, at::Tensor(), other_grad};
}
};

13 changes: 6 additions & 7 deletions pyg_lib/csrc/ops/autograd/sampled_kernel.cpp
@@ -7,19 +7,18 @@ namespace ops {

namespace {

-using torch::autograd::Variable;
using torch::autograd::variable_list;

class SampledOp : public torch::autograd::Function<SampledOp> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
-const Variable& left,
-const Variable& right,
+const at::Tensor& left,
+const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
at::AutoDispatchBelowADInplaceOrView g;
-Variable out = sampled_op(left, right, left_index, right_index, fn);
+at::Tensor out = sampled_op(left, right, left_index, right_index, fn);
ctx->saved_data["has_left_index"] = left_index.has_value();
ctx->saved_data["has_right_index"] = right_index.has_value();
ctx->saved_data["fn"] = fn;
@@ -48,7 +47,7 @@ class SampledOp : public torch::autograd::Function<SampledOp> {
}
auto fn = ctx->saved_data["fn"].toStringRef();

-auto grad_left = Variable();
+auto grad_left = at::Tensor();
if (torch::autograd::any_variable_requires_grad({left})) {
grad_left = grad_out;

@@ -66,7 +65,7 @@ class SampledOp : public torch::autograd::Function<SampledOp> {
}
}

-auto grad_right = Variable();
+auto grad_right = at::Tensor();
if (torch::autograd::any_variable_requires_grad({right})) {
grad_right = grad_out;

@@ -91,7 +90,7 @@ class SampledOp : public torch::autograd::Function<SampledOp> {
}
}

-return {grad_left, grad_right, Variable(), Variable(), Variable()};
+return {grad_left, grad_right, at::Tensor(), at::Tensor(), at::Tensor()};
}
};

9 changes: 4 additions & 5 deletions pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
@@ -7,18 +7,17 @@ namespace ops {

namespace {

-using torch::autograd::Variable;
using torch::autograd::variable_list;

class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
-const Variable& src,
+const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
at::AutoDispatchBelowADInplaceOrView g;

-Variable out = softmax_csr(src, ptr, dim);
+at::Tensor out = softmax_csr(src, ptr, dim);
ctx->saved_data["dim"] = dim;
ctx->save_for_backward({src, out, ptr});

@@ -34,12 +33,12 @@ class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
const auto ptr = saved[2];
const auto dim = ctx->saved_data["dim"].toInt();

-auto src_grad = Variable();
+auto src_grad = at::Tensor();
if (torch::autograd::any_variable_requires_grad({src})) {
src_grad = softmax_csr_backward(out, out_grad, ptr, dim);
}

-return {src_grad, Variable(), Variable()};
+return {src_grad, at::Tensor(), at::Tensor()};
}
};

