Skip to content

Commit

Permalink
[tuner]: address comments and fix CI errors
Browse files Browse the repository at this point in the history
Signed-off-by: Bangtian Liu <[email protected]>
  • Loading branch information
bangtianliu committed Nov 22, 2024
1 parent 1c3a670 commit d5136dd
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 15 deletions.
4 changes: 2 additions & 2 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_codegen
from iree.compiler.dialects import iree_codegen # type: ignore

from .common import *
from .dispatch_constraints import *
Expand Down Expand Up @@ -538,7 +538,7 @@ def tune(
walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)

variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
assert len(variant_op_list) == 1, "Support only one op in one disptach"
assert len(variant_op_list) == 1, "Expect one executable variant op"
variant_op = variant_op_list[0]
# Get the MMA intrinsic instructions supported by the target.
mma_list = iree_codegen.query_mma_intrinsics(variant_op)
Expand Down
8 changes: 4 additions & 4 deletions tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_gpu
from iree.compiler.dialects import iree_gpu # type: ignore


class CommonTypes:
Expand Down Expand Up @@ -134,9 +134,9 @@ def all():

def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
mma_list_target = {str(mma) for mma in mma_intrinsics} if mma_intrinsics else None
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
Expand All @@ -147,7 +147,7 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.rhs_type.element_type != intrinsic.input_type:
return False

if mma_list_target is not None and str(intrinsic) not in mma_list_target:
if available_mma_intrinsics and str(intrinsic) not in available_mma_intrinsics:
return False

return True
Expand Down
9 changes: 6 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.f16),
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.mmt,
)
),
[],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
Expand All @@ -122,7 +123,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.i8),
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.mmt,
)
),
[],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
Expand All @@ -135,7 +137,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
)
),
[],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
Expand Down
9 changes: 4 additions & 5 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from typing import Iterator


from iree.compiler import ir
from iree.compiler.dialects import iree_codegen, iree_gpu
from iree.compiler.dialects import iree_gpu # type: ignore

from .common import *

Expand All @@ -22,7 +21,7 @@ def get_mfma_intrinsic_constraints(
intrinsic_m: z3.ArithRef,
intrinsic_n: z3.ArithRef,
intrinsic_k: z3.ArithRef,
mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> z3.BoolRef:
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
Expand Down Expand Up @@ -73,7 +72,7 @@ def generate_constraints(
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = (
problem_size.matmul_size.M,
Expand Down Expand Up @@ -139,7 +138,7 @@ def generate_solutions(
logger: logging.Logger,
problem_size: ProblemSize,
num_subgrups: int,
mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> Iterator[Configuration]:
M, N, K = problem_size.MNK
logger.info(f"{M},{N},{K}")
Expand Down
6 changes: 5 additions & 1 deletion tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
configs = dispatch_constraints.generate_solutions(
tuner_ctx.logger, problem_size, 4, []
)
assert configs is not None


Expand Down Expand Up @@ -115,6 +117,7 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[],
)

solver = z3.Solver()
Expand Down Expand Up @@ -160,6 +163,7 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[],
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint

Expand Down

0 comments on commit d5136dd

Please sign in to comment.