
Optimization for Roberta unstick->reshape->transpose->reshape->stick (#3056)

Signed-off-by: Alexandre Eichenberger <[email protected]>
AlexandreEichenberger authored Feb 1, 2025
1 parent 988271d commit 8530104
Showing 22 changed files with 832 additions and 76 deletions.
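
As a hedged illustration of the pattern named in the commit title (shapes, names, and the exact Roberta dimensions below are assumptions, not taken from this commit), the optimization targets subgraphs where data leaves the stickified Z format only to be reshaped and transposed before being re-stickified:

```mlir
// Hypothetical sketch of an unstick->reshape->transpose->reshape->stick chain.
func.func @roberta_like(%zt: tensor<8x96x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>)
    -> tensor<96x96x64xf16, #zhigh.layout<{dataLayout = "3DS"}>> {
  %s1 = onnx.Constant dense<[8, 96, 12, 64]> : tensor<4xi64>
  %s2 = onnx.Constant dense<[96, 96, 64]> : tensor<3xi64>
  // Leave the Z (stickified) format.
  %cpu = "zhigh.Unstick"(%zt) : (tensor<8x96x768xf16, #zhigh.layout<{dataLayout = "3DS"}>>) -> tensor<8x96x768xf32>
  // Split heads, swap the sequence and head dimensions, then merge batch and heads.
  %r1 = "onnx.Reshape"(%cpu, %s1) : (tensor<8x96x768xf32>, tensor<4xi64>) -> tensor<8x96x12x64xf32>
  %t = "onnx.Transpose"(%r1) {perm = [0, 2, 1, 3]} : (tensor<8x96x12x64xf32>) -> tensor<8x12x96x64xf32>
  %r2 = "onnx.Reshape"(%t, %s2) : (tensor<8x12x96x64xf32>, tensor<3xi64>) -> tensor<96x96x64xf32>
  // Return to the Z (stickified) format.
  %out = "zhigh.Stick"(%r2) {layout = "3DS"} : (tensor<96x96x64xf32>) -> tensor<96x96x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
  return %out : tensor<96x96x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
}
```
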
48 changes: 34 additions & 14 deletions docs/Dialects/zhigh.md
@@ -782,13 +782,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -814,13 +807,6 @@ Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
@@ -857,6 +843,40 @@ Effects: `MemoryEffects::Effect{}`
| :----: | ----------- |
| `Out` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH

### `zhigh.Reshape` (::onnx_mlir::zhigh::ZHighReshapeOp)

_ZHigh Reshape operation for Z Tensors_

ZHigh operation to convert a Z Tensor from one type to an equivalent type
with a provided shape. The data is never copied or modified. When no layout is specified,
the output preserves the same layout as the source input.

Traits: `AlwaysSpeculatableImplTrait`

Interfaces: `ConditionallySpeculatable`, `NoMemoryEffect (MemoryEffectOpInterface)`, `ShapeHelperOpInterface`, `ShapeInferenceOpInterface`

Effects: `MemoryEffects::Effect{}`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `source` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
| `shape` | tensor of 64-bit signless integer values

#### Results:

| Result | Description |
| :----: | ----------- |
| `result` | unranked tensor of 16-bit float values or 1D tensor of 16-bit float values with layout _1D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2D or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3D or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4D or unranked tensor of 16-bit float values or 2D tensor of 16-bit float values with layout _2DS or unranked tensor of 16-bit float values or 3D tensor of 16-bit float values with layout _3DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout _4DS or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NCHW or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout NHWC or unranked tensor of 16-bit float values or 4D tensor of 16-bit float values with layout HWCK or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout FICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout ZRH or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BFICO or unranked tensor of 16-bit float values or 2D/3D tensor of 16-bit float values with layout BZRH
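
A minimal usage sketch in generic MLIR syntax (the shapes, SSA names, and the spelling of the `#zhigh.layout` encoding below are illustrative assumptions; the source and result must describe the same stickified buffer, since no data is moved):

```mlir
// Reinterpret a 4D Z tensor as a 3DS Z tensor by merging its two leading
// dimensions; the underlying stickified data is left untouched.
%shape = onnx.Constant dense<[96, 96, 64]> : tensor<3xi64>
%r = "zhigh.Reshape"(%x, %shape) {layout = "3DS"}
    : (tensor<8x12x96x64xf16, #zhigh.layout<{dataLayout = "4D"}>>, tensor<3xi64>)
    -> tensor<96x96x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
```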

### `zhigh.Sigmoid` (::onnx_mlir::zhigh::ZHighSigmoidOp)

_ZHigh Sigmoid operation_
25 changes: 23 additions & 2 deletions docs/Dialects/zlow.md
@@ -770,7 +770,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -795,7 +794,6 @@ Traits: `MemRefsNormalizable`
<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
<tr><td><code>op_type</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:
@@ -832,6 +830,29 @@ Interfaces: `MemoryEffectOpInterface`
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values

### `zlow.reshape` (::onnx_mlir::zlow::ZLowReshapeOp)

_ZLow Reshape operation_

ZLow operation to perform a reshape (no data movement).

Traits: `MemRefsNormalizable`

#### Attributes:

<table>
<tr><th>Attribute</th><th>MLIR Type</th><th>Description</th></tr>
<tr><td><code>layout</code></td><td>::mlir::StringAttr</td><td>string attribute</td></tr>
</table>

#### Operands:

| Operand | Description |
| :-----: | ----------- |
| `X` | memref of dlfloat16 type values
| `shape` | memref of 64-bit signless integer values
| `Out` | memref of dlfloat16 type values
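
A hypothetical sketch in generic MLIR syntax (names and shapes are illustrative, and the stickified affine layout maps normally carried by these memrefs are omitted for brevity). `Out` views the same dlfloat16 buffer as `X`; only the logical shape, provided in `shape`, changes:

```mlir
"zlow.reshape"(%X, %shape, %Out) {layout = "3DS"}
    : (memref<8x12x96x64xf16>, memref<3xi64>, memref<96x96x64xf16>) -> ()
```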

### `zlow.sigmoid` (::onnx_mlir::zlow::ZLowSigmoidOp)

_ZLow sigmoid operation_
@@ -95,17 +95,6 @@ def replaceONNXBatchNormalizationInferenceModePattern : Pattern<
//
//===----------------------------------------------------------------------===//


// Create an ONNX Shape Op with type
def CreateShapeOp: NativeCodeCall<
"$_builder.create<mlir::ONNXShapeOp>($_loc, $0, $1, IntegerAttr(), 0)"
>;

// Get a type for a tensor that stores the shape of another tensor.
def GetShapeTypeOf: NativeCodeCall<
"RankedTensorType::get({mlir::cast<ShapedType>($0.getType()).getRank()}, $_builder.getIntegerType(64))"
>;

// Check unidirectional broadcasting from the first to second tensor.
def IsUniBroadcastingFromFirstToSecond: Constraint<
CPred<"isUniBroadcatableFirstToSecond($0, $1)">,
45 changes: 43 additions & 2 deletions src/Accelerators/NNPA/Conversion/ZHighToZLow/ZHighToZLow.cpp
@@ -1093,6 +1093,47 @@ struct ZHighToZLowUnaryOpLowering : public ConversionPattern {
}
};

// Reshape operation. Code similar to unary lowering, except that we use the
// operation's specialized shape here.
struct ZHighToZLowReshapeOpLowering : public ConversionPattern {
ZHighToZLowReshapeOpLowering(TypeConverter &typeConverter, MLIRContext *ctx)
: ConversionPattern(ZHighReshapeOp::getOperationName(), 1, ctx) {}

LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
Location loc = op->getLoc();
Value input = operands[0];

// Helper builders.
MultiDialectBuilder<IndexExprBuilderForKrnl> create(rewriter, loc);

// Convert ZTensor type to MemRefType.
ZMemRefType zMemRefType =
convertZTensorToMemRefType(*op->result_type_begin());

// Shape helper.
ZHighReshapeOpShapeHelper shapeHelper(op, operands, &create.krnlIE);
shapeHelper.computeShapeAndAssertOnFailure();
SmallVector<IndexExpr, 4> &dims = shapeHelper.getOutputDims();

// Allocate a buffer for the result MemRef. Follow this pattern to be
// similar to all the other zlow patterns. Will remove the alloc when
// lowering zlow.reshape to memref.reinterpret_cast once memrefs are
// normalized. See code in ReshapeToReinterpretCastPattern.
Value alloc = insertAllocForZMemRef(zMemRefType, dims, op, rewriter);

// Note: we do not need to save the shape of the original operation, as this
// reshape is a no-op that logically reinterprets the data between two
// equivalent shapes under their given layouts.

// Emit a ZLow operation.
rewriter.create<ZLowReshapeOp>(
loc, input, /* shape,*/ alloc, zMemRefType.layout);
rewriter.replaceOp(op, alloc);
return success();
}
};
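
For reference, a hedged sketch of the eventual lowering mentioned in the comment above (not part of this commit): once the memrefs are normalized, a no-data-movement reshape can be expressed as a `memref.reinterpret_cast` over the same buffer instead of a fresh allocation. Shapes and names are illustrative only.

```mlir
// Reuse the buffer of %x under a new logical shape; no copy is emitted.
%out = memref.reinterpret_cast %x to offset: [0], sizes: [96, 96, 64], strides: [6144, 64, 1]
    : memref<8x12x96x64xf16> to memref<96x96x64xf16>
```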

//===----------------------------------------------------------------------===//
// Lower ZHigh ReduceMax/ReduceMin to ZLow ReduceMax/ReduceMin
//===----------------------------------------------------------------------===//
Expand All @@ -1117,8 +1158,6 @@ struct ZHighToZLowReduceOpLowering : public ConversionPattern {
: ConversionPattern(OP_TYPE::getOperationName(), 1, ctx) {}
LogicalResult matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
MLIRContext *context = rewriter.getContext();
OP_TYPE reduceOp = mlir::cast<OP_TYPE>(op);
Location loc = op->getLoc();
Value data = operands[0];

Expand Down Expand Up @@ -2285,6 +2324,8 @@ void populateZHighToZLowConversionPattern(mlir::RewritePatternSet &patterns,
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighTanhOp>>(typeConverter, ctx);
patterns.insert<ZHighToZLowUnaryOpLowering<ZHighSigmoidOp>>(
typeConverter, ctx);
// Reshape operations.
patterns.insert<ZHighToZLowReshapeOpLowering>(typeConverter, ctx);
// Neural network operations.
patterns.insert<ZHighToZLowReduceOpLowering<ZHighReduceMaxOp>>(
typeConverter, ctx);
1 change: 1 addition & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/CMakeLists.txt
@@ -26,6 +26,7 @@ add_onnx_mlir_library(OMZHighOps
ZHighOps/QuantizedMatMul/QuantizedMatMul.cpp
ZHighOps/QuantizedStick/QuantizedStick.cpp
ZHighOps/Reduction/Reduction.cpp
ZHighOps/Reshape/Reshape.cpp
ZHighOps/Softmax/Softmax.cpp
ZHighOps/Stick/Stick.cpp
ZHighOps/StickForGRU/StickForGRU.cpp
24 changes: 24 additions & 0 deletions src/Accelerators/NNPA/Dialect/ZHigh/ZHigh.td
@@ -1195,4 +1195,28 @@ def ZHighFixGRUYhOp:ZHigh_Op<"FixGRUYh", [Pure,
}];
}

def ZHighReshapeOp:ZHigh_Op<"Reshape", [Pure,
DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
DeclareOpInterfaceMethods<ShapeHelperOpInterface>]> {
let summary = "ZHigh Reshape operation for Z Tensors";
let description = [{
ZHigh operation to convert a Z Tensor from one type to an equivalent type
with a provided shape. The data is never copied or modified. When no layout is specified,
the output preserves the same layout as the source input.
}];
let arguments = (ins AnyTypeOf<[AnyZTensor]>:$source, // Input Z Tensor to be reshaped.
TensorOf<[I64]>:$shape, // Shape of output Z Tensor.
OptionalAttr<StrAttr>:$layout); // Layout of output Z Tensor, default same as input.
let results = (outs AnyTypeOf<[AnyZTensor]>:$result);

let extraClassDefinition = [{
onnx_mlir::ONNXOpShapeHelper * ZHighReshapeOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
onnx_mlir::ONNXOpShapeHelper *sh = new ZHighReshapeOpShapeHelper(op, oper, ieb, scope);
assert(sh && "failed to allocate shape helper");
return sh;
}
}];
}
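
Since `layout` is optional, a `zhigh.Reshape` with no `layout` attribute keeps the layout of its source. A hypothetical sketch (shapes are illustrative; source and result must map to the same stickified buffer):

```mlir
%shape = onnx.Constant dense<[4, 192, 64]> : tensor<3xi64>
%r = "zhigh.Reshape"(%x, %shape)
    : (tensor<8x96x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>, tensor<3xi64>)
    -> tensor<4x192x64xf16, #zhigh.layout<{dataLayout = "3DS"}>>
```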

#endif // ZHIGH_OPS

