[NNPA] Support bias none in ONNX.Gemm (#1466)
* Support for bias none in Gemm

Signed-off-by: Haruki Imai <[email protected]>
imaihal authored May 27, 2022
1 parent 862bca6 commit 0996500
Showing 3 changed files with 48 additions and 12 deletions.
26 changes: 17 additions & 9 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXLegalityCheck.cpp
@@ -41,8 +41,10 @@ bool isValidElementType(Value val) {
/// detect whether the shapes are exactly the same or not. Hence, return false.
/// Also, check the ranks of two tensors, they must be in range of (0, 4].
bool haveSameStaticShape(Value value1, Value value2) {
-  auto valueType1 = value1.getType().cast<ShapedType>();
-  auto valueType2 = value2.getType().cast<ShapedType>();
+  ShapedType valueType1 = value1.getType().cast<ShapedType>();
+  ShapedType valueType2 = value2.getType().cast<ShapedType>();
+  if (!valueType1.hasRank() || !valueType2.hasRank())
+    return false;
  // Different rank, return false.
  if (valueType1.getRank() != valueType2.getRank())
    return false;
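
Note: the new hasRank() guard matters because MLIR's ShapedType::getRank()
asserts that the type is ranked, so calling it on an unranked tensor such as
tensor<*xf32> aborts. A minimal sketch of the guard pattern this commit
applies throughout (the helper name is hypothetical, not from the commit):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Reject unranked shaped types before querying the rank: getRank() on an
// unranked type fails an assertion inside MLIR.
static bool hasRankAtMost(mlir::Value v, int64_t maxRank) {
  mlir::ShapedType ty = v.getType().cast<mlir::ShapedType>();
  if (!ty.hasRank())
    return false; // e.g. tensor<*xf32>: rank unknown, treat as unsupported
  return ty.getRank() <= maxRank;
}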
@@ -360,48 +362,54 @@ template <>
bool isSuitableForZDNN<ONNXSoftmaxOp>(ONNXSoftmaxOp op) {
  if (!isValidElementType(op.input()))
    return false;
-  return ((op.axis() == 1 || op.axis() == -1) &&
-          (op.input().getType().cast<ShapedType>().getRank() == 2));
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return (op.axis() == 1 || op.axis() == -1) && inputType.hasRank() &&
+         (inputType.getRank() == 2);
}
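
Note: the rewritten Softmax check (and the Tanh check below) reads
op.getType(), i.e. the result type, where the other checks read the operand
type. For these element-wise ops the result shape equals the input shape once
shape inference has run; if the result is still unranked, the new hasRank()
test rejects the op instead of asserting.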

/// Check legality for ONNXRelu.
template <>
bool isSuitableForZDNN<ONNXReluOp>(ONNXReluOp op) {
  if (!isValidElementType(op.X()))
    return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
}

/// Check legality for ONNXTanh.
template <>
bool isSuitableForZDNN<ONNXTanhOp>(ONNXTanhOp op) {
  if (!isValidElementType(op.input()))
    return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
}

/// Check legality for ONNXSigmoid.
template <>
bool isSuitableForZDNN<ONNXSigmoidOp>(ONNXSigmoidOp op) {
  if (!isValidElementType(op.X()))
    return false;
-  return (op.X().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType xType = op.X().getType().cast<ShapedType>();
+  return xType.hasRank() && (xType.getRank() <= 4);
}

/// Check legality for ONNXLog.
template <>
bool isSuitableForZDNN<ONNXLogOp>(ONNXLogOp op) {
  if (!isValidElementType(op.input()))
    return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
}

/// Check legality for ONNXExp.
template <>
bool isSuitableForZDNN<ONNXExpOp>(ONNXExpOp op) {
  if (!isValidElementType(op.input()))
    return false;
-  return (op.input().getType().cast<ShapedType>().getRank() <= 4);
+  ShapedType inputType = op.input().getType().cast<ShapedType>();
+  return inputType.hasRank() && (inputType.getRank() <= 4);
}

/// Check legality for ONNXMatMul.
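
Note: the five element-wise checks above (Relu, Tanh, Sigmoid, Log, Exp) now
repeat the same hasRank() && getRank() <= 4 test. A hypothetical follow-up
refactoring, not part of this commit, could share a single predicate:

// Hypothetical helper: the unary legality checks differ only in which
// operand accessor they use (op.X() vs. op.input()).
static bool isRankedWithRankAtMost4(mlir::Value v) {
  mlir::ShapedType ty = v.getType().cast<mlir::ShapedType>();
  return ty.hasRank() && ty.getRank() <= 4;
}

// The ONNXRelu check would then reduce to:
//   return isValidElementType(op.X()) && isRankedWithRankAtMost4(op.X());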
16 changes: 13 additions & 3 deletions src/Accelerators/NNPA/Conversion/ONNXToZHigh/ONNXToZHigh.td
@@ -33,7 +33,16 @@ def IsNoneType : Constraint<CPred<"(($_self).getType().isa<NoneType>())">>;
def IsNotNoneType : Constraint<CPred<"(!($_self).getType().isa<NoneType>())">>;

class HasRankOf<int rank> : Constraint<
-  CPred<"$0.getType().isa<ShapedType>() && $0.getType().cast<ShapedType>().getRank() == " # rank>
+  CPred<"$0.getType().isa<ShapedType>() && "
+        "$0.getType().cast<ShapedType>().hasRank() && "
+        "$0.getType().cast<ShapedType>().getRank() == " # rank>
>;

+def IsBiasNoneOr1D : Constraint<
+  CPred<"$_self.getType().isa<NoneType>() || "
+        " ($_self.getType().isa<ShapedType>() && "
+        "  $_self.getType().cast<ShapedType>().hasRank() && "
+        "  $_self.getType().cast<ShapedType>().getRank() == 1)">
+>;

class VariadicSizeIs<int N> : Constraint<
@@ -536,14 +545,15 @@ def normalizeONNXGemmTransBPattern : Pat<
  (addBenefit 1)
>;

-def replaceONNXGemmBias1DPattern : Pat<
+
+def replaceONNXGemmBiasNoneOr1DPattern : Pat<
  (ONNXGemmOp $a, $b, $c, $_, $_, $_, $_),
  (ZHighUnstickOp
    (ZHighMatMulOp
      (ZHighStickOp $a, (_2DLayoutAttr)),
      (ZHighStickOp $b, (_2DLayoutAttr)),
      (ZHighStickOp $c, (_1DLayoutAttr)))),
-  [(HasRankOf<1> $c)],
+  [(IsBiasNoneOr1D:$c)],
  (addBenefit 0)
>;
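
Note: in DRR, a CPred string is spliced verbatim into the generated C++
matcher, so the IsBiasNoneOr1D constraint above amounts to roughly the
following predicate over the matched Gemm bias $c (a sketch of the semantics,
not the exact mlir-tblgen output):

// Accept either a missing bias (NoneType) or a ranked 1-D tensor.
static bool isBiasNoneOr1D(mlir::Value bias) {
  mlir::Type t = bias.getType();
  return t.isa<mlir::NoneType>() ||
         (t.isa<mlir::ShapedType>() &&
          t.cast<mlir::ShapedType>().hasRank() &&
          t.cast<mlir::ShapedType>().getRank() == 1);
}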
18 changes: 18 additions & 0 deletions test/mlir/accelerators/nnpa/conversion/onnx-to-zhigh/gemm.mlir
@@ -1,5 +1,23 @@
// RUN: onnx-mlir-opt --maccel=NNPA --shape-inference --convert-onnx-to-zhigh --canonicalize %s -split-input-file | FileCheck %s

+func @test_gemm_bias_none(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>) -> tensor<*xf32> {
+  %bias = "onnx.NoValue"() {value} : () -> none
+  %0 ="onnx.Gemm"(%arg0, %arg1, %bias) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, none) -> tensor<*xf32>
+  "func.return"(%0) : (tensor<*xf32>) -> ()
+
+// CHECK-LABEL: func @test_gemm_bias_none
+// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x5xf32>, [[PARAM_1_:%.+]]: tensor<5x10xf32>) -> tensor<10x10xf32> {
+// CHECK-DAG: [[VAR_0_:%.+]] = "zhigh.Stick"([[PARAM_0_]]) {layout = "2D"} : (tensor<10x5xf32>) -> tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_1_:%.+]] = "zhigh.Stick"([[PARAM_1_]]) {layout = "2D"} : (tensor<5x10xf32>) -> tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>
+// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.NoValue"() {value} : () -> none
+// CHECK: [[VAR_3_:%.+]] = "zhigh.MatMul"([[VAR_0_]], [[VAR_1_]], [[VAR_2_]]) : (tensor<10x5xf32, #zhigh.encoding<{dataLayout = "2D"}>>, tensor<5x10xf32, #zhigh.encoding<{dataLayout = "2D"}>>, none) -> tensor<*xf32>
+// CHECK: [[VAR_4_:%.+]] = "zhigh.Unstick"([[VAR_3_]]) : (tensor<*xf32>) -> tensor<10x10xf32>
+// CHECK: return [[VAR_4_]] : tensor<10x10xf32>
+// CHECK: }
+}
+
+// -----
+
func @test_gemm_bias_1d(%arg0 : tensor<10x5xf32>, %arg1 : tensor<5x10xf32>, %arg2: tensor<10xf32>) -> tensor<*xf32> {
  %0 ="onnx.Gemm"(%arg0, %arg1, %arg2) {alpha = 1.0 : f32, beta = 1.0 : f32, transA = 0 : si64, transB = 0 : si64} : (tensor<10x5xf32>, tensor<5x10xf32>, tensor<10xf32>) -> tensor<*xf32>
  "func.return"(%0) : (tensor<*xf32>) -> ()