feat: adding verifiers for some tosa operators #90

Draft
wants to merge 12 commits into base: feature/fused-ops
13 changes: 8 additions & 5 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -305,6 +305,7 @@ def Tosa_MaxPool2dOp : Tosa_Op<"max_pool2d", [
);

let hasCanonicalizer = 1;
let hasVerifier = 1;
}
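For context: hasVerifier = 1 makes ODS declare a LogicalResult verify() hook on the generated MaxPool2dOp class, which the op must then define by hand. The matching definition this patch adds in TosaOps.cpp is the one-liner

LogicalResult tosa::MaxPool2dOp::verify() { return verifyPoolOp(*this); }

which forwards to the new shared verifyPoolOp template further down in the diff.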

//===----------------------------------------------------------------------===//
@@ -498,7 +499,7 @@ def Tosa_AddOp : Tosa_ElemWiseBinaryOp<"add", [Commutative]> {
Tosa_Tensor:$output
);

let hasFolder = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
@@ -1124,7 +1125,7 @@ def Tosa_SelectOp : Tosa_Op<"select", [
Tosa_Tensor:$output
);
let hasCanonicalizeMethod = 1;
let hasFolder = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
@@ -1208,7 +1209,7 @@ def Tosa_GreaterEqualOp : Tosa_Op<"greater_equal", [
I1Tensor:$output
);

let hasFolder = 1;
let hasFolder = 1;
}

//===----------------------------------------------------------------------===//
@@ -1591,8 +1592,7 @@ def Tosa_TileOp: Tosa_Op<"tile", [
// Operator: transpose
//===----------------------------------------------------------------------===//
def Tosa_TransposeOp : Tosa_Op<"transpose", [
DeclareOpInterfaceMethods<InferShapedTypeOpInterface,
["inferReturnTypeComponents"]>,
InferTensorType,
Pure]> {
let summary = "Transpose operator";

@@ -1611,6 +1611,9 @@ def Tosa_TransposeOp : Tosa_Op<"transpose", [

let extraClassDeclaration = [{
LogicalResult getConstantPerms(llvm::SmallVector<int64_t> &perms);
/// Returns true when two result types are compatible for this op;
/// Method used by InferTypeOpInterface.
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];

let hasCanonicalizer = 1;
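Background on the transpose change (based on the upstream InferTypeOpInterface definition; treat the exact default below as an assumption, not something this diff states): the InferTensorType trait derives result types from inferReturnTypeComponents and then checks the declared result types through the static isCompatibleReturnTypes hook, whose default accepts only exactly matching types, roughly:

// Paraphrase of the default hook, not code from this patch:
// static bool isCompatibleReturnTypes(TypeRange l, TypeRange r) { return l == r; }

The override declared above and defined in TosaOps.cpp relaxes this to "same storage element type and compatible shape", so transposes on quantized tensors keep verifying.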
163 changes: 146 additions & 17 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -140,7 +140,102 @@ template <typename T> static LogicalResult verifyConvOp(T op) {

return success();
}
template <typename T> static LogicalResult verifyPoolOp(T op) {
auto inputETy = llvm::cast<ShapedType>(op.getInput().getType()).getElementType();
auto resultETy = llvm::cast<ShapedType>(op.getType()).getElementType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputETy))
inputETy = quantType.getStorageType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultETy))
resultETy = quantType.getStorageType();

// [kernel_y, kernel_x] <-> [0,1]
auto kernel = op.getKernel();
// [stride_y, stride_x]
auto stride = op.getStride();
// [pad_top, pad_bottom, pad_left, pad_right]
auto pad = op.getPad();
// ERROR_IF(kernel_y < 1 || kernel_x < 1); // kernel size must be >= 1
if (kernel[0] < 1 || kernel[1] < 1) {
return op.emitOpError("kernel size must be at least one.");
}
// ERROR_IF(stride_y < 1 || stride_x < 1);
if (stride[0] < 1 || stride[1] < 1) {
return op.emitOpError("stride must be at least one.");
}
// ERROR_IF(pad_top < 0 || pad_bottom < 0 || pad_left < 0 || pad_right < 0);
if (pad[0] < 0 || pad[1] < 0 || pad[2] < 0 || pad[3] < 0) {
return op.emitOpError("pad must be non-negative.");
}
// Padding must be less than kernel size to avoid
// a divide-by-zero.
/*
ERROR_IF(pad_right >= kernel_x || pad_left >= kernel_x);
ERROR_IF(pad_top >= kernel_y || pad_bottom >= kernel_y);
*/

if (pad[3] >= kernel[1] || pad[2] >= kernel[1] || pad[0] >= kernel[0] ||
pad[1] >= kernel[0]) {
return op.emitOpError("pad must be less than kernel size.");
}

//[N,IH,IW,C]
auto inputShapeType = llvm::cast<ShapedType>(op.getInput().getType());
//[N,OH,OW,C]
auto outputShapeType = llvm::cast<ShapedType>(op.getOutput().getType());
if (inputShapeType.hasStaticShape() && outputShapeType.hasStaticShape()) {
auto inputShape = inputShapeType.getShape();
auto outputShape = outputShapeType.getShape();
auto inputHeight = inputShape[1];
auto inputWidth = inputShape[2];
auto outputHeight = outputShape[1];
auto outputWidth = outputShape[2];
// IH + pad_top + pad_bottom - kernel_y
auto height = inputHeight + pad[0] + pad[1] - kernel[0];
// IW + pad_left + pad_right - kernel_x
auto width = inputWidth + pad[2] + pad[3] - kernel[1];
// idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
if (height % stride[0] != 0) {
return op.emitOpError("vertical stride is not in correct multiple.");
}
// idiv_check(IW + pad_left + pad_right - kernel_x, stride_x)
if (width % stride[1] != 0) {
return op.emitOpError("horizontal stride is not in correct multiple.");
}
/*
ERROR_IF(OH != idiv_check(IH + pad_top + pad_bottom - kernel_y, stride_y)
+ 1);
*/

if (outputHeight != (height / stride[0]) + 1) {
return op.emitOpError("output height is incorrect, expected ")
<< (height / stride[0]) + 1 << ".";
}
/*
ERROR_IF(OW != idiv_check(IW + pad_left + pad_right - kernel_x, stride_x) +
1);
*/
if (outputWidth != (width / stride[1]) + 1) {
return op.emitOpError("output width is not correct, should be ")
<< (width / stride[1]) + 1 << ".";
}
}
if (inputETy.isF32() && resultETy.isF32())
return success();
if (inputETy.isInteger(8) && resultETy.isInteger(8))
return success();
if (inputETy.isInteger(16) && resultETy.isInteger(16))
return success();
if (inputETy.isInteger(32) && resultETy.isInteger(32))
return success();

return op.emitOpError("input/output element types are incompatible.");
}
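As a worked example of the static-shape checks above (illustrative numbers, not taken from the patch): an input height of 8 with no vertical padding, a kernel height of 2, and a vertical stride of 2 must declare an output height of (8 + 0 + 0 - 2) / 2 + 1 = 4, and the intermediate sum has to divide evenly by the stride. A self-contained sketch of the same arithmetic:

#include <cstdint>
#include <iostream>

int main() {
  // Mirrors verifyPoolOp's height computation for one axis.
  int64_t ih = 8, padTop = 0, padBottom = 0, kernelY = 2, strideY = 2;
  int64_t span = ih + padTop + padBottom - kernelY; // 6
  if (span % strideY != 0) { // the spec's idiv_check
    std::cerr << "height is not an exact multiple of the vertical stride\n";
    return 1;
  }
  std::cout << "expected output height: " << span / strideY + 1 << "\n"; // prints 4
  return 0;
}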

LogicalResult tosa::MaxPool2dOp::verify() { return verifyPoolOp(*this); }
LogicalResult tosa::AvgPool2dOp::verify() {
auto inputETy = llvm::cast<ShapedType>(getInput().getType()).getElementType();
auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -157,21 +252,18 @@ LogicalResult tosa::AvgPool2dOp::verify() {
if (llvm::isa<IntegerType>(inputETy) && !accType.isInteger(32))
return emitOpError("accumulator type for integer tensor is not i32");

if ((inputETy.isBF16() || inputETy.isF16()) &&
!(accType.isF16() || accType.isF32()))
return emitOpError("accumulator type for f16/bf16 tensor is not f16/f32");
auto result = verifyPoolOp(*this);
if (result.succeeded()) {
if ((inputETy.isF16()) && !(accType.isF16() || accType.isF32()))
return emitOpError("accumulator type for f16 tensor is not f16/f32");

if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
if ((inputETy.isBF16()) && !(accType.isF32()))
return emitOpError("accumulator type for bf16 tensor is not f32");

if (inputETy.isF32() && resultETy.isF32())
return success();
if (inputETy.isInteger(8) && resultETy.isInteger(8))
return success();
if (inputETy.isInteger(16) && resultETy.isInteger(16))
return success();

return emitOpError("input/output element types are incompatible.");
if (inputETy.isF32() && !accType.isF32())
return emitOpError("accumulator type for f32 tensor is not f32");
}
return result;
}

//===----------------------------------------------------------------------===//
@@ -712,6 +804,33 @@ bool tosa::ReshapeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
return getElementTypeOrSelf(l[0]) == getElementTypeOrSelf(r[0]);
}

bool tosa::TransposeOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != r.size() || l.size() != 1)
return false;

auto left = getElementTypeOrSelf(l[0]);
auto right = getElementTypeOrSelf(r[0]);

if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(left))
left = quantType.getStorageType();

if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(left))
left = quantType.getStorageType();

if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedType>(right))
right = quantType.getStorageType();

if (auto quantType = llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(right))
right = quantType.getStorageType();

if (left != right)
return false;
return succeeded(verifyCompatibleShape(l[0], r[0]));
}

LogicalResult tosa::ReshapeOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
@@ -860,6 +979,16 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
ShapeAdaptor inputShape = operands.getShape(0);
ShapeAdaptor permsShape = operands.getShape(1);
auto inputType = getElementTypeOrSelf(operands[0]);

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputType))
inputType = quantType.getStorageType();

if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedPerAxisType>(inputType))
inputType = quantType.getStorageType();


// If input rank and permutation length is unknown, the output rank is
// unknown.
@@ -880,13 +1009,13 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
SmallVector<int64_t> outputShape;
if (!inputShape.hasRank()) {
outputShape.resize(permsShape.getDimSize(0), ShapedType::kDynamic);
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}

// Rank-0 means no permutations matter.
if (inputShape.getRank() == 0) {
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}

Expand All @@ -903,7 +1032,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
// permutation.
if (allTheSame) {
outputShape.resize(inputShape.getRank(), inputShape.getDimSize(0));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}

@@ -917,7 +1046,7 @@ LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
}
}

inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
return success();
}
