diff --git a/docs/Dialects/zhigh.md b/docs/Dialects/zhigh.md index dd87eeecf5..8aa2197c08 100644 --- a/docs/Dialects/zhigh.md +++ b/docs/Dialects/zhigh.md @@ -782,13 +782,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac Effects: `MemoryEffects::Effect{}` -#### Attributes: - - - - -
AttributeMLIR TypeDescription
op_type::mlir::StringAttrstring attribute
- #### Operands: | Operand | Description | @@ -814,13 +807,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac Effects: `MemoryEffects::Effect{}` -#### Attributes: - - - - -
AttributeMLIR TypeDescription
op_type::mlir::StringAttrstring attribute
- #### Operands: | Operand | Description | @@ -857,6 +843,40 @@ Effects: `MemoryEffects::Effect{}` | :----: | ----------- | | `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH +### `zhigh.Reshape` (::onnx_mlir::zhigh::ZHighReshapeOp) + +_ZHigh Reshape operation for Z Tensors_ + +ZHigh operation to perform a converts a Z Tensor from one type to an equivalent type +with a provided shape. The data is never copied or modified. When no layout is specified, +the output preserve the same layout as the source input. + +Traits: `AlwaysSpeculatableImplTrait` + +Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface` + +Effects: `MemoryEffects::Effect{}` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `source` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH +| `shape` | tensor of 64-bit signless integer values + +#### Results: + +| Result | Description | +| :----: | ----------- | +| `result` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH + ### `zhigh.Sigmoid` (::onnx_mlir::zhigh::ZHighSigmoidOp) _ZHigh Sigmoid operation_ diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md index 7be1c6457b..e97ece5313 100644 --- a/docs/Dialects/zlow.md +++ b/docs/Dialects/zlow.md @@ -770,7 +770,6 @@ Traits: `MemRefsNormalizable` -
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
op_type::mlir::StringAttrstring attribute
#### Operands: @@ -795,7 +794,6 @@ Traits: `MemRefsNormalizable` -
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
op_type::mlir::StringAttrstring attribute
#### Operands: @@ -832,6 +830,29 @@ Interfaces: `MemoryEffectOpInterface` | `shape` | memref of 64-bit signless integer values | `Out` | memref of dlfloat16 type values +### `zlow.reshape` (::onnx_mlir::zlow::ZLowReshapeOp) + +_ZLow Reshape operation_ + +ZLow operation to perform a reshape (no data movement). + +Traits: `MemRefsNormalizable` + +#### Attributes: + + + + +
AttributeMLIR TypeDescription
layout::mlir::StringAttrstring attribute
+ +#### Operands: + +| Operand | Description | +| :-----: | ----------- | +| `X` | memref of dlfloat16 type values +| `shape` | memref of 64-bit signless integer values +| `Out` | memref of dlfloat16 type values + ### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp) _ZLow sigmoid operation_ diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td index 2173000382..05a373f9b5 100644 --- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td +++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td @@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern< // //===----------------------------------------------------------------------===// - -// Create an ONNX Shape Op with type -def CreateShapeOp: NativeCodeCall< - "$_builder.create($_loc, $0, $1, IntegerAttr(), 0)" ->; - -// Get a type for a tensor that stores the shape of another tensor. -def GetShapeTypeOf: NativeCodeCall< - "RankedTensorType::get({mlir::cast($0.getType()).getRank()}, $_builder.getIntegerType(64))" ->; - // Check unidirectional broadcasting from the first to second tensor. def IsUniBroadcastingFromFirstToSecond: Constraint< CPred<"isUniBroadcatableFirstToSecond($0, $1)">, diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp index 2cdc850e02..1ed8f6e85b 100644 --- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp +++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp @@ -1093,6 +1093,47 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern { } }; +// Reshape operation. Code similar to unary lowering, except that we use the +// operation's specialized shape here. +struct ZHighToZLowReshapeOpLowering : public ConversionPattern { + ZHighToZLowReshapeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx) + : ConversionPattern(ZHighReshapeOp::getOperationName(), 1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { + Location loc = op->getLoc(); + Value input = operands[0]; + + // Helper builders. + MultiDialectBuilder create(rewriter, loc); + + // Convert ZTensor type to MemRefType. + ZMemRefType zMemRefType = + convertZTensorToMemRefType(*op->result_type_begin()); + + // Shape helper. + ZHighReshapeOpShapeHelper shapeHelper(op, operands, &create.krnlIE); + shapeHelper.computeShapeAndAssertOnFailure(); + SmallVector &dims = shapeHelper.getOutputDims(); + + // Allocate a buffer for the result MemRef. Follow this pattern to be + // similar to all the other zlow patterns. Will remove the alloc when + // lowering zlow.reshape to memref.reinterpret_cast once memrefs are + // normalized. See code in ReshapeToReinterpretCastPattern. + Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter); + + // Note, we do not need to save the shape of the original operation, as this + // reshape is "no-op" that logically reorganize the shape of the operation + // into 2 equivalent shapes under their given layout. + + // Emit a ZLow operation. 
+ rewriter.create( + loc, input, /* shape,*/ alloc, zMemRefType.layout); + rewriter.replaceOp(op, alloc); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin //===----------------------------------------------------------------------===// @@ -1117,8 +1158,6 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern { : ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {} LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const final { - MLIRContext *context = rewriter.getContext(); - OP_TYPE reduceOp = mlir::cast(op); Location loc = op->getLoc(); Value data = operands[0]; @@ -2285,6 +2324,8 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns, patterns.insert>(typeConverter, ctx); patterns.insert>( typeConverter, ctx); + // Reshape operations. + patterns.insert(typeConverter, ctx); // Neural network operations. patterns.insert>( typeConverter, ctx); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt index 2c3e6ca953..7d8fa855de 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt +++ b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt @@ -26,6 +26,7 @@ add_onnx_mlir_library(OMZHighOps ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp ZHighOps/QuantizedStick/QuantizedStick.cpp ZHighOps/Reduction/Reduction.cpp + ZHighOps/Reshape/Reshape.cpp ZHighOps/Softmax/Softmax.cpp ZHighOps/Stick/Stick.cpp ZHighOps/StickForGRU/StickForGRU.cpp diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td index 7bbcd02c87..cbc73dad18 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td @@ -1195,4 +1195,28 @@ def ZHighFixGRUYhOp:ZHigh_Op<"FixGRUYh", [Pure, }]; } +def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "ZHigh Reshape operation for Z Tensors"; + let description = [{ + ZHigh operation to perform a converts a Z Tensor from one type to an equivalent type + with a provided shape. The data is never copied or modified. When no layout is specified, + the output preserve the same layout as the source input. + }]; + let arguments = (ins AnyTypeOf<[AnyZTensor]>:$source, // Input Z Tensor to be reshaped. + TensorOf<[I64]>:$shape, // Shape of output Z Tensor. + OptionalAttr:$layout); // Layout of output Z Tensor, default same as input. + let results = (outs AnyTypeOf<[AnyZTensor]>:$result); + + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReshapeOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; +} + #endif // ZHIGH_OPS diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp index c120bc6b44..926d08913c 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp @@ -218,7 +218,7 @@ ZTensorEncodingAttr::QuantizedType getZTensorQuantizedType(Type type) { // Utility functions. 
Value getMinusBcastConst( - mlir::OpBuilder &builder, Location loc, FloatAttr floatAttr, Value X) { + OpBuilder &builder, Location loc, FloatAttr floatAttr, Value X) { ShapedType xType = mlir::cast(X.getType()); assert(xType.hasStaticShape() && "expected static shape"); float val = floatAttr.getValueAsDouble() * -1.0; @@ -283,14 +283,14 @@ bool isTiling2DTo4D(Value val) { if (!(inputShape.size() == 2 && outputShape.size() == 4)) return false; - // Tiling over each input dimension. + // Tiling over each input dimension. Assume here that the dims are static. return ((inputShape[0] == outputShape[0] * outputShape[1]) && (inputShape[1] == outputShape[2] * outputShape[3])); } /// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling the first input /// dimension. -bool isTiling3DTo4D(Value val) { +bool isLeftmostTiling3DTo4D(Value val) { auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -312,12 +312,50 @@ bool isTiling3DTo4D(Value val) { if (!(inputShape.size() == 3 && outputShape.size() == 4)) return false; - // Tiling over each input dimension. + // Tiling over each input dimension. Assume here that the dims are static. return ((inputShape[0] == outputShape[0] * outputShape[1]) && (inputShape[1] == outputShape[2]) && (inputShape[2] == outputShape[3])); } +/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling the last input +/// dimension. If tilingSize>0, then check that it is tiling by that amount (or +/// a multiple thereof). +bool isRightmostTiling3DTo4D(Value val, int64_t tilingSize) { + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); + if (!reshapeOp) + return false; + + Value input = reshapeOp.getData(); + Value output = reshapeOp.getReshaped(); + Type inputType = input.getType(); + Type outputType = output.getType(); + + if (!isRankedShapedType(inputType)) + return false; + if (!isRankedShapedType(outputType)) + return false; + + ArrayRef inputShape = getShape(inputType); + ArrayRef outputShape = getShape(outputType); + + // Not reshape from 3D to 4D. + if (!(inputShape.size() == 3 && outputShape.size() == 4)) + return false; + + // Check that the tiling size is given, then the last dim of the output is + // statically determined and is a multiples of tiling size. + if (tilingSize > 0) + if (ShapedType::isDynamic(outputShape[3]) || + (outputShape[3] % tilingSize != 0)) + return false; + + // Tiling over each input dimension. Assume here that the dims are static. + return ((inputShape[0] == outputShape[0]) && + (inputShape[1] == outputShape[1]) && + (inputShape[2] == outputShape[2] * outputShape[3])); +} + /// Check if a 4D tensor is collapsed into 2D by merging the each two /// dimensions. bool isCollapsing4DTo2D(Value val) { @@ -342,14 +380,15 @@ bool isCollapsing4DTo2D(Value val) { if (!(inputShape.size() == 4 && outputShape.size() == 2)) return false; - // Collapsing by merging the first two dimensions. + // Collapsing by merging the first two dimensions. Assume here that the dims + // are static. return ((inputShape[0] * inputShape[1] == outputShape[0]) && (inputShape[2] * inputShape[3] == outputShape[1])); } /// Check if a 4D tensor is collapsed into 3D by merging the first two -/// dimensions. -bool isCollapsing4DTo3D(Value val) { +/// (leftmost) dimensions. 
+bool isLeftmostCollapsing4DTo3D(Value val) { auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); if (!reshapeOp) return false; @@ -371,12 +410,44 @@ bool isCollapsing4DTo3D(Value val) { if (!(inputShape.size() == 4 && outputShape.size() == 3)) return false; - // Collapsing by merging the first two dimensions. + // Collapsing by merging the first two dimensions. Assume here that the dims + // are static. return ((inputShape[0] * inputShape[1] == outputShape[0]) && (inputShape[2] == outputShape[1]) && (inputShape[3] == outputShape[2])); } +/// Check if a 4D tensor is collapsed into 3D by merging the last two +/// (rightmost) dimensions. +bool isRightmostCollapsing4DTo3D(Value val) { + auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); + if (!reshapeOp) + return false; + + Value input = reshapeOp.getData(); + Value output = reshapeOp.getReshaped(); + Type inputType = input.getType(); + Type outputType = output.getType(); + + if (!isRankedShapedType(inputType)) + return false; + if (!isRankedShapedType(outputType)) + return false; + + ArrayRef inputShape = getShape(inputType); + ArrayRef outputShape = getShape(outputType); + + // Not reshape from 4D to 3D. + if (!(inputShape.size() == 4 && outputShape.size() == 3)) + return false; + + // Collapsing by merging the first two dimensions. Assume here that the dims + // are static. + return ((inputShape[0] == outputShape[0]) && + (inputShape[1] == outputShape[1]) && + (inputShape[2] * inputShape[3] == outputShape[2])); +} + AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) { assert(isTiling2DTo4D(val) && "ONNXReshapeOp is not suitable for getting a tiling affine map"); @@ -402,8 +473,8 @@ AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) { return AffineMapAttr::get(map); } -AffineMapAttr getTiling3DTo4DMap(OpBuilder &b, Value val) { - assert(isTiling3DTo4D(val) && +AffineMapAttr getLeftmostTiling3DTo4DMap(OpBuilder &b, Value val) { + assert(isLeftmostTiling3DTo4D(val) && "ONNXReshapeOp is not suitable for getting a tiling affine map"); auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); @@ -450,8 +521,8 @@ AffineMapAttr getCollapsing4DTo2DMap(OpBuilder &b, Value val) { return AffineMapAttr::get(map); } -AffineMapAttr getCollapsing4DTo3DMap(OpBuilder &b, Value val) { - assert(isCollapsing4DTo3D(val) && +AffineMapAttr getLeftmostCollapsing4DTo3DMap(OpBuilder &b, Value val) { + assert(isLeftmostCollapsing4DTo3D(val) && "ONNXReshapeOp is not suitable for getting a collapsing affine map"); auto reshapeOp = mlir::dyn_cast(val.getDefiningOp()); @@ -484,6 +555,44 @@ AffineMapAttr getTransposeMap(OpBuilder &b, ArrayAttr permAttr) { return AffineMapAttr::get(map); } +/// Check the values of a transpose map to be equal to the permVals. +bool isTransposePermutationEqualTo( + ArrayAttr permAttr, mlir::ArrayRef permVals) { + // Check same rank. + int64_t permAttrSize = ArrayAttrSize(permAttr); + int64_t permValSize = permVals.size(); + if (permAttrSize != permValSize) + return false; + // Check same values; abort on failure. + for (int64_t i = 0; i < permAttrSize; ++i) { + int64_t v = ArrayAttrIntVal(permAttr, i); + if (permVals[i] != v) + return false; + } + // Identical, success. + return true; +} + +bool isShapeDimMultipleOf(Value val, int64_t index, int64_t multipleVal) { + // Type must be shaped and ranked. + Type type = val.getType(); + if (!isRankedShapedType(type)) + return false; + // Index must be within bounds of the shape rank; negative is from back. 
+ ArrayRef shape = getShape(type); + int64_t size = shape.size(); + if (index < 0) + index += size; + if (index < 0 || index >= size) + return false; + // At this time, only reason about static shapes. + int64_t dim = shape[index]; + if (ShapedType::isDynamic(dim)) + return false; + // All good now, check if dim is a multiple of "multipleVal." + return dim % multipleVal == 0; +} + IntegerAttr getAxisNHWC(IntegerAttr axisNCHWAttr) { int64_t axisNCHW = axisNCHWAttr.getSInt(); int64_t axisNHWC; diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp index cc346ef17d..c37d2af544 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp @@ -74,22 +74,33 @@ mlir::Value getConstantOfType( bool oneIsOfLayout( mlir::Type t1, mlir::Type t2, ZTensorEncodingAttr::DataLayout layout); -/// Check if ONNXReshapeOp is reshaping 2D/3D to 4D by tiling each input +/// Check if ONNXReshapeOp is reshaping 2D/3D to 4D by tiling an input /// dimension. bool isTiling2DTo4D(mlir::Value val); mlir::AffineMapAttr getTiling2DTo4DMap(mlir::OpBuilder &b, mlir::Value val); -bool isTiling3DTo4D(mlir::Value val); -mlir::AffineMapAttr getTiling3DTo4DMap(mlir::OpBuilder &b, mlir::Value val); -/// Check if ONNXReshapeOp is collapsing 4D into 3D/2D by merging the first two -/// dimensions. -bool isCollapsing4DTo3D(mlir::Value val); -mlir::AffineMapAttr getCollapsing4DTo3DMap(mlir::OpBuilder &b, mlir::Value val); +bool isLeftmostTiling3DTo4D(mlir::Value val); +bool isRightmostTiling3DTo4D(mlir::Value val, int64_t tilingSize); +mlir::AffineMapAttr getLeftmostTiling3DTo4DMap( + mlir::OpBuilder &b, mlir::Value val); +/// Check if ONNXReshapeOp is collapsing 4D into 3D by merging the first two +/// (leftmost) dimensions. +bool isLeftmostCollapsing4DTo3D(mlir::Value val); +/// Check if ONNXReshapeOp is collapsing 4D into 3D by merging the last two +/// (rightmost) dimensions. +bool isRightmostCollapsing4DTo3D(mlir::Value val); +mlir::AffineMapAttr getLeftmostCollapsing4DTo3DMap( + mlir::OpBuilder &b, mlir::Value val); bool isCollapsing4DTo2D(mlir::Value val); mlir::AffineMapAttr getCollapsing4DTo2DMap(mlir::OpBuilder &b, mlir::Value val); /// Get an affine map for the permutation array. mlir::AffineMapAttr getTransposeMap( mlir::OpBuilder &b, mlir::ArrayAttr permAttr); - +/// Check the values of a transpose map to be equal to the permVals. +bool isTransposePermutationEqualTo( + mlir::ArrayAttr permAttr, mlir::ArrayRef permVals); +/// Return true when shape(Value)[index] % multipleVal == 0. +/// Negative indices, count from the back (-1 is last element). +bool isShapeDimMultipleOf(mlir::Value val, int64_t index, int64_t multipleVal); /// Get an axis for NHWC layout given an axis for NCHW layout. mlir::IntegerAttr getAxisNHWC(mlir::IntegerAttr axisNCHWAttr); diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td index d3321a6464..f046029b13 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td @@ -39,6 +39,16 @@ def NotSameLayout: Constraint< def IsNoneType : Constraint(($_self).getType())">>; +// Create an ONNX Shape Op with type +def CreateShapeOp: NativeCodeCall< + "$_builder.create($_loc, $0, $1, IntegerAttr(), 0)" +>; + +// Get a type for a tensor that stores the shape of another tensor. 
+def GetShapeTypeOf: NativeCodeCall< + "RankedTensorType::get({mlir::cast($0.getType()).getRank()}, $_builder.getIntegerType(64))" +>; + def GetLayout : NativeCodeCall< "::onnx_mlir::zhigh::convertZTensorDataLayoutToStringAttr($_builder, " "::onnx_mlir::zhigh::getZTensorLayout($0.getType()))" @@ -144,10 +154,16 @@ def IsTiling2DTo4D : Constraint< "Is tiling by ONNReshapeOp" >; -/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling each input dimension. -def IsTiling3DTo4D : Constraint< - CPred<"::onnx_mlir::zhigh::isTiling3DTo4D($0)">, - "Is tiling by ONNReshapeOp" +/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling leftmost input dimension. +def IsLeftmostTiling3DTo4D : Constraint< + CPred<"::onnx_mlir::zhigh::isLeftmostTiling3DTo4D($0)">, + "Is leftmost tiling by ONNReshapeOp" +>; + +/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling rightmost input dimension. +def IsRightmostTiling3DTo4DBy64 : Constraint< + CPred<"::onnx_mlir::zhigh::isRightmostTiling3DTo4D($0, 64)">, + "Is rightmost tiling of size 64 by ONNReshapeOp" >; /// Check if ONNXReshapeOp is reshaping 4D to 2D by collapsing the first two input dimensions. @@ -157,9 +173,15 @@ def IsCollapsing4DTo2D : Constraint< >; /// Check if ONNXReshapeOp is reshaping 4D to 3D by collapsing the first two input dimensions. -def IsCollapsing4DTo3D : Constraint< - CPred<"::onnx_mlir::zhigh::isCollapsing4DTo3D($0)">, - "Is collapsing by ONNXReshapeOp" +def IsLeftmostCollapsing4DTo3D : Constraint< + CPred<"::onnx_mlir::zhigh::isLeftmostCollapsing4DTo3D($0)">, + "Is leftmost collapsing by ONNXReshapeOp" +>; + +/// Check if ONNXReshapeOp is reshaping 4D to 3D by collapsing the last two input dimensions. +def IsRightmostCollapsing4DTo3D : Constraint< + CPred<"::onnx_mlir::zhigh::isRightmostCollapsing4DTo3D($0)">, + "Is rightmost collapsing by ONNXReshapeOp" >; def GetResultType : NativeCodeCall< @@ -170,22 +192,37 @@ def GetTiling2DTo4DMap : NativeCodeCall< "::onnx_mlir::zhigh::getTiling2DTo4DMap($_builder, $0)" >; -def GetTiling3DTo4DMap : NativeCodeCall< - "::onnx_mlir::zhigh::getTiling3DTo4DMap($_builder, $0)" +def GetLeftmostTiling3DTo4DMap : NativeCodeCall< + "::onnx_mlir::zhigh::getLeftmostTiling3DTo4DMap($_builder, $0)" >; def GetCollapsing4DTo2DMap: NativeCodeCall< "::onnx_mlir::zhigh::getCollapsing4DTo2DMap($_builder, $0)" >; -def GetCollapsing4DTo3DMap: NativeCodeCall< - "::onnx_mlir::zhigh::getCollapsing4DTo3DMap($_builder, $0)" +def GetLeftmostCollapsing4DTo3DMap: NativeCodeCall< + "::onnx_mlir::zhigh::getLeftmostCollapsing4DTo3DMap($_builder, $0)" >; def GetTransposeMap : NativeCodeCall< "::onnx_mlir::zhigh::getTransposeMap($_builder, $0)" >; +def Is4DTransposePermutationEqualTo0213 : Constraint< + CPred<"::onnx_mlir::zhigh::isTransposePermutationEqualTo($0, {0, 2, 1, 3})">, + "Is 4D Transpose with pattern (0, 2, 1, 3)" +>; + +class IsShapeDimMultipleOf32 : Constraint< + CPred<"::onnx_mlir::zhigh::isShapeDimMultipleOf($0, " # index # ", 32)">, + "The operand shape at given index is a multiple of 32" +>; + +class IsShapeDimMultipleOf64 : Constraint< + CPred<"::onnx_mlir::zhigh::isShapeDimMultipleOf($0, " # index # ", 64)">, + "The operand shape at given index is a multiple of 64" +>; + def IsIdentityAffineMap : Constraint< CPred<"$_self.isIdentity()">, "Is identity AffineMap" diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp new file mode 100644 index 0000000000..e1affe1a18 --- /dev/null +++ 
b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp @@ -0,0 +1,70 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===------------------ Reshape.cpp - ZHigh Operations +//---------------------===// +// +// Copyright 2025 The IBM Research Authors. +// +// ============================================================================= +// +// +//===----------------------------------------------------------------------===// + +#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +namespace onnx_mlir { +namespace zhigh { + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// + +LogicalResult ZHighReshapeOpShapeHelper::computeShape() { + ZHighReshapeOpAdaptor operandAdaptor(operands); + + // Shape has the dimensions of the output. + DimsExpr outputDims; + createIE->getIntFromArrayAsSymbols(operandAdaptor.getShape(), outputDims); + setOutputDims(outputDims); + return success(); +} + +//===----------------------------------------------------------------------===// +// Shape inference +//===----------------------------------------------------------------------===// + +// Builder + +LogicalResult ZHighReshapeOp::inferShapes( + std::function doShapeInference) { + Value source = getSource(); + if (!hasRankedType(source)) + return success(); + // Output type has the same type as the input/source type. + RankedTensorType sourceType = + mlir::dyn_cast(source.getType()); + Type elementType = sourceType.getElementType(); + // Get encoding + StringAttr layout = getLayoutAttr(); + ZTensorEncodingAttr::DataLayout dataLayout; + Attribute encoding; + if (layout) { + // Operation has an optional output layout, use it. + dataLayout = convertStringAttrToZTensorDataLayout(layout); + encoding = ZTensorEncodingAttr::get(this->getContext(), dataLayout); + } else { + // Operation does not have an optional output layout, reuse it from input. 
+ encoding = sourceType.getEncoding(); + } + + ZHighReshapeOpShapeHelper shapeHelper(getOperation()); + return shapeHelper.computeShapeAndUpdateType(elementType, encoding); +} + +} // namespace zhigh +} // namespace onnx_mlir diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp index f9427116af..84282a9b1b 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp @@ -56,6 +56,7 @@ DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickifiedConstantOfShapeOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighQuantizedStickOpShapeHelper) DECLARE_SHAPE_HELPER_ZHIGH(ZHighUnstickOpShapeHelper) +DECLARE_SHAPE_HELPER_ZHIGH(ZHighReshapeOpShapeHelper) #undef DECLARE_SHAPE_HELPER_ZHIGH //===----------------------------------------------------------------------===// diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp index 179cfe139b..780411cadb 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp @@ -12,6 +12,7 @@ //===----------------------------------------------------------------------===// #include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp" +#include "src/Compiler/CompilerOptions.hpp" using namespace mlir; using namespace onnx_mlir; @@ -142,6 +143,8 @@ void ZHighStickOp::getCanonicalizationPatterns( results.insert(context); results.insert(context); results.insert(context); + results.insert(context); + results.insert(context); } } // namespace zhigh diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td index e612f788d6..69fda996e0 100644 --- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td +++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td @@ -105,7 +105,7 @@ def ReplaceONNXSoftplusPattern: Pattern< (ZHighStickOp:$stickout (ONNXSoftplusOp:$out (ZHighUnstickOp $X)), $layout, $_), [ // Get stickified constant of minus one with input shape. - // Donot saturate since input orignally from NNPA. + // Donot saturate since input originally from NNPA. (ZHighStickOp:$minusOne (GetConstantOfType<"-1.0"> $out), $layout, (NoneIntegerAttr)), // Get minus X with input shape. (ZHighMulOp:$minusX $X, $minusOne, (returnType $X)), @@ -122,7 +122,7 @@ def ReplaceONNXSoftplusPattern: Pattern< [(IsStaticShapeTensor $X), (SameLayout $X, $stickout)] >; -// Calulation of `1/sqrt(X)` or reciprocal square root is often found in +// Calculation of `1/sqrt(X)` or reciprocal square root is often found in // deep learning models, but zDNN does not support it. Thus, we rewrite it into // zDNN-supported operations. // @@ -149,7 +149,94 @@ def ReplaceONNXReciprocalSqrtPattern: Pat< [(IsStaticShapeTensor $X), (SameLayout $X, $stick)] >; -// The folowing pattern was found in bertsquad and GPT models. +// The following pattern was found in Roberta models. 
+// ``` +// %66 = "zhigh.Unstick"(%65) : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x64xf32> +// %67 = "onnx.Reshape"(%66, %2) {allowzero = 0 : si64} : (tensor<12x384x64xf32>, tensor<4xi64>) -> tensor<1x12x384x64xf32> +// %68 = "onnx.Transpose"(%67) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<1x12x384x64xf32>) -> tensor<1x384x12x64xf32> +// %69 = "onnx.Reshape"(%68, %9) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<1x384x12x64xf32>, tensor<3xi64>) -> tensor<1x384x768xf32> +// %70 = "zhigh.Stick"(%69) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// ``` +// When the input tensor %65 (here tensor<12x384x64xf16) with dims E3xE2xE1 and 3DS layout satisfies the following properties +// 1) E2 % 32 == 0 +// 2) E1 % 64 == 0 +// Then we can guarantee that input value %65 has the same value +// in the same memory location (*) than output %70. +// (*) does not say the tensor are identical, just that they store +// the same value at the same byte offset in the memory. +// +// High level intuition for this is that the reshape / transpose / reshape +// perform some memory layout operations, which results in no change in the +// 3DS representation as that representation also perform layout changes. +// +// Limitation: current pattern assume static sizes. + +def ReshapeTransposeReshapeRoberta3DSWPattern1 : Pat< + // Input: X -> unstick -> reshape1 -> transpose -> reshape 2 -> stick. + (ZHighStickOp:$stick + (ONNXReshapeOp:$reshape2 + (ONNXTransposeOp:$transpose + (ONNXReshapeOp:$reshape1 + (ZHighUnstickOp:$unstick $X), + $shape1, $_), + $perm), + $shape2, $_), + $layout3DS, $saturation), + // Output: initial X value unchanged, but transformed with the new compatible shape. + (ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)), + // Conditions. + [(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS. + (IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only. + (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose), + (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick), + (IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32. + (IsShapeDimMultipleOf64<2> $X), // Third dim of input is a multiple of 64. + (Is4DTransposePermutationEqualTo0213 $perm), // Permute middle 2 dims. + (IsLeftmostTiling3DTo4D $reshape1), // 1st reshape is tiling in the leftmost dimension + (IsRightmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the last two dimensions. 
+ ] +>; + +// Second pattern found in roberta +// +// %33 = "zhigh.Unstick"(%32) : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<8x384x768xf32> // unstick +// %44 = "onnx.Reshape"(%33, %7) {allowzero = 0 : si64, onnx_node_name = "Reshape_64"} : (tensor<8x384x768xf32>, tensor<4xi64>) -> tensor<8x384x12x64xf32> // last goes from 768 to 12, 64 +// %45 = "onnx.Transpose"(%44) {onnx_node_name = "Transpose_65", perm = [0, 2, 1, 3]} : (tensor<8x384x12x64xf32>) -> tensor<8x12x384x64xf32> // same permute +// %50 = "onnx.Reshape"(%45, %2) {allowzero = 0 : si64} : (tensor<8x12x384x64xf32>, tensor<3xi64>) -> tensor<96x384x64xf32> // collapse first 2 dims +// %52 = "zhigh.Stick"(%50) {layout = "3DS"} : (tensor<96x384x64xf32>) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// +// Namely unstick -> reshape (by tiling rightmost one to E1/64, 64) -> permute (2nd & 3rd) -> reshape (collapse first 2) -> stick +// Shapes goes from 8 x 384 x (12*64 = 768) to (8*12 = 96) x 384 x 64 +// +// pattern works only when input E2 % 32 and E1 % 64 == 0 + +def ReshapeTransposeReshapeRoberta3DSWPattern2 : Pat< + // Input: X -> unstick -> reshape1 -> transpose -> reshape 2 -> stick. + (ZHighStickOp:$stick + (ONNXReshapeOp:$reshape2 + (ONNXTransposeOp:$transpose + (ONNXReshapeOp:$reshape1 + (ZHighUnstickOp:$unstick $X), + $shape1, $_), + $perm), + $shape2, $_), + $layout3DS, $saturation), + // Output: initial X value unchanged, but transformed with the compatible shape. + (ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)), + // Conditions. + [(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS. + (IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only. + (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose), + (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick), + (IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32. + (IsShapeDimMultipleOf64<2> $X), // Third dim of input is a multiple of 64. + (Is4DTransposePermutationEqualTo0213 $perm), // Permute middle 2 dims. + (IsRightmostTiling3DTo4DBy64 $reshape1), // 1st reshape is tiling by 64 the rightmost dimension + (IsLeftmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions. + ] +>; + +// The following pattern was found in bertsquad and GPT models. // ``` // %0 = "zhigh.Unstick"(%X) {layout = "2D"} : tensor> -> tensor // %1 = "onnx.Reshape"(%0) : tensor -> tensor @@ -186,7 +273,7 @@ def ReshapeTransposeReshape2DTo3DSPattern : Pat< (returnType (GetResultType $reshape1))), (GetTransposeMap $perm), (returnType (GetResultType $transpose))), - (GetCollapsing4DTo3DMap $reshape2), + (GetLeftmostCollapsing4DTo3DMap $reshape2), (returnType (GetResultType $reshape2))), $layout3DS, $saturation), [(TensorHas2DLayout $X), (Is3DSLayout $layout3DS), @@ -194,7 +281,7 @@ def ReshapeTransposeReshape2DTo3DSPattern : Pat< (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose), (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick), (IsTiling2DTo4D $reshape1), // 1st reshape is tiling over each input dimension - (IsCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions. + (IsLeftmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions. 
] >; @@ -214,7 +301,7 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat< (ONNXShapeTransformOp // transpose (ONNXShapeTransformOp // reshape (ZHighUnstickOp $X), - (GetTiling3DTo4DMap $reshape1), + (GetLeftmostTiling3DTo4DMap $reshape1), (returnType (GetResultType $reshape1))), (GetTransposeMap $perm), (returnType (GetResultType $transpose))), @@ -225,7 +312,7 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat< (IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose), (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick), - (IsTiling3DTo4D $reshape1), // 1st reshape is tiling over each input dimension + (IsLeftmostTiling3DTo4D $reshape1), // 1st reshape is tiling over each input dimension (IsCollapsing4DTo2D $reshape2), // 2nd reshape is collapsing the first two dimensions. ] >; diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td index a66cb8273f..1ab8ae958d 100644 --- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td +++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td @@ -687,4 +687,15 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> { ]; } +def ZLowReshapeOp:ZLow_Op<"reshape", [MemRefsNormalizable]> { + let summary = "ZLow Reshape operation"; + let description = [{ + ZLow operation to perform a reshape (no data movement). + }]; + // Note that no shape is needed for this operation. + let arguments = (ins ZMemRef:$X, + ZMemRef:$Out, + StrAttr:$layout); +} + #endif // ZLOW_OPS diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp index f97373764a..2328f70265 100644 --- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp +++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp @@ -25,6 +25,7 @@ #include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp" #include "src/Accelerators/NNPA/Support/LayoutHelper.hpp" #include "src/Dialect/Mlir/DialectBuilder.hpp" +#include "src/Dialect/Mlir/IndexExpr.hpp" #include @@ -33,6 +34,43 @@ using namespace mlir; namespace onnx_mlir { namespace zlow { +/// Transform the zlow.reshape into a memref.reinterpret_cast as this pass +/// operates after all memrefs are fully normalized, which is a requirement +/// here. + +class ReshapeToReinterpretCastPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite( + ZLowReshapeOp reshapeOp, PatternRewriter &rewriter) const override { + Location loc = reshapeOp.getLoc(); + MultiDialectBuilder create(rewriter, loc); + IndexExprScope currScope(&rewriter, loc); + // Here, cannot use the shape found in the reshape op, as it is the original + // shape before memref normalization. + Value input = reshapeOp.getX(); + Value output = reshapeOp.getOut(); + // Input must have no affine layout. In other words, it has been normalized. + if (hasNonIdentityLayout(input.getType()) || + hasNonIdentityLayout(output.getType())) { + return failure(); + } + Operation *outputAllocOp = output.getDefiningOp(); + ShapedType outputType = mlir::cast(output.getType()); + DimsExpr outputDims; + for (int64_t i = 0; i < (int64_t)outputType.getRank(); ++i) { + Value shape = create.mem.dim(output, i); + outputDims.emplace_back(DimIE(shape)); + } + Value reinterpretCast = create.mem.reinterpretCast(input, outputDims); + // Reshape is no longer needed. And instead of allocating data, we simply + // replace the alloc by the reinterpret_cast. 
+ rewriter.eraseOp(reshapeOp); + rewriter.replaceOp(outputAllocOp, reinterpretCast); + return success(); + } +}; + /// Remove unstick if there is no use of its second operand except itself. class UnstickRemovalPattern : public OpRewritePattern { public: @@ -650,6 +688,7 @@ class ZLowRewritePass llvm::SmallDenseSet removableStickOps; ConversionTarget target(getContext()); RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); patterns.insert(&getContext()); diff --git a/src/Dialect/ONNX/DialectBuilder.cpp b/src/Dialect/ONNX/DialectBuilder.cpp index 672b4f6865..c4020a7c1d 100644 --- a/src/Dialect/ONNX/DialectBuilder.cpp +++ b/src/Dialect/ONNX/DialectBuilder.cpp @@ -350,6 +350,13 @@ Value OnnxBuilder::round(Value input, bool scalarType) const { toTensor(input.getType()), toTensor(input)); } +Value OnnxBuilder::shape(Value input) const { + int64_t rank = getRank(input.getType()); + Type outputType = RankedTensorType::get({rank}, b().getI64Type()); + return createTypedOpAndInferShapes( + toTensor(outputType), toTensor(input)); +} + Value OnnxBuilder::shape(Type outputType, Value input) const { return createTypedOpAndInferShapes( toTensor(outputType), toTensor(input)); diff --git a/src/Dialect/ONNX/DialectBuilder.hpp b/src/Dialect/ONNX/DialectBuilder.hpp index 7cec4a5b37..7712175963 100644 --- a/src/Dialect/ONNX/DialectBuilder.hpp +++ b/src/Dialect/ONNX/DialectBuilder.hpp @@ -184,6 +184,7 @@ struct OnnxBuilder : DialectBuilder { // ONNXShapeOp (start is inclusive, default 0; end is exclusive, default // nullptr means all) + mlir::Value shape(mlir::Value input) const; mlir::Value shape(mlir::Type outputType, mlir::Value input) const; mlir::Value shape( mlir::Type outputType, mlir::Value input, int64_t start) const; diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp index 02955d52e5..b82fa129b9 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.cpp @@ -113,19 +113,20 @@ static void refineDims(Operation *op, DimsExpr &inferredDims, Value output) { // inferredDim is different from existingDim. Believe in existingDim. assert(inferredDims[i].isLiteral() && "isLiteral failed"); if (existingDims[i] != inferredDims[i].getLiteral()) { - if (op) + if (op) { llvm::outs() << "\nWarning for operation " << op->getName() << ": [Shape inference, dim " << i << "] the inferred dim (" << inferredDims[i].getLiteral() << ") is different from the existing dim (" << existingDims[i] << "). Use the existing dim instead.\n\n"; - else + } else { llvm::outs() << "\nWarning: [Shape inference, dim " << i << "] the inferred dim (" << inferredDims[i].getLiteral() << ") is different from the existing dim (" << existingDims[i] << "). 
Use the existing dim instead.\n\n"; + } inferredDims[i] = LitIE(existingDims[i]); } } diff --git a/test/accelerators/NNPA/backend/CMakeLists.txt b/test/accelerators/NNPA/backend/CMakeLists.txt index 272be114d2..600962c21f 100644 --- a/test/accelerators/NNPA/backend/CMakeLists.txt +++ b/test/accelerators/NNPA/backend/CMakeLists.txt @@ -595,10 +595,9 @@ add_dependencies(check-onnx-backend-compilerlib-nnpa CompilerLibTest) add_dependencies(check-onnx-backend-compilerlib-nnpa PyRuntimeC) add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-nnpa) -# If on arch 15 machines then (TODO: enable once avail on test machines): +# TODO arch15: if (avail on test machines): +# In addition to testing arch14, also test arch15. # add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-arch15-nnpa) -# else while on an arch 14 machine: -add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-nnpa) # end if. add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-dynamic-nnpa) add_dependencies(check-onnx-backend-numerical-nnpa check-onnx-backend-constant-nnpa) diff --git a/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reshape.mlir b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reshape.mlir new file mode 100644 index 0000000000..d1aae63d56 --- /dev/null +++ b/test/mlir/accelerators/nnpa/conversion/zhigh-to-zlow/reshape.mlir @@ -0,0 +1,20 @@ +// RUN: onnx-mlir-opt --march=z16 --maccel=NNPA --shape-inference --convert-onnx-to-krnl --canonicalize %s -split-input-file | FileCheck %s + +// ----- + + +func.func @should_lower_to_zlow(%arg0: tensor<3x4x50xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<*xf16> { + %0 = onnx.Constant dense<[30, 4, 5]> : tensor<3xi64> + %1 = "zhigh.Reshape"(%arg0, %0) {layout = "3DS"} : (tensor<3x4x50xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>) -> tensor<*xf16> + return %1 : tensor<*xf16> + +// mlir2FileCheck.py +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)> +// CHECK-LABEL: func.func @should_lower_to_zlow +// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<3x4x50xf16, #map>) -> memref<30x4x5xf16, #map> { +// CHECK: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<30x4x5xf16, #map> +// CHECK: "zlow.reshape"([[PARAM_0_]], [[RES_]]) {layout = "3DS"} : (memref<3x4x50xf16, #map>, memref<30x4x5xf16, #map>) -> () +// CHECK: return [[RES_]] : memref<30x4x5xf16, #map> +// CHECK: } +} + diff --git a/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir b/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir index 175f228aba..639186e868 100644 --- a/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir +++ b/test/mlir/accelerators/nnpa/transform/zhigh-combine.mlir @@ -18,7 +18,6 @@ func.func @remove_stick_and_unstick_same_layout(%arg0 : tensor<10x10xf32>) -> te // CHECK: zhigh.Relu // CHECK: zhigh.Unstick } - // ----- func.func @remove_stick_only(%arg0 : tensor<10x10xf32>) -> tensor<10x10xf32> { @@ -89,8 +88,8 @@ func.func @donot_replace_stick_and_unstick_by_layout_transform(%arg0 : tensor<5x // Remove Stick with NoneType input. 
func.func @remove_nonetype_stick() -> () { - %cst = "onnx.NoValue"() {value} : () -> none - %0 = "zhigh.Stick"(%cst) : (none) -> none + %cst = "onnx.NoValue"() {value} : () -> none + %0 = "zhigh.Stick"(%cst) : (none) -> none return // CHECK-LABEL: remove_nonetype_stick @@ -257,7 +256,7 @@ func.func @reshape_transpose_reshape_3ds_to_2d(%arg0: tensor<48x256x64xf16, #zhi %4 = "zhigh.Stick"(%3) {layout = "2D"} : (tensor<1024x768xf32>) -> tensor<1024x768xf16, #zhigh.layout<{dataLayout = "2D"}>> return %4 : tensor<1024x768xf16, #zhigh.layout<{dataLayout = "2D"}>> -// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> ((d0 floordiv 12) * 256 + d1, (d0 mod 12) * 64 + d2)> +// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> ((d0 floordiv 12) * 256 + d1, (d0 mod 12) * 64 + d2)> // CHECK-LABEL: func.func @reshape_transpose_reshape_3ds_to_2d // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<48x256x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1024x768xf16, #zhigh.layout<{dataLayout = "2D"}>> { // CHECK: [[VAR_0_:%.+]] = "zhigh.Unstick"([[PARAM_0_]]) : (tensor<48x256x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<48x256x64xf32> @@ -333,3 +332,199 @@ func.func @test_delay_dlf16_to_f32(%arg0: tensor<1x3x5x?xf16>, %arg1: tensor<3xi // CHECK: onnx.Return [[VAR_1_]] : tensor<5x3x?xf16> // CHECK: } } + +// ----- + +// COM: Roberta pattern with BS=1 + +func.func @test_Roberta_bs1(%arg0: tensor<12x384x384xf32>, %arg1: tensor<12x384x64xf32>, %arg2: tensor<768x768xf32>) -> tensor<1x384x768xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %2 = onnx.Constant dense<[1, 12, 384, 64]> : tensor<4xi64> + %9 = onnx.Constant dense<[1, 384, 768]> : tensor<3xi64> + %76 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<12x384x384xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %77 = "zhigh.Stick"(%arg1) {layout = "3DS"} : (tensor<12x384x64xf32>) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %78 = "zhigh.MatMul"(%76, %77, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %79 = "zhigh.Unstick"(%78) : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x64xf32> + %80 = "onnx.Reshape"(%79, %2) {allowzero = 0 : si64} : (tensor<12x384x64xf32>, tensor<4xi64>) -> tensor<1x12x384x64xf32> + %81 = "onnx.Transpose"(%80) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<1x12x384x64xf32>) -> tensor<1x384x12x64xf32> + %82 = "onnx.Reshape"(%81, %9) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<1x384x12x64xf32>, tensor<3xi64>) -> tensor<1x384x768xf32> + %83 = "zhigh.Stick"(%82) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %84 = "zhigh.Stick"(%arg2) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>> + %85 = "zhigh.MatMul"(%83, %84, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %86 = "zhigh.Unstick"(%85) : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x768xf32> + onnx.Return %86 : tensor<1x384x768xf32> + +// mlir2FileCheck.py +// CHECK-LABEL: func.func @test_Roberta_bs1 +// CHECK-SAME: 
([[PARAM_0_:%.+]]: tensor<12x384x384xf32>, [[PARAM_1_:%.+]]: tensor<12x384x64xf32>, [[PARAM_2_:%.+]]: tensor<768x768xf32>) -> tensor<1x384x768xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[1, 384, 768]> : tensor<3xi64> +// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none +// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<12x384x384xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<12x384x64xf32>) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_4_:%.+]] = "zhigh.MatMul"([[VAR_2_]], [[VAR_3_]], [[VAR_1_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Reshape"([[VAR_4_]], [[VAR_0_]]) {layout = "3DS"} : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK-DAG: [[VAR_6_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>> +// CHECK: [[VAR_7_:%.+]] = "zhigh.MatMul"([[VAR_5_]], [[VAR_6_]], [[VAR_1_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> +// CHECK: [[VAR_8_:%.+]] = "zhigh.Unstick"([[VAR_7_]]) : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x768xf32> +// CHECK: onnx.Return [[VAR_8_]] : tensor<1x384x768xf32> +// CHECK: } +} + +// ----- + +// COM: Roberta pattern with BS=8 + +func.func @test_Roberta_bs8(%arg0: tensor<96x384x384xf32>, %arg1: tensor<96x384x64xf32>, %arg2: tensor<768x768xf32>) -> tensor<8x384x768xf32> { + %0 = "onnx.NoValue"() {value} : () -> none + %2 = onnx.Constant dense<[8, 12, 384, 64]> : tensor<4xi64> + %9 = onnx.Constant dense<[8, 384, 768]> : tensor<3xi64> + %76 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<96x384x384xf32>) -> tensor<96x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %77 = "zhigh.Stick"(%arg1) {layout = "3DS"} : (tensor<96x384x64xf32>) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %78 = "zhigh.MatMul"(%76, %77, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<96x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %79 = "zhigh.Unstick"(%78) : (tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<96x384x64xf32> + %80 = "onnx.Reshape"(%79, %2) {allowzero = 0 : si64} : (tensor<96x384x64xf32>, tensor<4xi64>) -> tensor<8x12x384x64xf32> + %81 = "onnx.Transpose"(%80) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<8x12x384x64xf32>) -> tensor<8x384x12x64xf32> + %82 = "onnx.Reshape"(%81, %9) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<8x384x12x64xf32>, tensor<3xi64>) -> tensor<8x384x768xf32> + %83 = "zhigh.Stick"(%82) {layout = "3DS"} : (tensor<8x384x768xf32>) -> tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>> + %84 = "zhigh.Stick"(%arg2) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, 
#zhigh.layout<{dataLayout = "2D"}>>
+  %85 = "zhigh.MatMul"(%83, %84, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %86 = "zhigh.Unstick"(%85) : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<8x384x768xf32>
+  onnx.Return %86 : tensor<8x384x768xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_Roberta_bs8
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<96x384x384xf32>, [[PARAM_1_:%.+]]: tensor<96x384x64xf32>, [[PARAM_2_:%.+]]: tensor<768x768xf32>) -> tensor<8x384x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[8, 384, 768]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<96x384x384xf32>) -> tensor<96x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<96x384x64xf32>) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.MatMul"([[VAR_2_]], [[VAR_3_]], [[VAR_1_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<96x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Reshape"([[VAR_4_]], [[VAR_0_]]) {layout = "3DS"} : (tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>) -> tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_6_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_7_:%.+]] = "zhigh.MatMul"([[VAR_5_]], [[VAR_6_]], [[VAR_1_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_8_:%.+]] = "zhigh.Unstick"([[VAR_7_]]) : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<8x384x768xf32>
+// CHECK: onnx.Return [[VAR_8_]] : tensor<8x384x768xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: Roberta pattern with BS=1 but dim 2 (385) is not mod 32 = 0; should fail to apply pattern
+
+func.func @test_Roberta_bs1_not_mod32(%arg0: tensor<12x385x385xf32>, %arg1: tensor<12x385x64xf32>, %arg2: tensor<768x768xf32>) -> tensor<1x385x768xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %2 = onnx.Constant dense<[1, 12, 384, 64]> : tensor<4xi64>
+  %9 = onnx.Constant dense<[1, 384, 768]> : tensor<3xi64>
+  %76 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<12x385x385xf32>) -> tensor<12x385x385xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %77 = "zhigh.Stick"(%arg1) {layout = "3DS"} : (tensor<12x385x64xf32>) -> tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %78 = "zhigh.MatMul"(%76, %77, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x385x385xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %79 = "zhigh.Unstick"(%78) : (tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x385x64xf32>
+  %80 = "onnx.Reshape"(%79, %2) {allowzero = 0 : si64} : (tensor<12x385x64xf32>, tensor<4xi64>) -> tensor<1x12x385x64xf32>
+  %81 = "onnx.Transpose"(%80) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<1x12x385x64xf32>) -> tensor<1x385x12x64xf32>
+  %82 = "onnx.Reshape"(%81, %9) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<1x385x12x64xf32>, tensor<3xi64>) -> tensor<1x385x768xf32>
+  %83 = "zhigh.Stick"(%82) {layout = "3DS"} : (tensor<1x385x768xf32>) -> tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %84 = "zhigh.Stick"(%arg2) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>
+  %85 = "zhigh.MatMul"(%83, %84, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %86 = "zhigh.Unstick"(%85) : (tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x385x768xf32>
+  onnx.Return %86 : tensor<1x385x768xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_Roberta_bs1_not_mod32
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<12x385x385xf32>, [[PARAM_1_:%.+]]: tensor<12x385x64xf32>, [[PARAM_2_:%.+]]: tensor<768x768xf32>) -> tensor<1x385x768xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[1, 12, 384, 64]> : tensor<4xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 384, 768]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<12x385x385xf32>) -> tensor<12x385x385xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<12x385x64xf32>) -> tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.MatMul"([[VAR_3_]], [[VAR_4_]], [[VAR_0_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x385x385xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Unstick"([[VAR_5_]]) : (tensor<12x385x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x385x64xf32>
+// CHECK: [[VAR_7_:%.+]] = "onnx.Reshape"([[VAR_6_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<12x385x64xf32>, tensor<4xi64>) -> tensor<1x12x385x64xf32>
+// CHECK: [[VAR_8_:%.+]] = "onnx.Transpose"([[VAR_7_]]) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<1x12x385x64xf32>) -> tensor<1x385x12x64xf32>
+// CHECK: [[VAR_9_:%.+]] = "onnx.Reshape"([[VAR_8_]], [[VAR_2_]]) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<1x385x12x64xf32>, tensor<3xi64>) -> tensor<1x385x768xf32>
+// CHECK-DAG: [[VAR_10_:%.+]] = "zhigh.Stick"([[VAR_9_]]) {layout = "3DS"} : (tensor<1x385x768xf32>) -> tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_11_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "2D"} : (tensor<768x768xf32>) -> tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>
+// CHECK: [[VAR_12_:%.+]] = "zhigh.MatMul"([[VAR_10_]], [[VAR_11_]], [[VAR_0_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<768x768xf16, #zhigh.layout<{dataLayout = "2D"}>>, none) -> tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_13_:%.+]] = "zhigh.Unstick"([[VAR_12_]]) : (tensor<1x385x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x385x768xf32>
+// CHECK: onnx.Return [[VAR_13_]] : tensor<1x385x768xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: Second pattern found in Roberta, with BS=1
+
+func.func @test_Roberta_pattern2_bs1(%arg0: tensor<1x384x768xf32>, %arg1: tensor<1x384x768xf32>, %arg2: tensor<1x384x768xf32>) -> tensor<12x384x64xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %2 = onnx.Constant dense<[-1, 384, 64]> : tensor<3xi64>
+  %8 = onnx.Constant dense<[1, 384, 12, 64]> : tensor<4xi64>
+  %48 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %49 = "zhigh.Stick"(%arg1) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %63 = "zhigh.Stick"(%arg2) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %50 = "zhigh.Add"(%48, %49) : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %51 = "zhigh.Unstick"(%50) : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x768xf32>
+  %55 = "onnx.Reshape"(%51, %8) {allowzero = 0 : si64, onnx_node_name = "Reshape_85"} : (tensor<1x384x768xf32>, tensor<4xi64>) -> tensor<1x384x12x64xf32>
+  %56 = "onnx.Transpose"(%55) {onnx_node_name = "Transpose_86", perm = [0, 2, 1, 3]} : (tensor<1x384x12x64xf32>) -> tensor<1x12x384x64xf32>
+  %64 = "onnx.Reshape"(%56, %2) {allowzero = 0 : si64} : (tensor<1x12x384x64xf32>, tensor<3xi64>) -> tensor<12x384x64xf32>
+  %65 = "zhigh.Stick"(%64) {layout = "3DS"} : (tensor<12x384x64xf32>) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %66 = "zhigh.MatMul"(%63, %65, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %73 = "zhigh.Unstick"(%66) : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x64xf32>
+  onnx.Return %73 : tensor<12x384x64xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_Roberta_pattern2_bs1
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x768xf32>, [[PARAM_1_:%.+]]: tensor<1x384x768xf32>, [[PARAM_2_:%.+]]: tensor<1x384x768xf32>) -> tensor<12x384x64xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<[12, 384, 64]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_1_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_2_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_5_:%.+]] = "zhigh.Add"([[VAR_2_]], [[VAR_3_]]) : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Reshape"([[VAR_5_]], [[VAR_0_]]) {layout = "3DS"} : (tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_7_:%.+]] = "zhigh.MatMul"([[VAR_4_]], [[VAR_6_]], [[VAR_1_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_8_:%.+]] = "zhigh.Unstick"([[VAR_7_]]) : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x64xf32>
+// CHECK: onnx.Return [[VAR_8_]] : tensor<12x384x64xf32>
+// CHECK: }
+}
+
+// -----
+
+// COM: Second pattern found in Roberta, with BS=1, not mod 64
+
+func.func @test_Roberta_pattern2_bs1_notmod64(%arg0: tensor<1x384x756xf32>, %arg1: tensor<1x384x756xf32>, %arg2: tensor<1x384x756xf32>) -> tensor<12x384x63xf32> {
+  %0 = "onnx.NoValue"() {value} : () -> none
+  %2 = onnx.Constant dense<[-1, 384, 64]> : tensor<3xi64>
+  %8 = onnx.Constant dense<[1, 384, 12, 64]> : tensor<4xi64>
+  %48 = "zhigh.Stick"(%arg0) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %49 = "zhigh.Stick"(%arg1) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %63 = "zhigh.Stick"(%arg2) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %50 = "zhigh.Add"(%48, %49) : (tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %51 = "zhigh.Unstick"(%50) : (tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x756xf32>
+  %55 = "onnx.Reshape"(%51, %8) {allowzero = 0 : si64, onnx_node_name = "Reshape_85"} : (tensor<1x384x756xf32>, tensor<4xi64>) -> tensor<1x384x12x63xf32>
+  %56 = "onnx.Transpose"(%55) {onnx_node_name = "Transpose_86", perm = [0, 2, 1, 3]} : (tensor<1x384x12x63xf32>) -> tensor<1x12x384x63xf32>
+  %64 = "onnx.Reshape"(%56, %2) {allowzero = 0 : si64} : (tensor<1x12x384x63xf32>, tensor<3xi64>) -> tensor<12x384x63xf32>
+  %65 = "zhigh.Stick"(%64) {layout = "3DS"} : (tensor<12x384x63xf32>) -> tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %66 = "zhigh.MatMul"(%63, %65, %0) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+  %73 = "zhigh.Unstick"(%66) : (tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x63xf32>
+  onnx.Return %73 : tensor<12x384x63xf32>
+
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @test_Roberta_pattern2_bs1_notmod64
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x384x756xf32>, [[PARAM_1_:%.+]]: tensor<1x384x756xf32>, [[PARAM_2_:%.+]]: tensor<1x384x756xf32>) -> tensor<12x384x63xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK-DAG: [[VAR_1_:%.+]] = onnx.Constant dense<[-1, 384, 64]> : tensor<3xi64>
+// CHECK-DAG: [[VAR_2_:%.+]] = onnx.Constant dense<[1, 384, 12, 64]> : tensor<4xi64>
+// CHECK-DAG: [[VAR_3_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_4_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK-DAG: [[VAR_5_:%.+]] = "zhigh.Stick"([[PARAM_2_]]) {layout = "3DS"} : (tensor<1x384x756xf32>) -> tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_6_:%.+]] = "zhigh.Add"([[VAR_3_]], [[VAR_4_]]) : (tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_7_:%.+]] = "zhigh.Unstick"([[VAR_6_]]) : (tensor<1x384x756xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<1x384x756xf32>
+// CHECK: [[VAR_8_:%.+]] = "onnx.Reshape"([[VAR_7_]], [[VAR_2_]]) {allowzero = 0 : si64, onnx_node_name = "Reshape_85"} : (tensor<1x384x756xf32>, tensor<4xi64>) -> tensor<1x384x12x63xf32>
+// CHECK: [[VAR_9_:%.+]] = "onnx.Transpose"([[VAR_8_]]) {onnx_node_name = "Transpose_86", perm = [0, 2, 1, 3]} : (tensor<1x384x12x63xf32>) -> tensor<1x12x384x63xf32>
+// CHECK: [[VAR_10_:%.+]] = "onnx.Reshape"([[VAR_9_]], [[VAR_1_]]) {allowzero = 0 : si64} : (tensor<1x12x384x63xf32>, tensor<3xi64>) -> tensor<12x384x63xf32>
+// CHECK: [[VAR_11_:%.+]] = "zhigh.Stick"([[VAR_10_]]) {layout = "3DS"} : (tensor<12x384x63xf32>) -> tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_12_:%.+]] = "zhigh.MatMul"([[VAR_5_]], [[VAR_11_]], [[VAR_0_]]) {transposeA = 0 : si64, transposeB = 0 : si64} : (tensor<12x384x384xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>, none) -> tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// CHECK: [[VAR_13_:%.+]] = "zhigh.Unstick"([[VAR_12_]]) : (tensor<12x384x63xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x63xf32>
+// CHECK: onnx.Return [[VAR_13_]] : tensor<12x384x63xf32>
+// CHECK: }
+}
+
diff --git a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir
index 07cffa92b5..2c7e558c9c 100644
--- a/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir
+++ b/test/mlir/accelerators/nnpa/transform/zlow-rewrite.mlir
@@ -713,6 +713,75 @@ func.func @should_not_rewrite_unstick_transpose_stick_3(%arg0: memref<5x10xf16,
 
 // -----
+
+// Test reshape; a use of the reshape had to be added to make it work. Test only the reshape.
+// Fails here because the memref is not normalized.
+#map = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
+func.func @handle_zlow_reshape_fail(%arg0: memref<8x384x768xf16, #map>, %arg1: memref<96x64x384xf16, #map>) -> memref<96x384x384xf16, #map>{
+  // Constants.
+  %c64_i64 = arith.constant 64 : i64
+  %c96_i64 = arith.constant 96 : i64
+  %c384_i64 = arith.constant 384 : i64
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // Reshape of input.
+  %alloc = memref.alloc() {alignment = 4096 : i64} : memref<96x384x64xf16, #map>
+  "zlow.reshape"(%arg0, %alloc) {layout = "3DS"} : (memref<8x384x768xf16, #map>, memref<96x384x64xf16, #map>) -> ()
+  // Use of reshape.
+  %alloc_0 = memref.alloc() {alignment = 4096 : i64} : memref<96x384x384xf16, #map>
+  %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<4xi64>
+  krnl.store %c96_i64, %alloc_1[%c0] : memref<4xi64>
+  krnl.store %c384_i64, %alloc_1[%c1] : memref<4xi64>
+  krnl.store %c64_i64, %alloc_1[%c2] : memref<4xi64>
+  krnl.store %c384_i64, %alloc_1[%c3] : memref<4xi64>
+  %0 = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_3", shape = [96, 6, 1, 1, 32, 64], value = dense_resource : tensor<2359296xi8>} : () -> memref<96x6x1x1x32x64xf16>
+  "zlow.matmul"(%alloc, %arg1, %0, %alloc_1, %alloc_0) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<96x384x64xf16, #map>, memref<96x64x384xf16, #map>, memref<96x6x1x1x32x64xf16>, memref<4xi64>, memref<96x384x384xf16, #map>) -> ()
+  return %alloc_0 : memref<96x384x384xf16, #map>
+
+// mlir2FileCheck.py
+// CHECK-DAG: [[MAP_0_:#.+]] = affine_map<(d0, d1, d2) -> (d0, d2 floordiv 64, 0, d1 floordiv 32, d1 mod 32, d2 mod 64)>
+// CHECK-LABEL: func.func @handle_zlow_reshape_fail
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<8x384x768xf16, #map>, [[PARAM_1_:%.+]]: memref<96x64x384xf16, #map>) -> memref<96x384x384xf16, #map> {
+
+// CHECK: [[RES_:%.+]] = memref.alloc() {{.*}}: memref<96x384x64xf16, #map>
+// CHECK: "zlow.reshape"([[PARAM_0_]], [[RES_]]) {layout = "3DS"} : (memref<8x384x768xf16, #map>, memref<96x384x64xf16, #map>) -> ()
+
+// CHECK: }
+}
+
+// -----
+
+// Succeeds now because the memref is normalized.
+
+func.func @handle_zlow_reshape_success(%arg0: memref<8x12x1x12x32x64xf16>, %arg1: memref<96x6x1x2x32x64xf16>) -> memref<96x6x1x12x32x64xf16> {
+  %c64_i64 = arith.constant 64 : i64
+  %c96_i64 = arith.constant 96 : i64
+  %c384_i64 = arith.constant 384 : i64
+  %c3 = arith.constant 3 : index
+  %c2 = arith.constant 2 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() {alignment = 4096 : i64} : memref<96x1x1x12x32x64xf16>
+  "zlow.reshape"(%arg0, %alloc) {layout = "3DS"} : (memref<8x12x1x12x32x64xf16>, memref<96x1x1x12x32x64xf16>) -> ()
+  %alloc_0 = memref.alloc() {alignment = 4096 : i64} : memref<96x6x1x12x32x64xf16>
+  %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<4xi64>
+  krnl.store %c96_i64, %alloc_1[%c0] : memref<4xi64>
+  krnl.store %c384_i64, %alloc_1[%c1] : memref<4xi64>
+  krnl.store %c64_i64, %alloc_1[%c2] : memref<4xi64>
+  krnl.store %c384_i64, %alloc_1[%c3] : memref<4xi64>
+  %0 = "krnl.global"() {alignment = 4096 : i64, name = "constant_stickify_3", shape = [96, 6, 1, 1, 32, 64], value = dense_resource : tensor<2359296xi8>} : () -> memref<96x6x1x1x32x64xf16>
+  "zlow.matmul"(%alloc, %arg1, %0, %alloc_1, %alloc_0) {is_bcast1 = 0 : si64, is_bcast23 = 0 : si64, is_stacked = -1 : si64, transposeA = 0 : si64, transposeB = 0 : si64} : (memref<96x1x1x12x32x64xf16>, memref<96x6x1x2x32x64xf16>, memref<96x6x1x1x32x64xf16>, memref<4xi64>, memref<96x6x1x12x32x64xf16>) -> ()
+  return %alloc_0 : memref<96x6x1x12x32x64xf16>
+// mlir2FileCheck.py
+// CHECK-LABEL: func.func @handle_zlow_reshape_success
+// CHECK-SAME: ([[PARAM_0_:%.+]]: memref<8x12x1x12x32x64xf16>, [[PARAM_1_:%.+]]: memref<96x6x1x2x32x64xf16>) -> memref<96x6x1x12x32x64xf16> {
+
+// CHECK-DAG: [[VAR_reinterpret_cast_:%.+]] = memref.reinterpret_cast [[PARAM_0_]] to offset: [0], sizes: [96, 1, 1, 12, 32, 64], strides: [24576, 24576, 24576, 2048, 64, 1] : memref<8x12x1x12x32x64xf16> to memref<96x1x1x12x32x64xf16>
+
+// CHECK: }
+}
+
 // Do not rewrite because there is a AffineStoreOp without AffineLoadOp in pattern: unstick -> pad -> stick
 // TODO: support this pattern.