Skip to content

Commit c1a8454

Browse files
authored
Copy DecomposeAggregatedOps pass from tpp-mlir (#215)
1 parent 61361c2 commit c1a8454

File tree

5 files changed

+70
-0
lines changed

5 files changed

+70
-0
lines changed

include/gc/Transforms/Passes.td

+7
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,13 @@ def VerifyTargetDescription : Pass<"verify-target-description", "ModuleOp"> {
105105
];
106106
}
107107

108+
// Pass record anchored on `func::FuncOp` and registered under the
// command-line name `decompose-aggregated-ops`.
def DecomposeAggregatedOps : Pass<"decompose-aggregated-ops", "func::FuncOp"> {
109+
let summary = "Decompose aggregated operations.";
110+
let description = [{
111+
Decompose operations that implement the `AggregatedOpInterface`.
112+
}];
113+
}
114+
108115
def SinkOpIntoInnerLoop : Pass<"sink-op-into-inner-loop"> {
109116
let summary = "Sink operations into inner loops";
110117
let description = [{The pass tries to sink operations into inner loops as deep as possible to maximize the chance for outer loop optimization.

lib/gc/Transforms/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ gc_add_mlir_library(GcPasses
1616
IterativeTilingAndFusion.cpp
1717
TilingUsingInterfaceX.cpp
1818
VerifyTargetDescription.cpp
19+
DecomposeAggregatedOps.cpp
1920
DeepTileContractionOp.cpp
2021
TilingUtil.cpp
2122
SinkOpIntoInnerLoop.cpp
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Transforms/Passes.h"
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
13+
14+
using namespace mlir;
15+
namespace mlir {
16+
namespace gc {
17+
#define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
18+
#include "gc/Transforms/Passes.h.inc"
19+
} // namespace gc
20+
} // namespace mlir
21+
22+
namespace {
23+
24+
/// Rewrite pattern rooted at `linalg::SoftmaxOp`.
///
/// The matched op is viewed through `linalg::AggregatedOpInterface` and asked
/// to decompose itself; when that succeeds, the op is replaced by the values
/// the decomposition returned. When it fails, the pattern reports failure and
/// does not replace the op.
struct DecomposeAggregateOpsImpl : public OpRewritePattern<linalg::SoftmaxOp> {
  using OpRewritePattern<linalg::SoftmaxOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::SoftmaxOp softmaxOp,
                                PatternRewriter &rewriter) const override {
    // Go through the interface rather than the concrete op so that the
    // decomposition entry point is the interface method.
    auto aggregated =
        cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation());
    FailureOr<SmallVector<Value>> replacements =
        aggregated.decomposeOperation(rewriter);

    if (succeeded(replacements)) {
      rewriter.replaceOp(softmaxOp, *replacements);
      return success();
    }
    return failure();
  }
};
39+
40+
struct DecomposeAggregatedOps
41+
: public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
42+
void runOnOperation() override {
43+
RewritePatternSet patterns(getOperation().getContext());
44+
patterns.add<DecomposeAggregateOpsImpl>(patterns.getContext());
45+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
46+
}
47+
};
48+
49+
} // namespace

lib/gc/Transforms/Pipeline.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
6262
// REMOVE this pass after the above passes are added. Currently we add this
6363
// pass to make the pipeline work properly
6464
pm.addNestedPass<func::FuncOp>(createLinalgGeneralizeNamedOpsPass());
65+
// copied from tpp project
66+
pm.addNestedPass<func::FuncOp>(createDecomposeAggregatedOps());
6567
// fold useless tensor operation pass
6668
pm.addPass(createFoldTensorOperation());
6769
pm.addPass(createLoopInvariantCodeMotionPass());
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
// RUN: gc-opt %s -decompose-aggregated-ops | FileCheck %s
2+
3+
// CHECK-LABEL: softmax
4+
// Test input: a single linalg.softmax along dimension 3 of a 2x2x2x2 f32
// tensor, writing into a tensor.empty destination. The directives inside the
// body expect the pass output to contain no softmax op and four
// linalg.generic ops.
func.func @softmax(%arg0: tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32> {
5+
%0 = tensor.empty() : tensor<2x2x2x2xf32>
6+
// CHECK-NOT: linalg.softmax
7+
// CHECK-COUNT-4: linalg.generic
8+
%1 = linalg.softmax dimension(3)
9+
ins(%arg0 : tensor<2x2x2x2xf32>) outs(%0 : tensor<2x2x2x2xf32>) -> tensor<2x2x2x2xf32>
10+
return %1 : tensor<2x2x2x2xf32>
11+
}

0 commit comments

Comments
 (0)