diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 3331ca4cb8643f..7808ea7da5fc7e 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -305,6 +305,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
   );
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -498,7 +499,7 @@ def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
     Tosa_Tensor:$output
   );
 
-  let hasFolder = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1124,7 +1125,7 @@ def Tosa_SelectOp : Tosa_Op<"select", [
     Tosa_Tensor:$output
   );
   let hasCanonicalizeMethod = 1;
-  let hasFolder = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1208,7 +1209,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
     I1Tensor:$output
   );
 
-  let hasFolder = 1;
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1591,8 +1592,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
 // Operator: transpose
 //===----------------------------------------------------------------------===//
 def Tosa_TransposeOp : Tosa_Op<"transpose", [
-    DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
-                              ["inferReturnTypeComponents"]>,
+    InferTensorType,
     Pure]> {
 
   let summary = "Transpose operator";
@@ -1611,6 +1611,9 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [
 
   let extraClassDeclaration = [{
     LogicalResult getConstantPerms(llvm::SmallVector<int32_t> &perms);
+    /// Returns true when two result types are compatible for this op;
+    /// method used by InferTypeOpInterface.
+    static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
   }];
 
   let hasCanonicalizer = 1;
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 48dc95b3bed496..040eac24c85d0e 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -140,7 +140,102 @@ template <typename T> static LogicalResult verifyConvOp(T op) {
   return success();
 }
 
+template <typename T> static LogicalResult verifyPoolOp(T op) {
+  auto inputETy =
+      llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
+  auto resultETy = llvm::cast<ShapedType>(op.getType()).getElementType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
+    inputETy = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
+    resultETy = quantType.getStorageType();
+
+  // [kernel_y, kernel_x] <-> [0, 1]
+  auto kernel = op.getKernel();
+  // [stride_y, stride_x]
+  auto stride = op.getStride();
+  // [pad_top, pad_bottom, pad_left, pad_right]
+  auto pad = op.getPad();
+
+  // ERROR_IF(kernel_y < 1 || kernel_x < 1); // kernel size must be >= 1
+  if (kernel[0] < 1 || kernel[1] < 1) {
+    return op.emitOpError("kernel should be greater than or equal to one.");
+  }
+
+  // ERROR_IF(stride_y < 1 || stride_x < 1);
+  if (stride[0] < 1 || stride[1] < 1) {
+    return op.emitOpError("stride should be greater than or equal to one.");
+  }
+
+  // ERROR_IF(pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0);
+  if (pad[0] < 0 || pad[1] < 0 || pad[2] < 0 || pad[3] < 0) {
+    return op.emitOpError("pad should be non-negative.");
+  }
+
+  // Padding must be less than kernel size to avoid
+  // a divide-by-zero.
+  /*
+    ERROR_IF(pad_right >= kernel_x || pad_left >= kernel_x);
+    ERROR_IF(pad_top >= kernel_y || pad_bottom >= kernel_y);
+  */
+  if (pad[3] >= kernel[1] || pad[2] >= kernel[1] || pad[0] >= kernel[0] ||
+      pad[1] >= kernel[0]) {
+    return op.emitOpError("pad must be less than kernel size.");
+  }
+
+  // [N, IH, IW, C]
+  auto inputShapeType = llvm::cast<ShapedType>(op.getInput().getType());
+  // [N, OH, OW, C]
+  auto outputShapeType = llvm::cast<ShapedType>(op.getOutput().getType());
+
+  if (inputShapeType.hasStaticShape() && outputShapeType.hasStaticShape()) {
+    auto inputShape = inputShapeType.getShape();
+    auto outputShape = outputShapeType.getShape();
+    auto inputHeight = inputShape[1];
+    auto inputWidth = inputShape[2];
+    auto outputHeight = outputShape[1];
+    auto outputWidth = outputShape[2];
+
+    // IH + pad_top + pad_bottom - kernel_y
+    auto height = inputHeight + pad[0] + pad[1] - kernel[0];
+    // IW + pad_left + pad_right - kernel_x
+    auto width = inputWidth + pad[2] + pad[3] - kernel[1];
+
+    // idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
+    if (height % stride[0] != 0) {
+      return op.emitOpError("vertical stride is not in correct multiple.");
+    }
+    // idiv_check(IW + pad_left + pad_right - kernel_x, stride_x)
+    if (width % stride[1] != 0) {
+      return op.emitOpError("horizontal stride is not in correct multiple.");
+    }
+
+    /*
+      ERROR_IF(OH != idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
+               + 1);
+    */
+    if (outputHeight != (height / stride[0]) + 1) {
+      return op.emitOpError("output height is not correct, should be ")
+             << (height / stride[0]) + 1 << ".";
+    }
+    /*
+      ERROR_IF(OW != idiv_check(IW + pad_left + pad_right - kernel_x, stride_x)
+               + 1);
+    */
+    if (outputWidth != (width / stride[1]) + 1) {
+      return op.emitOpError("output width is not correct, should be ")
+             << (width / stride[1]) + 1 << ".";
+    }
+  }
+
+  if (inputETy.isF32() && resultETy.isF32())
+    return success();
+  if (inputETy.isInteger(8) && resultETy.isInteger(8))
+    return success();
+  if (inputETy.isInteger(16) && resultETy.isInteger(16))
+    return success();
+  if (inputETy.isInteger(32) && resultETy.isInteger(32))
+    return success();
+
+  return op.emitOpError("input/output element types are incompatible.");
+}
+
+LogicalResult tosa::MaxPool2dOp::verify() { return verifyPoolOp(*this); }
+
 LogicalResult tosa::AvgPool2dOp::verify() {
   auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
   auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -157,21 +252,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
   if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
     return emitOpError("accumulator type for integer tensor is not i32");
 
-  if ((inputETy.isBF16() || inputETy.isF16()) &&
-      !(accType.isF16() || accType.isF32()))
-    return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
+  auto result = verifyPoolOp(*this);
+  if (result.succeeded()) {
+    if (inputETy.isF16() && !(accType.isF16() || accType.isF32()))
+      return emitOpError("accumulator type for f16 tensor is not f16/f32");
 
-  if (inputETy.isF32() && !accType.isF32())
-    return emitOpError("accumulator type for f32 tensor is not f32");
+    if (inputETy.isBF16() && !accType.isF32())
+      return emitOpError("accumulator type for bf16 tensor is not f32");
 
-  if (inputETy.isF32() && resultETy.isF32())
-    return success();
-  if (inputETy.isInteger(8) && resultETy.isInteger(8))
-    return success();
-  if (inputETy.isInteger(16) && resultETy.isInteger(16))
-    return success();
-
-  return emitOpError("input/output element types are incompatible.");
+    if (inputETy.isF32() && !accType.isF32())
+      return emitOpError("accumulator type for f32 tensor is not f32");
+  }
+  return result;
 }
 
 //===----------------------------------------------------------------------===//
@@ -712,6 +804,33 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
 }
 
+bool tosa::TransposeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
+  if (l.size() != r.size() || l.size() != 1)
+    return false;
+
+  auto left = getElementTypeOrSelf(l[0]);
+  auto right = getElementTypeOrSelf(r[0]);
+
+  if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(left))
+    left = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(left))
+    left = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(right))
+    right = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(right))
+    right = quantType.getStorageType();
+
+  if (left != right)
+    return false;
+
+  return succeeded(verifyCompatibleShape(l[0], r[0]));
+}
+
 LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     ValueShapeRange operands, DictionaryAttr attributes,
@@ -860,6 +979,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
   ShapeAdaptor inputShape = operands.getShape(0);
   ShapeAdaptor permsShape = operands.getShape(1);
+  auto inputType = getElementTypeOrSelf(operands[0]);
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputType))
+    inputType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(inputType))
+    inputType = quantType.getStorageType();
+
 
   // If input rank and permutation length is unknown, the output rank is
   // unknown.
@@ -880,13 +1009,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
   // Rank-0 means no permutations matter.
   if (inputShape.getRank() == 0) {
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -903,7 +1032,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
   // permutation.
   if (allTheSame) {
     outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
     return success();
   }
 
@@ -917,7 +1046,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
     }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir
index e285a9de1d66d3..4a6d7576b38a0b 100644
--- a/mlir/test/Dialect/Tosa/invalid.mlir
+++ b/mlir/test/Dialect/Tosa/invalid.mlir
@@ -151,3 +151,148 @@ func.func @test_const_attribute_type_mismatch() -> tensor<100x100xf32> {
   %0 = "tosa.const"() {value = dense<0.000000e+00> : tensor<1x1xf32>} : () -> tensor<100x100xf32>
   return %0 : tensor<100x100xf32>
 }
+
+// -----
+
+func.func @test_avg_pool2d_negative_kernel(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op kernel should be greater than or equal to one.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: -1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_negative_stride(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op stride should be greater than or equal to one.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: -1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_negative_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op pad should be non-negative}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: -1, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_kernel_lessthan_pad(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op pad must be less than kernel size}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 1, 1>, pad = array<i64: 1, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_vert_stride_incorrect_multiple(%arg0: tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op vertical stride is not in correct multiple.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x7x7x9xi8>) -> tensor<1x7x7x9xi8>
+  return %0 : tensor<1x7x7x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_hor_stride_incorrect_multiple(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op horizontal stride is not in correct multiple.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 3>} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8>
+  return %0 : tensor<1x7x4x9xi8>
+}
+
+// -----
+
+func.func @test_max_pool2d_hor_stride_incorrect_multiple(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8> {
+  // expected-error@+1 {{'tosa.max_pool2d' op horizontal stride is not in correct multiple.}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 3>} : (tensor<1x6x6x9xi8>) -> tensor<1x7x4x9xi8>
+  return %0 : tensor<1x7x4x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_output_height_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op output height is not correct, should be 3.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x6x6x9xi8>) -> tensor<1x7x8x9xi8>
+  return %0 : tensor<1x7x8x9xi8>
+}
+
+// -----
+
+func.func @test_avg_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> {
+  // expected-error@+1 {{'tosa.avg_pool2d' op output width is not correct, should be 3.}}
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = i32, kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8>
+  return %0 : tensor<1x3x8x9xi8>
+}
+
+// -----
+
+func.func @test_max_pool2d_output_width_incorrect(%arg0: tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8> {
+  // expected-error@+1 {{'tosa.max_pool2d' op output width is not correct, should be 3.}}
+  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x6x6x9xi8>) -> tensor<1x3x8x9xi8>
+  return %0 : tensor<1x3x8x9xi8>
+}
+
+// -----
+
+func.func @test_const_incorrect_output(%arg0 : index) -> tensor<4xi32> {
+  // expected-error@+1 {{inferred shape of elements literal ([4]) does not match type ([3])}}
+  %0 = "tosa.const"() {value = dense<[3, 0, 1, 2]> : tensor<3xi32>} : () -> tensor<4xi32>
+  return %0 : tensor<4xi32>
+}
+
+// -----
+
+func.func @test_transpose_incorrect_result_shape(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x20xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // expected-error@+2 {{'tosa.transpose' op failed to infer returned types}}
+  // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x20xf32>'}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x20xf32>
+  return %1 : tensor<3x13x20xf32>
+}
+
+// -----
+
+func.func @test_transpose_incorrect_result_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // expected-error@+2 {{'tosa.transpose' op failed to infer returned types}}
+  // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13xf32>'}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13xf32>
+  return %1 : tensor<3x13xf32>
+}
+
+// -----
+
+func.func @test_transpose_incorrect_result_type(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xi8> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // expected-error@+2 {{failed to infer returned types}}
+  // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x21xi8>'}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21xi8>
+  return %1 : tensor<3x13x21xi8>
+}
+
+// -----
+
+func.func @test_transpose_high_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi32>} : () -> tensor<4xi32>
+  // expected-error@+1 {{failed to infer returned types}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<4xi32>) -> tensor<3x13x21x4xf32>
+  return %1 : tensor<3x13x21x4xf32>
+}
+
+// -----
+
+func.func @test_transpose_low_rank_perm(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> {
+  %0 = "tosa.const"() {value = dense<[0, 1]> : tensor<2xi32>} : () -> tensor<2xi32>
+  // expected-error@+1 {{failed to infer returned types}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<2xi32>) -> tensor<3x13x21x4xf32>
+  return %1 : tensor<3x13x21x4xf32>
+}
+
+// -----
+
+func.func @test_transpose_result_high_rank(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21x4xf32> {
+  %0 = "tosa.const"() {value = dense<[2, 0, 1]> : tensor<3xi32>} : () -> tensor<3xi32>
+  // expected-error@+2 {{failed to infer returned types}}
+  // expected-error@+1 {{'tosa.transpose' op inferred type(s) 'tensor<3x13x21xf32>' are incompatible with return type(s) of operation 'tensor<3x13x21x4xf32>'}}
+  %1 = "tosa.transpose"(%arg0, %0) : (tensor<13x21x3xf32>, tensor<3xi32>) -> tensor<3x13x21x4xf32>
+  return %1 : tensor<3x13x21x4xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index bf913363039d79..7f32b7db76258a 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -703,11 +703,11 @@ func.func @test_pool_dynamic_input(%arg0: tensor<?x?x?x?xf32>) {
 
 // CHECK-LABEL: @test_pool_padded
 func.func @test_pool_padded(%arg0: tensor<3x5x6x7xf32>) {
-  // CHECK: -> tensor<3x5x11x7xf32>
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: -> tensor<3x5x8x7xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 2, 2>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: -> tensor<3x5x11x7xf32>
-  %1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 3, 4>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: -> tensor<3x5x8x7xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 3>, pad = array<i64: 1, 2, 2, 2>, stride = array<i64: 1, 1>} : (tensor<3x5x6x7xf32>) -> tensor<?x?x?x?xf32>
   return
 }
 
@@ -733,11 +733,11 @@ func.func @conv2d_dynamic_bias(%input: tensor<2x8x9x3xf32>, %weights: tensor<5x3
 
 // CHECK-LABEL: @test_pool_stride
 func.func @test_pool_stride(%arg0: tensor<3x11x12x7xf32>) {
-  // CHECK: -> tensor<3x4x4x7xf32>
-  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 4, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: -> tensor<3x5x4x7xf32>
+  %0 = "tosa.avg_pool2d"(%arg0) {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
 
-  // CHECK: -> tensor<3x4x4x7xf32>
-  %1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 4, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
+  // CHECK: -> tensor<3x5x4x7xf32>
+  %1 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<3x11x12x7xf32>) -> tensor<?x?x?x?xf32>
   return
 }
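For reference, a minimal sketch of a pooling op that satisfies every check in the new verifyPoolOp (illustrative only, not part of the patch; the function name and attribute values are hypothetical): kernel and stride entries are at least one, pads are non-negative and smaller than the kernel, and the padded extent divides evenly by the stride, since OH = (IH + pad_top + pad_bottom - kernel_y) / stride_y + 1 = (6 + 0 + 0 - 2) / 2 + 1 = 3, and likewise for OW.

func.func @example_max_pool2d_valid(%arg0: tensor<1x6x6x9xf32>) -> tensor<1x3x3x9xf32> {
  // 6 + 0 + 0 - 2 = 4; 4 % 2 == 0, so idiv_check passes; 4 / 2 + 1 == 3
  // matches the static result shape, so the verifier succeeds.
  %0 = "tosa.max_pool2d"(%arg0) {kernel = array<i64: 2, 2>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 2>} : (tensor<1x6x6x9xf32>) -> tensor<1x3x3x9xf32>
  return %0 : tensor<1x3x3x9xf32>
}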