-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ODS] Allow inferring operand types from multiple variables #127517
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -387,7 +387,8 @@ void Operator::populateTypeInferenceInfo( | |
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { | ||
// Check for a non-variable length operand to use as the type anchor. | ||
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) { | ||
NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg); | ||
NamedTypeConstraint *operand = | ||
llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg); | ||
return operand && !operand->isVariableLength(); | ||
}); | ||
if (operandI == arguments.end()) | ||
|
@@ -396,7 +397,7 @@ void Operator::populateTypeInferenceInfo( | |
// All result types are inferred from the operand type. | ||
int operandIdx = operandI - arguments.begin(); | ||
for (int i = 0; i < getNumResults(); ++i) | ||
resultTypeMapping.emplace_back(operandIdx, "$_self"); | ||
resultTypeMapping.emplace_back(operandIdx, "$arg0"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So independent of the operand this is arg0? |
||
|
||
allResultsHaveKnownTypes = true; | ||
traits.push_back(Trait::create(inferTrait->getDefInit())); | ||
|
@@ -424,12 +425,12 @@ void Operator::populateTypeInferenceInfo( | |
for (auto [idx, infer] : llvm::enumerate(inference)) { | ||
if (getResult(idx).constraint.getBuilderCall()) { | ||
infer.sources.emplace_back(InferredResultType::mapResultIndex(idx), | ||
"$_self"); | ||
"$arg0"); | ||
infer.inferred = true; | ||
} | ||
} | ||
|
||
// Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the | ||
// Use `AllTypesMatch` and `InferTypesFrom` operation traits to build the | ||
// result type inference graph. | ||
for (const Trait &trait : traits) { | ||
const Record &def = trait.getDef(); | ||
|
@@ -445,10 +446,11 @@ void Operator::populateTypeInferenceInfo( | |
if (&traitDef->getDef() == inferTrait) | ||
return; | ||
|
||
// The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a | ||
// The `InferTypesFrom` trait represents a 1 -> 1 type inference edge with a | ||
// type transformer. | ||
if (def.isSubClassOf("TypesMatchWith")) { | ||
int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs")); | ||
if (def.isSubClassOf("InferTypesFrom")) { | ||
int target = | ||
argumentsAndResultsIndex.lookup(def.getValueAsString("target")); | ||
// Ignore operand type inference. | ||
if (InferredResultType::isArgIndex(target)) | ||
continue; | ||
|
@@ -457,8 +459,10 @@ void Operator::populateTypeInferenceInfo( | |
// If the type of the result has already been inferred, do nothing. | ||
if (infer.inferred) | ||
continue; | ||
int sourceIndex = | ||
argumentsAndResultsIndex.lookup(def.getValueAsString("lhs")); | ||
std::vector<StringRef> args = def.getValueAsListOfStrings("args"); | ||
assert(args.size() == 1 && | ||
"multiple arguments for result inference not yet supported."); | ||
Comment on lines
+463
to
+464
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure how we've handled this in the past, but I feel There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. +1 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ODS error handling isn't great, but There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tablegen error reporting is quite awful, but the best that can be done here is likely |
||
int sourceIndex = argumentsAndResultsIndex.lookup(args[0]); | ||
infer.sources.emplace_back(sourceIndex, | ||
def.getValueAsString("transformer").str()); | ||
// Locally propagate inferredness. | ||
|
@@ -493,7 +497,7 @@ void Operator::populateTypeInferenceInfo( | |
for (int resultIndex : resultIndices) { | ||
ResultTypeInference &infer = inference[resultIndex]; | ||
if (!infer.inferred) { | ||
infer.sources.assign(1, {*fullyInferredIndex, "$_self"}); | ||
infer.sources.assign(1, {*fullyInferredIndex, "$arg0"}); | ||
infer.inferred = true; | ||
} | ||
} | ||
|
@@ -504,7 +508,7 @@ void Operator::populateTypeInferenceInfo( | |
if (resultIndex == otherResultIndex) | ||
continue; | ||
inference[resultIndex].sources.emplace_back( | ||
InferredResultType::unmapResultIndex(otherResultIndex), "$_self"); | ||
InferredResultType::unmapResultIndex(otherResultIndex), "$arg0"); | ||
} | ||
} | ||
} | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -303,22 +303,33 @@ struct OperationFormat { | |||||||||||||||
std::optional<int> getBuilderIdx() const { return builderIdx; } | ||||||||||||||||
void setBuilderIdx(int idx) { builderIdx = idx; } | ||||||||||||||||
|
||||||||||||||||
int getNumArgs() const { return resolver.size(); } | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could the doc comments here be adjusted? Ditto below |
||||||||||||||||
|
||||||||||||||||
/// Get the variable this type is resolved to, or nullptr. | ||||||||||||||||
const NamedTypeConstraint *getVariable() const { | ||||||||||||||||
return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver); | ||||||||||||||||
const NamedTypeConstraint *getVariable(int i) const { | ||||||||||||||||
return resolver.empty() | ||||||||||||||||
? nullptr | ||||||||||||||||
: llvm::dyn_cast_if_present<const NamedTypeConstraint *>( | ||||||||||||||||
resolver[i]); | ||||||||||||||||
Comment on lines
+310
to
+313
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
1 less line of code There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same below |
||||||||||||||||
} | ||||||||||||||||
/// Get the attribute this type is resolved to, or nullptr. | ||||||||||||||||
const NamedAttribute *getAttribute() const { | ||||||||||||||||
return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver); | ||||||||||||||||
const NamedAttribute *getAttribute(int i) const { | ||||||||||||||||
return resolver.empty() | ||||||||||||||||
? nullptr | ||||||||||||||||
: llvm::dyn_cast_if_present<const NamedAttribute *>( | ||||||||||||||||
resolver[i]); | ||||||||||||||||
} | ||||||||||||||||
/// Get the transformer for the type of the variable, or std::nullopt. | ||||||||||||||||
std::optional<StringRef> getVarTransformer() const { | ||||||||||||||||
return variableTransformer; | ||||||||||||||||
} | ||||||||||||||||
void setResolver(ConstArgument arg, std::optional<StringRef> transformer) { | ||||||||||||||||
void setResolver(const SmallVector<ConstArgument, 1> &arg, | ||||||||||||||||
std::optional<StringRef> transformer) { | ||||||||||||||||
resolver = arg; | ||||||||||||||||
Comment on lines
+326
to
328
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
ulta nit |
||||||||||||||||
variableTransformer = transformer; | ||||||||||||||||
assert(getVariable() || getAttribute()); | ||||||||||||||||
assert(llvm::all_of(llvm::seq<int>(arg.size()), [&](int i) { | ||||||||||||||||
return getVariable(i) || getAttribute(i); | ||||||||||||||||
})); | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
private: | ||||||||||||||||
|
@@ -327,7 +338,7 @@ struct OperationFormat { | |||||||||||||||
std::optional<int> builderIdx; | ||||||||||||||||
/// If the type is resolved based upon another operand or result, this is | ||||||||||||||||
/// the variable or the attribute that this type is resolved to. | ||||||||||||||||
ConstArgument resolver; | ||||||||||||||||
SmallVector<ConstArgument, 1> resolver; | ||||||||||||||||
/// If the type is resolved based upon another operand or result, this is | ||||||||||||||||
/// a transformer to apply to the variable when resolving. | ||||||||||||||||
std::optional<StringRef> variableTransformer; | ||||||||||||||||
|
@@ -1685,23 +1696,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { | |||||||||||||||
std::optional<StringRef> transformer = resolver.getVarTransformer(); | ||||||||||||||||
if (!transformer) | ||||||||||||||||
continue; | ||||||||||||||||
// Ensure that we don't verify the same variables twice. | ||||||||||||||||
const NamedTypeConstraint *variable = resolver.getVariable(); | ||||||||||||||||
if (!variable || !verifiedVariables.insert(variable).second) | ||||||||||||||||
continue; | ||||||||||||||||
for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) { | ||||||||||||||||
// Ensure that we don't verify the same variables twice. | ||||||||||||||||
const NamedTypeConstraint *variable = resolver.getVariable(i); | ||||||||||||||||
if (!variable || !verifiedVariables.insert(variable).second) | ||||||||||||||||
continue; | ||||||||||||||||
|
||||||||||||||||
auto constraint = variable->constraint; | ||||||||||||||||
body << " for (::mlir::Type type : " << variable->name << "Types) {\n" | ||||||||||||||||
<< " (void)type;\n" | ||||||||||||||||
<< " if (!(" | ||||||||||||||||
<< tgfmt(constraint.getConditionTemplate(), | ||||||||||||||||
&verifierFCtx.withSelf("type")) | ||||||||||||||||
<< ")) {\n" | ||||||||||||||||
<< formatv(" return parser.emitError(parser.getNameLoc()) << " | ||||||||||||||||
"\"'{0}' must be {1}, but got \" << type;\n", | ||||||||||||||||
variable->name, constraint.getSummary()) | ||||||||||||||||
<< " }\n" | ||||||||||||||||
<< " }\n"; | ||||||||||||||||
auto constraint = variable->constraint; | ||||||||||||||||
body << " for (::mlir::Type type : " << variable->name << "Types) {\n" | ||||||||||||||||
<< " (void)type;\n" | ||||||||||||||||
<< " if (!(" | ||||||||||||||||
<< tgfmt(constraint.getConditionTemplate(), | ||||||||||||||||
&verifierFCtx.withSelf("type")) | ||||||||||||||||
<< ")) {\n" | ||||||||||||||||
<< formatv(" return parser.emitError(parser.getNameLoc()) << " | ||||||||||||||||
"\"'{0}' must be {1}, but got \" << type;\n", | ||||||||||||||||
variable->name, constraint.getSummary()) | ||||||||||||||||
<< " }\n" | ||||||||||||||||
<< " }\n"; | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
// Initialize the set of buildable types. | ||||||||||||||||
|
@@ -1717,26 +1730,30 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) { | |||||||||||||||
auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) { | ||||||||||||||||
if (std::optional<int> val = resolver.getBuilderIdx()) { | ||||||||||||||||
body << "odsBuildableType" << *val; | ||||||||||||||||
} else if (const NamedTypeConstraint *var = resolver.getVariable()) { | ||||||||||||||||
if (std::optional<StringRef> tform = resolver.getVarTransformer()) { | ||||||||||||||||
FmtContext fmtContext; | ||||||||||||||||
fmtContext.addSubst("_ctxt", "parser.getContext()"); | ||||||||||||||||
if (var->isVariadic()) | ||||||||||||||||
fmtContext.withSelf(var->name + "Types"); | ||||||||||||||||
else | ||||||||||||||||
fmtContext.withSelf(var->name + "Types[0]"); | ||||||||||||||||
body << tgfmt(*tform, &fmtContext); | ||||||||||||||||
} else { | ||||||||||||||||
body << var->name << "Types"; | ||||||||||||||||
if (!var->isVariadic()) | ||||||||||||||||
body << "[0]"; | ||||||||||||||||
} else if (std::optional<StringRef> tform = resolver.getVarTransformer()) { | ||||||||||||||||
FmtContext fmtContext; | ||||||||||||||||
fmtContext.addSubst("_ctxt", "parser.getContext()"); | ||||||||||||||||
for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) { | ||||||||||||||||
std::string substName = "arg" + std::to_string(i); | ||||||||||||||||
if (const NamedTypeConstraint *var = resolver.getVariable(i)) { | ||||||||||||||||
if (var->isVariadic()) | ||||||||||||||||
fmtContext.addSubst(substName, var->name + "Types"); | ||||||||||||||||
else | ||||||||||||||||
fmtContext.addSubst(substName, var->name + "Types[0]"); | ||||||||||||||||
} else if (const NamedAttribute *attr = resolver.getAttribute(i)) { | ||||||||||||||||
fmtContext.addSubst(substName, attr->name + "Attr.getType()"); | ||||||||||||||||
} else { | ||||||||||||||||
assert(false && "resolver arguements should be a type constraint or " | ||||||||||||||||
"an attribute"); | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
} else if (const NamedAttribute *attr = resolver.getAttribute()) { | ||||||||||||||||
if (std::optional<StringRef> tform = resolver.getVarTransformer()) | ||||||||||||||||
body << tgfmt(*tform, | ||||||||||||||||
&FmtContext().withSelf(attr->name + "Attr.getType()")); | ||||||||||||||||
else | ||||||||||||||||
body << attr->name << "Attr.getType()"; | ||||||||||||||||
body << tgfmt(*tform, &fmtContext); | ||||||||||||||||
} else if (const NamedTypeConstraint *var = resolver.getVariable(0)) { | ||||||||||||||||
body << var->name << "Types"; | ||||||||||||||||
if (!var->isVariadic()) | ||||||||||||||||
body << "[0]"; | ||||||||||||||||
} else if (const NamedAttribute *attr = resolver.getAttribute(0)) { | ||||||||||||||||
body << attr->name << "Attr.getType()"; | ||||||||||||||||
} else { | ||||||||||||||||
body << curVar << "Types"; | ||||||||||||||||
} | ||||||||||||||||
|
@@ -2717,7 +2734,7 @@ class OpFormatParser : public FormatParser { | |||||||||||||||
/// type as well as an optional transformer to apply to that type in order to | ||||||||||||||||
/// properly resolve the type of a variable. | ||||||||||||||||
struct TypeResolutionInstance { | ||||||||||||||||
ConstArgument resolver; | ||||||||||||||||
SmallVector<ConstArgument, 1> resolver; | ||||||||||||||||
std::optional<StringRef> transformer; | ||||||||||||||||
}; | ||||||||||||||||
|
||||||||||||||||
|
@@ -2827,7 +2844,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc, | |||||||||||||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); | ||||||||||||||||
} else if (def.getName() == "SameOperandsAndResultType") { | ||||||||||||||||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); | ||||||||||||||||
} else if (def.isSubClassOf("TypesMatchWith")) { | ||||||||||||||||
} else if (def.isSubClassOf("InferTypesFrom")) { | ||||||||||||||||
handleTypesMatchConstraint(variableTyResolver, def); | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we rename this function |
||||||||||||||||
} else if (!op.allResultTypesKnown()) { | ||||||||||||||||
// This doesn't check the name directly to handle | ||||||||||||||||
|
@@ -3228,9 +3245,9 @@ void OpFormatParser::handleAllTypesMatchConstraint( | |||||||||||||||
|
||||||||||||||||
// Mark this value as the type resolver for the other variables. | ||||||||||||||||
for (unsigned j = 0; j != i; ++j) | ||||||||||||||||
variableTyResolver[values[j]] = {arg, std::nullopt}; | ||||||||||||||||
variableTyResolver[values[j]] = {{arg}, std::nullopt}; | ||||||||||||||||
for (unsigned j = i + 1; j != e; ++j) | ||||||||||||||||
variableTyResolver[values[j]] = {arg, std::nullopt}; | ||||||||||||||||
variableTyResolver[values[j]] = {{arg}, std::nullopt}; | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
|
@@ -3251,21 +3268,26 @@ void OpFormatParser::handleSameTypesConstraint( | |||||||||||||||
// Set the resolvers for each operand and result. | ||||||||||||||||
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i) | ||||||||||||||||
if (!seenOperandTypes.test(i)) | ||||||||||||||||
variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt}; | ||||||||||||||||
variableTyResolver[op.getOperand(i).name] = {{resolver}, std::nullopt}; | ||||||||||||||||
if (includeResults) { | ||||||||||||||||
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i) | ||||||||||||||||
if (!seenResultTypes.test(i)) | ||||||||||||||||
variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt}; | ||||||||||||||||
variableTyResolver[op.getResultName(i)] = {{resolver}, std::nullopt}; | ||||||||||||||||
} | ||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
void OpFormatParser::handleTypesMatchConstraint( | ||||||||||||||||
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) { | ||||||||||||||||
StringRef lhsName = def.getValueAsString("lhs"); | ||||||||||||||||
StringRef rhsName = def.getValueAsString("rhs"); | ||||||||||||||||
std::vector<StringRef> args = def.getValueAsListOfStrings("args"); | ||||||||||||||||
StringRef target = def.getValueAsString("target"); | ||||||||||||||||
StringRef transformer = def.getValueAsString("transformer"); | ||||||||||||||||
if (ConstArgument arg = findSeenArg(lhsName)) | ||||||||||||||||
variableTyResolver[rhsName] = {arg, transformer}; | ||||||||||||||||
|
||||||||||||||||
SmallVector<ConstArgument, 1> resolutionArgs; | ||||||||||||||||
llvm::for_each(args, [&](StringRef arg) { | ||||||||||||||||
if (ConstArgument seenArg = findSeenArg(arg)) | ||||||||||||||||
resolutionArgs.push_back(seenArg); | ||||||||||||||||
}); | ||||||||||||||||
variableTyResolver[target] = {resolutionArgs, transformer}; | ||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||
} | ||||||||||||||||
|
||||||||||||||||
ConstArgument OpFormatParser::findSeenArg(StringRef name) { | ||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unpack
seems to be a bit confusing to me, I would keep it astransform(lhs)
or maybetransform(types(lhs))
.Could we also document the requirements of the
transform
string? E.g. that the types of arguments inlhs
are expected to be referenced as$argN