1
+ // ===-- DecomposeAggregatedOps.cpp - Decompose Aggregated Ops ---*- C++ -*-===//
2
+ //
3
+ // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4
+ // See https://llvm.org/LICENSE.txt for license information.
5
+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
+ //
7
+ // ===----------------------------------------------------------------------===//
8
+
9
+ #include " gc/Transforms/Passes.h"
10
+ #include " mlir/Dialect/Func/IR/FuncOps.h"
11
+ #include " mlir/Dialect/Linalg/IR/Linalg.h"
12
+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
13
+
14
+ using namespace mlir ;
15
+ namespace mlir {
16
+ namespace gc {
17
+ #define GEN_PASS_DEF_DECOMPOSEAGGREGATEDOPS
18
+ #include " gc/Transforms/Passes.h.inc"
19
+ } // namespace gc
20
+ } // namespace mlir
21
+
22
+ namespace {
23
+
24
+ struct DecomposeAggregateOpsImpl : public OpRewritePattern <linalg::SoftmaxOp> {
25
+ using OpRewritePattern<linalg::SoftmaxOp>::OpRewritePattern;
26
+
27
+ LogicalResult matchAndRewrite (linalg::SoftmaxOp softmaxOp,
28
+ PatternRewriter &rewriter) const override {
29
+ auto decomposableOp =
30
+ cast<linalg::AggregatedOpInterface>(softmaxOp.getOperation ());
31
+ FailureOr<SmallVector<Value>> maybeNewResult =
32
+ decomposableOp.decomposeOperation (rewriter);
33
+ if (failed (maybeNewResult))
34
+ return failure ();
35
+ rewriter.replaceOp (softmaxOp, *maybeNewResult);
36
+ return success ();
37
+ }
38
+ };
39
+
40
+ struct DecomposeAggregatedOps
41
+ : public gc::impl::DecomposeAggregatedOpsBase<DecomposeAggregatedOps> {
42
+ void runOnOperation () override {
43
+ RewritePatternSet patterns (getOperation ().getContext ());
44
+ patterns.add <DecomposeAggregateOpsImpl>(patterns.getContext ());
45
+ (void )applyPatternsAndFoldGreedily (getOperation (), std::move (patterns));
46
+ }
47
+ };
48
+
49
+ } // namespace
0 commit comments