DeformConv2d: SymInt support + meta-implem + opchecks #8063


Merged: 5 commits merged on Oct 30, 2023
13 changes: 13 additions & 0 deletions test/test_ops.py
@@ -1019,6 +1019,7 @@ def test_is_leaf_node(self, device):
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_forward(self, device, contiguous, batch_sz, dtype=None):
        dtype = dtype or self.dtype
        x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype)
@@ -1071,6 +1072,7 @@ def test_wrong_sizes(self):
    @pytest.mark.parametrize("device", cpu_and_cuda())
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.opcheck_only_one()
    def test_backward(self, device, contiguous, batch_sz):
        x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args(
            device, contiguous, batch_sz, self.dtype
@@ -1120,6 +1122,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_):

    @needs_cuda
    @pytest.mark.parametrize("contiguous", (True, False))
    @pytest.mark.opcheck_only_one()
    def test_compare_cpu_cuda_grads(self, contiguous):
        # Test from https://github.com/pytorch/vision/issues/2598
        # Run on CUDA only
@@ -1154,6 +1157,7 @@ def test_compare_cpu_cuda_grads(self, contiguous):
    @needs_cuda
    @pytest.mark.parametrize("batch_sz", (0, 33))
    @pytest.mark.parametrize("dtype", (torch.float, torch.half))
    @pytest.mark.opcheck_only_one()
    def test_autocast(self, batch_sz, dtype):
        with torch.cuda.amp.autocast():
            self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype)
@@ -1163,6 +1167,15 @@ def test_forward_scriptability(self):
        torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3))


optests.generate_opcheck_tests(
    testcase=TestDeformConv,
    namespaces=["torchvision"],
    failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"),
    additional_decorators=[],
    test_utils=OPTESTS,
)
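For readers unfamiliar with the generated tests: optests.generate_opcheck_tests walks the decorated tests of TestDeformConv, records every call made to a torchvision:: operator, and re-runs each recorded call through the operator checks listed in OPTESTS (typically schema, fake-tensor, autograd-registration, and AOT-dispatch checks), while the @pytest.mark.opcheck_only_one() marks limit the generated checks to a single parametrization of each test. The block below is a hand-written, hedged equivalent of one such check; the sample shapes and the use of torch.library.opcheck (available in newer PyTorch releases) are illustrative assumptions, not part of this PR.

import torch
import torchvision  # loads the C extension that registers torchvision::deform_conv2d

# Hedged sketch of a single manual opcheck call; shapes are made up for
# illustration. offset carries 2*offset_groups*kh*kw channels and mask
# carries offset_groups*kh*kw channels, both at the output spatial size (8x8 here).
x = torch.randn(2, 3, 10, 10)
weight = torch.randn(6, 3, 3, 3)
offset = torch.randn(2, 18, 8, 8)
mask = torch.ones(2, 9, 8, 8)
bias = torch.zeros(6)

torch.library.opcheck(
    torch.ops.torchvision.deform_conv2d.default,
    (x, weight, offset, mask, bias, 1, 1, 0, 0, 1, 1, 1, 1, True),
)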


class TestFrozenBNT:
    def test_frozenbatchnorm2d_repr(self):
        num_features = 32
51 changes: 51 additions & 0 deletions torchvision/_meta_registrations.py
@@ -172,3 +172,54 @@ def meta_nms(dets, scores, iou_threshold):
    ctx = torch._custom_ops.get_ctx()
    num_to_keep = ctx.create_unbacked_symint()
    return dets.new_empty(num_to_keep, dtype=torch.long)


@register_meta("deform_conv2d")
def meta_deform_conv2d(
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dil_h,
    dil_w,
    n_weight_grps,
    n_offset_grps,
    use_mask,
):

    out_height, out_width = offset.shape[-2:]
    out_channels = weight.shape[0]
    batch_size = input.shape[0]
    return input.new_empty((batch_size, out_channels, out_height, out_width))


@register_meta("_deform_conv2d_backward")
def meta_deform_conv2d_backward(
    grad,
    input,
    weight,
    offset,
    mask,
    bias,
    stride_h,
    stride_w,
    pad_h,
    pad_w,
    dilation_h,
    dilation_w,
    groups,
    offset_groups,
    use_mask,
):

    grad_input = input.new_empty(input.shape)
    grad_weight = weight.new_empty(weight.shape)
    grad_offset = offset.new_empty(offset.shape)
    grad_mask = mask.new_empty(mask.shape)
    grad_bias = bias.new_empty(bias.shape)
    return grad_input, grad_weight, grad_offset, grad_mask, grad_bias
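Worth noting about the forward meta kernel above: it never recomputes the convolution output size from stride, padding, and dilation; it simply reuses the spatial dimensions of offset, which by construction has shape (N, 2 * offset_groups * kernel_h * kernel_w, out_h, out_w). That keeps the shape logic trivial and friendly to symbolic shapes. Below is a hedged sketch of the effect under FakeTensorMode; the shapes and the FakeTensorMode import path are assumptions for illustration, not taken from the PR.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import torchvision  # registers torchvision::deform_conv2d and the meta kernels above

# Under FakeTensorMode the meta kernel runs, so output shapes propagate
# without executing the real CPU/CUDA implementation.
with FakeTensorMode():
    x = torch.empty(2, 3, 10, 10)
    weight = torch.empty(6, 3, 3, 3)
    offset = torch.empty(2, 18, 8, 8)  # (N, 2*offset_grps*kh*kw, out_h, out_w)
    mask = torch.empty(2, 9, 8, 8)
    bias = torch.empty(6)
    out = torch.ops.torchvision.deform_conv2d(
        x, weight, offset, mask, bias, 1, 1, 0, 0, 1, 1, 1, 1, True
    )
    print(out.shape)  # torch.Size([2, 6, 8, 8]); H and W come from offset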
86 changes: 43 additions & 43 deletions torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp
@@ -18,17 +18,17 @@ class DeformConv2dFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto output = deform_conv2d(
auto output = deform_conv2d_symint(
input,
weight,
offset,
@@ -70,17 +70,17 @@ class DeformConv2dFunction
auto mask = saved[3];
auto bias = saved[4];

auto stride_h = ctx->saved_data["stride_h"].toInt();
auto stride_w = ctx->saved_data["stride_w"].toInt();
auto pad_h = ctx->saved_data["pad_h"].toInt();
auto pad_w = ctx->saved_data["pad_w"].toInt();
auto dilation_h = ctx->saved_data["dilation_h"].toInt();
auto dilation_w = ctx->saved_data["dilation_w"].toInt();
auto groups = ctx->saved_data["groups"].toInt();
auto offset_groups = ctx->saved_data["offset_groups"].toInt();
auto stride_h = ctx->saved_data["stride_h"].toSymInt();
auto stride_w = ctx->saved_data["stride_w"].toSymInt();
auto pad_h = ctx->saved_data["pad_h"].toSymInt();
auto pad_w = ctx->saved_data["pad_w"].toSymInt();
auto dilation_h = ctx->saved_data["dilation_h"].toSymInt();
auto dilation_w = ctx->saved_data["dilation_w"].toSymInt();
auto groups = ctx->saved_data["groups"].toSymInt();
auto offset_groups = ctx->saved_data["offset_groups"].toSymInt();
auto use_mask = ctx->saved_data["use_mask"].toBool();

auto grads = detail::_deform_conv2d_backward(
auto grads = detail::_deform_conv2d_backward_symint(
grad_output[0],
input,
weight,
@@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction
const torch::autograd::Variable& offset,
const torch::autograd::Variable& mask,
const torch::autograd::Variable& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
at::AutoDispatchBelowADInplaceOrView g;
auto result = detail::_deform_conv2d_backward(
auto result = detail::_deform_conv2d_backward_symint(
grad,
input,
weight,
@@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
return DeformConv2dFunction::apply(
input,
@@ -222,14 +222,14 @@ deform_conv2d_backward_autograd(
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
int64_t stride_h,
int64_t stride_w,
int64_t pad_h,
int64_t pad_w,
int64_t dilation_h,
int64_t dilation_w,
int64_t groups,
int64_t offset_groups,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
auto result = DeformConv2dBackwardFunction::apply(
grad,
79 changes: 77 additions & 2 deletions torchvision/csrc/ops/deform_conv2d.cpp
@@ -43,6 +43,42 @@ at::Tensor deform_conv2d(
use_mask);
}

at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d");
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::deform_conv2d", "")
.typed<decltype(deform_conv2d_symint)>();
return op.call(
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}

namespace detail {

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -84,13 +120,52 @@ _deform_conv2d_backward(
use_mask);
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask) {
static auto op =
c10::Dispatcher::singleton()
.findSchemaOrThrow("torchvision::_deform_conv2d_backward", "")
.typed<decltype(_deform_conv2d_backward_symint)>();
return op.call(
grad,
input,
weight,
offset,
mask,
bias,
stride_h,
stride_w,
pad_h,
pad_w,
dilation_h,
dilation_w,
groups,
offset_groups,
use_mask);
}

} // namespace detail

TORCH_LIBRARY_FRAGMENT(torchvision, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor"));
"torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
"torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)"));
}

} // namespace ops
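The schema change above, int to SymInt for the stride, padding, dilation, and group arguments, together with the *_symint entry points and the meta kernels, is what lets the operator be traced with symbolic sizes. The following is a hedged end-to-end sketch; the module configuration and shapes are made up for illustration, and depending on the PyTorch version the compile step may still fall back or graph-break.

import torch
import torchvision.ops as ops

# Hedged sketch: compile DeformConv2d with dynamic shapes so sizes are traced
# as SymInts instead of being burned in as constants.
model = ops.DeformConv2d(in_channels=3, out_channels=6, kernel_size=3)
compiled = torch.compile(model, dynamic=True)

x = torch.randn(4, 3, 10, 10)
offset = torch.randn(4, 18, 8, 8)  # 2 * offset_groups * kh * kw = 18 channels
out = compiled(x, offset)
print(out.shape)  # torch.Size([4, 6, 8, 8])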
34 changes: 34 additions & 0 deletions torchvision/csrc/ops/deform_conv2d.h
@@ -22,6 +22,22 @@ VISION_API at::Tensor deform_conv2d(
int64_t offset_groups,
bool use_mask);

VISION_API at::Tensor deform_conv2d_symint(
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);

namespace detail {

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
@@ -42,6 +58,24 @@ _deform_conv2d_backward(
int64_t offset_groups,
bool use_mask);

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor>
_deform_conv2d_backward_symint(
const at::Tensor& grad,
const at::Tensor& input,
const at::Tensor& weight,
const at::Tensor& offset,
const at::Tensor& mask,
const at::Tensor& bias,
c10::SymInt stride_h,
c10::SymInt stride_w,
c10::SymInt pad_h,
c10::SymInt pad_w,
c10::SymInt dilation_h,
c10::SymInt dilation_w,
c10::SymInt groups,
c10::SymInt offset_groups,
bool use_mask);

} // namespace detail

} // namespace ops