diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 41f208216374f..c808e1bedc8de 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -46,6 +46,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS +#define GEN_PASS_DECL_SELECTPASS #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default @@ -139,6 +140,13 @@ std::unique_ptr createCompositeFixedPointPass( std::string name, llvm::function_ref populateFunc, int maxIterations = 10); +/// Creates select pass, which dynamically selects pass pipeline to run based on +/// root op attribute. +std::unique_ptr createSelectPass( + std::string name, std::string selectCondName, + ArrayRef>> + populateFuncs); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a39ab77fc8fb3..b707e1edf3d6f 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { ]; } +def SelectPass : Pass<"select-pass"> { + let summary = "Select pass"; + let description = [{ + Select pass dynamically selects pass pipeline to run based on root op + attribute. + }]; + + let options = [ + Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"", + "Select pass display name">, + Option<"selectCondName", "select-cond-name", "std::string", "\"select\"", + "Attribute name used for condition">, + ListOption<"selectValues", "select-values", "std::string", + "Values used to check select condition">, + ListOption<"selectPipelines", "select-pipelines", "std::string", + "Pipelines, assotiated with corresponding select values">, + ]; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 3a8088bccf299..b94e390b627d6 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms PrintIR.cpp RemoveDeadValues.cpp SCCP.cpp + SelectPass.cpp SROA.cpp StripDebugInfo.cpp SymbolDCE.cpp diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp new file mode 100644 index 0000000000000..60b956422a9d0 --- /dev/null +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -0,0 +1,133 @@ +//===- SelectPass.cpp - Select pass code ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SelectPass dynamically selects pass pipeline to run based on root op +// attribute. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +#define GEN_PASS_DEF_SELECTPASS +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct SelectPass final : public impl::SelectPassBase { + using SelectPassBase::SelectPassBase; + + SelectPass( + std::string name_, std::string selectCondName_, + ArrayRef>> + populateFuncs) { + name = std::move(name_); + selectCondName = std::move(selectCondName_); + + SmallVector selectVals; + SmallVector selectPpls; + selectVals.reserve(populateFuncs.size()); + selectPpls.reserve(populateFuncs.size()); + selectPassManagers.reserve(populateFuncs.size()); + for (auto &&[name, populate] : populateFuncs) { + selectVals.emplace_back(name); + + auto &pm = selectPassManagers.emplace_back(); + populate(pm); + + llvm::raw_string_ostream os(selectPpls.emplace_back()); + pm.printAsTextualPipeline(os); + } + + selectValues = selectVals; + selectPipelines = selectPpls; + } + + LogicalResult initializeOptions( + StringRef options, + function_ref errorHandler) override { + if (failed(SelectPassBase::initializeOptions(options, errorHandler))) + return failure(); + + if (selectCondName.empty()) + return errorHandler("select-cond-name is empty"); + + if (selectValues.size() != selectPipelines.size()) + return errorHandler("values and pipelines size mismatch"); + + selectPassManagers.resize(selectPipelines.size()); + + for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) { + if (failed(parsePassPipeline(pipeline, selectPassManagers[i]))) + return errorHandler("failed to parse pipeline"); + } + + return success(); + } + + LogicalResult initialize(MLIRContext *context) override { + condAttrName = StringAttr::get(context, selectCondName); + + selectAttrs.reserve(selectAttrs.size()); + for (StringRef value : selectValues) + selectAttrs.emplace_back(StringAttr::get(context, value)); + + return success(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + for (const OpPassManager &pipeline : selectPassManagers) + pipeline.getDependentDialects(registry); + } + + void runOnOperation() override { + Operation *op = getOperation(); + Attribute condAttrValue = op->getAttr(condAttrName); + if (!condAttrValue) { + op->emitError("condition attribute not present: ") << condAttrName; + return signalPassFailure(); + } + + for (auto &&[value, pm] : + llvm::zip_equal(selectAttrs, selectPassManagers)) { + if (value != condAttrValue) + continue; + + if (failed(runPipeline(pm, op))) + return signalPassFailure(); + + return; + } + + // TODO: add a default pipeline option. + op->emitError("unhandled condition value: ") << condAttrValue; + return signalPassFailure(); + } + +protected: + StringRef getName() const override { return name; } + +private: + StringAttr condAttrName; + SmallVector selectAttrs; + SmallVector selectPassManagers; +}; +} // namespace + +std::unique_ptr mlir::createSelectPass( + std::string name, std::string selectCondName, + ArrayRef>> + populateFuncs) { + return std::make_unique(std::move(name), + std::move(selectCondName), populateFuncs); +} diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir new file mode 100644 index 0000000000000..42340f957b4fc --- /dev/null +++ b/mlir/test/Transforms/select-pass.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \ +// RUN: name=TestSelectPass \ +// RUN: select-cond-name=test.attr \ +// RUN: select-values=rocdl,nvvm \ +// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \ +// RUN: }))' | FileCheck %s + +gpu.module @rocdl_module attributes {test.attr = "rocdl"} { +// CHECK-LABEL: func @foo() +// CHECK: rocdl.workitem.id.x + func.func @foo() -> index { + %0 = gpu.thread_id x + return %0 : index + } +} + +gpu.module @nvvm_module attributes {test.attr = "nvvm"} { +// CHECK-LABEL: func @bar() +// CHECK: nvvm.read.ptx.sreg.tid.x + func.func @bar() -> index { + %0 = gpu.thread_id x + return %0 : index + } +}