
Commit eb7bf78

Add RestructureNonConstantAxes pass to address reduce op tests failing on non constant axes (llvm#3600)
1 parent 638ef14 commit eb7bf78

5 files changed, +308 -0 lines changed

include/torch-mlir/Dialect/Torch/Transforms/Passes.h (+6)

@@ -149,6 +149,12 @@ StringRef getAbstractInterpLibrary();

 static const char kTorchOpPrefix[] = R"(torch.)";

+void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
+                                               MLIRContext *context);
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+createRestructureNonConstantAxesPass();
+
 } // namespace Torch

 /// Registers all Torch transformation passes.

include/torch-mlir/Dialect/Torch/Transforms/Passes.td (+20)

@@ -431,4 +431,24 @@ def VerifyBackendContractNoDecompositions
   }];
 }

+def RestructureNonConstantAxes
+    : Pass<"torch-restructure-non-constant-axes", "func::FuncOp"> {
+  let summary = "Ensure that every Reduction.cpp op has a constant reduction axis.";
+  let constructor = [{
+    mlir::torch::Torch::createRestructureNonConstantAxesPass()
+  }];
+  let description = [{
+    This pass ensures that every Reduction.cpp op has a constant reduction axis.
+
+    It does so using reshapes. For example, a <1,2,3,4,5> tensor is reshaped to a <?,?,?> tensor
+    and reduced on axis 1 to produce a <?,1,?> tensor, which is then reshaped back to the original shape.
+
+    When the axis is supplied at runtime (say axis = -2), the shapes are computed as follows:
+    <?,?,?> becomes <6,4,5>,
+    which gets reduced to <6,1,5>
+    and reshaped back to the original reduction op's output shape,
+    <1,2,3,1,5>.
+  }];
+}
+
 #endif // TORCHMLIR_TORCH_PASSES
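
As an aside for readers of the description above: the reshape arguments are computed at runtime from the input shape and the (possibly negative) axis. Below is a minimal standalone sketch of that arithmetic, assuming a hypothetical helper named collapseAroundDim; it is not part of this commit, which emits the equivalent computation as Torch dialect integer ops.

#include <cassert>
#include <cstdint>
#include <vector>

// Collapse the shape into [product-before-dim, dim-size, product-after-dim].
std::vector<int64_t> collapseAroundDim(const std::vector<int64_t> &shape,
                                       int64_t dim) {
  if (dim < 0)
    dim += static_cast<int64_t>(shape.size()); // normalize a negative axis
  int64_t before = 1, dimSize = 1, after = 1;
  for (int64_t i = 0; i < static_cast<int64_t>(shape.size()); ++i) {
    if (i < dim)
      before *= shape[i];
    else if (i == dim)
      dimSize = shape[i];
    else
      after *= shape[i];
  }
  return {before, dimSize, after};
}

int main() {
  // The example from the description: shape <1,2,3,4,5> with axis = -2
  // collapses to <6,4,5>; the reduction then yields <6,1,5>, which is
  // reshaped back to <1,2,3,1,5>.
  assert((collapseAroundDim({1, 2, 3, 4, 5}, -2) ==
          std::vector<int64_t>{6, 4, 5}));
  return 0;
}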

lib/Dialect/Torch/Transforms/CMakeLists.txt (+1)

@@ -17,6 +17,7 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ReifyDtypeCalculations.cpp
   ReifyAbstractInterpCalculationsUtils.cpp
+  RestructureNonConstantAxes.cpp
   ScalarizeShapes.cpp
   AbstractInterpLibrary.cpp
   SimplifyShapeCalculations.cpp
lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp (new file, +277)

@@ -0,0 +1,277 @@
+//===- RestructureNonConstantAxes.cpp --------------------------------*-
+// C++-*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+// Also available under a BSD-style license. See LICENSE.
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
+#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
+#include "torch-mlir/Dialect/Torch/Utils/Utils.h"
+#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Debug.h"
+
+#define DEBUG_TYPE "torch-lower-to-backend-contract"
+
+using namespace mlir;
+using namespace mlir::torch;
+using namespace mlir::torch::Torch;
+
+namespace {
+
+template <typename SrcOp>
+class ConstantifyDimArgument : public OpRewritePattern<SrcOp> {
+public:
+  using OpRewritePattern<SrcOp>::OpRewritePattern;
+
+  bool isDimConstant(SrcOp op) const {
+    SmallVector<int64_t> dimList;
+    int64_t dim;
+    return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) ||
+           matchPattern(op.getDim(), m_TorchConstantInt(&dim));
+  }
+
+  /*
+  This function renders the reduction dim constant by reshaping the input
+  tensor such that the dim argument is the middle dimension.
+
+  For example, if the input tensor has shape [3,4,5,6,7] and the dim argument
+  is -3, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the
+  reduction operation is applied, and the result is reshaped back to
+  [3,4,1,6,7].
+
+  Since we don't know the dim argument at compile time, we need to compute the
+  arguments to the reshape op at runtime. We do this by computing the new shape
+  of the tensor by multiplying the sizes of the tensor before and after the dim
+  argument, and then reshaping the tensor to this new shape.
+  */
+  LogicalResult matchAndRewrite(SrcOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    Value self = op.getSelf();
+    Value dim = op.getDim();
+
+    if (isDimConstant(op)) {
+      return rewriter.notifyMatchFailure(op,
+                                         "dim argument is already constant");
+    }
+
+    if (isa<Torch::NoneType>(dim.getType())) {
+      return rewriter.notifyMatchFailure(
+          op, "RestructureNonConstantAxes does not support None dim");
+    }
+
+    // When keepdim is not constant, check the ranks of the input and output
+    // tensors.
+    ValueTensorType selfTy =
+        llvm::cast<ValueTensorType>(op.getSelf().getType());
+    ValueTensorType resultTy =
+        llvm::cast<ValueTensorType>(op.getResult().getType());
+    if (selfTy.hasSizes() && resultTy.hasSizes() &&
+        selfTy.getSizes().size() != resultTy.getSizes().size()) {
+      return rewriter.notifyMatchFailure(
+          op,
+          "RestructureNonConstantAxes does not yet support keepdim=false, but "
+          "the input and output tensors have different ranks");
+    }
+
+    Type intType = rewriter.getType<Torch::IntType>();
+    Type boolType = rewriter.getType<Torch::BoolType>();
+    auto createInt = [&](int value) {
+      return rewriter.create<Torch::ConstantIntOp>(
+          loc, intType,
+          rewriter.getIntegerAttr(rewriter.getIntegerType(64), value));
+    };
+    Value zero = createInt(0);
+    Value one = createInt(1);
+
+    // Handle when dim is a single-element list.
+    bool oldDimIsList = isa<Torch::ListType>(dim.getType());
+    if (oldDimIsList) {
+      Value len = rewriter.create<Torch::AtenLenTOp>(loc, intType, dim);
+      Value dimListIsLengthOne =
+          rewriter.create<Torch::AtenEqIntOp>(loc, boolType, len, one);
+      rewriter.create<Torch::RuntimeAssertOp>(
+          loc, dimListIsLengthOne,
+          rewriter.getStringAttr("RestructureNonConstantAxes does not support "
+                                 "dim lists with more than one element"));
+      dim = rewriter.create<Torch::Aten__Getitem__TOp>(loc, intType, dim, zero);
+    }
+
+    // Normalize negative dim.
+    Value rank = rewriter.create<Torch::AtenDimOp>(loc, intType, self);
+    Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero);
+    Value rankOffset = rewriter.create<Torch::AtenMulIntOp>(
+        loc, intType,
+        rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isNegative), rank);
+    dim = rewriter.create<Torch::AtenAddIntOp>(loc, intType, dim, rankOffset);
+
+    auto createConditionalMult = [&](Value self, Value multiplier,
+                                     Value condition) {
+      // compute:
+      //   result = condition ? (self * multiplier) : self
+      // via
+      //   result = self * (1 + (multiplier - 1) * condition)
+      // which translates to:
+
+      // result = multiplier - 1
+      Value result = rewriter.create<Torch::AtenSubIntOp>(
+          loc, intType, multiplier, createInt(1));
+      // result = result * condition
+      result =
+          rewriter.create<Torch::AtenMulIntOp>(loc, intType, result, condition);
+      // result = result + 1
+      result = rewriter.create<Torch::AtenAddIntOp>(loc, intType, result,
+                                                    createInt(1));
+      // result = self * result
+      result = rewriter.create<Torch::AtenMulIntOp>(loc, intType, self, result);
+      return result;
+    };
+
+    // new shape = [beforeDim, dimSize, afterDim]
+    Value beforeProd = createInt(1);
+    Value afterProd = createInt(1);
+    Value dimSize = createInt(1);
+
+    for (size_t i = 0; i < selfTy.getSizes().size(); ++i) {
+      Value idx = createInt(i);
+      Value size =
+          rewriter.create<Torch::AtenSizeIntOp>(loc, intType, self, idx);
+
+      Value isBeforeDim =
+          rewriter.create<Torch::AtenLtIntOp>(loc, boolType, idx, dim);
+      isBeforeDim =
+          rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isBeforeDim);
+      Value isAfterDim =
+          rewriter.create<Torch::AtenGtIntOp>(loc, boolType, idx, dim);
+      isAfterDim =
+          rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isAfterDim);
+
+      Value isEqualToDim =
+          rewriter.create<Torch::AtenEqIntOp>(loc, boolType, idx, dim);
+      isEqualToDim =
+          rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isEqualToDim);
+      dimSize = createConditionalMult(dimSize, size, isEqualToDim);
+
+      beforeProd = createConditionalMult(beforeProd, size, isBeforeDim);
+      afterProd = createConditionalMult(afterProd, size, isAfterDim);
+    }
+
+    Value newShape = rewriter.create<Torch::PrimListConstructOp>(
+        loc, rewriter.getType<Torch::ListType>(intType),
+        ValueRange{beforeProd, dimSize, afterProd});
+
+    // Reshape the input.
+    auto newSelfTy = selfTy.getWithSizesAndDtype(
+        SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize,
+                             Torch::kUnknownSize},
+        selfTy.getDtype());
+    Value reshapedSelf =
+        rewriter.create<Torch::AtenViewOp>(loc, newSelfTy, self, newShape);
+
+    // Construct a new operand range where self is replaced with the
+    // reshapedSelf tensor, and dim is replaced with 1.
+    Value newDim;
+    if (oldDimIsList) {
+      newDim = rewriter.create<Torch::PrimListConstructOp>(
+          loc, rewriter.getType<Torch::ListType>(intType), ValueRange{one});
+    } else {
+      newDim = one;
+    }
+    ValueRange oldOperands = op->getOperands();
+    SmallVector<Value> newOperandsVect;
+    for (size_t i = 0; i < oldOperands.size(); ++i) {
+      if (oldOperands[i] == op.getSelf()) {
+        newOperandsVect.push_back(reshapedSelf);
+      } else if (oldOperands[i] == op.getDim()) {
+        newOperandsVect.push_back(newDim);
+      } else {
+        newOperandsVect.push_back(oldOperands[i]);
+      }
+    }
+    ValueRange newOperands = ValueRange(newOperandsVect);
+
+    // Construct the new reduction op result type.
+    ValueTensorType newResultTy =
+        cast<ValueTensorType>(resultTy.getWithSizesAndDtype(
+            SmallVector<int64_t>{Torch::kUnknownSize, 1, Torch::kUnknownSize},
+            resultTy.getDtype()));
+
+    Value newReductionOp =
+        rewriter.create<SrcOp>(loc, newResultTy, newOperands, op->getAttrs());
+
+    // Reshape the result back to the original shape.
+    ValueTensorType oldResultTy =
+        cast<ValueTensorType>(op.getResult().getType());
+    SmallVector<Value> shapeValues;
+    for (auto dim : oldResultTy.getSizes()) {
+      shapeValues.push_back(createInt(dim));
+    }
+    Value originalShape = rewriter.create<Torch::PrimListConstructOp>(
+        loc, rewriter.getType<Torch::ListType>(intType), shapeValues);
+    Value result = rewriter.create<Torch::AtenViewOp>(
+        loc, op->getResult(0).getType(), newReductionOp, originalShape);
+
+    rewriter.replaceOp(op, result);
+    return success();
+  };
+};
+
+template <typename... OpTypes>
+void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns,
+                                       MLIRContext *context) {
+  // simple variadic template to sugar up adding the patterns
+  (patterns.add<ConstantifyDimArgument<OpTypes>>(context), ...);
+}
+
+void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns,
+                                               MLIRContext *context) {
+  // these are the reduction ops with a dim argument
+
+  addConstantifyDimArgumentPatterns<
+      // not supported because they have multiple results
+      // AtenMaxDimOp,
+      // AtenMinDimOp,
+      AtenSumDimIntListOp, AtenAllDimOp, AtenLinalgVectorNormOp,
+      AtenFrobeniusNormDimOp>(patterns, context);
+}
+
+class RestructureNonConstantAxesPass
+    : public RestructureNonConstantAxesBase<RestructureNonConstantAxesPass> {
+public:
+  RestructureNonConstantAxesPass() = default;
+
+  void runOnOperation() override {
+    MLIRContext *context = &getContext();
+
+    RewritePatternSet patterns(context);
+
+    populateRestructureNonConstantAxesPattern(patterns, context);
+
+    // TODO: Debug visitation order to make this more efficient.
+    // A single linear scan should suffice.
+    GreedyRewriteConfig config;
+    config.useTopDownTraversal = true;
+    config.maxIterations = GreedyRewriteConfig::kNoLimit;
+    if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns),
+                                            config))) {
+      return signalPassFailure();
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<OperationPass<func::FuncOp>>
+mlir::torch::Torch::createRestructureNonConstantAxesPass() {
+  return std::make_unique<RestructureNonConstantAxesPass>();
+}
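
The per-dimension loop in the pattern above keeps the runtime shape computation branch-free: createConditionalMult folds the select `condition ? self * multiplier : self` into straight-line integer ops via the identity `self * (1 + (multiplier - 1) * condition)`. Below is a tiny plain-integer sketch of that identity, using a hypothetical helper that is not part of this diff:

#include <cassert>
#include <cstdint>

// condition is 0 or 1, so no control flow is needed in the emitted IR.
int64_t conditionalMult(int64_t self, int64_t multiplier, int64_t condition) {
  return self * (1 + (multiplier - 1) * condition);
}

int main() {
  assert(conditionalMult(7, 5, /*condition=*/1) == 35); // multiply when true
  assert(conditionalMult(7, 5, /*condition=*/0) == 7);  // pass through otherwise
  return 0;
}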

lib/Dialect/TorchConversion/Transforms/Passes.cpp (+4)

@@ -64,6 +64,10 @@ void mlir::torch::registerTorchConversionPasses() {

 void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline(
     OpPassManager &pm) {
+  // Fix non constant dims passed to reduction ops
+  pm.addNestedPass<func::FuncOp>(
+      torch::Torch::createRestructureNonConstantAxesPass());
+
   // We want to fuse quantized operations together before lowering to linalg.
   pm.addNestedPass<func::FuncOp>(Torch::createFuseQuantizedOpsPass());
   pm.addNestedPass<func::FuncOp>(Torch::createScalarizeShapesPass());