Generalize padding logic in conv2d op (#966)
Summary:

Apply padding in conv2d based on the dtype, padding the channel dimension to a multiple of 8 bytes.

Differential Revision: D51139833
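
As a reader's sketch (not part of the commit): the rule described above pads the channel dimension so that channels * dtype_bytes becomes a multiple of 8 bytes. The helper name padded_channels below is illustrative only.

# Reader's sketch of the 8-byte alignment rule, mirroring the logic added to
# _choose_conv2d_op in this commit; illustrative, not part of the diff.
REQUIRED_ALIGNMENT = 8  # bytes

def padded_channels(channels: int, dtype_bytes: int) -> int:
    # Pad channels so that channels * dtype_bytes is a multiple of 8 bytes.
    last_dim_bytes = channels * dtype_bytes
    if last_dim_bytes % REQUIRED_ALIGNMENT == 0:
        return channels  # already aligned, no padding needed
    return ((last_dim_bytes // REQUIRED_ALIGNMENT) + 1) * REQUIRED_ALIGNMENT // dtype_bytes

# float16 (2 bytes): 3 channels -> 6 bytes  -> pad to 8 bytes  -> 4 channels
assert padded_channels(3, 2) == 4
# float16 (2 bytes): 7 channels -> 14 bytes -> pad to 16 bytes -> 8 channels
assert padded_channels(7, 2) == 8
# float32 (4 bytes): 3 channels -> 12 bytes -> pad to 16 bytes -> 4 channels
assert padded_channels(3, 4) == 4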
henryhu6 authored and facebook-github-bot committed Nov 10, 2023
1 parent 992e1a0 commit 9c3e450
Showing 4 changed files with 86 additions and 33 deletions.
30 changes: 19 additions & 11 deletions fx2ait/fx2ait/converters/ait_converters.py
@@ -98,6 +98,7 @@

logger: logging.Logger = logging.getLogger(__name__)
ConverterOutput = Union[AITTensor, Tuple[AITTensor, ...], List[IntVar], IntVar]
REQUIRED_ALIGNMENT = 8


@ait_converter(acc_ops.sigmoid)
@@ -1289,18 +1290,25 @@ def _choose_conv2d_op(
else:
return transposed_conv2d(stride=stride, pad=pad, dilate=dilate)(x, weight)
last_dim = x._attrs["shape"][-1]._attrs["values"][0]
# CUDA conv channel dim weights need to align w/ a multiple of 2/4/8
# if CI < 4, pad to 4; if 5 < CI < 8, pad to 8;
if last_dim < 4:
weight = pad_last_dim(len(weight._attrs["shape"]), 4)(weight)
x = pad_last_dim(len(x._attrs["shape"]), 4)(x)
elif last_dim > 4 and last_dim < 8:
weight = pad_last_dim(len(weight._attrs["shape"]), 8)(weight)
x = pad_last_dim(len(x._attrs["shape"]), 8)(x)
elif last_dim % 2 != 0:
raise RuntimeError(
f"Conv2d is not implemented for input channel dim {last_dim}: it needs to be aligned to a multiple of 2/4/8"
)
dtype = x._attrs["dtype"]
if dtype == "float16":
dtype_bytes = 2
elif dtype == "bfloat16":
dtype_bytes = 2
elif dtype == "float32":
dtype_bytes = 4
else:
raise NotImplementedError(f"Unsupported dtype: {dtype}")
last_dim_bytes = last_dim * dtype_bytes
# CUDA conv channel dim weights need to align w/ a multiple of REQUIRED_ALIGNMENT bytes.
if last_dim_bytes % REQUIRED_ALIGNMENT != 0:
new_dim_size = int(
((last_dim_bytes // REQUIRED_ALIGNMENT) + 1)
* REQUIRED_ALIGNMENT
/ dtype_bytes
)
weight = pad_last_dim(len(weight._attrs["shape"]), new_dim_size)(weight)
x = pad_last_dim(len(x._attrs["shape"]), new_dim_size)(x)
if bias:
return conv2d_bias(stride=stride, pad=pad, dilate=dilate)(x, weight, bias)
else:
70 changes: 52 additions & 18 deletions fx2ait/fx2ait/test/converters/test_ait_conv2d.py
@@ -14,35 +14,28 @@
#
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import logging

import torch
from aitemplate.testing.test_utils import filter_test_cases_by_params, TestEnv
from aitemplate.utils.torch_utils import string_to_torch_dtype
from fx2ait.acc_tracer import acc_ops
from fx2ait.tools.common_fx2ait import AITTestCase
from parameterized import param, parameterized
from fx2ait.tools.common_fx2ait import AITTestCase, torch_type_to_lower_precision
from parameterized import parameterized


class TestConv2dConverter(AITTestCase):
@parameterized.expand(
[
param("default", 1),
param("no_bias", 1, bias=False),
param("tuple_parameters", 1, (1, 1), (1, 1)),
param("non_zero_padding", 1, padding=1),
param("non_unary_params", 3, 2, padding=1, bias=False),
param("dilation", 1, dilation=2),
param("multi_group", 1, 1, 1, 1, 3, bias=True),
param("in_channel_padding_gt_4_lt_8", 1, in_channel=7),
]
)
def test_conv2d(
def _test_conv2d(
self,
name,
test_name,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
in_channel=3,
bias=True,
ait_dtype="float16",
):
class TestModule(torch.nn.Module):
def __init__(self):
@@ -55,10 +48,51 @@ def __init__(self):
def forward(self, x):
return self.relu(self.conv(x))

model = TestModule().cuda().half()
inputs = [torch.randn(1, in_channel, 224, 224).cuda().half()]
logging.info(f"Running test {test_name}.")

dtype = string_to_torch_dtype(ait_dtype)
model = TestModule().cuda().to(dtype)
inputs = [torch.randn(1, in_channel, 224, 224).cuda().to(dtype)]
self.run_test(
model,
inputs,
expected_ops={acc_ops.conv2d},
precision=torch_type_to_lower_precision(dtype),
)

@parameterized.expand(
**filter_test_cases_by_params(
{
TestEnv.CUDA_LESS_THAN_SM80: [("float16")],
TestEnv.CUDA_SM80: [("bfloat16"), ("float32")],
TestEnv.ROCM: [("float16")],
}
)
)
def test_conv2d(self, ait_dtype):
self._test_conv2d(f"{ait_dtype}_default", 1, ait_dtype=ait_dtype)
self._test_conv2d(f"{ait_dtype}_no_bias", 1, bias=False, ait_dtype=ait_dtype)
self._test_conv2d(
f"{ait_dtype}_tuple_parameters", 1, (1, 1), (1, 1), ait_dtype=ait_dtype
)
self._test_conv2d(
f"{ait_dtype}_non_zero_padding", 1, padding=1, ait_dtype=ait_dtype
)
self._test_conv2d(
f"{ait_dtype}_non_unary_params",
3,
2,
padding=1,
bias=False,
ait_dtype=ait_dtype,
)
self._test_conv2d(f"{ait_dtype}_dilation", 1, dilation=2, ait_dtype=ait_dtype)
self._test_conv2d(
f"{ait_dtype}_multi_group", 1, 1, 1, 1, 3, bias=True, ait_dtype=ait_dtype
)
self._test_conv2d(
f"{ait_dtype}_padding_3", 1, in_channel=3, ait_dtype=ait_dtype
)
self._test_conv2d(
f"{ait_dtype}_padding_7", 1, in_channel=7, ait_dtype=ait_dtype
)
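
Reader's note (not part of the diff): string_to_torch_dtype converts the ait_dtype strings used in the parameterization above into torch dtypes before the eager model and inputs are built, roughly:

# Assumed mapping used by the test above (based on the dtype strings it passes):
#   string_to_torch_dtype("float16")  -> torch.float16
#   string_to_torch_dtype("bfloat16") -> torch.bfloat16
#   string_to_torch_dtype("float32")  -> torch.float32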
15 changes: 15 additions & 0 deletions fx2ait/fx2ait/tools/common_fx2ait.py
@@ -56,6 +56,21 @@ def lower_precision_to_torch_type(
raise ValueError(f"Unsupported precision: {precision}")


def torch_type_to_lower_precision(
dtype: torch.dtype,
) -> LowerPrecision:
if dtype == torch.float16:
return LowerPrecision.FP16
elif dtype == torch.bfloat16:
return LowerPrecision.BF16
elif dtype == torch.float:
return LowerPrecision.FP32
elif dtype == torch.int8:
return LowerPrecision.INT8
else:
raise ValueError(f"Unsupported dtype: {dtype}")


def fetch_attr(mod, target):
"""
Fetch an attribute from the ``Module`` hierarchy of ``mod.module``.
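
A small usage illustration of the helper added above (reader's sketch, not part of the diff); the conv2d test passes its result as the precision argument of run_test:

import torch
from fx2ait.tools.common_fx2ait import torch_type_to_lower_precision

# Maps torch dtypes to the LowerPrecision enum expected by run_test.
torch_type_to_lower_precision(torch.float16)   # LowerPrecision.FP16
torch_type_to_lower_precision(torch.bfloat16)  # LowerPrecision.BF16
torch_type_to_lower_precision(torch.float)     # LowerPrecision.FP32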
4 changes: 0 additions & 4 deletions python/aitemplate/backend/cuda/conv2d/common.py
@@ -534,10 +534,6 @@ def extract_config(
elif "bfloat16" in lib_dtype:
data_type = cutlass_lib.library.DataType.bf16
acc_type = cutlass_lib.library.DataType.f32
# check target use fp16 acc
if "use_fp16_acc" in Target.current()._kwargs:
if Target.current()._kwargs["use_fp16_acc"]:
acc_type = cutlass_lib.library.DataType.bf16
else:
raise RuntimeError(f"Unsupported dtype {lib_dtype}")

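
Effect of the removal above, read from this hunk alone (reader's summary, not part of the diff): the bfloat16 branch no longer downgrades the accumulator when use_fp16_acc is set, so it now always keeps:

# bfloat16 branch after this change:
#   data_type = cutlass_lib.library.DataType.bf16
#   acc_type  = cutlass_lib.library.DataType.f32  # no longer overridden by use_fp16_acc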
