Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][ODS] Allow inferring operand types from multiple variables #127517

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 19 additions & 9 deletions mlir/include/mlir/IR/OpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -556,20 +556,30 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;

// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
// A type constraint that denotes `transform(unpack(lhs.getTypes())) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform, string comparator = "std::equal_to<>()">
// into: `comparator(transform(unpack(lhs.getTypes())), rhs.getType())`.
class InferTypesFrom<string summary, list<string> lhsArg, string rhsArg,
string transform,
string comparator = "std::equal_to<>()">
Comment on lines +559 to +564
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unpack seems to be a bit confusing to me, I would keep it as transform(lhs) or maybe transform(types(lhs)).
Could we also document the requirements of the transform string? E.g. that the types of arguments in lhs are expected to be referenced as $argN

: PredOpTrait<summary, CPred<
comparator # "(" #
!subst("$_self", "$" # lhsArg # ".getType()", transform) #
", $" # rhsArg # ".getType())">> {
string lhs = lhsArg;
string rhs = rhsArg;
comparator # "(" #
!foldl(transform, !range(lhsArg), acc, i, !subst("$arg" # i, "$" # lhsArg[i] # ".getType()", acc)) #
", $" # rhsArg # ".getType()" # ")">> {
list<string> args = lhsArg;
string target = rhsArg;
string transformer = transform;
}

// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
// into: `comparator(transform(lhs.getType()), rhs.getType())`.
class TypesMatchWith<string summary, string lhsArg, string rhsArg,
string transform, string comparator = "std::equal_to<>()">
: InferTypesFrom<summary, [lhsArg], rhsArg,
!subst("$_self", "$arg0", transform),
comparator>;

// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
// and not present returns success.
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
Expand Down
26 changes: 15 additions & 11 deletions mlir/lib/TableGen/Operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
NamedTypeConstraint *operand =
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
Expand All @@ -396,7 +397,7 @@ void Operator::populateTypeInferenceInfo(
// All result types are inferred from the operand type.
int operandIdx = operandI - arguments.begin();
for (int i = 0; i < getNumResults(); ++i)
resultTypeMapping.emplace_back(operandIdx, "$_self");
resultTypeMapping.emplace_back(operandIdx, "$arg0");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So independent of the operand this is arg0?


allResultsHaveKnownTypes = true;
traits.push_back(Trait::create(inferTrait->getDefInit()));
Expand Down Expand Up @@ -424,12 +425,12 @@ void Operator::populateTypeInferenceInfo(
for (auto [idx, infer] : llvm::enumerate(inference)) {
if (getResult(idx).constraint.getBuilderCall()) {
infer.sources.emplace_back(InferredResultType::mapResultIndex(idx),
"$_self");
"$arg0");
infer.inferred = true;
}
}

// Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
// Use `AllTypesMatch` and `InferTypesFrom` operation traits to build the
// result type inference graph.
for (const Trait &trait : traits) {
const Record &def = trait.getDef();
Expand All @@ -445,10 +446,11 @@ void Operator::populateTypeInferenceInfo(
if (&traitDef->getDef() == inferTrait)
return;

// The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
// The `InferTypesFrom` trait represents a 1 -> 1 type inference edge with a
// type transformer.
if (def.isSubClassOf("TypesMatchWith")) {
int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs"));
if (def.isSubClassOf("InferTypesFrom")) {
int target =
argumentsAndResultsIndex.lookup(def.getValueAsString("target"));
// Ignore operand type inference.
if (InferredResultType::isArgIndex(target))
continue;
Expand All @@ -457,8 +459,10 @@ void Operator::populateTypeInferenceInfo(
// If the type of the result has already been inferred, do nothing.
if (infer.inferred)
continue;
int sourceIndex =
argumentsAndResultsIndex.lookup(def.getValueAsString("lhs"));
std::vector<StringRef> args = def.getValueAsListOfStrings("args");
assert(args.size() == 1 &&
"multiple arguments for result inference not yet supported.");
Comment on lines +463 to +464
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how we've handled this in the past, but I feel llvm::report_fatal_error might be more appropriate here given this is a backend limitation that should also be reported in release builds

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ODS error handling isn't great, but report_fatal_error would at least mean the user is more likely to actually see the error message

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tablegen error reporting is quite awful, but the best that can be done here is likely PrintFatalError. Look at other uses in this file for that, there are some TableGen utilities that do this in a uniform way.

int sourceIndex = argumentsAndResultsIndex.lookup(args[0]);
infer.sources.emplace_back(sourceIndex,
def.getValueAsString("transformer").str());
// Locally propagate inferredness.
Expand Down Expand Up @@ -493,7 +497,7 @@ void Operator::populateTypeInferenceInfo(
for (int resultIndex : resultIndices) {
ResultTypeInference &infer = inference[resultIndex];
if (!infer.inferred) {
infer.sources.assign(1, {*fullyInferredIndex, "$_self"});
infer.sources.assign(1, {*fullyInferredIndex, "$arg0"});
infer.inferred = true;
}
}
Expand All @@ -504,7 +508,7 @@ void Operator::populateTypeInferenceInfo(
if (resultIndex == otherResultIndex)
continue;
inference[resultIndex].sources.emplace_back(
InferredResultType::unmapResultIndex(otherResultIndex), "$_self");
InferredResultType::unmapResultIndex(otherResultIndex), "$arg0");
}
}
}
Expand Down
13 changes: 13 additions & 0 deletions mlir/test/lib/Dialect/Test/TestOpsSyntax.td
Original file line number Diff line number Diff line change
Expand Up @@ -653,6 +653,19 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
let assemblyFormat = "attr-dict $value `:` type($value)";
}

//===----------------------------------------------------------------------===//
// InferTypesFrom type inference

def FormatTypesMatchMultipleVarOp : TEST_Op<"format_types_match_multiple_var", [
InferTypesFrom<"result type is a tuple of types of value1 and value2",
["value1", "result"], "value2",
"TupleType::get($_ctxt, {$arg0, $arg1})">]> {
let arguments = (ins AnyType:$value1,
AnyType:$value2);
let results = (outs AnyType:$result);
let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1) `->` type($result)";
}

//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference in assembly format

Expand Down
9 changes: 9 additions & 0 deletions mlir/test/mlir-tblgen/op-format.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
%i64 = "foo.op"() : () -> (i64)
// CHECK: %[[I32:.*]] =
%i32 = "foo.op"() : () -> (i32)
// CHECK: %[[I64_I32_TUP:.*]]
%i64_i32_tuple = "foo.op"() : () -> (tuple<i64, i32>)
// CHECK: %[[MEMREF:.*]] =
%memref = "foo.op"() : () -> (memref<1xf64>)

Expand Down Expand Up @@ -481,6 +483,13 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_types_match_context %[[I64]] : i64
%ignored_res6 = test.format_types_match_context %i64 : i64

//===----------------------------------------------------------------------===//
// InferTypesFrom type inference
//===----------------------------------------------------------------------===//

// CHECK: test.format_types_match_multiple_var
%ignored_res6a = test.format_types_match_multiple_var %i64, %i64_i32_tuple : i64 -> i32

//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference
//===----------------------------------------------------------------------===//
Expand Down
4 changes: 2 additions & 2 deletions mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3707,7 +3707,6 @@ void OpEmitter::genTypeInterfaceMethods() {
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
"].getType()")
.str();

// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
Expand Down Expand Up @@ -3743,7 +3742,8 @@ void OpEmitter::genTypeInterfaceMethods() {
continue;
}
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
<< tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
<< tgfmt(infer.getTransformer(), &fctx.addSubst("arg0", typeStr))
<< ";\n";
constructedIndices[i] = inferredTypeIdx - 1;
}
}
Expand Down
126 changes: 74 additions & 52 deletions mlir/tools/mlir-tblgen/OpFormatGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,22 +303,33 @@ struct OperationFormat {
std::optional<int> getBuilderIdx() const { return builderIdx; }
void setBuilderIdx(int idx) { builderIdx = idx; }

int getNumArgs() const { return resolver.size(); }
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could the doc comments here be adjusted? Ditto below


/// Get the variable this type is resolved to, or nullptr.
const NamedTypeConstraint *getVariable() const {
return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
const NamedTypeConstraint *getVariable(int i) const {
return resolver.empty()
? nullptr
: llvm::dyn_cast_if_present<const NamedTypeConstraint *>(
resolver[i]);
Comment on lines +310 to +313
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return resolver.empty()
? nullptr
: llvm::dyn_cast_if_present<const NamedTypeConstraint *>(
resolver[i]);
if (resolver.empty())
return nullptr
return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver[i]);

1 less line of code

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same below

}
/// Get the attribute this type is resolved to, or nullptr.
const NamedAttribute *getAttribute() const {
return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
const NamedAttribute *getAttribute(int i) const {
return resolver.empty()
? nullptr
: llvm::dyn_cast_if_present<const NamedAttribute *>(
resolver[i]);
}
/// Get the transformer for the type of the variable, or std::nullopt.
std::optional<StringRef> getVarTransformer() const {
return variableTransformer;
}
void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
void setResolver(const SmallVector<ConstArgument, 1> &arg,
std::optional<StringRef> transformer) {
resolver = arg;
Comment on lines +326 to 328
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
void setResolver(const SmallVector<ConstArgument, 1> &arg,
std::optional<StringRef> transformer) {
resolver = arg;
void setResolver(SmallVector<ConstArgument, 1> arg,
std::optional<StringRef> transformer) {
resolver = std::move(arg);

ulta nit

variableTransformer = transformer;
assert(getVariable() || getAttribute());
assert(llvm::all_of(llvm::seq<int>(arg.size()), [&](int i) {
return getVariable(i) || getAttribute(i);
}));
}

private:
Expand All @@ -327,7 +338,7 @@ struct OperationFormat {
std::optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
/// the variable or the attribute that this type is resolved to.
ConstArgument resolver;
SmallVector<ConstArgument, 1> resolver;
/// If the type is resolved based upon another operand or result, this is
/// a transformer to apply to the variable when resolving.
std::optional<StringRef> variableTransformer;
Expand Down Expand Up @@ -1685,23 +1696,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
std::optional<StringRef> transformer = resolver.getVarTransformer();
if (!transformer)
continue;
// Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable();
if (!variable || !verifiedVariables.insert(variable).second)
continue;
for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
// Ensure that we don't verify the same variables twice.
const NamedTypeConstraint *variable = resolver.getVariable(i);
if (!variable || !verifiedVariables.insert(variable).second)
continue;

auto constraint = variable->constraint;
body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
<< " (void)type;\n"
<< " if (!("
<< tgfmt(constraint.getConditionTemplate(),
&verifierFCtx.withSelf("type"))
<< ")) {\n"
<< formatv(" return parser.emitError(parser.getNameLoc()) << "
"\"'{0}' must be {1}, but got \" << type;\n",
variable->name, constraint.getSummary())
<< " }\n"
<< " }\n";
auto constraint = variable->constraint;
body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
<< " (void)type;\n"
<< " if (!("
<< tgfmt(constraint.getConditionTemplate(),
&verifierFCtx.withSelf("type"))
<< ")) {\n"
<< formatv(" return parser.emitError(parser.getNameLoc()) << "
"\"'{0}' must be {1}, but got \" << type;\n",
variable->name, constraint.getSummary())
<< " }\n"
<< " }\n";
}
}

// Initialize the set of buildable types.
Expand All @@ -1717,26 +1730,30 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
if (std::optional<int> val = resolver.getBuilderIdx()) {
body << "odsBuildableType" << *val;
} else if (const NamedTypeConstraint *var = resolver.getVariable()) {
if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
FmtContext fmtContext;
fmtContext.addSubst("_ctxt", "parser.getContext()");
if (var->isVariadic())
fmtContext.withSelf(var->name + "Types");
else
fmtContext.withSelf(var->name + "Types[0]");
body << tgfmt(*tform, &fmtContext);
} else {
body << var->name << "Types";
if (!var->isVariadic())
body << "[0]";
} else if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
FmtContext fmtContext;
fmtContext.addSubst("_ctxt", "parser.getContext()");
for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
std::string substName = "arg" + std::to_string(i);
if (const NamedTypeConstraint *var = resolver.getVariable(i)) {
if (var->isVariadic())
fmtContext.addSubst(substName, var->name + "Types");
else
fmtContext.addSubst(substName, var->name + "Types[0]");
} else if (const NamedAttribute *attr = resolver.getAttribute(i)) {
fmtContext.addSubst(substName, attr->name + "Attr.getType()");
} else {
assert(false && "resolver arguements should be a type constraint or "
"an attribute");
}
}
} else if (const NamedAttribute *attr = resolver.getAttribute()) {
if (std::optional<StringRef> tform = resolver.getVarTransformer())
body << tgfmt(*tform,
&FmtContext().withSelf(attr->name + "Attr.getType()"));
else
body << attr->name << "Attr.getType()";
body << tgfmt(*tform, &fmtContext);
} else if (const NamedTypeConstraint *var = resolver.getVariable(0)) {
body << var->name << "Types";
if (!var->isVariadic())
body << "[0]";
} else if (const NamedAttribute *attr = resolver.getAttribute(0)) {
body << attr->name << "Attr.getType()";
} else {
body << curVar << "Types";
}
Expand Down Expand Up @@ -2717,7 +2734,7 @@ class OpFormatParser : public FormatParser {
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
struct TypeResolutionInstance {
ConstArgument resolver;
SmallVector<ConstArgument, 1> resolver;
std::optional<StringRef> transformer;
};

Expand Down Expand Up @@ -2827,7 +2844,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
} else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
} else if (def.isSubClassOf("TypesMatchWith")) {
} else if (def.isSubClassOf("InferTypesFrom")) {
handleTypesMatchConstraint(variableTyResolver, def);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we rename this function

} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
Expand Down Expand Up @@ -3228,9 +3245,9 @@ void OpFormatParser::handleAllTypesMatchConstraint(

// Mark this value as the type resolver for the other variables.
for (unsigned j = 0; j != i; ++j)
variableTyResolver[values[j]] = {arg, std::nullopt};
variableTyResolver[values[j]] = {{arg}, std::nullopt};
for (unsigned j = i + 1; j != e; ++j)
variableTyResolver[values[j]] = {arg, std::nullopt};
variableTyResolver[values[j]] = {{arg}, std::nullopt};
}
}

Expand All @@ -3251,21 +3268,26 @@ void OpFormatParser::handleSameTypesConstraint(
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
if (!seenOperandTypes.test(i))
variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
variableTyResolver[op.getOperand(i).name] = {{resolver}, std::nullopt};
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
if (!seenResultTypes.test(i))
variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
variableTyResolver[op.getResultName(i)] = {{resolver}, std::nullopt};
}
}

void OpFormatParser::handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
StringRef lhsName = def.getValueAsString("lhs");
StringRef rhsName = def.getValueAsString("rhs");
std::vector<StringRef> args = def.getValueAsListOfStrings("args");
StringRef target = def.getValueAsString("target");
StringRef transformer = def.getValueAsString("transformer");
if (ConstArgument arg = findSeenArg(lhsName))
variableTyResolver[rhsName] = {arg, transformer};

SmallVector<ConstArgument, 1> resolutionArgs;
llvm::for_each(args, [&](StringRef arg) {
if (ConstArgument seenArg = findSeenArg(arg))
resolutionArgs.push_back(seenArg);
});
variableTyResolver[target] = {resolutionArgs, transformer};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
variableTyResolver[target] = {resolutionArgs, transformer};
variableTyResolver[target] = {std::move(resolutionArgs), transformer};

}

ConstArgument OpFormatParser::findSeenArg(StringRef name) {
Expand Down