Revert "Revert "Enabling SymInt in autograd; take 3 (pytorch#81145)""…
Browse files Browse the repository at this point in the history
… ; make sure is_intlist checks for symintnodes (pytorch#82189)

### Description
<!-- What did you change and why was it needed? -->

### Issue
<!-- Link to Issue ticket or RFP -->

### Testing
<!-- How did you test your change? -->

Pull Request resolved: pytorch#82189
Approved by: https://github.com/ezyang
Krovatkin authored and pytorchmergebot committed Jul 26, 2022
1 parent 30e74be commit d2c47d5
Showing 33 changed files with 373 additions and 64 deletions.
13 changes: 11 additions & 2 deletions aten/src/ATen/BatchingRegistrations.cpp
@@ -184,10 +184,13 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor expand_batching_rule_symint(const Tensor& self, SymIntArrayRef psize, bool implicit) {
Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
}

Tensor sum_symint_batching_rule(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
return sum_batching_rule(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
@@ -468,6 +471,10 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
return view_batching_rule(self, asIntArrayRefSlow(size));
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
// guard against the user passing in a batch of scalar tensors with batch
// size equal to 2.
@@ -1082,6 +1089,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("_new_zeros_with_same_feature_meta", _new_zeros_with_same_feature_meta_batching_rule);

m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("sum.SymInt", sum_symint_batching_rule);
m.impl("is_complex", native::is_complex);

// inplace operations
@@ -1096,7 +1104,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
m.impl("diagonal", diagonal_batching_rule);
m.impl("expand", expand_batching_rule);
m.impl("expand.SymInt", expand_batching_rule_symint);
m.impl("expand.SymInt", expand_symint_batching_rule);
m.impl("expand_as", native::expand_as); // composite wrt autograd
m.impl("movedim.intlist", movedim_batching_rule);
m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
@@ -1125,6 +1133,7 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("view", view_batching_rule);
m.impl("view.SymInt", view_symint_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd

// clamp operations
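
Illustrative sketch (not part of this diff; the operator names my_op / my_op_symint are hypothetical): every new *_symint batching rule registered above follows the same fallback pattern — materialize the symbolic sizes with c10::asIntArrayRefSlow and delegate to the existing integer kernel.

#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

namespace {

// Pre-existing integer kernel (hypothetical op, for illustration only).
at::Tensor my_op(const at::Tensor& self, at::IntArrayRef size) {
  return self.expand(size);
}

// SymInt overload: materialize the symbolic sizes and reuse the int kernel,
// mirroring expand_symint_batching_rule / sum_symint_batching_rule above.
at::Tensor my_op_symint(const at::Tensor& self, c10::SymIntArrayRef size) {
  return my_op(self, c10::asIntArrayRefSlow(size));
}

} // namespace
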
46 changes: 34 additions & 12 deletions aten/src/ATen/ExpandUtils.h
@@ -437,17 +437,16 @@ inline std::vector<Tensor> expand_outplace(TensorList to_expand) {
return result;
}

// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
static inline Tensor sum_to(
Tensor tensor,
const IntArrayRef shape,
const c10::SymIntArrayRef shape,
bool always_return_non_view = false) {
if (shape.size() == 0) {
return tensor.sum();
}
c10::SmallVector<int64_t, 8> reduce_dims;
const at::IntArrayRef sizes = tensor.sizes();

auto sizes = tensor.sym_sizes();
c10::SmallVector<c10::SymInt, 8> reduce_dims;
const int64_t leading_dims = sizes.size() - shape.size();
for (const auto i : c10::irange(leading_dims)) {
reduce_dims.push_back(i);
@@ -457,34 +456,57 @@ static inline Tensor sum_to(
reduce_dims.push_back(i);
}
}

if (!reduce_dims.empty()) {
tensor = tensor.sum(reduce_dims, /*keepdim=*/true);
tensor = tensor.sum_symint(reduce_dims, /*keepdim=*/true);
}

if (always_return_non_view) {
// This is only actually used by the functionalization pass.
// We want to be able to guarantee that this function doesn't return a view
// of the input.
return leading_dims > 0 ? at::view_copy(tensor, shape) : tensor.clone();
return leading_dims > 0 ? at::view_copy_symint(tensor, shape)
: tensor.clone();
} else {
return leading_dims > 0 ? tensor.view(shape) : tensor;
return leading_dims > 0 ? tensor.view_symint(shape) : tensor;
}
}

// True if `shape` can be broadcasted to `desired`
static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
// Sums `tensor` repeatedly to produce a tensor of shape `shape`.
// Precondition: is_expandable_to(shape, tensor.sizes()) must be true
static inline Tensor sum_to(
Tensor tensor,
const IntArrayRef shape,
bool always_return_non_view = false) {
auto sym_size = c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
return sum_to(tensor, sym_size, always_return_non_view);
}

static inline bool is_expandable_to(
SymIntArrayRef shape,
c10::SymIntArrayRef desired) {
size_t ndim = shape.size();
size_t target_dim = desired.size();
if (ndim > target_dim) {
return false;
}
for (const auto i : c10::irange(ndim)) {
int64_t size = shape[ndim - i - 1];
int64_t target = desired[target_dim - i - 1];
auto size = shape[ndim - i - 1];
auto target = desired[target_dim - i - 1];
if (size != target && size != 1) {
return false;
}
}
return true;
}

static inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) {
auto sym_shape = c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(shape.data()), shape.size());
auto sym_desired = c10::SymIntArrayRef(
reinterpret_cast<const c10::SymInt*>(desired.data()), desired.size());
return is_expandable_to(sym_shape, sym_desired);
}

} // namespace at
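
Illustrative sketch of what at::sum_to computes (the shapes below are made up for the example): it reduces a tensor that was broadcast up to a larger shape back down to a target shape by summing over the broadcast dimensions — the same reduction the new SymInt overload above performs on symbolic shapes.

#include <vector>
#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>

at::Tensor sum_to_example() {
  // A gradient that was broadcast from shape (3) up to shape (2, 3).
  at::Tensor grad = at::ones({2, 3});
  // Reduce back to the original shape (3); each element becomes 2 because
  // the two broadcast rows are summed.
  std::vector<int64_t> target_shape = {3};
  return at::sum_to(grad, c10::IntArrayRef(target_shape));
}
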
8 changes: 8 additions & 0 deletions aten/src/ATen/FunctionalInverses.cpp
@@ -299,6 +299,14 @@ Tensor FunctionalInverses::view_copy_inverse(const Tensor& base, const Tensor& m
}
}

Tensor FunctionalInverses::view_copy_SymInt_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, c10::SymIntArrayRef size) {
if (reapply_views) {
return mutated_view.view_symint(base.sym_sizes());
} else {
return at::view_copy_symint(mutated_view, base.sym_sizes());
}
}

Tensor FunctionalInverses::view_copy_dtype_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::ScalarType dtype) {
if (reapply_views) {
return mutated_view.view(base.scalar_type());
1 change: 1 addition & 0 deletions aten/src/ATen/core/NamedRegistrations.cpp
@@ -467,6 +467,7 @@ TORCH_LIBRARY_IMPL(aten, Named, m) {
m.impl("sum.IntList_out", CppFunction::makeFallthrough());
m.impl("sum.dim_DimnameList", CppFunction::makeFallthrough());
m.impl("sum.dim_IntList", CppFunction::makeFallthrough());
m.impl("sum.SymInt", CppFunction::makeFallthrough());
m.impl("t", CppFunction::makeFallthrough());
m.impl("tan", CppFunction::makeFallthrough());
m.impl("tan.out", CppFunction::makeFallthrough());
4 changes: 4 additions & 0 deletions aten/src/ATen/native/ReduceOps.cpp
@@ -1079,6 +1079,10 @@ Tensor sum(const Tensor& self, DimnameList dim, bool keepdim, c10::optional<Scal
return at::sum(self, dimnames_to_positions(self, dim), keepdim, dtype);
}

Tensor sum_symint(const Tensor& input_t, c10::SymIntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
return at::sum(input_t, c10::asIntArrayRefSlow(dim), keepdim, opt_dtype);
}

Tensor& sum_out(const Tensor& self, DimnameList dim,
bool keepdim, optional<ScalarType> opt_dtype, Tensor& result) {
return at::sum_out(result, self, dimnames_to_positions(self, dim), keepdim, opt_dtype);
8 changes: 8 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -1083,6 +1083,14 @@ Tensor zeros(IntArrayRef size,
return result.zero_();
}

Tensor zeros_symint(c10::SymIntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
return zeros(asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
}

Tensor _efficientzerotensor(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
6 changes: 6 additions & 0 deletions aten/src/ATen/native/TensorShape.cpp
@@ -27,6 +27,7 @@
#include <algorithm>
#include <cstdint>
#include <vector>
#include <c10/util/StringUtil.h>

namespace at {
namespace meta {
@@ -3105,6 +3106,11 @@ Tensor view(const Tensor& self,
return view_impl(self, size);
}

Tensor view_symint(const Tensor& self,
c10::SymIntArrayRef size) {
return self.view(c10::asIntArrayRefSlow(size));
}

Tensor alias(const Tensor& self) {
return alias_with_sizes_and_strides(self, self.sizes(), self.strides());
}
13 changes: 13 additions & 0 deletions aten/src/ATen/native/mkldnn/TensorShape.cpp
@@ -2,6 +2,7 @@
#include <ATen/Config.h>
#include <ATen/InferSize.h>
#include <ATen/NativeFunctions.h>
#include <c10/core/SymIntArrayRef.h>

#if !AT_MKLDNN_ENABLED()

@@ -86,3 +87,15 @@ Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
} // namespace at

#endif // AT_MKLDNN_ENABLED


namespace at {
namespace native {


Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size) {
return mkldnn_view(self, c10::asIntArrayRefSlow(size));
}

} // namespace native
} // namespace at
3 changes: 3 additions & 0 deletions aten/src/ATen/native/mkldnn/TensorShape.h
@@ -1,12 +1,15 @@
#pragma once

#include <ATen/ATen.h>
#include <c10/core/SymIntArrayRef.h>

namespace at {
namespace native {

Tensor mkldnn_view(const Tensor& self, IntArrayRef size);

Tensor mkldnn_view_symint(const Tensor& self, c10::SymIntArrayRef size);

Tensor mkldnn_clone(const Tensor& self);

} // namespace native
23 changes: 23 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -4588,6 +4588,12 @@
CompositeExplicitAutograd: sum
SparseCsrCPU, SparseCsrCUDA: sum_csr

- func: sum.SymInt(Tensor self, SymInt[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
device_check: NoCheck # TensorIterator
variants: function, method
dispatch:
CompositeExplicitAutograd: sum_symint

- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
structured_delegate: sum.IntList_out
device_check: NoCheck # TensorIterator
@@ -5197,6 +5203,8 @@

- func: zeros(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: zeros.SymInt(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

- func: zeros.out(int[] size, *, Tensor(a!) out) -> Tensor(a!)

- func: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
@@ -6448,6 +6456,14 @@
CUDA: masked_softmax_backward_cuda
CPU: masked_softmax_backward_cpu

- func: view.SymInt(Tensor(a) self, SymInt[] size) -> Tensor(a)
variants: method
device_check: NoCheck
device_guard: False
dispatch:
CompositeExplicitAutograd: view_symint
MkldnnCPU: mkldnn_view_symint

- func: view(Tensor(a) self, int[] size) -> Tensor(a)
variants: method
device_check: NoCheck
@@ -12335,6 +12351,13 @@
CompositeExplicitAutograd: _neg_view_copy_out


- func: view_copy.SymInt(Tensor self, SymInt[] size) -> Tensor
variants: function
dispatch:
CompositeExplicitAutograd: view_copy_SymInt
tags: view_copy


- func: as_strided_copy.out(Tensor self, int[] size, int[] stride, int? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)
variants: function
dispatch:
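
Illustrative sketch (not part of the diff) of calling the new SymInt overloads registered above: Tensor::sym_sizes() returns the sizes as a c10::SymIntArrayRef, which view_symint accepts directly; with concrete (non-symbolic) values this reduces to the plain view path via asIntArrayRefSlow.

#include <ATen/ATen.h>

at::Tensor view_symint_example(const at::Tensor& t) {
  // View the tensor to its own symbolic sizes; with concrete SymInts this is
  // equivalent to t.view(t.sizes()).
  return t.view_symint(t.sym_sizes());
}
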
56 changes: 52 additions & 4 deletions c10/core/SymInt.cpp
@@ -25,6 +25,56 @@ SymInt SymInt::operator+(SymInt sci) const {
return SymInt(data_ + sci.data_);
}

bool SymInt::operator!=(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ != sci.data_;
}
// TODO: This is way too much boilerplate
std::shared_ptr<SymbolicIntNode> a =
is_symbolic() ? toSymbolicIntNode() : nullptr;
std::shared_ptr<SymbolicIntNode> b =
sci.is_symbolic() ? sci.toSymbolicIntNode() : nullptr;

SymbolicIntNode* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
if (!a) {
a = common->wrap(data_);
toSymInt(a); //
}
if (!b) {
b = common->wrap(sci.data_);
toSymInt(b);
}

auto c = a->ne(b);
return c->bool_();
}

bool SymInt::operator==(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return data_ == sci.data_;
}
// TODO: This is way too much boilerplate
std::shared_ptr<SymbolicIntNode> a =
is_symbolic() ? toSymbolicIntNode() : nullptr;
std::shared_ptr<SymbolicIntNode> b =
sci.is_symbolic() ? sci.toSymbolicIntNode() : nullptr;

SymbolicIntNode* common = a ? a.get() : b.get();
// TODO: technically we need to check that the classes match
if (!a) {
a = common->wrap(data_);
toSymInt(a); //
}
if (!b) {
b = common->wrap(sci.data_);
toSymInt(b);
}

auto c = a->eq(b);
return c->bool_();
}

SymInt SymInt::operator*(SymInt sci) const {
if (!is_symbolic() && !sci.is_symbolic()) {
return SymInt(data_ * sci.data_);
@@ -68,13 +118,11 @@ bool SymInt::operator<(int64_t sci) const {
}

bool SymInt::operator==(int64_t sci) const {
TORCH_CHECK(!this->is_symbolic(), "Symbolic eq isn't supported yet");
return data_ == sci;
return *this == c10::SymInt(sci);
}

bool SymInt::operator!=(int64_t sci) const {
TORCH_CHECK(!this->is_symbolic(), "Symbolic neq isn't supported yet");
return data_ != sci;
return *this != c10::SymInt(sci);
}

SymInt SymInt::operator*(int64_t sci) const {
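
Illustrative sketch (not part of the diff) of the comparison semantics introduced above: when neither operand is symbolic, operator== / operator!= compare the underlying int64_t payloads directly; when either side is symbolic, the plain value is wrapped via SymbolicIntNode::wrap and the comparison is deferred to the node's eq() / ne(). The int64_t overloads now forward to the SymInt overloads instead of rejecting symbolic values.

#include <c10/core/SymInt.h>

bool symint_equality_demo() {
  c10::SymInt a(3);
  c10::SymInt b(3);
  // Fast path: both operands hold plain integers, so the payloads are
  // compared directly without creating any SymbolicIntNode.
  bool eq = (a == b);                 // true
  bool ne = (a != c10::SymInt(4));    // true
  // The int64_t overloads route through the SymInt comparison above.
  bool eq_int = (a == 3);             // true
  return eq && ne && eq_int;
}
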
10 changes: 2 additions & 8 deletions c10/core/SymInt.h
@@ -39,16 +39,10 @@ class C10_API SymInt {
return (MASK & static_cast<uint64_t>(this->data_)) == IS_SYM;
}

bool operator==(const SymInt& p2) const {
return data_ == p2.data_;
}

bool operator!=(const SymInt& p2) const {
return data_ != p2.data_;
}

SymInt operator+(SymInt sci) const;
SymInt operator*(SymInt sci) const;
bool operator==(SymInt sci) const;
bool operator!=(SymInt p2) const;
bool operator<(SymInt sci) const;
void operator*=(SymInt sci);

4 changes: 0 additions & 4 deletions c10/core/SymIntArrayRef.cpp
@@ -31,8 +31,4 @@ std::ostream& operator<<(std::ostream& os, SymInt s) {
return os;
}

std::ostream& operator<<(std::ostream& out, const c10::SymIntArrayRef& list) {
return out << list.wrapped_symint_array_ref;
}

} // namespace c10