Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Codegen][Tuner] attr verifier for tuning specs #19486

Merged
merged 9 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Verifier.h"

#define DEBUG_TYPE "iree-codegen-link-tuning-specs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
Expand Down Expand Up @@ -53,35 +54,14 @@ static SmallVector<NamedSequenceOp> findTuningSpecs(ModuleOp module) {
});
}

// Returns true iff the entrypoint has the following signature:
// ```
// transform.named_sequence @name(%arg0: !transform.any_op) ->
// (!transform.any_op)
// ```
static LogicalResult validateTuningSpec(NamedSequenceOp op) {
ArrayRef<Type> resTypes = op.getFunctionType().getResults();
if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
return op.emitWarning()
<< "Tuning spec entry point expected to return any_op";
}

ArrayRef<Type> argTypes = op.getArgumentTypes();
if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
return op.emitWarning() << "Tuning spec entry point expected to have a "
"single any_op argument";
}

return success();
}

static bool consumesInputOp(NamedSequenceOp op) {
if (op.getArgAttr(0, kArgConsumedAttrName)) {
return true;
}
return false;
}

static NamedSequenceOp
static FailureOr<NamedSequenceOp>
emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
OpBuilder builder(module->getContext());
builder.setInsertionPointToEnd(module.getBody());
Expand Down Expand Up @@ -144,6 +124,12 @@ emitLinkedTuningSpec(ModuleOp module, ArrayRef<NamedSequenceOp> specsToLink) {
}

builder.create<transform::YieldOp>(loc, operand);

if (failed(mlir::verify(module))) {
module.emitError("Linked tuning spec failed to verify");
return failure();
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
}

return newSpec;
}

Expand All @@ -169,13 +155,6 @@ FailureOr<NamedSequenceOp> linkTuningSpecs(ModuleOp module) {
llvm::append_range(tuningSpecs, findTuningSpecs(nested));
}

for (NamedSequenceOp spec : tuningSpecs) {
LDBG("Found tuning spec: " << spec.getSymName());
if (failed(validateTuningSpec(spec))) {
return failure();
}
}

