Skip to content

Commit

Permalink
[tuner] Add direct TD spec generation for candidates (#606)
Browse files Browse the repository at this point in the history
This PR adds direct transform dialect spec generation for candidate
configurations. This is the first part of the large refactoring
described in #577. The way TD
specs are generated is by matching against certain types of operations,
and then creating a named sequence with
`transform.iree.match.cast_compatible_dag_from_root` based on the
matched operation. This is done for each configuration found, and the
specs are saved to the temporary tuning directory to be used later in
tuning.

One main difference in the flow of candidate generation is that state is
no longer tracked by saving files to a temporary directory. Instead, ir
modules are passed to each function, and only at the very end of
candidate generation are the transform dialect specs written to files.
This makes things cleaner, since there no longer needs to be a
coordination of file paths.

Signed-off-by: Max Dawkins <[email protected]>
  • Loading branch information
Max191 authored and eagarvey-amd committed Jan 8, 2025
1 parent cbca5d9 commit 71902b0
Show file tree
Hide file tree
Showing 12 changed files with 900 additions and 8 deletions.
5 changes: 5 additions & 0 deletions tuner/examples/test/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
9 changes: 9 additions & 0 deletions tuner/examples/test/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from . import tuner_test

tuner_test.main()
40 changes: 40 additions & 0 deletions tuner/examples/test/tuner_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from tuner import libtuner


def main():
args = libtuner.parse_arguments()

path_config = libtuner.PathConfig()
path_config.base_dir.mkdir(parents=True, exist_ok=True)
path_config.output_unilog.touch()
candidate_trackers: list[libtuner.CandidateTracker] = []
stop_after_phase: str = args.stop_after

print("Setup logging")
libtuner.setup_logging(args, path_config)
print(path_config.run_log, end="\n\n")

if not args.dry_run:
print("Validating devices")
libtuner.validate_devices(args.devices)
print("Validation successful!\n")

print("Generating candidates...")
candidates = libtuner.generate_candidate_specs(
args, path_config, candidate_trackers
)
print(f"Stored candidate specs in {path_config.specs_dir}\n")
if stop_after_phase == libtuner.ExecutionPhases.generate_candidates:
return

print("Check the detailed execution logs in:")
print(path_config.run_log.resolve())

for candidate in candidate_trackers:
libtuner.logging.debug(candidate)
155 changes: 155 additions & 0 deletions tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
from .spec_builder import *

tune_logger = logging.getLogger("tune")

Expand Down Expand Up @@ -106,6 +107,15 @@ def apply_params(
"""Apply parameter transformations to the operation."""
pass

@abstractmethod
def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
"""Generate a transform dialect spec that applies the compilation info attr."""
pass


class DispatchTunerRegistry:
def __init__(self):
Expand All @@ -130,6 +140,68 @@ def find_handler(self, op_name: str) -> DispatchTuner:
assert False, "Dispatch kind not supported"


class ContractionOpInterfaceTuner(DispatchTuner, ContractionOpInterfaceParser):
def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
raise NotImplementedError

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
contraction_op: ir.Operation = self.get_contraction_operation(ir_module)
lhs_type = ir.ShapedType(contraction_op.operands[0].type)
rhs_type = ir.ShapedType(contraction_op.operands[1].type)
acc_type = ir.ShapedType(contraction_op.operands[2].type)
M = acc_type.get_dim_size(0)
N = acc_type.get_dim_size(1)
K = lhs_type.get_dim_size(1)
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_contraction_{M}x{N}x{K}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(
ir_module.context, contraction_op, compilation_info, func_name
)


class ConvolutionOpInterfaceTuner(DispatchTuner, ConvolutionOpInterfaceParser):
def apply_params(
self,
problem_size: ProblemSize,
template: list[str],
compilation_info: iree_codegen.CompilationInfoAttr,
) -> MLIRTransformation:
raise NotImplementedError

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
conv_op: ir.Operation = self.get_conv_operation(ir_module)
assert (
conv_op.name == "linalg.conv_2d_nhwc_hwcf"
), "expected linalg.conv_2d_nhwc_hwcf"
lhs_type = ir.ShapedType(conv_op.operands[0].type)
rhs_type = ir.ShapedType(conv_op.operands[1].type)
acc_type = ir.ShapedType(conv_op.operands[2].type)
N = acc_type.get_dim_size(0)
H = acc_type.get_dim_size(1)
W = acc_type.get_dim_size(2)
C = rhs_type.get_dim_size(2)
P = rhs_type.get_dim_size(0)
Q = rhs_type.get_dim_size(1)
F = rhs_type.get_dim_size(3)
conv_type = conv_op.name.split(".")[-1]
# TODO(Max191): Get the function name from the func.func in the input module.
func_name = f"match_{conv_type}_{N}x{H}x{W}x{C}x{P}x{Q}x{F}_{lhs_type.element_type}x{rhs_type.element_type}x{acc_type.element_type}"
return build_td_spec(ir_module.context, conv_op, compilation_info, func_name)


class MmtTuner(DispatchTuner, MmtParser):
def get_transform_function_mmt(
self,
Expand Down Expand Up @@ -174,6 +246,13 @@ def apply_params(
)
return MLIRTransformation(template, modified, embeddable)

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
raise NotImplementedError


class ConvTuner(DispatchTuner, ConvParser):
def get_transform_function_conv(
Expand Down Expand Up @@ -235,6 +314,13 @@ def apply_params(
)
return MLIRTransformation(template, modified, embeddable)

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
raise NotImplementedError


class ContractionTuner(DispatchTuner, ContractionParser):
def get_transform_function_broadcast_rhs_mmt(
Expand Down Expand Up @@ -306,6 +392,13 @@ def apply_params(
"",
)

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
raise NotImplementedError


class BatchMmtTuner(DispatchTuner, BatchMmtParser):
def get_transform_function_batch_mmt(
Expand Down Expand Up @@ -353,6 +446,13 @@ def apply_params(
)
return MLIRTransformation(template, modified, embeddable)

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
raise NotImplementedError


class BatchMatmulTuner(DispatchTuner, BatchMatmulParser):
def get_transform_function_batch_matmul(
Expand Down Expand Up @@ -409,6 +509,13 @@ def apply_params(
)
return MLIRTransformation(template, modified, embeddable)

def get_td_spec(
self,
ir_module: ir.Module,
compilation_info: iree_codegen.CompilationInfoAttr,
) -> ir.Module:
raise NotImplementedError


@dataclass
class OpWalkResult:
Expand Down Expand Up @@ -452,6 +559,7 @@ def get_default_output_dir() -> str:
return "tuning_" + datetime.now().strftime("%Y_%m_%d_%H_%M")


# TODO(https://github.com/nod-ai/shark-ai/issues/453): Remove in favor of using tune_with_td.
def tune(
input: str, # Path to the mlir file to be tuned
output: str = "", # Path to the output directory, auto creates one if not given
Expand Down Expand Up @@ -527,6 +635,53 @@ def tune(
tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")


def generate_configs_and_td_specs(
input_module: ir.Module, # Path to the mlir file to be tuned
tuner_context: TunerContext,
limit: int = 4096, # Max candidates to be generated
num_subgroups: int = 4, # GPU spec, used to determine candidate generation constraints
) -> list[ir.Module]:
dispatch_tuner_registry = DispatchTunerRegistry()
dispatch_tuner_registry.register(
[
ContractionOpInterfaceTuner(),
ConvolutionOpInterfaceTuner(),
]
)

walk_result: OpWalkResult = walk_mlir_op(input_module, dispatch_tuner_registry)

dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(
str(input_module).splitlines()
)
tune_logger.debug(str(problem_size))

# Index 0 is reserved for default config, so it gets no td spec.
with ir.Location.unknown() as loc:
empty_module = ir.Module.create(loc)
config_specs: list[ir.Module] = [empty_module]

# Get the MMA intrinisic intructions supported by the target.
variant_op_list = iree_codegen.get_executable_variant_ops(input_module)
assert len(variant_op_list) == 1, "Expect one executable variant op"
variant_op = variant_op_list[0]
mma_list = iree_codegen.query_mma_intrinsics(variant_op)
for i, config in enumerate(
generate_solutions(tuner_context, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
tune_logger.info(f"Solution #{i+1}: {config}")
td_spec_module = dispatch_tuner.get_td_spec(input_module, config)
assert td_spec_module, "Failed to generate transform dialect spec"
config_specs.append(td_spec_module)

tune_logger.info(f"Generated {len(config_specs)} tuning specs")
return config_specs


def main():
parser = argparse.ArgumentParser()
parser.add_argument("input", help="Input mlir file", type=str)
Expand Down
Loading

0 comments on commit 71902b0

Please sign in to comment.