Skip to content

Commit

Permalink
Add verifiers for types of Input ops global load/store (#8908)
Browse files Browse the repository at this point in the history
Follow behavior (approximately) in Util dialect for these ops and verify
that the value loaded/stored matches load/store's types.
  • Loading branch information
jpienaar authored Apr 16, 2022
1 parent a6fbb76 commit 4b9fd17
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def IREEInput_GlobalAddressOp : IREEInput_PureOp<"global.address"> {
}];
}

def IREEInput_GlobalLoadOp : IREEInput_Op<"global.load"> {
def IREEInput_GlobalLoadOp : IREEInput_Op<"global.load",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = [{loads a value from a global variable}];
let description = [{
Returns a copy of the global value.
Expand Down Expand Up @@ -161,9 +162,11 @@ def IREEInput_GlobalLoadIndirectOp : IREEInput_Op<"global.load.indirect"> {
let assemblyFormat = [{
$global attr-dict `:` type($global) `->` type($result)
}];
let hasVerifier = 1;
}

def IREEInput_GlobalStoreOp : IREEInput_Op<"global.store"> {
def IREEInput_GlobalStoreOp : IREEInput_Op<"global.store",
[DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
let summary = [{stores a value into a global variable}];
let description = [{
Stores a copy of the value into a global.
Expand Down Expand Up @@ -193,6 +196,7 @@ def IREEInput_GlobalStoreIndirectOp : IREEInput_Op<"global.store.indirect"> {
let assemblyFormat = [{
$value `,` $global attr-dict `:` type($value) `->` type($global)
}];
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,74 @@ void GlobalOp::build(OpBuilder &builder, OperationState &result, StringRef name,
build(builder, result, name, isMutable, type, llvm::None, attrs);
}

// Returns true if the given |accessType| is compatible with the |globalType|.
// For example, this will return true if the global type is a tensor<?xf32>
// and the access is tensor<4xf32>.
static bool isGlobalTypeCompatible(Type globalType, Type accessType) {
// If one is a shaped type, then they both must be and have compatible
// shapes.
if (globalType.isa<ShapedType>() && accessType.isa<ShapedType>()) {
return succeeded(mlir::verifyCompatibleShape(globalType, accessType)) &&
globalType.cast<ShapedType>().getElementType() ==
accessType.cast<ShapedType>().getElementType();
}

// Permissively allow any other types to be marked compatible as long as
// neither are shaped type.
return !globalType.isa<ShapedType>() && !accessType.isa<ShapedType>();
}

LogicalResult
GlobalLoadOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto globalOp =
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, globalAttr());
if (!globalOp) {
return emitOpError() << "undefined global: " << global();
}
auto loadType = getResult().getType();
if (!isGlobalTypeCompatible(globalOp.type(), loadType)) {
return emitOpError() << "global type mismatch; global " << global()
<< " is " << globalOp.type() << " but load is "
<< loadType;
}
return success();
}

LogicalResult GlobalLoadIndirectOp::verify() {
auto globalType = global().getType().cast<PtrType>().getTargetType();
auto loadType = getResult().getType();
if (!isGlobalTypeCompatible(globalType, loadType)) {
return emitOpError() << "global type mismatch; global pointer is "
<< globalType << " but load is " << loadType;
}
return success();
}

LogicalResult
GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto globalOp =
symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, globalAttr());
if (!globalOp) {
return emitOpError() << "undefined global: " << global();
}
auto storeType = value().getType();
if (!isGlobalTypeCompatible(globalOp.type(), storeType)) {
return emitOpError() << "global type mismatch; global " << global()
<< " is " << globalOp.type() << " but store is "
<< storeType;
}
return success();
}

LogicalResult GlobalStoreIndirectOp::verify() {
auto globalType = global().getType().cast<PtrType>().getTargetType();
auto storeType = value().getType();
if (!isGlobalTypeCompatible(globalType, storeType)) {
return emitOpError() << "global type mismatch; global pointer is "
<< globalType << " but store is " << storeType;
}
return success();
}

#define GET_OP_CLASSES
#include "iree-dialects/Dialect/Input/InputOps.cpp.inc"

0 comments on commit 4b9fd17

Please sign in to comment.