Skip to content

Commit

Permalink
[tuner]: use iree_gpu.MMAIntrinsic and iree_gpu.MMAAttr (nod-ai#605)
Browse files Browse the repository at this point in the history
Remove the data class `MfmaIntrinsic` from the codebase, and use IREE
attributes (` iree_gpu.MMAIntrinsic` and `iree_gpu.MMAAttr` ) for MFMA
intrinsics in the tuner.

**Motivation for this PR**: The original MLIR processing relies heavily
on string-based operations, making it fragile and prone to breaking with
updates to the IREE Compiler. To address this, we aim to leverage key
attributes directly through IREE Python bindings, enabled by exposing
these attributes. For more details, refer to [this
issue](nod-ai#453).
  • Loading branch information
bangtianliu authored Nov 26, 2024
1 parent ddc3091 commit 6bb24a3
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 96 deletions.
13 changes: 6 additions & 7 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def apply_configuration(
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
Expand Down Expand Up @@ -119,7 +119,6 @@ def get_transform_function_mmt(

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
Expand All @@ -132,7 +131,7 @@ def get_transform_function_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
Expand Down Expand Up @@ -205,7 +204,7 @@ def get_transform_function_conv(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
Expand Down Expand Up @@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
Expand Down Expand Up @@ -346,7 +345,7 @@ def get_transform_function_batch_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
Expand Down Expand Up @@ -414,7 +413,7 @@ def get_transform_function_batch_matmul(
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
Expand Down
29 changes: 22 additions & 7 deletions tuner/tuner/candidate_gen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore

from . import candidate_gen
from . import common
Expand Down Expand Up @@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:

M, N, K = 2048, 1280, 1280

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
Expand Down Expand Up @@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
Expand Down Expand Up @@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.contraction,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
Expand Down Expand Up @@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_matmul,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
Expand Down Expand Up @@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
Expand Down Expand Up @@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
Expand Down Expand Up @@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.broadcast_rhs_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
Expand Down
72 changes: 11 additions & 61 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]:
return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)


@dataclass
class MfmaIntrinsic:
output_type: ir.IntegerType | ir.FloatType
m: int
n: int
k: int
input_type: ir.IntegerType | ir.FloatType

def __str__(self) -> str:
input = str(self.input_type).upper()
output = str(self.output_type).upper()
return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"

@staticmethod
def mfma_f32_16x16x16_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 16, 16, 16, f16)

@staticmethod
def mfma_f32_32x32x8_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 32, 32, 8, f16)

@staticmethod
def mfma_i32_16x16x32_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 16, 16, 32, i8)

@staticmethod
def mfma_i32_32x32x16_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 32, 32, 16, i8)

@staticmethod
def all():
return [
MfmaIntrinsic.mfma_f32_16x16x16_f16(),
MfmaIntrinsic.mfma_f32_32x32x8_f16(),
MfmaIntrinsic.mfma_i32_16x16x32_i8(),
MfmaIntrinsic.mfma_i32_32x32x16_i8(),
]


def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
) -> list[iree_gpu.MMAIntrinsic]:
def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
a_type, b_type, c_type = mma_attr.abc_element_types
if problem_size.res_type.element_type != c_type:
return False
if problem_size.dispatch_kind != DispatchKind.batch_matmul:
if problem_size.lhs_type.element_type != intrinsic.input_type:
return False
if problem_size.rhs_type.element_type != intrinsic.input_type:
if (
problem_size.lhs_type.element_type != a_type
or problem_size.rhs_type.element_type != b_type
):
return False

if str(intrinsic) not in available_mma_intrinsics:
return False

return True

return list(filter(is_compatible, MfmaIntrinsic.all()))
return list(filter(is_comptible, mma_intrinsics))


class ReorderWorkgroupsStrategy(Enum):
Expand Down Expand Up @@ -197,7 +147,7 @@ def __str__(self) -> str:
class Configuration:
subgroup_size: int
workgroup_size: list[int]
intrinsic: MfmaIntrinsic
intrinsic: iree_gpu.MMAAttr
tile_sizes: list[int]
subgroup_m_count: int
subgroup_n_count: int
Expand Down
25 changes: 11 additions & 14 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Usage: python -m pytest candidate_gen_test.py
Usage: python -m pytest common_test.py
"""

import pytest
Expand Down Expand Up @@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None:


def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
Expand All @@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
)


def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None:
assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"


def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
Expand All @@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
Expand All @@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
]

assert common.get_compatible_mfma_intrinsics(
Expand All @@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
Expand All @@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert (
Expand Down
Loading

0 comments on commit 6bb24a3

Please sign in to comment.