Skip to content

Commit

Permalink
Add allow_cutlass_sm90 and force_cutlass_sm90 to lowering config (#967)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #967

allow_cutlass_sm90: on sm90 arch, generate cutlass sm90 kernels in addition to the sm80 kernels

force_cutlass_sm90: on sm90 arch, generate only cutlass sm90 kernels

Reviewed By: frank-wei, aakhundov

Differential Revision: D51140079

fbshipit-source-id: d0cd9460fd9bd03f6606eba3789cd2c47cb5cf22
  • Loading branch information
chenyang78 authored and facebook-github-bot committed Nov 10, 2023
1 parent cb2dcf9 commit 992e1a0
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 10 deletions.
8 changes: 8 additions & 0 deletions fx2ait/fx2ait/fx2ait.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def __init__(
use_tanh_for_sigmoid: bool = False,
profile_timeout: int = 500,
optimize_for_compilation_time: bool = False,
allow_cutlass_sm90: bool = False,
force_cutlass_sm90: bool = False,
):
"""
Args:
Expand All @@ -93,6 +95,8 @@ def __init__(
save_remote_cache: whether to save the updated cache
use_fast_math: whether to use fast math in CUDA kernels
use_tanh_for_sigmoid: whether to use tanh to approximate sigmoid in CUDA kernels
allow_cutlass_sm90: whether to generate cutlass sm90 kernels alongside sm80 kernels on sm90 arch
force_cutlass_sm90: whether to generate only cutlass sm90 kernels on sm90 arch
profile_timeout: timeout in seconds for AIT profilers to complete
optimize_for_compilation_time: we use O1 and disable the ProfileImpl function to reduce compilation time.
"""
Expand All @@ -119,6 +123,8 @@ def __init__(
self.use_fp16_acc = use_fp16_acc
self.use_fast_math = use_fast_math
self.use_tanh_for_sigmoid = use_tanh_for_sigmoid
self.allow_cutlass_sm90 = allow_cutlass_sm90
self.force_cutlass_sm90 = force_cutlass_sm90
self.optimize_for_compilation_time = optimize_for_compilation_time
self.hardware_target = self._create_target()
self.input_specs = input_specs
Expand Down Expand Up @@ -149,6 +155,8 @@ def _create_target(self):
remote_cache_bytes=self.remote_cache_bytes,
use_fast_math=self.use_fast_math,
use_tanh_for_sigmoid=self.use_tanh_for_sigmoid,
allow_cutlass_sm90=self.allow_cutlass_sm90,
force_cutlass_sm90=self.force_cutlass_sm90,
optimize_for_compilation_time=self.optimize_for_compilation_time,
)

Expand Down
4 changes: 4 additions & 0 deletions fx2ait/fx2ait/lower/lower_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,7 @@ class LowerSettings:
optimize_for_compilation_time: bool = False
# If True, use tanh to approximate sigmoid in CUDA kernels
use_tanh_for_sigmoid: bool = False
# If True, generate cutlass sm90 kernels alongside sm80 kernels on sm90 arch
allow_cutlass_sm90: bool = False
# If True, generate only cutlass sm90 kernels on sm90 arch
force_cutlass_sm90: bool = False
12 changes: 11 additions & 1 deletion python/aitemplate/backend/cuda/target_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,17 @@ def __enter__(self):
super().__enter__()
self._gen_cutlass_lib_pkg()
f_gen_ops = registry.get("cuda.gen_cutlass_ops")
self._operators = f_gen_ops(self._arch, self._cuda_version)
allow_cutlass_sm90 = (
self._kwargs.get("allow_cutlass_sm90", False)
or environ.allow_cutlass_sm90_kernels()
)
force_cutlass_sm90 = (
self._kwargs.get("force_cutlass_sm90", False)
or environ.force_cutlass_sm90_kernels()
)
self._operators = f_gen_ops(
self._arch, self._cuda_version, allow_cutlass_sm90, force_cutlass_sm90
)

def __exit__(self, ptype, value, trace):
super().__exit__(ptype, value, trace)
Expand Down
15 changes: 8 additions & 7 deletions python/aitemplate/backend/cuda/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@
import logging

from aitemplate.backend import registry
from aitemplate.utils.environ import (
allow_cutlass_sm90_kernels,
force_cutlass_sm90_kernels,
)
from aitemplate.utils.mk_cutlass_lib.mk_cutlass_lib import mk_cutlass_lib

# pylint: disable=C0103,C0415,W0707
Expand Down Expand Up @@ -51,7 +47,12 @@ def __init__(self, arch):


@registry.reg("cuda.gen_cutlass_ops")
def gen_ops(arch, cuda_version):
def gen_ops(
arch,
cuda_version,
allow_cutlass_sm90,
force_cutlass_sm90,
):
import cutlass_lib

args = Args(arch)
Expand All @@ -60,9 +61,9 @@ def gen_ops(arch, cuda_version):
manifest = cutlass_lib.manifest.Manifest(args)

if arch == "90":
if force_cutlass_sm90_kernels():
if force_cutlass_sm90:
cutlass_lib.generator.GenerateSM90(manifest, args.cuda_version)
elif allow_cutlass_sm90_kernels():
elif allow_cutlass_sm90:
cutlass_lib.generator.GenerateSM90(manifest, args.cuda_version)
cutlass_lib.generator.GenerateSM80(manifest, args.cuda_version)
cutlass_lib.extra_operation.GenerateSM80(manifest, args)
Expand Down
28 changes: 26 additions & 2 deletions tests/unittest/ops/test_gemm_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,12 @@ def __init__(self, *args, **kwargs):
super(GEMMBiasTestCase, self).__init__(*args, **kwargs)
self._test_id = 0

def _test_rcr(self, Ms, N, K, test_name, dtype="float16"):
target = detect_target()
def _test_rcr(
self, Ms, N, K, test_name, dtype="float16", allow_sm90=False, force_sm90=False
):
target = detect_target(
allow_cutlass_sm90=allow_sm90, force_cutlass_sm90=force_sm90
)
tolerance_limits = _TOLERANCE_LIMITS[dtype]
MDim = shape_utils.gen_int_var_min_max(Ms, name="m")
X = Tensor(shape=[MDim, IntImm(K)], dtype=dtype, name="input_0", is_input=True)
Expand Down Expand Up @@ -108,6 +112,26 @@ def test_rcr_bfloat16_bf16(self):
)

def test_rcr_sm90(self) -> None:
with env_variables(
INSIDE_RE_WORKER="1",
FORCE_PROFILE="1",
):
self._test_rcr(
Ms=[128],
N=32,
K=32,
test_name="target_fp16_allow_sm90",
dtype="float16",
allow_sm90=True,
)
self._test_rcr(
Ms=[128],
N=32,
K=32,
test_name="target_fp16_force_sm90",
dtype="float16",
force_sm90=True,
)
with env_variables(
AIT_FORCE_CUTLASS_SM90_KERNELS="1",
INSIDE_RE_WORKER="1",
Expand Down

0 comments on commit 992e1a0

Please sign in to comment.