From b7d7e02705deb1dc4942bf39efc19f133e2181f7 Mon Sep 17 00:00:00 2001 From: wms2537 Date: Sun, 28 Nov 2021 21:53:35 +0800 Subject: [PATCH] fix slice_axis --- python/mxnet/gluon/nn/conv_layers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py index 65e22d82eded..2cd56c9c3edd 100644 --- a/python/mxnet/gluon/nn/conv_layers.py +++ b/python/mxnet/gluon/nn/conv_layers.py @@ -1647,8 +1647,8 @@ def forward(self, x): offset = npx.convolution(x, self.offset_weight.data(ctx), self.offset_bias.data(ctx), cudnn_off=True, **self._kwargs_offset) - offset_t = npx.slice_axis(offset, axis=1, begin=0, end=self.offset_split_index) - mask = npx.slice_axis(offset, axis=1, begin=self.offset_split_index, end=None) + offset_t = offset[:,0:self.offset_split_index,:, :] + mask = offset[:,self.offset_split_index:,:, :] mask = npx.sigmoid(mask) * 2 if self.deformable_conv_bias is None: