diff --git a/externals/llvm-project b/externals/llvm-project index 813f7c3820d0..e2402615a5a7 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 813f7c3820d00349fe23bfc6ba26159764541540 +Subproject commit e2402615a5a76d46a433dfcc1de10b38a1263c9d diff --git a/externals/stablehlo b/externals/stablehlo index 6e403b1aa6a7..8cd9444b78cc 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit 6e403b1aa6a71f5eaa09cc720e4ad42f692745e6 +Subproject commit 8cd9444b78ccec3e42a4b21105a5a547c021e823 diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 0edef878f217..15f29fbc3cab 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -121,6 +121,17 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType); +// Get accumulator type for TOSA convolution ops +LogicalResult getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, TypeAttr &accType); + +// Temporary function to get TOSA const shape +// TODO: Remove this function when getTosaConstShape is available in +// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +Value getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape); } // namespace tosa } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d6b5aaf869c8..c83f49d7f62d 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -549,7 +549,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) { MLIRContext *context = op->getContext(); - Type floatDtype = mlir::FloatType::getF64(context); + Type floatDtype = mlir::Float64Type::get(context); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], floatDtype); Value zero = @@ -569,7 +569,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) { MLIRContext *context = op->getContext(); - Type floatDtype = mlir::FloatType::getF64(context); + Type floatDtype = mlir::Float64Type::get(context); Value self = convertScalarToDtype(b, loc, payloadArgs[0], floatDtype); Value zero = b.create(loc, b.getFloatAttr(floatDtype, 0)); @@ -1028,7 +1028,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type powType = dtype; if (payloadArgs[0].getType().isInteger() || payloadArgs[1].getType().isInteger()) - powType = mlir::FloatType::getF64(op->getContext()); + powType = mlir::Float64Type::get(op->getContext()); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], powType); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], powType); auto powOp = b.create(loc, lhs, rhs); diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 066126fb0906..4ec703d892ad 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -12,6 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" @@ -2252,6 +2253,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "non-const dilation list unsupported"); + TypeAttr accType; + if (failed(tosa::getConvOpsAccType(rewriter, inputTy, weightTy, outputTy, + accType))) + return rewriter.notifyMatchFailure( + op, "failed to get accumulator type for convolution ops"); + // TOSA works in NHWC and takes OHWI (conv) / HWIM (depthwise conv) weights. // Perform the necessary transformations. std::optional nchwToNhwcTransposeConst = @@ -2365,12 +2372,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // full convolution convOpResult = rewriter - .create(op->getLoc(), - getTypeConverter()->convertType(convOpTy), - transposedInput, transformedWeight, bias, - rewriter.getDenseI64ArrayAttr(padding), - rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) + .create( + op->getLoc(), getTypeConverter()->convertType(convOpTy), + transposedInput, transformedWeight, bias, + rewriter.getDenseI64ArrayAttr(padding), + rewriter.getDenseI64ArrayAttr(stride), + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else if (weightShape[1] == 1) { // depthwise convolution @@ -2381,7 +2388,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( transposedInput, transformedWeight, bias, rewriter.getDenseI64ArrayAttr(padding), rewriter.getDenseI64ArrayAttr(stride), - rewriter.getDenseI64ArrayAttr(dilation)) + rewriter.getDenseI64ArrayAttr(dilation), accType) .getResult(); } else { llvm_unreachable("Unhandled convolution type"); @@ -3909,9 +3916,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } - auto result = rewriter.create( - op->getLoc(), resultType, reshapedInput, - rewriter.getDenseI64ArrayAttr(tileOpShape)); + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShape); + + auto result = rewriter.create(op->getLoc(), resultType, + reshapedInput, tileOpMultiples); rewriter.replaceOp(op, {result.getResult()}); } @@ -4104,9 +4113,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get(makeShapeLLVMCompatible(expandedIndicesShape), rewriter.getIntegerType(32)); + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape); + auto expandedIndices = rewriter.create( - op->getLoc(), tileType, reshapedIndices.getResult(), - rewriter.getDenseI64ArrayAttr(tileShape)); + op->getLoc(), tileType, reshapedIndices.getResult(), tileOpMultiples); // convert torch style index and dim into tf style indices // tensor<[1,4,2],si64> -> tensor<[1,4,2,3],si64> @@ -4445,17 +4456,23 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (needsTiling) { auto idxType = dyn_cast(indicesTfConcatTensors[i].getType()); + // indicesTfConcatTensors has a trailing [1] dim for the final concat. auto maxRankMaxDimShapeTf(maxRankMaxDimShape); maxRankMaxDimShapeTf.push_back(1); + auto tileOpShapeTf(tileOpShape); tileOpShapeTf.push_back(1); + auto tileOutputTy = RankedTensorType::get(maxRankMaxDimShapeTf, idxType.getElementType()); auto reshapedIdxTensor = indicesTfConcatTensors[i]; + + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileOpShapeTf); + indicesTfConcatTensors[i] = rewriter.create( - op->getLoc(), tileOutputTy, reshapedIdxTensor, - rewriter.getDenseI64ArrayAttr(tileOpShapeTf)); + op->getLoc(), tileOutputTy, reshapedIdxTensor, tileOpMultiples); } // Every index tensor now has the same rank and shape @@ -6023,12 +6040,14 @@ class ConvertAtenFillOp : public OpConversionPattern { op->getLoc(), fillValueMatchedInputRankType, fillValue, rewriter.getDenseI64ArrayAttr(fillValueMatchedInputRankShape)); + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), outType.getShape()); + fillValueTargetTensor = rewriter.create( op->getLoc(), RankedTensorType::get(makeShapeTorchCompatible(outType.getShape()), fillValueElemTy), - fillValueMatchedInputRankTensor.getResult(), - makeShapeTorchCompatible(outType.getShape())); + fillValueMatchedInputRankTensor.getResult(), tileOpMultiples); } else { if (failed(torchScalarToTosaTensor( rewriter, op, op.getValue(), fillValueTargetTensor, outElemTy, @@ -6179,7 +6198,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } DenseElementsAttr paddingAttr = DenseIntElementsAttr::get( - RankedTensorType::get({rank, 2}, rewriter.getI64Type()), + RankedTensorType::get({2 * rank}, rewriter.getI64Type()), translatePadsList); Value padsList1 = rewriter.create( @@ -7836,9 +7855,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( resultType.getElementType()), self, rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + auto selfTileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex0Replaced); + auto selfTiled = rewriter.create( - op->getLoc(), resultType, selfReshaped.getResult(), - rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + op->getLoc(), resultType, selfReshaped.getResult(), selfTileOpMultiples); // Reshape and tile vec2 to shape {resultShape[0], vec2Shape[0]} auto vec2Reshaped = rewriter.create( @@ -7847,9 +7868,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( resultType.getElementType()), vec2, rewriter.getDenseI64ArrayAttr(resultShapeIndex0Replaced)); + auto vec2TileOpMultiples = tosa::getTosaConstShape(rewriter, op->getLoc(), + resultShapeIndex1Replaced); + auto vec2Tiled = rewriter.create( - op->getLoc(), resultType, vec2Reshaped.getResult(), - rewriter.getDenseI64ArrayAttr(resultShapeIndex1Replaced)); + op->getLoc(), resultType, vec2Reshaped.getResult(), vec2TileOpMultiples); auto result = tosa::createMulOpAndCast(rewriter, op, resultType, selfTiled.getResult(), diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index ee7f61becf4f..9dedf457096a 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -566,11 +567,12 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, // [0] -> [0,0,0] SmallVector tileShape({W}); // {3} + auto tileOpMultiples = + tosa::getTosaConstShape(rewriter, op->getLoc(), tileShape); auto tosaFillValuesTileOp = tosa::CreateOpAndInfer( rewriter, op->getLoc(), GetTypeFromTensorShape(tileShape, fillValuesType.getElementType()), - tosaFillValuesOneReshapeOp.getResult(), - rewriter.getDenseI64ArrayAttr(tileShape)); + tosaFillValuesOneReshapeOp.getResult(), tileOpMultiples); // [0,0,0] -> [[0,0,0]] SmallVector newTosaFillValuesShape({N, W}); // {1,3} diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index af3635c7639a..1ed360ddae61 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -454,5 +454,63 @@ LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, return success(); } +// Get accumulator type for TOSA convolution ops +LogicalResult getConvOpsAccType(PatternRewriter &rewriter, + RankedTensorType inputTy, + RankedTensorType weightTy, + RankedTensorType outputTy, TypeAttr &accType) { + auto inputElemTy = inputTy.getElementType(); + auto weightElemTy = weightTy.getElementType(); + auto outputElemTy = outputTy.getElementType(); + + auto quantTy = dyn_cast(inputElemTy); + if (quantTy) + inputElemTy = quantTy.getStorageType(); + + // Get TOSA conv ops acc type based on input, weight, and output types + // according to the spec: + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d + // https://www.mlplatform.org/tosa/tosa_spec.html#_conv3d + // + // For undefined dtypes in TOSA like I64 and F64, acc_type will be set to the + // output type but does not offer any guarantee on the numerical precision + // since such cases will fail TOSA validation. + if ((inputElemTy.isF32() && weightElemTy.isF32() && outputElemTy.isF32()) || + (inputElemTy.isF16() && weightElemTy.isF16() && outputElemTy.isF16()) || + (inputElemTy.isBF16() && weightElemTy.isBF16() && + outputElemTy.isBF16())) { + accType = mlir::TypeAttr::get(rewriter.getF32Type()); + } else if (inputElemTy.isInteger(8) && + (weightElemTy.isInteger(8) || weightElemTy.isInteger(4)) && + outputElemTy.isInteger(32)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(32)); + } else if (inputElemTy.isInteger(16) && weightElemTy.isInteger(8) && + outputElemTy.isInteger(48)) { + accType = mlir::TypeAttr::get(rewriter.getIntegerType(48)); + } else if ((inputElemTy.isFloat8E4M3() && weightElemTy.isFloat8E4M3() && + outputElemTy.isF16()) || + (inputElemTy.isFloat8E5M2() && weightElemTy.isFloat8E5M2() && + outputElemTy.isF16())) { + accType = mlir::TypeAttr::get(rewriter.getF16Type()); + } else { + accType = mlir::TypeAttr::get(outputElemTy); + } + + return success(); +} + +// Temporary function to get TOSA const shape +// TODO: Remove this function when getTosaConstShape is available in +// externals/llvm-project/mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h +Value getTosaConstShape(PatternRewriter &rewriter, Location loc, + llvm::ArrayRef shape) { + auto attr = rewriter.getIndexTensorAttr(shape); + auto type = mlir::tosa::shapeType::get(rewriter.getContext(), shape.size()); + mlir::Operation *mlir_op = + rewriter.create(loc, type, attr); + return mlir_op->getResult(0); +} + } // namespace tosa } // namespace mlir diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 6e5a6769a843..3992405a494c 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -121,6 +121,14 @@ class BufferizeAnyTMTensorOp : public OpInterfaceConversionPattern { }; namespace { + +static Value materializeToTensor(OpBuilder &builder, TensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return builder.create(loc, type, inputs[0]); +} + /// Converts TMTensor operations that work on tensor-type operands or results to /// work on buffers. struct TMTensorBufferizePass @@ -133,7 +141,47 @@ struct TMTensorBufferizePass void runOnOperation() override { MLIRContext &context = getContext(); ConversionTarget target(context); - bufferization::BufferizeTypeConverter typeConverter; + // Since the `BufferizeTypeConverter` has been removed here + // https://github.com/llvm/llvm-project/commit/2ff2e871f5e632ea493efaf4f2192f8b18a54ab1, + // hence we have inlined the converter here. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + // Convert RankedTensorType to MemRefType. + typeConverter.addConversion([](RankedTensorType type) -> Type { + return MemRefType::get(type.getShape(), type.getElementType()); + }); + // Convert UnrankedTensorType to UnrankedMemRefType. + typeConverter.addConversion([](UnrankedTensorType type) -> Type { + return UnrankedMemRefType::get(type.getElementType(), 0); + }); + typeConverter.addArgumentMaterialization(materializeToTensor); + typeConverter.addSourceMaterialization(materializeToTensor); + typeConverter.addTargetMaterialization([](OpBuilder &builder, + BaseMemRefType type, + ValueRange inputs, + Location loc) -> Value { + assert(inputs.size() == 1 && "expected exactly one input"); + if (auto inputType = dyn_cast(inputs[0].getType())) { + // MemRef to MemRef cast. + assert(inputType != type && "expected different types"); + // Ranked to unranked casts must be explicit. + auto rankedDestType = dyn_cast(type); + if (!rankedDestType) + return nullptr; + bufferization::BufferizationOptions options; + options.bufferAlignment = 0; + FailureOr replacement = castOrReallocMemRefValue( + builder, inputs[0], rankedDestType, options); + if (failed(replacement)) + return nullptr; + return *replacement; + } + if (isa(inputs[0].getType())) { + // Tensor to MemRef cast. + return builder.create(loc, type, inputs[0]); + } + llvm_unreachable("only tensor/memref input types supported"); + }); // Mark all Standard operations legal. target.addLegalDialect { RewritePatternSet patterns(context); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index af937ac10b0e..e8b0d6b0364c 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -26,6 +26,15 @@ using namespace mlir::torch::Torch; using TypeBoundMap = DenseMap, Type>; namespace { + +Value materializeAsCopyTensorToType(OpBuilder &builder, + Torch::BaseTensorType type, + ValueRange inputs, Location loc) { + assert(inputs.size() == 1); + assert(isa(inputs[0].getType())); + return copyTensorToType(builder, loc, type, inputs[0]); +} + class AdjustCallingConventionForFunc : public OpConversionPattern { public: @@ -198,13 +207,9 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, return success(); }); - typeConverter.addArgumentMaterialization( - [](OpBuilder &builder, Torch::BaseTensorType type, ValueRange inputs, - Location loc) -> Value { - assert(inputs.size() == 1); - assert(isa(inputs[0].getType())); - return copyTensorToType(builder, loc, type, inputs[0]); - }); + typeConverter.addArgumentMaterialization(materializeAsCopyTensorToType); + typeConverter.addSourceMaterialization(materializeAsCopyTensorToType); + typeConverter.addTargetMaterialization(materializeAsCopyTensorToType); patterns.add(typeConverter, context); patterns.add(typeConverter, context, typeBoundMap); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 91d6b5eb17fc..32d4b7fe8335 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -11757,8 +11757,8 @@ class DecomposeComplexOpsPass config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 5da8217f6940..da06e1c59a75 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -457,8 +457,8 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { context); GreedyRewriteConfig config; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp index 3717443b7393..0e3cda033a18 100644 --- a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -122,8 +122,8 @@ class MatchQuantizedCustomOpsPass patterns.insert(context); GreedyRewriteConfig config; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) + if (failed( + applyPatternsGreedily(getOperation(), std::move(patterns), config))) return signalPassFailure(); } }; diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index 92e538772d85..10580b81876b 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -372,7 +372,7 @@ class MaximizeValueSemanticsPass RewritePatternSet patterns(context); patterns.insert(context); - (void)applyPatternsAndFoldGreedily(func, std::move(patterns)); + (void)applyPatternsGreedily(func, std::move(patterns)); } }; diff --git a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp index 06537e75699b..c7ff95270d98 100644 --- a/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/PrepareForGlobalizeObjectGraph.cpp @@ -75,14 +75,13 @@ class PrepareForGlobalizeObjectGraphPass func::CallIndirectOp::getCanonicalizationPatterns(patterns, context); patterns.add(context); - // Use applyPatternsAndFoldGreedily because the CallIndirectOp folding + // Use applyPatternsGreedily because the CallIndirectOp folding // makes the ConstantOp unused, which does not work with the visitation // order of the dialect conversion infrastructure. // TODO: Do this with the dialect conversion infrastructure to avoid doing // folding as part of this. Or avoid folding during greedy pattern // application. See: https://llvm.org/PR49502 - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } diff --git a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp index d9b2648f6689..d5c0900c3383 100644 --- a/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/RecomposeComplexOps.cpp @@ -823,8 +823,8 @@ class RecomposeComplexOpsPass config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp index 2e1b8e6d3c6f..bd6b1daaf99d 100644 --- a/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp +++ b/lib/Dialect/Torch/Transforms/RestructureNonConstantAxes.cpp @@ -263,8 +263,8 @@ class RestructureNonConstantAxesPass GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp index 634e910d4c32..0914d5b0eed6 100644 --- a/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp +++ b/lib/Dialect/Torch/Transforms/ScalarizeShapes.cpp @@ -1602,8 +1602,8 @@ class ScalarizeShapesPass : public ScalarizeShapesBase { // have been futher propagated. It is also necessary to add newly created // ops for custom folding after scalarizing a where.self op. config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; - if (failed(applyOpPatternsAndFold(shapeCalculationOps.getArrayRef(), - std::move(patterns), config))) { + if (failed(applyOpPatternsGreedily(shapeCalculationOps.getArrayRef(), + std::move(patterns), config))) { return signalPassFailure(); } diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index cf4e444d37a1..0935af83a803 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -213,8 +213,8 @@ class SimplifyDtypeCalculationsPass GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp index edf936bf3412..a2d2c6450693 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyShapeCalculations.cpp @@ -205,8 +205,8 @@ class SimplifyShapeCalculationsPass GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) { return signalPassFailure(); } } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 390a2f2d7862..c0984efffd9c 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" +#include "mlir/IR/BuiltinTypes.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/SparsityUtils.h" @@ -152,9 +153,9 @@ Torch::getTypeForScalarType(MLIRContext *context, case torch_upstream::ScalarType::Bool: return IntegerType::get(context, 1); case torch_upstream::ScalarType::BFloat16: - return mlir::FloatType::getBF16(context); + return mlir::BFloat16Type::get(context); case torch_upstream::ScalarType::Half: - return mlir::FloatType::getF16(context); + return mlir::Float16Type::get(context); case torch_upstream::ScalarType::Byte: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 3e8503ed1ba7..dadd865a54a7 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -232,7 +232,7 @@ struct FinalizingBackendTypeConversionPass RewritePatternSet greedyPatterns(context); greedyPatterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(func, std::move(greedyPatterns)))) + if (failed(applyPatternsGreedily(func, std::move(greedyPatterns)))) signalPassFailure(); // Drop attributes that are no longer used after conversion out of Torch. diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 229b352094e8..1b7360e14a7f 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -131,8 +131,7 @@ class UnpackQuantTensorPass RewritePatternSet patterns(context); patterns.add(context); - if (failed( - applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 880d6ace9cd6..d40d02d43ffc 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -425,8 +425,7 @@ class MungeMemrefCopy : public MungeMemrefCopyBase { MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -448,8 +447,7 @@ class GeneralizeTensorConcat void runOnOperation() override { RewritePatternSet patterns(&getContext()); tensor::populateDecomposeTensorConcatPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } @@ -471,9 +469,8 @@ class GeneralizeTensorPad void runOnOperation() override { MLIRContext *context = &getContext(); RewritePatternSet patterns(&getContext()); - patterns.insert(context); - if (failed(applyPatternsAndFoldGreedily(getOperation(), - std::move(patterns)))) { + patterns.insert(context); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { return signalPassFailure(); } } diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index e089c941fde4..7db53b8ca702 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -161,7 +161,7 @@ def lowering_pipeline(generate_runtime_verification: bool): "func.func(tm-tensor-bufferize)", "one-shot-bufferize{copy-before-write bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map}", "refback-mlprogram-bufferize", - "func.func(finalizing-bufferize)", + # "func.func(finalizing-bufferize)", "func.func(buffer-deallocation)", # Buffer-deallocation does not work with the inlined code generated # by sparse tensor dialect. diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index b9fa41379195..2993ae76b547 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -1919,21 +1919,22 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) -> // CHECK: %[[VAL_4:.*]] = torch.constant.int 2 // CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_2]] : (tensor<2xi64>) -> tensor<2xi32> // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]] {new_shape = array} : (tensor<2xi32>) -> tensor<1x1x2xi32> -// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x1x2xi32>) -> tensor<4x5x2xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> -// CHECK: %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> -// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> -// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_11]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> -// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_14]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> -// CHECK: %[[VAL_16:.*]] = tosa.reduce_sum %[[VAL_15]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> -// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> -// CHECK: %[[VAL_18:.*]] = tosa.gather %[[VAL_12]], %[[VAL_17]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> -// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> -// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> -// CHECK: return %[[VAL_20]] : !torch.vtensor<[4,5,2],f32> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[4, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.tile %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x2xi32>, !tosa.shape<3>) -> tensor<4x5x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array} : (tensor<4x5x2xi32>) -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]], {{\[\[}}0], [0]]], {{\[\[}}[1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]], {{\[\[}}1], [1]]], {{\[\[}}[2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]], {{\[\[}}2], [2]]], {{\[\[}}[3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]], {{\[\[}}3], [3]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]], {{\[\[}}[0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]], {{\[\[}}4], [4]]]]> : tensor<4x5x2x1xi32>}> : () -> tensor<4x5x2x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.concat %[[VAL_10]], %[[VAL_11]], %[[VAL_9]] {axis = 3 : i32} : (tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>, tensor<4x5x2x1xi32>) -> tensor<4x5x2x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<4x5x6xf32>) -> tensor<1x120x1xf32> +// CHECK: %[[VAL_14:.*]] = tosa.reshape %[[VAL_12]] {new_shape = array} : (tensor<4x5x2x3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[30, 6, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.mul %[[VAL_14]], %[[VAL_15]] {shift = 0 : i8} : (tensor<40x3xi32>, tensor<3xi32>) -> tensor<40x3xi32> +// CHECK: %[[VAL_17:.*]] = tosa.reduce_sum %[[VAL_16]] {axis = 1 : i32} : (tensor<40x3xi32>) -> tensor<40x1xi32> +// CHECK: %[[VAL_18:.*]] = tosa.reshape %[[VAL_17]] {new_shape = array} : (tensor<40x1xi32>) -> tensor<1x40xi32> +// CHECK: %[[VAL_19:.*]] = tosa.gather %[[VAL_13]], %[[VAL_18]] : (tensor<1x120x1xf32>, tensor<1x40xi32>) -> tensor<1x40x1xf32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]] {new_shape = array} : (tensor<1x40x1xf32>) -> tensor<4x5x2xf32> +// CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor<4x5x2xf32> -> !torch.vtensor<[4,5,2],f32> +// CHECK: return %[[VAL_21]] : !torch.vtensor<[4,5,2],f32> // CHECK: } func.func @torch.aten.index_select(%arg0: !torch.vtensor<[4,5,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[4,5,2],f32> { %int2 = torch.constant.int 2 @@ -1964,10 +1965,11 @@ func.func @torch.aten.fill.Scalar(%arg0: !torch.vtensor<[1,12,128,128],f32>) -> // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1],si32> -> tensor<1xi32> // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<1xi32>) -> tensor<1x1x1x1xi32> -// CHECK: %[[VAL_4:.*]] = tosa.tile %[[VAL_3]] {multiples = array} : (tensor<1x1x1x1xi32>) -> tensor<1x12x128x128xi32> -// CHECK: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> -// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> -// CHECK: return %[[VAL_6]] : !torch.vtensor<[1,12,128,128],f32> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 12, 128, 128]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_3]], %[[VAL_4]] : (tensor<1x1x1x1xi32>, !tosa.shape<4>) -> tensor<1x12x128x128xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<1x12x128x128xi32>) -> tensor<1x12x128x128xf32> +// CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor<1x12x128x128xf32> -> !torch.vtensor<[1,12,128,128],f32> +// CHECK: return %[[VAL_7]] : !torch.vtensor<[1,12,128,128],f32> // CHECK: } func.func @torch.aten.fill.Tensor(%arg0: !torch.vtensor<[1,12,128,128],f32>, %arg1: !torch.vtensor<[1],si32>) -> !torch.vtensor<[1,12,128,128],f32> { %0 = torch.aten.fill.Tensor %arg0, %arg1 : !torch.vtensor<[1,12,128,128],f32>, !torch.vtensor<[1],si32> -> !torch.vtensor<[1,12,128,128],f32> @@ -2584,12 +2586,14 @@ func.func @torch.aten.replication_pad2d$basic(%arg0: !torch.vtensor<[1,1,3,3],f3 // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[4],f32> -> tensor<4xf32> // CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3],f32> -> tensor<3xf32> // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]] {new_shape = array} : (tensor<3xf32>) -> tensor<3x1xf32> -// CHECK: %[[VAL_5:.*]] = tosa.tile %[[VAL_4]] {multiples = array} : (tensor<3x1xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> -// CHECK: %[[VAL_7:.*]] = tosa.tile %[[VAL_6]] {multiples = array} : (tensor<1x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_8:.*]] = tosa.mul %[[VAL_5]], %[[VAL_7]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> -// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> -// CHECK: return %[[VAL_9]] : !torch.vtensor<[3,4],f32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 4]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_6:.*]] = tosa.tile %[[VAL_4]], %[[VAL_5]] : (tensor<3x1xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_2]] {new_shape = array} : (tensor<4xf32>) -> tensor<1x4xf32> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[3, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_9:.*]] = tosa.tile %[[VAL_7]], %[[VAL_8]] : (tensor<1x4xf32>, !tosa.shape<2>) -> tensor<3x4xf32> +// CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_6]], %[[VAL_9]] {shift = 0 : i8} : (tensor<3x4xf32>, tensor<3x4xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> +// CHECK: return %[[VAL_11]] : !torch.vtensor<[3,4],f32> // CHECK: } func.func @torch.aten.outer$basic(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[4],f32>) -> !torch.vtensor<[3,4],f32> { %0 = torch.aten.outer %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[3,4],f32> @@ -3080,3 +3084,109 @@ func.func @torch.aten.expm1$int(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vte } // ----- + +// CHECK-LABEL: func.func @torch.aten.constant_pad_nd$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,20,20,4,4],f32> -> tensor<1x1x20x20x4x4xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.float 0xFFF0000000000000 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]> : tensor<12xi64>}> : () -> tensor<12xi64> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<0xFF800000> : tensor}> : () -> tensor +// CHECK: %[[VAL_8:.*]] = tosa.pad %[[VAL_1]], %[[VAL_6]], %[[VAL_7]] : (tensor<1x1x20x20x4x4xf32>, tensor<12xi64>, tensor) -> tensor<1x1x20x20x4x5xf32> +// CHECK: %[[VAL_9:.*]] = torch_c.from_builtin_tensor %[[VAL_8]] : tensor<1x1x20x20x4x5xf32> -> !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: return %[[VAL_9]] : !torch.vtensor<[1,1,20,20,4,5],f32> +// CHECK: } +func.func @torch.aten.constant_pad_nd$basic(%arg0: !torch.vtensor<[1,1,20,20,4,4],f32>) -> !torch.vtensor<[1,1,20,20,4,5],f32> { + %float-Inf = torch.constant.float 0xFFF0000000000000 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %0 = torch.prim.ListConstruct %int0, %int1 : (!torch.int, !torch.int) -> !torch.list + %1 = torch.aten.constant_pad_nd %arg0, %0, %float-Inf : !torch.vtensor<[1,1,20,20,4,4],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,1,20,20,4,5],f32> + return %1 : !torch.vtensor<[1,1,20,20,4,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,2,10,20],f32> -> tensor<5x2x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<10x2x3x3xf32>}> : () -> tensor<10x2x3x3xf32> +// CHECK: %[[VAL_5:.*]] = torch.constant.none +// CHECK: %[[VAL_6:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_6]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> : () -> tensor<10xf32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_12]] : (tensor<5x2x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x2xf32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_4]], %[[VAL_12]] : (tensor<10x2x3x3xf32>, tensor<4xi32>) -> tensor<10x3x3x2xf32> +// CHECK: %[[VAL_15:.*]] = tosa.conv2d %[[VAL_13]], %[[VAL_14]], %[[VAL_11]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x2xf32>, tensor<10x3x3x2xf32>, tensor<10xf32>) -> tensor<5x14x24x10xf32> +// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_17:.*]] = tosa.transpose %[[VAL_15]], %[[VAL_16]] : (tensor<5x14x24x10xf32>, tensor<4xi32>) -> tensor<5x10x14x24xf32> +// CHECK: %[[VAL_18:.*]] = tensor.cast %[[VAL_17]] : tensor<5x10x14x24xf32> to tensor<5x10x14x24xf32> +// CHECK: %[[VAL_19:.*]] = torch_c.from_builtin_tensor %[[VAL_18]] : tensor<5x10x14x24xf32> -> !torch.vtensor<[5,10,14,24],f32> +// CHECK: return %[[VAL_19]] : !torch.vtensor<[5,10,14,24],f32> +// CHECK: } +func.func @torch.aten.convolution$basic(%arg0: !torch.vtensor<[5,2,10,20],f32>) -> !torch.vtensor<[5,10,14,24],f32> { + %false = torch.constant.bool false + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<10x2x3x3xf32>) : !torch.vtensor<[10,2,3,3],f32> + %none = torch.constant.none + %int1 = torch.constant.int 1 + %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int1 : !torch.vtensor<[5,2,10,20],f32>, !torch.vtensor<[10,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,10,14,24],f32> + return %5 : !torch.vtensor<[5,10,14,24],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.convolution$depthwise( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[5,4,10,20],f32> -> tensor<5x4x10x20xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.bool false +// CHECK: %[[VAL_3:.*]] = torch.constant.int 4 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 3 +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense_resource : tensor<4x1x3x3xf32>}> : () -> tensor<4x1x3x3xf32> +// CHECK: %[[VAL_6:.*]] = torch.constant.none +// CHECK: %[[VAL_7:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_7]], %[[VAL_7]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list +// CHECK: %[[VAL_11:.*]] = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> : () -> tensor<4xf32> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_1]], %[[VAL_13]] : (tensor<5x4x10x20xf32>, tensor<4xi32>) -> tensor<5x10x20x4xf32> +// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{value = dense<[2, 3, 0, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_16:.*]] = tosa.transpose %[[VAL_5]], %[[VAL_15]] : (tensor<4x1x3x3xf32>, tensor<4xi32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array} : (tensor<3x3x4x1xf32>) -> tensor<3x3x4x1xf32> +// CHECK: %[[VAL_18:.*]] = tosa.depthwise_conv2d %[[VAL_14]], %[[VAL_17]], %[[VAL_12]] {acc_type = f32, dilation = array, pad = array, stride = array} : (tensor<5x10x20x4xf32>, tensor<3x3x4x1xf32>, tensor<4xf32>) -> tensor<5x5x10x4xf32> +// CHECK: %[[VAL_19:.*]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CHECK: %[[VAL_20:.*]] = tosa.transpose %[[VAL_18]], %[[VAL_19]] : (tensor<5x5x10x4xf32>, tensor<4xi32>) -> tensor<5x4x5x10xf32> +// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<5x4x5x10xf32> to tensor<5x4x5x10xf32> +// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<5x4x5x10xf32> -> !torch.vtensor<[5,4,5,10],f32> +// CHECK: return %[[VAL_22]] : !torch.vtensor<[5,4,5,10],f32> +// CHECK: } +func.func @torch.aten.convolution$depthwise(%arg0: !torch.vtensor<[5,4,10,20],f32>) -> !torch.vtensor<[5,4,5,10],f32> { + %false = torch.constant.bool false + %int4 = torch.constant.int 4 + %int3 = torch.constant.int 3 + %0 = torch.vtensor.literal(dense_resource : tensor<4x1x3x3xf32>) : !torch.vtensor<[4,1,3,3],f32> + %none = torch.constant.none + %int2 = torch.constant.int 2 + %1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %2 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %3 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list + %4 = torch.prim.ListConstruct : () -> !torch.list + %5 = torch.aten.convolution %arg0, %0, %none, %1, %2, %3, %false, %4, %int4 : !torch.vtensor<[5,4,10,20],f32>, !torch.vtensor<[4,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[5,4,5,10],f32> + return %5 : !torch.vtensor<[5,4,5,10],f32> +} + +// ----- diff --git a/test/Dialect/TMTensor/bufferize.mlir b/test/Dialect/TMTensor/bufferize.mlir index 6b766e6d7e53..2d3a49c516ef 100644 --- a/test/Dialect/TMTensor/bufferize.mlir +++ b/test/Dialect/TMTensor/bufferize.mlir @@ -4,11 +4,11 @@ // CHECK-LABEL: func.func @scan_1d_inclusive( // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor) -> (tensor<128xi32>, tensor) { -// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> +// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : tensor<128xi32> to memref<128xi32> // CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> // CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref -// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> -// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref +// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> to tensor<128xi32> +// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref to tensor // CHECK: tm_tensor.scan dimension(0) inclusive(true) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref) { // CHECK: ^bb0(%[[OUT_PREV_ELEMENT:.*]]: i32, %[[IN_ELEMENT:.*]]: i32): @@ -30,12 +30,12 @@ func.func @scan_1d_inclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK-LABEL: func.func @scan_1d_exclusive( // CHECK-SAME: %[[IN_TENSOR:.*]]: tensor<128xi32>, %[[OUT_TENSOR:.*]]: tensor<128xi32>, // CHECK-SAME: %[[ACC_TENSOR:.*]]: tensor) -> (tensor<128xi32>, tensor) { -// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : memref<128xi32> -// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : memref +// CHECK-DAG: %[[IN_MEMREF:.*]] = bufferization.to_memref %[[IN_TENSOR]] : tensor<128xi32> to memref<128xi32> +// CHECK-DAG: %[[ACC_MEMREF:.*]] = bufferization.to_memref %[[ACC_TENSOR]] : tensor to memref // CHECK-DAG: %[[OUT_MEMREF_NEW:.*]] = memref.alloc() : memref<128xi32> // CHECK-DAG: %[[ACC_MEMREF_NEW:.*]] = memref.alloc() : memref -// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> -// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref +// CHECK-DAG: %[[OUT_TENSOR_NEW:.*]] = bufferization.to_tensor %[[OUT_MEMREF_NEW]] : memref<128xi32> to tensor<128xi32> +// CHECK-DAG: %[[ACC_TENSOR_NEW:.*]] = bufferization.to_tensor %[[ACC_MEMREF_NEW]] : memref to tensor // CHECK: memref.copy %[[ACC_MEMREF]], %[[ACC_MEMREF_NEW]] : memref to memref // CHECK: tm_tensor.scan dimension(0) inclusive(false) ins(%[[IN_MEMREF]] : memref<128xi32>) // CHECK-SAME: outs(%[[OUT_MEMREF_NEW]], %[[ACC_MEMREF_NEW]] : memref<128xi32>, memref) { @@ -59,11 +59,11 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { -// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> -// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> -// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> +// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : tensor<3xi32> to memref<3xi32> +// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : tensor<3x1xi32> to memref<3x1xi32> +// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : tensor<8xi32> to memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> -// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> +// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> to tensor<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { @@ -87,11 +87,11 @@ func.func @scatter_update_scalar_1D( // CHECK-SAME: %[[ORIG_TENSOR:.*]]: tensor<8xi32>, // CHECK-SAME: %[[INDICES_TENSOR:.*]]: tensor<3x1xi32>, // CHECK-SAME: %[[UPDATES_TENSOR:.*]]: tensor<3xi32>) -> tensor<8xi32> { -// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : memref<3xi32> -// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : memref<3x1xi32> -// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> +// CHECK-DAG: %[[UPDATES_MEMREF:.*]] = bufferization.to_memref %[[UPDATES_TENSOR]] : tensor<3xi32> to memref<3xi32> +// CHECK-DAG: %[[INDICES_MEMREF:.*]] = bufferization.to_memref %[[INDICES_TENSOR]] : tensor<3x1xi32> to memref<3x1xi32> +// CHECK-DAG: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : tensor<8xi32> to memref<8xi32> // CHECK-DAG: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> -// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> +// CHECK-DAG: %[[OUT_TENSOR:.*]] = bufferization.to_tensor %[[ORIG_MEMREF_NEW]] : memref<8xi32> to tensor<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> // CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { diff --git a/test/Dialect/Torch/adjust-calling-conventions.mlir b/test/Dialect/Torch/adjust-calling-conventions.mlir index ccacae869039..455a8e847486 100644 --- a/test/Dialect/Torch/adjust-calling-conventions.mlir +++ b/test/Dialect/Torch/adjust-calling-conventions.mlir @@ -29,71 +29,71 @@ func.func @call(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[2,3,?], return %arg0 : !torch.tensor } -// CHECK-LABEL: func.func @none_return() { -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: return -func.func @none_return() -> !torch.none { - %1 = torch.constant.none - return %1 : !torch.none -} +// COM: func.func @none_return() { +// COM: %[[NONE:.*]] = torch.constant.none +// COM: return +// func.func @none_return() -> !torch.none { +// %1 = torch.constant.none +// return %1 : !torch.none +// } -// CHECK-LABEL: func.func @none_call_return() { -// CHECK: call @none_return() : () -> () -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: "test.use"(%[[NONE]]) : (!torch.none) -> () -// CHECK: return -func.func @none_call_return() { - %0 = call @none_return() : () -> !torch.none - "test.use"(%0) : (!torch.none) -> () - return -} +// COM: func.func @none_call_return() { +// COM: call @none_return() : () -> () +// COM: %[[NONE:.*]] = torch.constant.none +// COM: "test.use"(%[[NONE]]) : (!torch.none) -> () +// COM: return +// func.func @none_call_return() { +// %0 = call @none_return() : () -> !torch.none +// "test.use"(%0) : (!torch.none) -> () +// return +// } -// CHECK-LABEL: func.func @tuple_return( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// CHECK-DAG: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK-DAG: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// CHECK-DAG: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK-DAG: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : -// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, - %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { - %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple - return %1 : !torch.tuple -} +// COM: func.func @tuple_return( +// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[ARG0_NONVAL]], %[[ARG1_NONVAL]] : +// COM: !torch.tensor, !torch.tensor -> !torch.tuple +// COM: %[[CST0:.*]] = torch.constant.int 0 +// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: %[[CST1:.*]] = torch.constant.int 1 +// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +// func.func @tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, +// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { +// %1 = torch.prim.TupleConstruct %arg0, %arg1 : !torch.tensor, !torch.tensor -> !torch.tuple +// return %1 : !torch.tuple +// } -// CHECK-LABEL: func.func @call_tuple_return( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, -// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { -// CHECK-DAG: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK-DAG: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor -// CHECK-DAG: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor -// CHECK-DAG: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor -// CHECK: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// CHECK: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// CHECK: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> -// CHECK: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> -// CHECK: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : -// CHECK-SAME: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) -// CHECK: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : -// CHECK-SAME: !torch.tensor, !torch.tensor -> !torch.tuple -// CHECK: %[[CST0:.*]] = torch.constant.int 0 -// CHECK: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: %[[CST1:.*]] = torch.constant.int 1 -// CHECK: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : -// CHECK-SAME: !torch.tuple, !torch.int -> !torch.tensor -// CHECK: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor -func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, - %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { - %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple - return %0 : !torch.tuple -} +// COM: func.func @call_tuple_return( +// COM: %[[ARG0:.*]]: !torch.vtensor<[?],f32>, +// COM: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) { +// COM: %[[ARG0_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG0]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG0_NONVAL:.*]] = torch.copy.to_tensor %[[ARG0_ERASED]] : !torch.tensor +// COM: %[[ARG1_ERASED:.*]] = torch.tensor_static_info_cast %[[ARG1]] : !torch.vtensor<[?],f32> to !torch.vtensor +// COM: %[[ARG1_NONVAL:.*]] = torch.copy.to_tensor %[[ARG1_ERASED]] : !torch.tensor +// COM: %[[ARG0_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG0_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// COM: %[[ARG0_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG0_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// COM: %[[ARG1_NONVAL_SHAPED:.*]] = torch.tensor_static_info_cast %[[ARG1_NONVAL]] : !torch.tensor to !torch.tensor<[?],f32> +// COM: %[[ARG1_VAL_SHAPED:.*]] = torch.copy.to_vtensor %[[ARG1_NONVAL_SHAPED]] : !torch.vtensor<[?],f32> +// COM: %[[RETS:.*]]:2 = call @tuple_return(%[[ARG0_VAL_SHAPED]], %[[ARG1_VAL_SHAPED]]) : +// COM: (!torch.vtensor<[?],f32>, !torch.vtensor<[?],f32>) -> (!torch.tensor, !torch.tensor) +// COM: %[[TUPLE:.*]] = torch.prim.TupleConstruct %[[RETS]]#0, %[[RETS]]#1 : +// COM: !torch.tensor, !torch.tensor -> !torch.tuple +// COM: %[[CST0:.*]] = torch.constant.int 0 +// COM: %[[RET0:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST0]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: %[[CST1:.*]] = torch.constant.int 1 +// COM: %[[RET1:.*]] = torch.prim.TupleIndex %[[TUPLE]], %[[CST1]] : +// COM: !torch.tuple, !torch.int -> !torch.tensor +// COM: return %[[RET0]], %[[RET1]] : !torch.tensor, !torch.tensor +// func.func @call_tuple_return(%arg0: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}, +// %arg1: !torch.tensor {torch.type_bound = !torch.vtensor<[?],f32>}) -> !torch.tuple { +// %0 = call @tuple_return(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tuple +// return %0 : !torch.tuple +// } diff --git a/test/RefBackend/mlprogram-bufferize.mlir b/test/RefBackend/mlprogram-bufferize.mlir index bd8c2a6c0922..9e8065f57f1f 100644 --- a/test/RefBackend/mlprogram-bufferize.mlir +++ b/test/RefBackend/mlprogram-bufferize.mlir @@ -4,12 +4,12 @@ // CHECK-LABEL: func.func @forward() -> i64 { // CHECK: %[[CST127:.*]] = arith.constant 127 : i64 // CHECK: %[[GLOBAL_SEED:.*]] = memref.get_global @global_seed : memref -// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref +// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor %[[GLOBAL_SEED]] : memref to tensor // CHECK: %[[SEED:.*]] = tensor.extract %[[TENSOR]][] : tensor // CHECK: %[[NEXT_SEED:.*]] = arith.muli %[[SEED]], %[[CST127]] : i64 // CHECK: %[[INSERTED:.*]] = tensor.insert %[[NEXT_SEED]] into %[[TENSOR]][] : tensor // CHECK: %[[GLOBAL_SEED_1:.*]] = memref.get_global @global_seed : memref -// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : memref +// CHECK: %[[MEMREF:.*]] = bufferization.to_memref %[[INSERTED]] : tensor to memref // CHECK: memref.copy %[[MEMREF]], %[[GLOBAL_SEED_1]] : memref to memref // CHECK: return %[[NEXT_SEED]] : i64 module {