size_t numConsumedSpecs = llvm::count_if(tuningSpecs, consumesInputOp);
if (numConsumedSpecs > 0 && numConsumedSpecs != tuningSpecs.size()) {
LDBG("Only " << numConsumedSpecs << " tuning specs out of "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Support/FileUtilities.h"

#define DEBUG_TYPE "iree-codegen-materialize-tuning-specs"
Expand Down Expand Up @@ -138,8 +139,19 @@ getDefaultTuningSpec(ModuleOp module,

// Load the library through the codegen dialect so that we cache the parsed
// module.
return dialect.getOrParseTransformLibraryModule(defaultTuningSpecName,
*defaultTuningSpecSource);
FailureOr<ModuleOp> defaultTransformLibrary =
dialect.getOrParseTransformLibraryModule(defaultTuningSpecName,
*defaultTuningSpecSource);

#ifndef NDEBUG
if (succeeded(defaultTransformLibrary) &&
failed(mlir::verify(*defaultTransformLibrary)))
return (*defaultTransformLibrary).emitError()
<< "Default tuning spec " << defaultTuningSpecName
<< " failed to verify";
#endif

return *defaultTransformLibrary;
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
}

static FailureOr<DenseElementsAttr>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ iree_lit_test_suite(
"vector_layout_analysis.mlir",
"vectorize_memref_copy.mlir",
"vectorize_tensor_pad.mlir",
"verify_tuning_specs.mlir",
"verify_workgroup_distribution.mlir",
"vmvx_materialize_encoding.mlir",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ iree_lit_test_suite(
"vector_layout_analysis.mlir"
"vectorize_memref_copy.mlir"
"vectorize_tensor_pad.mlir"
"verify_tuning_specs.mlir"
"verify_workgroup_distribution.mlir"
"vmvx_materialize_encoding.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// RUN: iree-opt --verify-diagnostics --split-input-file %s

module @foo_module attributes { transform.with_named_sequence } {
func.func @baz(%arg0: i32) -> () {
return
}
transform.named_sequence @bar(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
attributes { iree_codegen.something } {
transform.yield %arg0 : !transform.any_op
}
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved
// expected-error @+1{{'iree_codegen.tuning_spec_entrypoint' attribute must be a UnitAttr}}
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> !transform.any_op
attributes { iree_codegen.tuning_spec_entrypoint = "foo" } {
transform.yield %arg0 : !transform.any_op
}
}
bangtianliu marked this conversation as resolved.
Show resolved Hide resolved

// -----

module @foo_module attributes { transform.with_named_sequence } {
// expected-error @+1{{Tuning spec entry point expected to have a single any_op argument}}
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}, %arg1: !transform.any_op {transform.readonly}) -> !transform.any_op
attributes { iree_codegen.tuning_spec_entrypoint } {
transform.yield %arg0 : !transform.any_op
}
}

// -----

module @foo_module attributes { transform.with_named_sequence } {
// expected-error @+1{{Tuning spec entry point expected to have a single any_op argument}}
transform.named_sequence @foo(%arg0: i32) -> !transform.any_op
attributes { iree_codegen.tuning_spec_entrypoint } {}
}

// -----

module @foo_module attributes { transform.with_named_sequence } {
// expected-error @+1{{Tuning spec entry point expected to return any_op}}
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly}) -> i32
attributes { iree_codegen.tuning_spec_entrypoint } {
%0 = arith.constant 0 : i32
transform.yield %0 : i32
}
}

// -----

module @foo_module attributes { transform.with_named_sequence } {
// expected-error @+1{{Tuning spec entry point expected to return any_op}}
transform.named_sequence @foo(%arg0: !transform.any_op {transform.readonly})
attributes { iree_codegen.tuning_spec_entrypoint } {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.cpp.inc"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
#include "iree/compiler/Codegen/Dialect/Codegen/IR/UKernelOps.h"
#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/IR/DialectImplementation.h"

namespace mlir::iree_compiler::IREE::Codegen {
Expand Down Expand Up @@ -45,4 +46,47 @@ void IREECodegenDialect::initialize() {
>();
}

LogicalResult
IREECodegenDialect::verifyOperationAttribute(Operation *op,
NamedAttribute attribute) {
StringRef symbol = attribute.getName().strref();
Attribute attr = attribute.getValue();

// This function verifies the validity of a specific operation attribute.
// - If the attribute's name matches `kTuningSpecEntrypointAttrName`
// ("iree_codegen.tuning_spec_entrypoint"):
// 1. The attribute value must be a UnitAttr.
// 2. If the operation is a transform::NamedSequenceOp:
// - The operation's function signature must satisfy the following:
// a. It must have exactly one result type, and the result must be of
// type `transform::AnyOpType`.
// b. It must have exactly one argument type, and the argument must be
// of type `transform::AnyOpType`.

if (symbol != kTuningSpecEntrypointAttrName)
return success();

// Verify that the attribute is a UnitAttr.
kuhar marked this conversation as resolved.
Show resolved Hide resolved
if (!isa<UnitAttr>(attr)) {
return op->emitError("'") << symbol << "' attribute must be a UnitAttr";
}

if (auto namedSeqOp = dyn_cast<transform::NamedSequenceOp>(op)) {
ArrayRef<Type> resTypes = namedSeqOp.getFunctionType().getResults();
if (resTypes.size() != 1 || !isa<transform::AnyOpType>(resTypes[0])) {
return namedSeqOp.emitError()
<< "Tuning spec entry point expected to return any_op";
}

ArrayRef<Type> argTypes = namedSeqOp.getArgumentTypes();
if (argTypes.size() != 1 || !isa<transform::AnyOpType>(argTypes[0])) {
return namedSeqOp.emitError()
<< "Tuning spec entry point expected to have a "
"single any_op argument";
}
}

return success();
}

} // namespace mlir::iree_compiler::IREE::Codegen
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def IREECodegen_Dialect : Dialect {
std::mutex libraryMutex;
}];
let useDefaultAttributePrinterParser = 1;
let hasOperationAttrVerify = 1;
}

def AnyRankedTensorOrMemRefType : AnyTypeOf<[AnyRankedTensor, AnyMemRef]>;
Expand Down
Loading