Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tuner]: use python binding to select mma intrinsics #586

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_codegen # type: ignore

from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
Expand Down Expand Up @@ -535,13 +537,19 @@ def tune(

walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)

variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
assert len(variant_op_list) == 1, "Expect one executable variant op"
variant_op = variant_op_list[0]
# Get the MMA intrinisic intructions supported by the target.
mma_list = iree_codegen.query_mma_intrinsics(variant_op)

dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
generate_solutions(tune_logger, problem_size, num_subgroups)
generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
Expand Down
13 changes: 12 additions & 1 deletion tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_gpu # type: ignore


class CommonTypes:
def __init__(self, ctx: ir.Context):
Expand Down Expand Up @@ -130,7 +132,12 @@ def all():
]


def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
return False
Expand All @@ -139,6 +146,10 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
return False
if problem_size.rhs_type.element_type != intrinsic.input_type:
return False

if str(intrinsic) not in available_mma_intrinsics:
return False

return True

return list(filter(is_compatible, MfmaIntrinsic.all()))
Expand Down
51 changes: 48 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore


@pytest.fixture
Expand Down Expand Up @@ -109,7 +110,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.f16),
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.mmt,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
Expand All @@ -122,7 +127,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.i8),
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.mmt,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
Expand All @@ -135,8 +144,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
]

assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
]

assert (
common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
),
[
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)
== []
)
15 changes: 12 additions & 3 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import z3 # type: ignore
from typing import Iterator


from iree.compiler.dialects import iree_gpu # type: ignore

from .common import *


Expand All @@ -18,8 +21,9 @@ def get_mfma_intrinsic_constraints(
intrinsic_m: z3.ArithRef,
intrinsic_n: z3.ArithRef,
intrinsic_k: z3.ArithRef,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> z3.BoolRef:
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
return z3.Or(
*(
Expand Down Expand Up @@ -68,6 +72,7 @@ def generate_constraints(
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = (
problem_size.matmul_size.M,
Expand All @@ -82,7 +87,7 @@ def generate_constraints(
constraints += [subgroup_size == 64, wg_threads <= 1024]
constraints += [
get_mfma_intrinsic_constraints(
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics
)
]
subgroup_k_count = 1
Expand Down Expand Up @@ -130,7 +135,10 @@ def generate_constraints(


def generate_solutions(
logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int
logger: logging.Logger,
problem_size: ProblemSize,
num_subgrups: int,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> Iterator[Configuration]:
M, N, K = problem_size.MNK
logger.info(f"{M},{N},{K}")
Expand Down Expand Up @@ -168,6 +176,7 @@ def generate_solutions(
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
mma_intrinsics,
)
solver.add(z3.simplify(z3.And(constraints)))
logger.debug(f"Initial constraints: {solver}")
Expand Down
26 changes: 25 additions & 1 deletion tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore

from . import common
from . import dispatch_constraints
Expand All @@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
configs = dispatch_constraints.generate_solutions(
tuner_ctx.logger,
problem_size,
4,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)

assert configs is not None


Expand Down Expand Up @@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)

solver = z3.Solver()
Expand Down Expand Up @@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint

Expand Down
Loading