From d07ecad3e8ee1dc6fcc629f3d17802b672fa2f3e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 23 Oct 2023 10:41:56 +0100 Subject: [PATCH 1/4] Add opcheck, add partial meta implem --- test/test_ops.py | 13 +++++++++ torchvision/_meta_registrations.py | 45 ++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) diff --git a/test/test_ops.py b/test/test_ops.py index 6d80f037b88..1e4bc9ae473 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -980,6 +980,7 @@ def test_is_leaf_node(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.opcheck_only_one() def test_forward(self, device, contiguous, batch_sz, dtype=None): dtype = dtype or self.dtype x, _, offset, mask, _, stride, padding, dilation = self.get_fn_args(device, contiguous, batch_sz, dtype) @@ -1032,6 +1033,7 @@ def test_wrong_sizes(self): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("contiguous", (True, False)) @pytest.mark.parametrize("batch_sz", (0, 33)) + @pytest.mark.opcheck_only_one() def test_backward(self, device, contiguous, batch_sz): x, weight, offset, mask, bias, stride, padding, dilation = self.get_fn_args( device, contiguous, batch_sz, self.dtype @@ -1081,6 +1083,7 @@ def script_func_no_mask(x_, offset_, weight_, bias_, stride_, pad_, dilation_): @needs_cuda @pytest.mark.parametrize("contiguous", (True, False)) + @pytest.mark.opcheck_only_one() def test_compare_cpu_cuda_grads(self, contiguous): # Test from https://github.com/pytorch/vision/issues/2598 # Run on CUDA only @@ -1115,6 +1118,7 @@ def test_compare_cpu_cuda_grads(self, contiguous): @needs_cuda @pytest.mark.parametrize("batch_sz", (0, 33)) @pytest.mark.parametrize("dtype", (torch.float, torch.half)) + @pytest.mark.opcheck_only_one() def test_autocast(self, batch_sz, dtype): with torch.cuda.amp.autocast(): self.test_forward(torch.device("cuda"), contiguous=False, batch_sz=batch_sz, dtype=dtype) @@ -1124,6 +1128,15 @@ def test_forward_scriptability(self): torch.jit.script(ops.DeformConv2d(in_channels=8, out_channels=8, kernel_size=3)) +optests.generate_opcheck_tests( + testcase=TestDeformConv, + namespaces=["torchvision"], + failures_dict_path=os.path.join(os.path.dirname(__file__), "optests_failures_dict.json"), + additional_decorators=[], + test_utils=OPTESTS, +) + + class TestFrozenBNT: def test_frozenbatchnorm2d_repr(self): num_features = 32 diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 7baece2ae2c..e464dd23234 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -51,6 +51,51 @@ def meta_roi_align_backward( return grad.new_empty((batch_size, channels, height, width)) +@register_meta("deform_conv2d") +def meta_deform_conv2d( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dil_h, + dil_w, + n_weight_grps, + n_offset_grps, + use_mask, +): + + out_height, out_width = offset.shape[-2:] + out_channels = weight.shape[0] + batch_size = input.shape[0] + return input.new_empty((batch_size, out_channels, out_height, out_width)) + + +@register_meta("_deform_conv2d_backward") +def meta_deform_conv2d_backward( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask, +): + return None # TODO + + @torch._custom_ops.impl_abstract("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") From 74994d09d57a960a51392f2286be098e243f3bd9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Oct 2023 14:17:22 +0100 Subject: [PATCH 2/4] Add Meta and symint support --- torchvision/_meta_registrations.py | 8 +- .../ops/autograd/deform_conv2d_kernel.cpp | 86 +++++++++---------- torchvision/csrc/ops/deform_conv2d.cpp | 79 ++++++++++++++++- torchvision/csrc/ops/deform_conv2d.h | 35 ++++++++ 4 files changed, 162 insertions(+), 46 deletions(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index e464dd23234..f86c14d1334 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -93,7 +93,13 @@ def meta_deform_conv2d_backward( offset_groups, use_mask, ): - return None # TODO + + grad_input = input.new_empty(input.shape) + grad_weight = input.new_empty(weight.shape) + grad_offset = input.new_empty(offset.shape) + grad_mask = input.new_empty(mask.shape) + grad_bias = input.new_empty(bias.shape) + return grad_input, grad_weight, grad_offset, grad_mask, grad_bias @torch._custom_ops.impl_abstract("torchvision::nms") diff --git a/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp b/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp index 801afb6a9bc..0a7bbf9014e 100644 --- a/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp +++ b/torchvision/csrc/ops/autograd/deform_conv2d_kernel.cpp @@ -18,17 +18,17 @@ class DeformConv2dFunction const torch::autograd::Variable& offset, const torch::autograd::Variable& mask, const torch::autograd::Variable& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, bool use_mask) { at::AutoDispatchBelowADInplaceOrView g; - auto output = deform_conv2d( + auto output = deform_conv2d_symint( input, weight, offset, @@ -70,17 +70,17 @@ class DeformConv2dFunction auto mask = saved[3]; auto bias = saved[4]; - auto stride_h = ctx->saved_data["stride_h"].toInt(); - auto stride_w = ctx->saved_data["stride_w"].toInt(); - auto pad_h = ctx->saved_data["pad_h"].toInt(); - auto pad_w = ctx->saved_data["pad_w"].toInt(); - auto dilation_h = ctx->saved_data["dilation_h"].toInt(); - auto dilation_w = ctx->saved_data["dilation_w"].toInt(); - auto groups = ctx->saved_data["groups"].toInt(); - auto offset_groups = ctx->saved_data["offset_groups"].toInt(); + auto stride_h = ctx->saved_data["stride_h"].toSymInt(); + auto stride_w = ctx->saved_data["stride_w"].toSymInt(); + auto pad_h = ctx->saved_data["pad_h"].toSymInt(); + auto pad_w = ctx->saved_data["pad_w"].toSymInt(); + auto dilation_h = ctx->saved_data["dilation_h"].toSymInt(); + auto dilation_w = ctx->saved_data["dilation_w"].toSymInt(); + auto groups = ctx->saved_data["groups"].toSymInt(); + auto offset_groups = ctx->saved_data["offset_groups"].toSymInt(); auto use_mask = ctx->saved_data["use_mask"].toBool(); - auto grads = detail::_deform_conv2d_backward( + auto grads = detail::_deform_conv2d_backward_symint( grad_output[0], input, weight, @@ -133,17 +133,17 @@ class DeformConv2dBackwardFunction const torch::autograd::Variable& offset, const torch::autograd::Variable& mask, const torch::autograd::Variable& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, bool use_mask) { at::AutoDispatchBelowADInplaceOrView g; - auto result = detail::_deform_conv2d_backward( + auto result = detail::_deform_conv2d_backward_symint( grad, input, weight, @@ -188,14 +188,14 @@ at::Tensor deform_conv2d_autograd( const at::Tensor& offset, const at::Tensor& mask, const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, bool use_mask) { return DeformConv2dFunction::apply( input, @@ -222,14 +222,14 @@ deform_conv2d_backward_autograd( const at::Tensor& offset, const at::Tensor& mask, const at::Tensor& bias, - int64_t stride_h, - int64_t stride_w, - int64_t pad_h, - int64_t pad_w, - int64_t dilation_h, - int64_t dilation_w, - int64_t groups, - int64_t offset_groups, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, bool use_mask) { auto result = DeformConv2dBackwardFunction::apply( grad, diff --git a/torchvision/csrc/ops/deform_conv2d.cpp b/torchvision/csrc/ops/deform_conv2d.cpp index d8f2c9b6ff4..3cda60fe0bc 100644 --- a/torchvision/csrc/ops/deform_conv2d.cpp +++ b/torchvision/csrc/ops/deform_conv2d.cpp @@ -43,6 +43,42 @@ at::Tensor deform_conv2d( use_mask); } +at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.deform_conv2d.deform_conv2d"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::deform_conv2d", "") + .typed(); + return op.call( + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + namespace detail { std::tuple @@ -84,13 +120,52 @@ _deform_conv2d_backward( use_mask); } +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_deform_conv2d_backward", "") + .typed(); + return op.call( + grad, + input, + weight, + offset, + mask, + bias, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + groups, + offset_groups, + use_mask); +} + } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> Tensor")); + "torchvision::deform_conv2d(Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> Tensor")); m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, int stride_h, int stride_w, int pad_h, int pad_w, int dilation_h, int dilation_w, int groups, int offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); + "torchvision::_deform_conv2d_backward(Tensor grad, Tensor input, Tensor weight, Tensor offset, Tensor mask, Tensor bias, SymInt stride_h, SymInt stride_w, SymInt pad_h, SymInt pad_w, SymInt dilation_h, SymInt dilation_w, SymInt groups, SymInt offset_groups, bool use_mask) -> (Tensor, Tensor, Tensor, Tensor, Tensor)")); } } // namespace ops diff --git a/torchvision/csrc/ops/deform_conv2d.h b/torchvision/csrc/ops/deform_conv2d.h index a35be02aac8..f4651e79fdb 100644 --- a/torchvision/csrc/ops/deform_conv2d.h +++ b/torchvision/csrc/ops/deform_conv2d.h @@ -22,6 +22,23 @@ VISION_API at::Tensor deform_conv2d( int64_t offset_groups, bool use_mask); +VISION_API at::Tensor deform_conv2d_symint( + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + + namespace detail { std::tuple @@ -42,6 +59,24 @@ _deform_conv2d_backward( int64_t offset_groups, bool use_mask); +std::tuple +_deform_conv2d_backward_symint( + const at::Tensor& grad, + const at::Tensor& input, + const at::Tensor& weight, + const at::Tensor& offset, + const at::Tensor& mask, + const at::Tensor& bias, + c10::SymInt stride_h, + c10::SymInt stride_w, + c10::SymInt pad_h, + c10::SymInt pad_w, + c10::SymInt dilation_h, + c10::SymInt dilation_w, + c10::SymInt groups, + c10::SymInt offset_groups, + bool use_mask); + } // namespace detail } // namespace ops From 52ab57c4b224cec8c40f8a105b0218176611deb1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Oct 2023 14:18:37 +0100 Subject: [PATCH 3/4] Fix meta --- torchvision/_meta_registrations.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index f86c14d1334..ecac29e4a56 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -95,10 +95,10 @@ def meta_deform_conv2d_backward( ): grad_input = input.new_empty(input.shape) - grad_weight = input.new_empty(weight.shape) - grad_offset = input.new_empty(offset.shape) - grad_mask = input.new_empty(mask.shape) - grad_bias = input.new_empty(bias.shape) + grad_weight = weight.new_empty(weight.shape) + grad_offset = offset.new_empty(offset.shape) + grad_mask = mask.new_empty(mask.shape) + grad_bias = bias.new_empty(bias.shape) return grad_input, grad_weight, grad_offset, grad_mask, grad_bias From ce7dab22aaa6578e092ba4258dce064aa2f1abf0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 25 Oct 2023 14:25:47 +0100 Subject: [PATCH 4/4] fix lint --- torchvision/csrc/ops/deform_conv2d.h | 1 - 1 file changed, 1 deletion(-) diff --git a/torchvision/csrc/ops/deform_conv2d.h b/torchvision/csrc/ops/deform_conv2d.h index f4651e79fdb..cf1f142e648 100644 --- a/torchvision/csrc/ops/deform_conv2d.h +++ b/torchvision/csrc/ops/deform_conv2d.h @@ -38,7 +38,6 @@ VISION_API at::Tensor deform_conv2d_symint( c10::SymInt offset_groups, bool use_mask); - namespace detail { std::tuple