|
| 1 | +//===- RestructureNonConstantAxes.cpp --------------------------------*- |
| 2 | +// C++-*-===// |
| 3 | +// |
| 4 | +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. |
| 5 | +// See https://llvm.org/LICENSE.txt for license information. |
| 6 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 7 | +// Also available under a BSD-style license. See LICENSE. |
| 8 | +// |
| 9 | +//===----------------------------------------------------------------------===// |
| 10 | + |
| 11 | +#include "PassDetail.h" |
| 12 | + |
| 13 | +#include "mlir/IR/BuiltinOps.h" |
| 14 | +#include "mlir/Pass/PassManager.h" |
| 15 | +#include "mlir/Transforms/DialectConversion.h" |
| 16 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 17 | +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" |
| 18 | +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" |
| 19 | +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" |
| 20 | +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" |
| 21 | +#include "llvm/ADT/StringSet.h" |
| 22 | +#include "llvm/Support/Debug.h" |
| 23 | + |
// Debug tag for `-debug-only=` output; must name THIS pass, not the
// backend-contract pass it was copied from.
#define DEBUG_TYPE "torch-restructure-non-constant-axes"
| 25 | + |
| 26 | +using namespace mlir; |
| 27 | +using namespace mlir::torch; |
| 28 | +using namespace mlir::torch::Torch; |
| 29 | + |
| 30 | +namespace { |
| 31 | + |
| 32 | +template <typename SrcOp> |
| 33 | +class ConstantifyDimArgument : public OpRewritePattern<SrcOp> { |
| 34 | +public: |
| 35 | + using OpRewritePattern<SrcOp>::OpRewritePattern; |
| 36 | + |
| 37 | + bool isDimConstant(SrcOp op) const { |
| 38 | + SmallVector<int64_t> dimList; |
| 39 | + int64_t dim; |
| 40 | + return matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList)) || |
| 41 | + matchPattern(op.getDim(), m_TorchConstantInt(&dim)); |
| 42 | + } |
| 43 | + |
| 44 | + /* |
| 45 | + This function renders the reduction dim constant by reshaping the input tensor |
| 46 | + such that the dim argument is the middle dimension. |
| 47 | +
|
| 48 | + For example, if the input tensor has shape [3,4,5,6,7] and the dim argument is |
| 49 | + -2, the input tensor is reshaped to [3,4,5,6,7] -> [12,5,42], the reduction |
| 50 | + operation is applied, and the result is reshaped back to [3,4,1,6,7]. |
| 51 | +
|
| 52 | + Since we don't know the dim argument at compile time, we need to compute the |
| 53 | + arguments to the reshape op at runtime. We do this by computing the new shape |
| 54 | + of the tensor by multiplying the shapes of the tensor before and after the dim |
| 55 | + argument, and then reshaping the tensor to this new shape. |
| 56 | + */ |
| 57 | + LogicalResult matchAndRewrite(SrcOp op, |
| 58 | + PatternRewriter &rewriter) const override { |
| 59 | + Location loc = op->getLoc(); |
| 60 | + |
| 61 | + Value self = op.getSelf(); |
| 62 | + Value dim = op.getDim(); |
| 63 | + |
| 64 | + if (isDimConstant(op)) { |
| 65 | + return rewriter.notifyMatchFailure(op, |
| 66 | + "dim argument is already constant"); |
| 67 | + } |
| 68 | + |
| 69 | + if (isa<Torch::NoneType>(dim.getType())) { |
| 70 | + return rewriter.notifyMatchFailure( |
| 71 | + op, "RestructureNonConstantAxes does not support None dim"); |
| 72 | + } |
| 73 | + |
| 74 | + // when keepdim is not constant, check the ranks of the input and output |
| 75 | + // tensors |
| 76 | + ValueTensorType selfTy = |
| 77 | + llvm::cast<ValueTensorType>(op.getSelf().getType()); |
| 78 | + ValueTensorType resultTy = |
| 79 | + llvm::cast<ValueTensorType>(op.getResult().getType()); |
| 80 | + if (selfTy.hasSizes() && resultTy.hasSizes() && |
| 81 | + selfTy.getSizes().size() != resultTy.getSizes().size()) { |
| 82 | + return rewriter.notifyMatchFailure( |
| 83 | + op, |
| 84 | + "RestructureNonConstantAxes does not yet support keepdim=false, but " |
| 85 | + "the input and output tensors have different ranks"); |
| 86 | + } |
| 87 | + |
| 88 | + Type intType = rewriter.getType<Torch::IntType>(); |
| 89 | + Type boolType = rewriter.getType<Torch::BoolType>(); |
| 90 | + auto createInt = [&](int value) { |
| 91 | + return rewriter.create<Torch::ConstantIntOp>( |
| 92 | + loc, intType, |
| 93 | + rewriter.getIntegerAttr(rewriter.getIntegerType(64), value)); |
| 94 | + }; |
| 95 | + Value zero = createInt(0); |
| 96 | + Value one = createInt(1); |
| 97 | + |
| 98 | + // handle when dim is a single element list |
| 99 | + bool oldDimIsList = isa<Torch::ListType>(dim.getType()); |
| 100 | + if (oldDimIsList) { |
| 101 | + Value len = rewriter.create<Torch::AtenLenTOp>(loc, intType, dim); |
| 102 | + Value dimListIsLengthOne = |
| 103 | + rewriter.create<Torch::AtenEqIntOp>(loc, boolType, len, one); |
| 104 | + rewriter.create<Torch::RuntimeAssertOp>( |
| 105 | + loc, dimListIsLengthOne, |
| 106 | + rewriter.getStringAttr("RestructureNonConstantAxes does not support " |
| 107 | + "dim lists with more than one element")); |
| 108 | + dim = rewriter.create<Torch::Aten__Getitem__TOp>(loc, intType, dim, zero); |
| 109 | + } |
| 110 | + |
| 111 | + // Normalize negative dim |
| 112 | + Value rank = rewriter.create<Torch::AtenDimOp>(loc, intType, self); |
| 113 | + Value isNegative = rewriter.create<Torch::AtenLtIntOp>(loc, dim, zero); |
| 114 | + Value rankOffset = rewriter.create<Torch::AtenMulIntOp>( |
| 115 | + loc, intType, |
| 116 | + rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isNegative), rank); |
| 117 | + dim = rewriter.create<Torch::AtenAddIntOp>(loc, intType, dim, rankOffset); |
| 118 | + |
| 119 | + auto createConditionalMult = [&](Value self, Value multiplier, |
| 120 | + Value condition) { |
| 121 | + // compute: |
| 122 | + // result = codition ? (self * multiplier) : self |
| 123 | + // via |
| 124 | + // result = self * (1 + (multiplier - 1) * condition) |
| 125 | + // which translates to: |
| 126 | + |
| 127 | + // result = multiplier - 1 |
| 128 | + Value result = rewriter.create<Torch::AtenSubIntOp>( |
| 129 | + loc, intType, multiplier, createInt(1)); |
| 130 | + // result = result * condition |
| 131 | + result = |
| 132 | + rewriter.create<Torch::AtenMulIntOp>(loc, intType, result, condition); |
| 133 | + // result = result + 1 |
| 134 | + result = rewriter.create<Torch::AtenAddIntOp>(loc, intType, result, |
| 135 | + createInt(1)); |
| 136 | + // result = self * result |
| 137 | + result = rewriter.create<Torch::AtenMulIntOp>(loc, intType, self, result); |
| 138 | + return result; |
| 139 | + }; |
| 140 | + |
| 141 | + // new shape = [beforeDim, dimSize, afterDim] |
| 142 | + Value beforeProd = createInt(1); |
| 143 | + Value afterProd = createInt(1); |
| 144 | + Value dimSize = createInt(1); |
| 145 | + |
| 146 | + for (size_t i = 0; i < selfTy.getSizes().size(); ++i) { |
| 147 | + Value idx = createInt(i); |
| 148 | + Value size = |
| 149 | + rewriter.create<Torch::AtenSizeIntOp>(loc, intType, self, idx); |
| 150 | + |
| 151 | + Value isBeforeDim = |
| 152 | + rewriter.create<Torch::AtenLtIntOp>(loc, boolType, idx, dim); |
| 153 | + isBeforeDim = |
| 154 | + rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isBeforeDim); |
| 155 | + Value isAfterDim = |
| 156 | + rewriter.create<Torch::AtenGtIntOp>(loc, boolType, idx, dim); |
| 157 | + isAfterDim = |
| 158 | + rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isAfterDim); |
| 159 | + |
| 160 | + Value isEqualToDim = |
| 161 | + rewriter.create<Torch::AtenEqIntOp>(loc, boolType, idx, dim); |
| 162 | + isEqualToDim = |
| 163 | + rewriter.create<Torch::AtenIntBoolOp>(loc, intType, isEqualToDim); |
| 164 | + dimSize = createConditionalMult(dimSize, size, isEqualToDim); |
| 165 | + |
| 166 | + beforeProd = createConditionalMult(beforeProd, size, isBeforeDim); |
| 167 | + afterProd = createConditionalMult(afterProd, size, isAfterDim); |
| 168 | + } |
| 169 | + |
| 170 | + Value newShape = rewriter.create<Torch::PrimListConstructOp>( |
| 171 | + loc, rewriter.getType<Torch::ListType>(intType), |
| 172 | + ValueRange{beforeProd, dimSize, afterProd}); |
| 173 | + |
| 174 | + // Reshape input |
| 175 | + auto newSelfTy = selfTy.getWithSizesAndDtype( |
| 176 | + SmallVector<int64_t>{Torch::kUnknownSize, Torch::kUnknownSize, |
| 177 | + Torch::kUnknownSize}, |
| 178 | + selfTy.getDtype()); |
| 179 | + Value reshapedSelf = |
| 180 | + rewriter.create<Torch::AtenViewOp>(loc, newSelfTy, self, newShape); |
| 181 | + |
| 182 | + // construct new operange range where self is replaced with reshapedSelf |
| 183 | + // tensor, and dim is replaced with 1 |
| 184 | + Value newDim; |
| 185 | + if (oldDimIsList) { |
| 186 | + newDim = rewriter.create<Torch::PrimListConstructOp>( |
| 187 | + loc, rewriter.getType<Torch::ListType>(intType), ValueRange{one}); |
| 188 | + } else { |
| 189 | + newDim = one; |
| 190 | + } |
| 191 | + ValueRange oldOperands = op->getOperands(); |
| 192 | + SmallVector<Value> newOperandsVect; |
| 193 | + for (size_t i = 0; i < oldOperands.size(); ++i) { |
| 194 | + if (oldOperands[i] == op.getSelf()) { |
| 195 | + newOperandsVect.push_back(reshapedSelf); |
| 196 | + } else if (oldOperands[i] == op.getDim()) { |
| 197 | + newOperandsVect.push_back(newDim); |
| 198 | + } else { |
| 199 | + newOperandsVect.push_back(oldOperands[i]); |
| 200 | + } |
| 201 | + } |
| 202 | + ValueRange newOperands = ValueRange(newOperandsVect); |
| 203 | + |
| 204 | + // construct new reduction op result type |
| 205 | + ValueTensorType newResultTy = |
| 206 | + cast<ValueTensorType>(resultTy.getWithSizesAndDtype( |
| 207 | + SmallVector<int64_t>{Torch::kUnknownSize, 1, Torch::kUnknownSize}, |
| 208 | + resultTy.getDtype())); |
| 209 | + |
| 210 | + Value newReductionOp = |
| 211 | + rewriter.create<SrcOp>(loc, newResultTy, newOperands, op->getAttrs()); |
| 212 | + |
| 213 | + // Reshape the result back to original shape |
| 214 | + ValueTensorType oldResultTy = |
| 215 | + cast<ValueTensorType>(op.getResult().getType()); |
| 216 | + SmallVector<Value> shapeValues; |
| 217 | + for (auto dim : oldResultTy.getSizes()) { |
| 218 | + shapeValues.push_back(createInt(dim)); |
| 219 | + } |
| 220 | + Value originalShape = rewriter.create<Torch::PrimListConstructOp>( |
| 221 | + loc, rewriter.getType<Torch::ListType>(intType), shapeValues); |
| 222 | + Value result = rewriter.create<Torch::AtenViewOp>( |
| 223 | + loc, op->getResult(0).getType(), newReductionOp, originalShape); |
| 224 | + |
| 225 | + rewriter.replaceOp(op, result); |
| 226 | + return success(); |
| 227 | + }; |
| 228 | +}; |
| 229 | + |
| 230 | +template <typename... OpTypes> |
| 231 | +void addConstantifyDimArgumentPatterns(RewritePatternSet &patterns, |
| 232 | + MLIRContext *context) { |
| 233 | + // simple variadic template to sugar up adding the patterns |
| 234 | + (patterns.add<ConstantifyDimArgument<OpTypes>>(context), ...); |
| 235 | +} |
| 236 | + |
| 237 | +void populateRestructureNonConstantAxesPattern(RewritePatternSet &patterns, |
| 238 | + MLIRContext *context) { |
| 239 | + // these are the reduction ops with a dim argument |
| 240 | + |
| 241 | + addConstantifyDimArgumentPatterns< |
| 242 | + // not supported because they have multiple results |
| 243 | + // AtenMaxDimOp, |
| 244 | + // AtenMinDimOp, |
| 245 | + AtenSumDimIntListOp, AtenAllDimOp, AtenLinalgVectorNormOp, |
| 246 | + AtenFrobeniusNormDimOp>(patterns, context); |
| 247 | +} |
| 248 | + |
| 249 | +class RestructureNonConstantAxesPass |
| 250 | + : public RestructureNonConstantAxesBase<RestructureNonConstantAxesPass> { |
| 251 | +public: |
| 252 | + RestructureNonConstantAxesPass() = default; |
| 253 | + |
| 254 | + void runOnOperation() override { |
| 255 | + MLIRContext *context = &getContext(); |
| 256 | + |
| 257 | + RewritePatternSet patterns(context); |
| 258 | + |
| 259 | + populateRestructureNonConstantAxesPattern(patterns, context); |
| 260 | + |
| 261 | + // TODO: Debug visitation order to make this more efficient. |
| 262 | + // A single linear scan should suffice. |
| 263 | + GreedyRewriteConfig config; |
| 264 | + config.useTopDownTraversal = true; |
| 265 | + config.maxIterations = GreedyRewriteConfig::kNoLimit; |
| 266 | + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), |
| 267 | + config))) { |
| 268 | + return signalPassFailure(); |
| 269 | + } |
| 270 | + } |
| 271 | +}; |
| 272 | +} // namespace |
| 273 | + |
| 274 | +std::unique_ptr<OperationPass<func::FuncOp>> |
| 275 | +mlir::torch::Torch::createRestructureNonConstantAxesPass() { |
| 276 | + return std::make_unique<RestructureNonConstantAxesPass>(); |
| 277 | +} |
0 commit comments