diff --git a/docs/Dialects/zhigh.md b/docs/Dialects/zhigh.md
index dd87eeecf5..8aa2197c08 100644
--- a/docs/Dialects/zhigh.md
+++ b/docs/Dialects/zhigh.md
@@ -782,13 +782,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac
Effects: `MemoryEffects::Effect{}`
-#### Attributes:
-
-
-Attribute | MLIR Type | Description |
-op_type | ::mlir::StringAttr | string attribute |
-
-
#### Operands:
| Operand | Description |
@@ -814,13 +807,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterfac
Effects: `MemoryEffects::Effect{}`
-#### Attributes:
-
-
-Attribute | MLIR Type | Description |
-op_type | ::mlir::StringAttr | string attribute |
-
-
#### Operands:
| Operand | Description |
@@ -857,6 +843,40 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+### `zhigh.Reshape` (::onnx_mlir::zhigh::ZHighReshapeOp)
+
+_ZHigh Reshape operation for Z Tensors_
+
+ZHigh operation to convert a Z Tensor from one type to an equivalent type
+with a provided shape. The data is never copied or modified. When no layout is specified,
+the output preserves the same layout as the source input.
+
+Traits: `AlwaysSpeculatableImplTrait`
+
+Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`
+
+Effects: `MemoryEffects::Effect{}`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `source` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+| `shape` | tensor of 64-bit signless integer values
+
+#### Results:
+
+| Result | Description |
+| :----: | ----------- |
+| `result` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
+
### `zhigh.Sigmoid` (::onnx_mlir::zhigh::ZHighSigmoidOp)
_ZHigh Sigmoid operation_
diff --git a/docs/Dialects/zlow.md b/docs/Dialects/zlow.md
index 7be1c6457b..e97ece5313 100644
--- a/docs/Dialects/zlow.md
+++ b/docs/Dialects/zlow.md
@@ -770,7 +770,6 @@ Traits: `MemRefsNormalizable`
Attribute | MLIR Type | Description |
layout | ::mlir::StringAttr | string attribute |
-op_type | ::mlir::StringAttr | string attribute |
#### Operands:
@@ -795,7 +794,6 @@ Traits: `MemRefsNormalizable`
Attribute | MLIR Type | Description |
layout | ::mlir::StringAttr | string attribute |
-op_type | ::mlir::StringAttr | string attribute |
#### Operands:
@@ -832,6 +830,29 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
+### `zlow.reshape` (::onnx_mlir::zlow::ZLowReshapeOp)
+
+_ZLow Reshape operation_
+
+ZLow operation to perform a reshape (no data movement).
+
+Traits: `MemRefsNormalizable`
+
+#### Attributes:
+
+
+Attribute | MLIR Type | Description |
+layout | ::mlir::StringAttr | string attribute |
+
+
+#### Operands:
+
+| Operand | Description |
+| :-----: | ----------- |
+| `X` | memref of dlfloat16 type values
+| `shape` | memref of 64-bit signless integer values
+| `Out` | memref of dlfloat16 type values
+
### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)
_ZLow sigmoid operation_
diff --git a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td
index 2173000382..05a373f9b5 100644
--- a/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td
+++ b/src/Accelerators/NNPA/Conversion/ONNXToZHigh/RewriteONNXForZHigh.td
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//
-
-// Create an ONNX Shape Op with type
-def CreateShapeOp: NativeCodeCall<
- "$_builder.create($_loc, $0, $1, IntegerAttr(), 0)"
->;
-
-// Get a type for a tensor that stores the shape of another tensor.
-def GetShapeTypeOf: NativeCodeCall<
- "RankedTensorType::get({mlir::cast($0.getType()).getRank()}, $_builder.getIntegerType(64))"
->;
-
// Check unidirectional broadcasting from the first to second tensor.
def IsUniBroadcastingFromFirstToSecond: Constraint<
CPred<"isUniBroadcatableFirstToSecond($0, $1)">,
diff --git a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
index 2cdc850e02..1ed8f6e85b 100644
--- a/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
+++ b/src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -1093,6 +1093,47 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern {
}
};
+// Reshape operation. Code similar to unary lowering, except that we use the
+// operation's specialized shape here.
+struct ZHighToZLowReshapeOpLowering : public ConversionPattern {
+ ZHighToZLowReshapeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
+ : ConversionPattern(ZHighReshapeOp::getOperationName(), 1, ctx) {}
+
+ LogicalResult matchAndRewrite(Operation *op, ArrayRef operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Location loc = op->getLoc();
+ Value input = operands[0];
+
+ // Helper builders.
+ MultiDialectBuilder create(rewriter, loc);
+
+ // Convert ZTensor type to MemRefType.
+ ZMemRefType zMemRefType =
+ convertZTensorToMemRefType(*op->result_type_begin());
+
+ // Shape helper.
+ ZHighReshapeOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
+ shapeHelper.computeShapeAndAssertOnFailure();
+ SmallVector &dims = shapeHelper.getOutputDims();
+
+ // Allocate a buffer for the result MemRef. Follow this pattern to be
+ // similar to all the other zlow patterns. Will remove the alloc when
+ // lowering zlow.reshape to memref.reinterpret_cast once memrefs are
+ // normalized. See code in ReshapeToReinterpretCastPattern.
+ Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter);
+
+ // Note, we do not need to save the shape of the original operation, as this
+ // reshape is "no-op" that logically reorganize the shape of the operation
+ // into 2 equivalent shapes under their given layout.
+
+ // Emit a ZLow operation.
+ rewriter.create(
+ loc, input, /* shape,*/ alloc, zMemRefType.layout);
+ rewriter.replaceOp(op, alloc);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin
//===----------------------------------------------------------------------===//
@@ -1117,8 +1158,6 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern {
: ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef operands,
ConversionPatternRewriter &rewriter) const final {
- MLIRContext *context = rewriter.getContext();
- OP_TYPE reduceOp = mlir::cast(op);
Location loc = op->getLoc();
Value data = operands[0];
@@ -2285,6 +2324,8 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
patterns.insert>(typeConverter, ctx);
patterns.insert>(
typeConverter, ctx);
+ // Reshape operations.
+ patterns.insert(typeConverter, ctx);
// Neural network operations.
patterns.insert>(
typeConverter, ctx);
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
index 2c3e6ca953..7d8fa855de 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
@@ -26,6 +26,7 @@ add_onnx_mlir_library(OMZHighOps
ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp
ZHighOps/QuantizedStick/QuantizedStick.cpp
ZHighOps/Reduction/Reduction.cpp
+ ZHighOps/Reshape/Reshape.cpp
ZHighOps/Softmax/Softmax.cpp
ZHighOps/Stick/Stick.cpp
ZHighOps/StickForGRU/StickForGRU.cpp
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
index 7bbcd02c87..cbc73dad18 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -1195,4 +1195,28 @@ def ZHighFixGRUYhOp:ZHigh_Op<"FixGRUYh", [Pure,
}];
}
+def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure,
+ DeclareOpInterfaceMethods,
+ DeclareOpInterfaceMethods]> {
+ let summary = "ZHigh Reshape operation for Z Tensors";
+ let description = [{
+ ZHigh operation to convert a Z Tensor from one type to an equivalent type
+ with a provided shape. The data is never copied or modified. When no layout is specified,
+ the output preserves the same layout as the source input.
+ }];
+ let arguments = (ins AnyTypeOf<[AnyZTensor]>:$source, // Input Z Tensor to be reshaped.
+ TensorOf<[I64]>:$shape, // Shape of output Z Tensor.
+ OptionalAttr:$layout); // Layout of output Z Tensor, default same as input.
+ let results = (outs AnyTypeOf<[AnyZTensor]>:$result);
+
+ let extraClassDefinition = [{
+ onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper,
+ onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
+ onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReshapeOpShapeHelper(op, oper, ieb, scope);
+ assert(sh && "failed to allocate shape helper");
+ return sh;
+ }
+ }];
+}
+
#endif // ZHIGH_OPS
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
index c120bc6b44..926d08913c 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.cpp
@@ -218,7 +218,7 @@ ZTensorEncodingAttr::QuantizedType getZTensorQuantizedType(Type type) {
// Utility functions.
Value getMinusBcastConst(
- mlir::OpBuilder &builder, Location loc, FloatAttr floatAttr, Value X) {
+ OpBuilder &builder, Location loc, FloatAttr floatAttr, Value X) {
ShapedType xType = mlir::cast(X.getType());
assert(xType.hasStaticShape() && "expected static shape");
float val = floatAttr.getValueAsDouble() * -1.0;
@@ -283,14 +283,14 @@ bool isTiling2DTo4D(Value val) {
if (!(inputShape.size() == 2 && outputShape.size() == 4))
return false;
- // Tiling over each input dimension.
+ // Tiling over each input dimension. Assume here that the dims are static.
return ((inputShape[0] == outputShape[0] * outputShape[1]) &&
(inputShape[1] == outputShape[2] * outputShape[3]));
}
/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling the first input
/// dimension.
-bool isTiling3DTo4D(Value val) {
+bool isLeftmostTiling3DTo4D(Value val) {
auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
if (!reshapeOp)
return false;
@@ -312,12 +312,50 @@ bool isTiling3DTo4D(Value val) {
if (!(inputShape.size() == 3 && outputShape.size() == 4))
return false;
- // Tiling over each input dimension.
+ // Tiling over the leftmost input dimension. Assume here that the dims are static.
return ((inputShape[0] == outputShape[0] * outputShape[1]) &&
(inputShape[1] == outputShape[2]) &&
(inputShape[2] == outputShape[3]));
}
+/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling the last input
+/// dimension. If tilingSize>0, then check that it is tiling by that amount (or
+/// a multiple thereof).
+bool isRightmostTiling3DTo4D(Value val, int64_t tilingSize) {
+ auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
+ if (!reshapeOp)
+ return false;
+
+ Value input = reshapeOp.getData();
+ Value output = reshapeOp.getReshaped();
+ Type inputType = input.getType();
+ Type outputType = output.getType();
+
+ if (!isRankedShapedType(inputType))
+ return false;
+ if (!isRankedShapedType(outputType))
+ return false;
+
+ ArrayRef inputShape = getShape(inputType);
+ ArrayRef outputShape = getShape(outputType);
+
+ // Not reshape from 3D to 4D.
+ if (!(inputShape.size() == 3 && outputShape.size() == 4))
+ return false;
+
+ // If a tiling size is given, check that the last dim of the output is
+ // statically determined and is a multiple of the tiling size.
+ if (tilingSize > 0)
+ if (ShapedType::isDynamic(outputShape[3]) ||
+ (outputShape[3] % tilingSize != 0))
+ return false;
+
+ // Tiling over the rightmost input dimension. Assume here that the dims are static.
+ return ((inputShape[0] == outputShape[0]) &&
+ (inputShape[1] == outputShape[1]) &&
+ (inputShape[2] == outputShape[2] * outputShape[3]));
+}
+
/// Check if a 4D tensor is collapsed into 2D by merging the each two
/// dimensions.
bool isCollapsing4DTo2D(Value val) {
@@ -342,14 +380,15 @@ bool isCollapsing4DTo2D(Value val) {
if (!(inputShape.size() == 4 && outputShape.size() == 2))
return false;
- // Collapsing by merging the first two dimensions.
+ // Collapsing by merging the first two dimensions. Assume here that the dims
+ // are static.
return ((inputShape[0] * inputShape[1] == outputShape[0]) &&
(inputShape[2] * inputShape[3] == outputShape[1]));
}
/// Check if a 4D tensor is collapsed into 3D by merging the first two
-/// dimensions.
-bool isCollapsing4DTo3D(Value val) {
+/// (leftmost) dimensions.
+bool isLeftmostCollapsing4DTo3D(Value val) {
auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
if (!reshapeOp)
return false;
@@ -371,12 +410,44 @@ bool isCollapsing4DTo3D(Value val) {
if (!(inputShape.size() == 4 && outputShape.size() == 3))
return false;
- // Collapsing by merging the first two dimensions.
+ // Collapsing by merging the first two dimensions. Assume here that the dims
+ // are static.
return ((inputShape[0] * inputShape[1] == outputShape[0]) &&
(inputShape[2] == outputShape[1]) &&
(inputShape[3] == outputShape[2]));
}
+/// Check if a 4D tensor is collapsed into 3D by merging the last two
+/// (rightmost) dimensions.
+bool isRightmostCollapsing4DTo3D(Value val) {
+ auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
+ if (!reshapeOp)
+ return false;
+
+ Value input = reshapeOp.getData();
+ Value output = reshapeOp.getReshaped();
+ Type inputType = input.getType();
+ Type outputType = output.getType();
+
+ if (!isRankedShapedType(inputType))
+ return false;
+ if (!isRankedShapedType(outputType))
+ return false;
+
+ ArrayRef inputShape = getShape(inputType);
+ ArrayRef outputShape = getShape(outputType);
+
+ // Not reshape from 4D to 3D.
+ if (!(inputShape.size() == 4 && outputShape.size() == 3))
+ return false;
+
+ // Collapsing by merging the last two dimensions. Assume here that the dims
+ // are static.
+ return ((inputShape[0] == outputShape[0]) &&
+ (inputShape[1] == outputShape[1]) &&
+ (inputShape[2] * inputShape[3] == outputShape[2]));
+}
+
AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) {
assert(isTiling2DTo4D(val) &&
"ONNXReshapeOp is not suitable for getting a tiling affine map");
@@ -402,8 +473,8 @@ AffineMapAttr getTiling2DTo4DMap(OpBuilder &b, Value val) {
return AffineMapAttr::get(map);
}
-AffineMapAttr getTiling3DTo4DMap(OpBuilder &b, Value val) {
- assert(isTiling3DTo4D(val) &&
+AffineMapAttr getLeftmostTiling3DTo4DMap(OpBuilder &b, Value val) {
+ assert(isLeftmostTiling3DTo4D(val) &&
"ONNXReshapeOp is not suitable for getting a tiling affine map");
auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
@@ -450,8 +521,8 @@ AffineMapAttr getCollapsing4DTo2DMap(OpBuilder &b, Value val) {
return AffineMapAttr::get(map);
}
-AffineMapAttr getCollapsing4DTo3DMap(OpBuilder &b, Value val) {
- assert(isCollapsing4DTo3D(val) &&
+AffineMapAttr getLeftmostCollapsing4DTo3DMap(OpBuilder &b, Value val) {
+ assert(isLeftmostCollapsing4DTo3D(val) &&
"ONNXReshapeOp is not suitable for getting a collapsing affine map");
auto reshapeOp = mlir::dyn_cast(val.getDefiningOp());
@@ -484,6 +555,44 @@ AffineMapAttr getTransposeMap(OpBuilder &b, ArrayAttr permAttr) {
return AffineMapAttr::get(map);
}
+/// Check the values of a transpose map to be equal to the permVals.
+bool isTransposePermutationEqualTo(
+ ArrayAttr permAttr, mlir::ArrayRef permVals) {
+ // Check same rank.
+ int64_t permAttrSize = ArrayAttrSize(permAttr);
+ int64_t permValSize = permVals.size();
+ if (permAttrSize != permValSize)
+ return false;
+ // Check same values; abort on failure.
+ for (int64_t i = 0; i < permAttrSize; ++i) {
+ int64_t v = ArrayAttrIntVal(permAttr, i);
+ if (permVals[i] != v)
+ return false;
+ }
+ // Identical, success.
+ return true;
+}
+
+bool isShapeDimMultipleOf(Value val, int64_t index, int64_t multipleVal) {
+ // Type must be shaped and ranked.
+ Type type = val.getType();
+ if (!isRankedShapedType(type))
+ return false;
+ // Index must be within bounds of the shape rank; negative is from back.
+ ArrayRef shape = getShape(type);
+ int64_t size = shape.size();
+ if (index < 0)
+ index += size;
+ if (index < 0 || index >= size)
+ return false;
+ // At this time, only reason about static shapes.
+ int64_t dim = shape[index];
+ if (ShapedType::isDynamic(dim))
+ return false;
+ // All good now, check if dim is a multiple of "multipleVal."
+ return dim % multipleVal == 0;
+}
+
IntegerAttr getAxisNHWC(IntegerAttr axisNCHWAttr) {
int64_t axisNCHW = axisNCHWAttr.getSInt();
int64_t axisNHWC;
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
index cc346ef17d..c37d2af544 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.hpp
@@ -74,22 +74,33 @@ mlir::Value getConstantOfType(
bool oneIsOfLayout(
mlir::Type t1, mlir::Type t2, ZTensorEncodingAttr::DataLayout layout);
-/// Check if ONNXReshapeOp is reshaping 2D/3D to 4D by tiling each input
+/// Check if ONNXReshapeOp is reshaping 2D/3D to 4D by tiling an input
/// dimension.
bool isTiling2DTo4D(mlir::Value val);
mlir::AffineMapAttr getTiling2DTo4DMap(mlir::OpBuilder &b, mlir::Value val);
-bool isTiling3DTo4D(mlir::Value val);
-mlir::AffineMapAttr getTiling3DTo4DMap(mlir::OpBuilder &b, mlir::Value val);
-/// Check if ONNXReshapeOp is collapsing 4D into 3D/2D by merging the first two
-/// dimensions.
-bool isCollapsing4DTo3D(mlir::Value val);
-mlir::AffineMapAttr getCollapsing4DTo3DMap(mlir::OpBuilder &b, mlir::Value val);
+bool isLeftmostTiling3DTo4D(mlir::Value val);
+bool isRightmostTiling3DTo4D(mlir::Value val, int64_t tilingSize);
+mlir::AffineMapAttr getLeftmostTiling3DTo4DMap(
+ mlir::OpBuilder &b, mlir::Value val);
+/// Check if ONNXReshapeOp is collapsing 4D into 3D by merging the first two
+/// (leftmost) dimensions.
+bool isLeftmostCollapsing4DTo3D(mlir::Value val);
+/// Check if ONNXReshapeOp is collapsing 4D into 3D by merging the last two
+/// (rightmost) dimensions.
+bool isRightmostCollapsing4DTo3D(mlir::Value val);
+mlir::AffineMapAttr getLeftmostCollapsing4DTo3DMap(
+ mlir::OpBuilder &b, mlir::Value val);
bool isCollapsing4DTo2D(mlir::Value val);
mlir::AffineMapAttr getCollapsing4DTo2DMap(mlir::OpBuilder &b, mlir::Value val);
/// Get an affine map for the permutation array.
mlir::AffineMapAttr getTransposeMap(
mlir::OpBuilder &b, mlir::ArrayAttr permAttr);
-
+/// Check the values of a transpose map to be equal to the permVals.
+bool isTransposePermutationEqualTo(
+ mlir::ArrayAttr permAttr, mlir::ArrayRef permVals);
+/// Return true when shape(Value)[index] % multipleVal == 0.
+/// Negative indices, count from the back (-1 is last element).
+bool isShapeDimMultipleOf(mlir::Value val, int64_t index, int64_t multipleVal);
/// Get an axis for NHWC layout given an axis for NCHW layout.
mlir::IntegerAttr getAxisNHWC(mlir::IntegerAttr axisNCHWAttr);
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td
index d3321a6464..f046029b13 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/OpHelper.td
@@ -39,6 +39,16 @@ def NotSameLayout: Constraint<
def IsNoneType : Constraint(($_self).getType())">>;
+// Create an ONNX Shape Op with type
+def CreateShapeOp: NativeCodeCall<
+ "$_builder.create($_loc, $0, $1, IntegerAttr(), 0)"
+>;
+
+// Get a type for a tensor that stores the shape of another tensor.
+def GetShapeTypeOf: NativeCodeCall<
+ "RankedTensorType::get({mlir::cast($0.getType()).getRank()}, $_builder.getIntegerType(64))"
+>;
+
def GetLayout : NativeCodeCall<
"::onnx_mlir::zhigh::convertZTensorDataLayoutToStringAttr($_builder, "
"::onnx_mlir::zhigh::getZTensorLayout($0.getType()))"
@@ -144,10 +154,16 @@ def IsTiling2DTo4D : Constraint<
"Is tiling by ONNReshapeOp"
>;
-/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling each input dimension.
-def IsTiling3DTo4D : Constraint<
- CPred<"::onnx_mlir::zhigh::isTiling3DTo4D($0)">,
- "Is tiling by ONNReshapeOp"
+/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling leftmost input dimension.
+def IsLeftmostTiling3DTo4D : Constraint<
+ CPred<"::onnx_mlir::zhigh::isLeftmostTiling3DTo4D($0)">,
+ "Is leftmost tiling by ONNReshapeOp"
+>;
+
+/// Check if ONNXReshapeOp is reshaping 3D to 4D by tiling rightmost input dimension.
+def IsRightmostTiling3DTo4DBy64 : Constraint<
+ CPred<"::onnx_mlir::zhigh::isRightmostTiling3DTo4D($0, 64)">,
+ "Is rightmost tiling of size 64 by ONNReshapeOp"
>;
/// Check if ONNXReshapeOp is reshaping 4D to 2D by collapsing the first two input dimensions.
@@ -157,9 +173,15 @@ def IsCollapsing4DTo2D : Constraint<
>;
/// Check if ONNXReshapeOp is reshaping 4D to 3D by collapsing the first two input dimensions.
-def IsCollapsing4DTo3D : Constraint<
- CPred<"::onnx_mlir::zhigh::isCollapsing4DTo3D($0)">,
- "Is collapsing by ONNXReshapeOp"
+def IsLeftmostCollapsing4DTo3D : Constraint<
+ CPred<"::onnx_mlir::zhigh::isLeftmostCollapsing4DTo3D($0)">,
+ "Is leftmost collapsing by ONNXReshapeOp"
+>;
+
+/// Check if ONNXReshapeOp is reshaping 4D to 3D by collapsing the last two input dimensions.
+def IsRightmostCollapsing4DTo3D : Constraint<
+ CPred<"::onnx_mlir::zhigh::isRightmostCollapsing4DTo3D($0)">,
+ "Is rightmost collapsing by ONNXReshapeOp"
>;
def GetResultType : NativeCodeCall<
@@ -170,22 +192,37 @@ def GetTiling2DTo4DMap : NativeCodeCall<
"::onnx_mlir::zhigh::getTiling2DTo4DMap($_builder, $0)"
>;
-def GetTiling3DTo4DMap : NativeCodeCall<
- "::onnx_mlir::zhigh::getTiling3DTo4DMap($_builder, $0)"
+def GetLeftmostTiling3DTo4DMap : NativeCodeCall<
+ "::onnx_mlir::zhigh::getLeftmostTiling3DTo4DMap($_builder, $0)"
>;
def GetCollapsing4DTo2DMap: NativeCodeCall<
"::onnx_mlir::zhigh::getCollapsing4DTo2DMap($_builder, $0)"
>;
-def GetCollapsing4DTo3DMap: NativeCodeCall<
- "::onnx_mlir::zhigh::getCollapsing4DTo3DMap($_builder, $0)"
+def GetLeftmostCollapsing4DTo3DMap: NativeCodeCall<
+ "::onnx_mlir::zhigh::getLeftmostCollapsing4DTo3DMap($_builder, $0)"
>;
def GetTransposeMap : NativeCodeCall<
"::onnx_mlir::zhigh::getTransposeMap($_builder, $0)"
>;
+def Is4DTransposePermutationEqualTo0213 : Constraint<
+ CPred<"::onnx_mlir::zhigh::isTransposePermutationEqualTo($0, {0, 2, 1, 3})">,
+ "Is 4D Transpose with pattern (0, 2, 1, 3)"
+>;
+
+class IsShapeDimMultipleOf32 : Constraint<
+ CPred<"::onnx_mlir::zhigh::isShapeDimMultipleOf($0, " # index # ", 32)">,
+ "The operand shape at given index is a multiple of 32"
+>;
+
+class IsShapeDimMultipleOf64 : Constraint<
+ CPred<"::onnx_mlir::zhigh::isShapeDimMultipleOf($0, " # index # ", 64)">,
+ "The operand shape at given index is a multiple of 64"
+>;
+
def IsIdentityAffineMap : Constraint<
CPred<"$_self.isIdentity()">,
"Is identity AffineMap"
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp
new file mode 100644
index 0000000000..e1affe1a18
--- /dev/null
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Reshape/Reshape.cpp
@@ -0,0 +1,70 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------------- Reshape.cpp - ZHigh Operations ---------------------===//
+//
+//
+// Copyright 2025 The IBM Research Authors.
+//
+// =============================================================================
+//
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+namespace onnx_mlir {
+namespace zhigh {
+
+//===----------------------------------------------------------------------===//
+// ShapeHelper
+//===----------------------------------------------------------------------===//
+
+LogicalResult ZHighReshapeOpShapeHelper::computeShape() {
+ ZHighReshapeOpAdaptor operandAdaptor(operands);
+
+ // Shape has the dimensions of the output.
+ DimsExpr outputDims;
+ createIE->getIntFromArrayAsSymbols(operandAdaptor.getShape(), outputDims);
+ setOutputDims(outputDims);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Shape inference
+//===----------------------------------------------------------------------===//
+
+// Builder
+
+LogicalResult ZHighReshapeOp::inferShapes(
+ std::function doShapeInference) {
+ Value source = getSource();
+ if (!hasRankedType(source))
+ return success();
+ // Output type has the same type as the input/source type.
+ RankedTensorType sourceType =
+ mlir::dyn_cast(source.getType());
+ Type elementType = sourceType.getElementType();
+ // Get encoding
+ StringAttr layout = getLayoutAttr();
+ ZTensorEncodingAttr::DataLayout dataLayout;
+ Attribute encoding;
+ if (layout) {
+ // Operation has an optional output layout, use it.
+ dataLayout = convertStringAttrToZTensorDataLayout(layout);
+ encoding = ZTensorEncodingAttr::get(this->getContext(), dataLayout);
+ } else {
+ // Operation does not have an optional output layout, reuse it from input.
+ encoding = sourceType.getEncoding();
+ }
+
+ ZHighReshapeOpShapeHelper shapeHelper(getOperation());
+ return shapeHelper.computeShapeAndUpdateType(elementType, encoding);
+}
+
+} // namespace zhigh
+} // namespace onnx_mlir
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp
index f9427116af..84282a9b1b 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp
@@ -56,6 +56,7 @@ DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickifiedConstantOfShapeOpShapeHelper)
DECLARE_SHAPE_HELPER_ZHIGH(ZHighStickOpShapeHelper)
DECLARE_SHAPE_HELPER_ZHIGH(ZHighQuantizedStickOpShapeHelper)
DECLARE_SHAPE_HELPER_ZHIGH(ZHighUnstickOpShapeHelper)
+DECLARE_SHAPE_HELPER_ZHIGH(ZHighReshapeOpShapeHelper)
#undef DECLARE_SHAPE_HELPER_ZHIGH
//===----------------------------------------------------------------------===//
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp
index 179cfe139b..780411cadb 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/Stick.cpp
@@ -12,6 +12,7 @@
//===----------------------------------------------------------------------===//
#include "src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/ShapeHelper.hpp"
+#include "src/Compiler/CompilerOptions.hpp"
using namespace mlir;
using namespace onnx_mlir;
@@ -142,6 +143,8 @@ void ZHighStickOp::getCanonicalizationPatterns(
results.insert(context);
results.insert(context);
results.insert(context);
+ results.insert(context);
+ results.insert(context);
}
} // namespace zhigh
diff --git a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td
index e612f788d6..69fda996e0 100644
--- a/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td
+++ b/src/Accelerators/NNPA/Dialect/ZHigh/ZHighOps/Stick/ZHighStick.td
@@ -105,7 +105,7 @@ def ReplaceONNXSoftplusPattern: Pattern<
(ZHighStickOp:$stickout (ONNXSoftplusOp:$out (ZHighUnstickOp $X)), $layout, $_),
[
// Get stickified constant of minus one with input shape.
- // Donot saturate since input orignally from NNPA.
+ // Donot saturate since input originally from NNPA.
(ZHighStickOp:$minusOne (GetConstantOfType<"-1.0"> $out), $layout, (NoneIntegerAttr)),
// Get minus X with input shape.
(ZHighMulOp:$minusX $X, $minusOne, (returnType $X)),
@@ -122,7 +122,7 @@ def ReplaceONNXSoftplusPattern: Pattern<
[(IsStaticShapeTensor $X), (SameLayout $X, $stickout)]
>;
-// Calulation of `1/sqrt(X)` or reciprocal square root is often found in
+// Calculation of `1/sqrt(X)` or reciprocal square root is often found in
// deep learning models, but zDNN does not support it. Thus, we rewrite it into
// zDNN-supported operations.
//
@@ -149,7 +149,94 @@ def ReplaceONNXReciprocalSqrtPattern: Pat<
[(IsStaticShapeTensor $X), (SameLayout $X, $stick)]
>;
-// The folowing pattern was found in bertsquad and GPT models.
+// The following pattern was found in Roberta models.
+// ```
+// %66 = "zhigh.Unstick"(%65) : (tensor<12x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<12x384x64xf32>
+// %67 = "onnx.Reshape"(%66, %2) {allowzero = 0 : si64} : (tensor<12x384x64xf32>, tensor<4xi64>) -> tensor<1x12x384x64xf32>
+// %68 = "onnx.Transpose"(%67) {onnx_node_name = "Transpose_94", perm = [0, 2, 1, 3]} : (tensor<1x12x384x64xf32>) -> tensor<1x384x12x64xf32>
+// %69 = "onnx.Reshape"(%68, %9) {allowzero = 0 : si64, onnx_node_name = "Reshape_104"} : (tensor<1x384x12x64xf32>, tensor<3xi64>) -> tensor<1x384x768xf32>
+// %70 = "zhigh.Stick"(%69) {layout = "3DS"} : (tensor<1x384x768xf32>) -> tensor<1x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+// ```
+// When the input tensor %65 (here tensor<12x384x64xf16>) with dims E3xE2xE1 and 3DS layout satisfies the following properties:
+// 1) E2 % 32 == 0
+// 2) E1 % 64 == 0
+// Then we can guarantee that input value %65 has the same value
+// in the same memory location (*) as output %70.
+// (*) this does not say the tensors are identical, just that they store
+// the same value at the same byte offset in memory.
+//
+// High level intuition for this is that the reshape / transpose / reshape
+// perform some memory layout operations, which results in no change in the
+// 3DS representation as that representation also performs layout changes.
+//
+// Limitation: the current pattern assumes static sizes.
+
+def ReshapeTransposeReshapeRoberta3DSWPattern1 : Pat<
+ // Input: X -> unstick -> reshape1 -> transpose -> reshape 2 -> stick.
+ (ZHighStickOp:$stick
+ (ONNXReshapeOp:$reshape2
+ (ONNXTransposeOp:$transpose
+ (ONNXReshapeOp:$reshape1
+ (ZHighUnstickOp:$unstick $X),
+ $shape1, $_),
+ $perm),
+ $shape2, $_),
+ $layout3DS, $saturation),
+ // Output: initial X value unchanged, but transformed with the new compatible shape.
+ (ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)),
+ // Conditions.
+ [(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS.
+ (IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only.
+ (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
+ (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
+ (IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32.
+ (IsShapeDimMultipleOf64<2> $X), // Third dim of input is a multiple of 64.
+ (Is4DTransposePermutationEqualTo0213 $perm), // Permute middle 2 dims.
+ (IsLeftmostTiling3DTo4D $reshape1), // 1st reshape is tiling in the leftmost dimension
+ (IsRightmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the last two dimensions.
+ ]
+>;
+
+// Second pattern found in Roberta models.
+//
+// %33 = "zhigh.Unstick"(%32) : (tensor<8x384x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<8x384x768xf32> // unstick
+// %44 = "onnx.Reshape"(%33, %7) {allowzero = 0 : si64, onnx_node_name = "Reshape_64"} : (tensor<8x384x768xf32>, tensor<4xi64>) -> tensor<8x384x12x64xf32> // last goes from 768 to 12, 64
+// %45 = "onnx.Transpose"(%44) {onnx_node_name = "Transpose_65", perm = [0, 2, 1, 3]} : (tensor<8x384x12x64xf32>) -> tensor<8x12x384x64xf32> // same permute
+// %50 = "onnx.Reshape"(%45, %2) {allowzero = 0 : si64} : (tensor<8x12x384x64xf32>, tensor<3xi64>) -> tensor<96x384x64xf32> // collapse first 2 dims
+// %52 = "zhigh.Stick"(%50) {layout = "3DS"} : (tensor<96x384x64xf32>) -> tensor<96x384x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
+//
+// Namely unstick -> reshape (by tiling rightmost one to E1/64, 64) -> permute (2nd & 3rd) -> reshape (collapse first 2) -> stick
+// Shapes goes from 8 x 384 x (12*64 = 768) to (8*12 = 96) x 384 x 64
+//
+// The pattern works only when the input satisfies E2 % 32 == 0 and E1 % 64 == 0.
+
+def ReshapeTransposeReshapeRoberta3DSWPattern2 : Pat<
+ // Input: X -> unstick -> reshape1 -> transpose -> reshape 2 -> stick.
+ (ZHighStickOp:$stick
+ (ONNXReshapeOp:$reshape2
+ (ONNXTransposeOp:$transpose
+ (ONNXReshapeOp:$reshape1
+ (ZHighUnstickOp:$unstick $X),
+ $shape1, $_),
+ $perm),
+ $shape2, $_),
+ $layout3DS, $saturation),
+ // Output: initial X value unchanged, but transformed with the compatible shape.
+ (ZHighReshapeOp $X, (CreateShapeOp (GetShapeTypeOf $stick), $stick), (GetLayout $stick)),
+ // Conditions.
+ [(TensorHas3DSLayout $X), (Is3DSLayout $layout3DS), // Input/output are 3DS.
+ (IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick), // Static shapes only.
+ (IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
+ (IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
+ (IsShapeDimMultipleOf32<1> $X), // Second dim of input is a multiple of 32.
+ (IsShapeDimMultipleOf64<2> $X), // Third dim of input is a multiple of 64.
+ (Is4DTransposePermutationEqualTo0213 $perm), // Permute middle 2 dims.
+ (IsRightmostTiling3DTo4DBy64 $reshape1), // 1st reshape is tiling by 64 the rightmost dimension
+ (IsLeftmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions.
+ ]
+>;
+
+// The following pattern was found in bertsquad and GPT models.
// ```
// %0 = "zhigh.Unstick"(%X) {layout = "2D"} : tensor> -> tensor
// %1 = "onnx.Reshape"(%0) : tensor -> tensor
@@ -186,7 +273,7 @@ def ReshapeTransposeReshape2DTo3DSPattern : Pat<
(returnType (GetResultType $reshape1))),
(GetTransposeMap $perm),
(returnType (GetResultType $transpose))),
- (GetCollapsing4DTo3DMap $reshape2),
+ (GetLeftmostCollapsing4DTo3DMap $reshape2),
(returnType (GetResultType $reshape2))),
$layout3DS, $saturation),
[(TensorHas2DLayout $X), (Is3DSLayout $layout3DS),
@@ -194,7 +281,7 @@ def ReshapeTransposeReshape2DTo3DSPattern : Pat<
(IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
(IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
(IsTiling2DTo4D $reshape1), // 1st reshape is tiling over each input dimension
- (IsCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions.
+ (IsLeftmostCollapsing4DTo3D $reshape2), // 2nd reshape is collapsing the first two dimensions.
]
>;
@@ -214,7 +301,7 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat<
(ONNXShapeTransformOp // transpose
(ONNXShapeTransformOp // reshape
(ZHighUnstickOp $X),
- (GetTiling3DTo4DMap $reshape1),
+ (GetLeftmostTiling3DTo4DMap $reshape1),
(returnType (GetResultType $reshape1))),
(GetTransposeMap $perm),
(returnType (GetResultType $transpose))),
@@ -225,7 +312,7 @@ def ReshapeTransposeReshape3DSTo2DPattern : Pat<
(IsStaticShapeTensor $X), (IsStaticShapeTensor $unstick),
(IsStaticShapeTensor $reshape1), (IsStaticShapeTensor $transpose),
(IsStaticShapeTensor $reshape2),(IsStaticShapeTensor $stick),
- (IsTiling3DTo4D $reshape1), // 1st reshape is tiling over each input dimension
+ (IsLeftmostTiling3DTo4D $reshape1), // 1st reshape is tiling over each input dimension
(IsCollapsing4DTo2D $reshape2), // 2nd reshape is collapsing the first two dimensions.
]
>;
diff --git a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
index a66cb8273f..1ab8ae958d 100644
--- a/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
+++ b/src/Accelerators/NNPA/Dialect/ZLow/ZLow.td
@@ -687,4 +687,15 @@ def ZLowConvertF32ToDLF16VectorOp:ZLow_Op<"vec_f32_to_dlf16", [Pure]> {
];
}
+def ZLowReshapeOp:ZLow_Op<"reshape", [MemRefsNormalizable]> {
+ let summary = "ZLow Reshape operation";
+ let description = [{
+ ZLow operation to perform a reshape (no data movement).
+ }];
+ // Note that no shape is needed for this operation.
+ let arguments = (ins ZMemRef:$X,
+ ZMemRef:$Out,
+ StrAttr:$layout);
+}
+
#endif // ZLOW_OPS
diff --git a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
index f97373764a..2328f70265 100644
--- a/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
+++ b/src/Accelerators/NNPA/Transform/ZLow/ZLowRewrite.cpp
@@ -25,6 +25,7 @@
#include "src/Accelerators/NNPA/Pass/NNPAPasses.hpp"
#include "src/Accelerators/NNPA/Support/LayoutHelper.hpp"
#include "src/Dialect/Mlir/DialectBuilder.hpp"
+#include "src/Dialect/Mlir/IndexExpr.hpp"
#include