From 0303b3c511c4e85e012412142702c4466d3ab01d Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 12:54:27 +0100 Subject: [PATCH 1/6] [mlir] Add `SelectPass` `SelectPass` allows to dynamically select the pass pipeline based on attribute value attached to some top-level op. --- mlir/include/mlir/Transforms/Passes.h | 8 ++ mlir/include/mlir/Transforms/Passes.td | 19 ++++ mlir/lib/Transforms/CMakeLists.txt | 1 + mlir/lib/Transforms/SelectPass.cpp | 132 +++++++++++++++++++++++++ mlir/test/Transforms/select-pass.mlir | 24 +++++ 5 files changed, 184 insertions(+) create mode 100644 mlir/lib/Transforms/SelectPass.cpp create mode 100644 mlir/test/Transforms/select-pass.mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 41f208216374f..e521705371b0b 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -46,6 +46,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_SYMBOLPRIVATIZE #define GEN_PASS_DECL_TOPOLOGICALSORT #define GEN_PASS_DECL_COMPOSITEFIXEDPOINTPASS +#define GEN_PASS_DECL_SELECTPASS #include "mlir/Transforms/Passes.h.inc" /// Creates an instance of the Canonicalizer pass, configured with default @@ -139,6 +140,13 @@ std::unique_ptr createCompositeFixedPointPass( std::string name, llvm::function_ref populateFunc, int maxIterations = 10); +/// Creates select pass which allows to run multiple different set of passes +/// based on attribute value on some top-level op. +std::unique_ptr createSelectPass( + std::string name, std::string selectCondName, + ArrayRef>> + populateFuncs); + //===----------------------------------------------------------------------===// // Registration //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index a39ab77fc8fb3..846dffb89e8f7 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -586,4 +586,23 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { ]; } +def SelectPass : Pass<"select-pass"> { + let summary = "Select pass"; + let description = [{ + Select pass allows to run multiple different set of passes based on + attribute value on some top-level op. + }]; + + let options = [ + Option<"name", "name", "std::string", /*default=*/"\"SelectPass\"", + "Select pass display name">, + Option<"selectCondName", "select-cond-name", "std::string", "\"select\"", + "Attribute name used for condition">, + ListOption<"selectValues", "select-values", "std::string", + "Values used to check select condition">, + ListOption<"selectPipelines", "select-pipelines", "std::string", + "Pipelines, assotiated with corresponding select values">, + ]; +} + #endif // MLIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Transforms/CMakeLists.txt b/mlir/lib/Transforms/CMakeLists.txt index 3a8088bccf299..b94e390b627d6 100644 --- a/mlir/lib/Transforms/CMakeLists.txt +++ b/mlir/lib/Transforms/CMakeLists.txt @@ -14,6 +14,7 @@ add_mlir_library(MLIRTransforms PrintIR.cpp RemoveDeadValues.cpp SCCP.cpp + SelectPass.cpp SROA.cpp StripDebugInfo.cpp SymbolDCE.cpp diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp new file mode 100644 index 0000000000000..750a6617fe9b9 --- /dev/null +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -0,0 +1,132 @@ +//===- SelectPass.cpp - Select pass code ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// SelectPass allows to run multiple different set of passes based on attribute +// value on some top-level op. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/Passes.h" + +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" + +namespace mlir { +#define GEN_PASS_DEF_SELECTPASS +#include "mlir/Transforms/Passes.h.inc" +} // namespace mlir + +using namespace mlir; + +namespace { +struct SelectPass final : public impl::SelectPassBase { + using SelectPassBase::SelectPassBase; + + SelectPass( + std::string name_, std::string selectCondName_, + ArrayRef>> + populateFuncs) { + name = std::move(name_); + selectCondName = std::move(selectCondName_); + + SmallVector selectVals; + SmallVector selectPpls; + selectVals.reserve(populateFuncs.size()); + selectPpls.reserve(populateFuncs.size()); + selectPassManagers.reserve(populateFuncs.size()); + for (auto &&[name, populate] : populateFuncs) { + selectVals.emplace_back(name); + + auto &pm = selectPassManagers.emplace_back(); + populate(pm); + + llvm::raw_string_ostream os(selectPpls.emplace_back()); + pm.printAsTextualPipeline(os); + } + + selectValues = selectVals; + selectPipelines = selectPpls; + } + + LogicalResult initializeOptions( + StringRef options, + function_ref errorHandler) override { + if (failed(SelectPassBase::initializeOptions(options, errorHandler))) + return failure(); + + if (selectCondName.empty()) + return errorHandler("Invalid select-cond-name"); + + if (selectValues.size() != selectPipelines.size()) + return errorHandler("Values and pipelines size mismatch"); + + selectPassManagers.resize(selectPipelines.size()); + + for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) { + if (failed(parsePassPipeline(pipeline, selectPassManagers[i]))) + return errorHandler("Failed to parse pipeline"); + } + + return success(); + } + + LogicalResult initialize(MLIRContext *context) override { + condAttrName = StringAttr::get(context, selectCondName); + + selectAttrs.reserve(selectAttrs.size()); + for (StringRef value : selectValues) + selectAttrs.emplace_back(StringAttr::get(context, value)); + + return success(); + } + + void getDependentDialects(DialectRegistry ®istry) const override { + for (const OpPassManager &pipeline : selectPassManagers) + pipeline.getDependentDialects(registry); + } + + void runOnOperation() override { + Operation *op = getOperation(); + Attribute condAttrValue = op->getAttr(condAttrName); + if (!condAttrValue) { + op->emitError("Condition attribute not present: ") << condAttrName; + return signalPassFailure(); + } + + for (auto &&[value, pm] : + llvm::zip_equal(selectAttrs, selectPassManagers)) { + if (value != condAttrValue) + continue; + + if (failed(runPipeline(pm, op))) + return signalPassFailure(); + + return; + } + + op->emitError("Unhandled condition value: ") << condAttrValue; + return signalPassFailure(); + } + +protected: + StringRef getName() const override { return name; } + +private: + StringAttr condAttrName; + SmallVector selectAttrs; + SmallVector selectPassManagers; +}; +} // namespace + +std::unique_ptr mlir::createSelectPass( + std::string name, std::string selectCondName, + ArrayRef>> + populateFuncs) { + return std::make_unique(std::move(name), + std::move(selectCondName), populateFuncs); +} diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir new file mode 100644 index 0000000000000..fb93486b94ed7 --- /dev/null +++ b/mlir/test/Transforms/select-pass.mlir @@ -0,0 +1,24 @@ +// RUN: mlir-opt %s -pass-pipeline='builtin.module(gpu.module(select-pass{ \ +// RUN: name=TestSelectPass \ +// RUN: select-cond-name=test.attr \ +// RUN: select-values=rocdl,nvvm \ +// RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \ +// RUN: }))' -split-input-file | FileCheck %s + +gpu.module @rocdl_module attributes {test.attr = "rocdl"} { +// CHECK-LABEL: func @foo() +// CHECK: rocdl.workitem.id.x + func.func @foo() -> index { + %0 = gpu.thread_id x + return %0 : index + } +} + +gpu.module @nvvm_module attributes {test.attr = "nvvm"} { +// CHECK-LABEL: func @bar() +// CHECK: nvvm.read.ptx.sreg.tid.x + func.func @bar() -> index { + %0 = gpu.thread_id x + return %0 : index + } +} From 6b62cad4fb4929190a9be0df5eba8fae3ae66f43 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 19:46:32 +0100 Subject: [PATCH 2/6] clarify doc --- mlir/include/mlir/Transforms/Passes.h | 4 ++-- mlir/include/mlir/Transforms/Passes.td | 4 ++-- mlir/lib/Transforms/SelectPass.cpp | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index e521705371b0b..c808e1bedc8de 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -140,8 +140,8 @@ std::unique_ptr createCompositeFixedPointPass( std::string name, llvm::function_ref populateFunc, int maxIterations = 10); -/// Creates select pass which allows to run multiple different set of passes -/// based on attribute value on some top-level op. +/// Creates select pass, which dynamically selects pass pipeline to run based on +/// root op attribute. std::unique_ptr createSelectPass( std::string name, std::string selectCondName, ArrayRef>> diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 846dffb89e8f7..b707e1edf3d6f 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -589,8 +589,8 @@ def CompositeFixedPointPass : Pass<"composite-fixed-point-pass"> { def SelectPass : Pass<"select-pass"> { let summary = "Select pass"; let description = [{ - Select pass allows to run multiple different set of passes based on - attribute value on some top-level op. + Select pass dynamically selects pass pipeline to run based on root op + attribute. }]; let options = [ diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp index 750a6617fe9b9..e538a944c97c9 100644 --- a/mlir/lib/Transforms/SelectPass.cpp +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -6,8 +6,8 @@ // //===----------------------------------------------------------------------===// // -// SelectPass allows to run multiple different set of passes based on attribute -// value on some top-level op. +// SelectPass dynamically selects pass pipeline to run based on root op +// attribute. // //===----------------------------------------------------------------------===// From c81b8f2b5f7c142f1e37113cdad3f59f55d13dde Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 19:47:36 +0100 Subject: [PATCH 3/6] remove -split-input-file --- mlir/test/Transforms/select-pass.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/Transforms/select-pass.mlir b/mlir/test/Transforms/select-pass.mlir index fb93486b94ed7..42340f957b4fc 100644 --- a/mlir/test/Transforms/select-pass.mlir +++ b/mlir/test/Transforms/select-pass.mlir @@ -3,7 +3,7 @@ // RUN: select-cond-name=test.attr \ // RUN: select-values=rocdl,nvvm \ // RUN: select-pipelines=convert-gpu-to-rocdl,convert-gpu-to-nvvm \ -// RUN: }))' -split-input-file | FileCheck %s +// RUN: }))' | FileCheck %s gpu.module @rocdl_module attributes {test.attr = "rocdl"} { // CHECK-LABEL: func @foo() From 85586b135f0861e8dab826c0032da7837d18a437 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 19:53:13 +0100 Subject: [PATCH 4/6] update err messages --- mlir/lib/Transforms/SelectPass.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp index e538a944c97c9..c895fb9d0cdd5 100644 --- a/mlir/lib/Transforms/SelectPass.cpp +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -60,16 +60,16 @@ struct SelectPass final : public impl::SelectPassBase { return failure(); if (selectCondName.empty()) - return errorHandler("Invalid select-cond-name"); + return errorHandler("invalid select-cond-name"); if (selectValues.size() != selectPipelines.size()) - return errorHandler("Values and pipelines size mismatch"); + return errorHandler("values and pipelines size mismatch"); selectPassManagers.resize(selectPipelines.size()); for (auto &&[i, pipeline] : llvm::enumerate(selectPipelines)) { if (failed(parsePassPipeline(pipeline, selectPassManagers[i]))) - return errorHandler("Failed to parse pipeline"); + return errorHandler("failed to parse pipeline"); } return success(); @@ -94,7 +94,7 @@ struct SelectPass final : public impl::SelectPassBase { Operation *op = getOperation(); Attribute condAttrValue = op->getAttr(condAttrName); if (!condAttrValue) { - op->emitError("Condition attribute not present: ") << condAttrName; + op->emitError("condition attribute not present: ") << condAttrName; return signalPassFailure(); } @@ -109,7 +109,7 @@ struct SelectPass final : public impl::SelectPassBase { return; } - op->emitError("Unhandled condition value: ") << condAttrValue; + op->emitError("unhandled condition value: ") << condAttrValue; return signalPassFailure(); } From 0d531d591d559c1b3653036cd6fe030586fefd20 Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 20:38:31 +0100 Subject: [PATCH 5/6] TODO --- mlir/lib/Transforms/SelectPass.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp index c895fb9d0cdd5..a025fd8847877 100644 --- a/mlir/lib/Transforms/SelectPass.cpp +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -109,6 +109,7 @@ struct SelectPass final : public impl::SelectPassBase { return; } + // TODO: add a default pipeline option. op->emitError("unhandled condition value: ") << condAttrValue; return signalPassFailure(); } From 1af8b3eea109cf4f6004944050b7946a3f8a7eff Mon Sep 17 00:00:00 2001 From: Ivan Butygin Date: Sat, 8 Mar 2025 21:14:37 +0100 Subject: [PATCH 6/6] improve err msg --- mlir/lib/Transforms/SelectPass.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Transforms/SelectPass.cpp b/mlir/lib/Transforms/SelectPass.cpp index a025fd8847877..60b956422a9d0 100644 --- a/mlir/lib/Transforms/SelectPass.cpp +++ b/mlir/lib/Transforms/SelectPass.cpp @@ -60,7 +60,7 @@ struct SelectPass final : public impl::SelectPassBase { return failure(); if (selectCondName.empty()) - return errorHandler("invalid select-cond-name"); + return errorHandler("select-cond-name is empty"); if (selectValues.size() != selectPipelines.size()) return errorHandler("values and pipelines size mismatch");