diff --git a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td index 624baf571f20..e4876455d2a7 100644 --- a/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td +++ b/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/Input/InputOps.td @@ -127,7 +127,8 @@ def IREEInput_GlobalAddressOp : IREEInput_PureOp<"global.address"> { }]; } -def IREEInput_GlobalLoadOp : IREEInput_Op<"global.load"> { +def IREEInput_GlobalLoadOp : IREEInput_Op<"global.load", + [DeclareOpInterfaceMethods]> { let summary = [{loads a value from a global variable}]; let description = [{ Returns a copy of the global value. @@ -161,9 +162,11 @@ def IREEInput_GlobalLoadIndirectOp : IREEInput_Op<"global.load.indirect"> { let assemblyFormat = [{ $global attr-dict `:` type($global) `->` type($result) }]; + let hasVerifier = 1; } -def IREEInput_GlobalStoreOp : IREEInput_Op<"global.store"> { +def IREEInput_GlobalStoreOp : IREEInput_Op<"global.store", + [DeclareOpInterfaceMethods]> { let summary = [{stores a value into a global variable}]; let description = [{ Stores a copy of the value into a global. @@ -193,6 +196,7 @@ def IREEInput_GlobalStoreIndirectOp : IREEInput_Op<"global.store.indirect"> { let assemblyFormat = [{ $value `,` $global attr-dict `:` type($value) `->` type($global) }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp index 403fa0173f3a..f96fc07f7317 100644 --- a/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp +++ b/llvm-external-projects/iree-dialects/lib/Dialect/Input/InputOps.cpp @@ -117,5 +117,74 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name, build(builder, result, name, isMutable, type, llvm::None, attrs); } +// Returns true if the given |accessType| is compatible with the |globalType|. +// For example, this will return true if the global type is a tensor +// and the access is tensor<4xf32>. +static bool isGlobalTypeCompatible(Type globalType, Type accessType) { + // If one is a shaped type, then they both must be and have compatible + // shapes. + if (globalType.isa() && accessType.isa()) { + return succeeded(mlir::verifyCompatibleShape(globalType, accessType)) && + globalType.cast().getElementType() == + accessType.cast().getElementType(); + } + + // Permissively allow any other types to be marked compatible as long as + // neither are shaped type. + return !globalType.isa() && !accessType.isa(); +} + +LogicalResult +GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto globalOp = + symbolTable.lookupNearestSymbolFrom(*this, globalAttr()); + if (!globalOp) { + return emitOpError() << "undefined global: " << global(); + } + auto loadType = getResult().getType(); + if (!isGlobalTypeCompatible(globalOp.type(), loadType)) { + return emitOpError() << "global type mismatch; global " << global() + << " is " << globalOp.type() << " but load is " + << loadType; + } + return success(); +} + +LogicalResult GlobalLoadIndirectOp::verify() { + auto globalType = global().getType().cast().getTargetType(); + auto loadType = getResult().getType(); + if (!isGlobalTypeCompatible(globalType, loadType)) { + return emitOpError() << "global type mismatch; global pointer is " + << globalType << " but load is " << loadType; + } + return success(); +} + +LogicalResult +GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + auto globalOp = + symbolTable.lookupNearestSymbolFrom(*this, globalAttr()); + if (!globalOp) { + return emitOpError() << "undefined global: " << global(); + } + auto storeType = value().getType(); + if (!isGlobalTypeCompatible(globalOp.type(), storeType)) { + return emitOpError() << "global type mismatch; global " << global() + << " is " << globalOp.type() << " but store is " + << storeType; + } + return success(); +} + +LogicalResult GlobalStoreIndirectOp::verify() { + auto globalType = global().getType().cast().getTargetType(); + auto storeType = value().getType(); + if (!isGlobalTypeCompatible(globalType, storeType)) { + return emitOpError() << "global type mismatch; global pointer is " + << globalType << " but store is " << storeType; + } + return success(); +} + #define GET_OP_CLASSES #include "iree-dialects/Dialect/Input/InputOps.cpp.inc"