From d5136dd7320ea734833c5414d0a6e3f7414df71b Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Thu, 21 Nov 2024 19:34:17 -0600 Subject: [PATCH] [tuner]: address comments and fix CI errors Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 4 ++-- tuner/tuner/common.py | 8 ++++---- tuner/tuner/common_test.py | 9 ++++++--- tuner/tuner/dispatch_constraints.py | 9 ++++----- tuner/tuner/dispatch_constraints_test.py | 6 +++++- 5 files changed, 21 insertions(+), 15 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 495fb031b..38696e6db 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -30,7 +30,7 @@ from iree.compiler import ir # type: ignore -from iree.compiler.dialects import iree_codegen +from iree.compiler.dialects import iree_codegen # type: ignore from .common import * from .dispatch_constraints import * @@ -538,7 +538,7 @@ def tune( walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module) - assert len(variant_op_list) == 1, "Support only one op in one disptach" + assert len(variant_op_list) == 1, "Expect one executable variant op" variant_op = variant_op_list[0] # Get the MMA intrinisic intructions supported by the target. mma_list = iree_codegen.query_mma_intrinsics(variant_op) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index c07ba231d..407f00b97 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -12,7 +12,7 @@ from iree.compiler import ir # type: ignore -from iree.compiler.dialects import iree_gpu +from iree.compiler.dialects import iree_gpu # type: ignore class CommonTypes: @@ -134,9 +134,9 @@ def all(): def get_compatible_mfma_intrinsics( problem_size: ProblemSize, - mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> list[MfmaIntrinsic]: - mma_list_target = {str(mma) for mma in mma_intrinsics} if mma_intrinsics else None + available_mma_intrinsics = [str(mma) for mma in mma_intrinsics] def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.res_type.element_type != intrinsic.output_type: @@ -147,7 +147,7 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool: if problem_size.rhs_type.element_type != intrinsic.input_type: return False - if mma_list_target is not None and str(intrinsic) not in mma_list_target: + if available_mma_intrinsics and str(intrinsic) not in available_mma_intrinsics: return False return True diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 891d703e2..733cfcc7f 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -109,7 +109,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.f16), common.ShapedType([2048, 1280], tuner_ctx.type.f32), common.DispatchKind.mmt, - ) + ), + [], ) == [ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), @@ -122,7 +123,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([1280, 1280], tuner_ctx.type.i8), common.ShapedType([2048, 1280], tuner_ctx.type.i32), common.DispatchKind.mmt, - ) + ), + [], ) == [ common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), @@ -135,7 +137,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: common.ShapedType([64, 640, 320], tuner_ctx.type.f32), common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, - ) + ), + [], ) == [ common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index e5eef3f26..85039a1e8 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -11,8 +11,7 @@ from typing import Iterator -from iree.compiler import ir -from iree.compiler.dialects import iree_codegen, iree_gpu +from iree.compiler.dialects import iree_gpu # type: ignore from .common import * @@ -22,7 +21,7 @@ def get_mfma_intrinsic_constraints( intrinsic_m: z3.ArithRef, intrinsic_n: z3.ArithRef, intrinsic_k: z3.ArithRef, - mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> z3.BoolRef: compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" @@ -73,7 +72,7 @@ def generate_constraints( subgroup_m_count, subgroup_n_count, waves_per_eu, - mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ): M, N, K = ( problem_size.matmul_size.M, @@ -139,7 +138,7 @@ def generate_solutions( logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int, - mma_intrinsics: Optional[list[iree_gpu.MMAIntrinsic]] = None, + mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> Iterator[Configuration]: M, N, K = problem_size.MNK logger.info(f"{M},{N},{K}") diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 7e1a5c55d..619076200 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -37,7 +37,9 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) - configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4) + configs = dispatch_constraints.generate_solutions( + tuner_ctx.logger, problem_size, 4, [] + ) assert configs is not None @@ -115,6 +117,7 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non sg_m_cnt, sg_n_cnt, waves_per_eu, + [], ) solver = z3.Solver() @@ -160,6 +163,7 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N sg_m_cnt, sg_n_cnt, waves_per_eu, + [], ) constraints.append(m > 1000) # Adding an additional unsatisfiable constraint