Commit 6f3f8c7

Replace pad op implementation in conv kernel with torch implementation (#336)
Signed-off-by: nithinsubbiah <[email protected]>
1 parent 1c3ed4b commit 6f3f8c7

3 files changed: +9, -28 lines

sharktank/sharktank/ops/qconv_impls.py (+2, -22)

@@ -12,6 +12,7 @@
 import warnings
 
 import torch
+import torch.nn.functional as F
 
 from sharktank import kernels
 
@@ -119,7 +120,7 @@ def qconv2d_tensor_scaled(
     padding = _expand_int_to_2_tuple(padding)
     dilation = _expand_int_to_2_tuple(dilation)
     extended_padding_list = [item for item in padding for _ in range(2)]
-    padded_input = _pad_last_2d(input_qs, extended_padding_list)
+    padded_input = F.pad(input_qs, pad=extended_padding_list)
     y_qs = _invoke_conv2d_kernel(
         padded_input,
         weight_qs,
@@ -258,27 +259,6 @@ def _invoke_pooling_sum_kernel(input, kernel_size, stride, dilation, *, accum_dt
     return output
 
 
-def _pad_last_2d(input_tensor, pad_width):
-    # pad_width should be in the format [pad_left, pad_right, pad_top, pad_bottom]
-    pad_left, pad_right, pad_top, pad_bottom = pad_width
-    batch_size, channels, height, width = input_tensor.shape
-
-    # Create a new tensor with the desired padded size filled with zeros
-    padded_height = height + pad_top + pad_bottom
-    padded_width = width + pad_left + pad_right
-    padded_tensor = torch.zeros(
-        (batch_size, channels, padded_height, padded_width),
-        dtype=input_tensor.dtype,
-        device=input_tensor.device,
-    )
-
-    # Copy the values from the input tensor to the appropriate location in the padded tensor
-    padded_tensor[
-        :, :, pad_top : pad_top + height, pad_left : pad_left + width
-    ] = input_tensor
-    return padded_tensor
-
-
 def _flatten_input_scale_offset_channels(d, m):
     """Flattens either a 4d or 0d scale/offset as [N, C, H, W] to 1D.
 
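
The deleted _pad_last_2d helper zero-padded only the last two (H, W) dimensions. torch.nn.functional.pad does the same thing by default (mode="constant", value=0) when given a four-element pad list, which it interprets as (left, right) for the last dimension and (top, bottom) for the second-to-last. A minimal standalone sketch (not part of the commit) that checks the two against each other:

# Equivalence check between the removed helper and F.pad (illustration only).
import torch
import torch.nn.functional as F

def _pad_last_2d(input_tensor, pad_width):
    # pad_width is [pad_left, pad_right, pad_top, pad_bottom]
    pad_left, pad_right, pad_top, pad_bottom = pad_width
    batch_size, channels, height, width = input_tensor.shape
    padded = torch.zeros(
        (batch_size, channels, height + pad_top + pad_bottom, width + pad_left + pad_right),
        dtype=input_tensor.dtype,
        device=input_tensor.device,
    )
    padded[:, :, pad_top : pad_top + height, pad_left : pad_left + width] = input_tensor
    return padded

x = torch.rand(2, 4, 8, 8)
pad = [1, 1, 2, 2]  # left, right, top, bottom
# F.pad with a 4-element list zero-pads W by (1, 1) and H by (2, 2),
# matching the hand-rolled helper exactly.
assert torch.equal(_pad_last_2d(x, pad), F.pad(x, pad=pad))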

sharktank/tests/kernels/conv_2d_nchw_fchw_test.py (+4, -3)

@@ -12,10 +12,10 @@
 from parameterized import parameterized
 
 import torch
+import torch.nn.functional as F
 
 from iree.turbine import aot
 from sharktank import kernels
-from sharktank.ops.qconv_impls import _pad_last_2d
 
 
 class conv_2d_nchw_fchw_test(unittest.TestCase):
@@ -36,7 +36,8 @@ def testBS32(self, input_dtype, output_dtype_name, atol, rtol):
         inputs = (torch.rand([2, 4, 64, 64]) * 64).to(input_dtype)
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
+
         weights = (torch.rand([8, 4, 3, 3]) * 64).to(input_dtype)
         bias = (torch.rand([8]) * 64).to(dtype=output_dtype)
         result = kernels.conv_2d_nchw_fchw(
@@ -68,7 +69,7 @@ def forward(self, a, b, c):
         inputs = torch.rand([2, 320, 64, 64]) * 64
         padding = [1, 1]
         extended_list = [item for item in padding for _ in range(2)]
-        inputs_pad = _pad_last_2d(inputs, extended_list)
+        inputs_pad = F.pad(inputs, pad=extended_list)
         ep = torch.export.export(
             mod,
             args=(
sharktank/tests/kernels/pooling_nchw_sum_test.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,10 @@
1212
from parameterized import parameterized
1313

1414
import torch
15+
import torch.nn.functional as F
1516

1617
from iree.turbine import aot
1718
from sharktank import kernels
18-
from sharktank.ops.qconv_impls import _pad_last_2d
1919

2020

2121
class pooling_nchw_sum_test(unittest.TestCase):
@@ -34,7 +34,7 @@ def testBS32(self, atol, rtol):
3434
a = (torch.randint(0, 100, (2, 1, 128, 128))).to(torch.float32)
3535
padding = [1, 1]
3636
extended_list = [item for item in padding for _ in range(2)]
37-
inputs_pad = _pad_last_2d(a, extended_list)
37+
inputs_pad = F.pad(a, pad=extended_list)
3838
weight_shape = [3, 3]
3939
stride = [1, 1]
4040
dilations = [1, 1]
@@ -62,7 +62,7 @@ def forward(self, a):
6262
inputs = torch.rand([2, 1, 128, 128]) * 64
6363
padding = [1, 1]
6464
extended_list = [item for item in padding for _ in range(2)]
65-
inputs_pad = _pad_last_2d(inputs, extended_list)
65+
inputs_pad = F.pad(inputs, pad=extended_list)
6666
ep = torch.export.export(
6767
mod,
6868
args=((inputs_pad).to(dtype),),
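
In every call site above, the four-element pad list handed to F.pad comes from the same comprehension over the two-element padding list; for the symmetric [1, 1] padding used in these tests it simply repeats each entry. Illustration only:

# How the call sites expand a two-element padding list into the
# (left, right, top, bottom) layout that F.pad expects for the last two dims.
padding = [1, 1]
extended_list = [item for item in padding for _ in range(2)]
assert extended_list == [1, 1, 1, 1]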
