-
Notifications
You must be signed in to change notification settings - Fork 12.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][ODS] Allow inferring operand types from multiple variables #127517
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-ods @llvm/pr-subscribers-mlir Author: Kunwar Grover (Groverkss) ChangesThis patch adds support for inferring operand types from multiple operand types. The patch introduces a new Inferring result types could also be added with this change, but is more complex as we need to generate a different builder call as well (which needs more intrusive changes into the operation definition builder) Full diff: https://github.com/llvm/llvm-project/pull/127517.diff 6 Files Affected:
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index 51b60972203e7..f992481d4aa31 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -556,20 +556,30 @@ class AllShapesMatch<list<string> names> :
class AllTypesMatch<list<string> names> :
AllMatchSameOperatorTrait<names, "$_self.getType()", "type">;
-// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// A type constraint that denotes `transform(unpack(lhs.getTypes())) == rhs.getType()`.
// An optional comparator function may be provided that changes the above form
-// into: `comparator(transform(lhs.getType()), rhs.getType())`.
-class TypesMatchWith<string summary, string lhsArg, string rhsArg,
- string transform, string comparator = "std::equal_to<>()">
+// into: `comparator(transform(unpack(lhs.getTypes())), rhs.getType())`.
+class InferTypesFrom<string summary, list<string> lhsArg, string rhsArg,
+ string transform,
+ string comparator = "std::equal_to<>()">
: PredOpTrait<summary, CPred<
- comparator # "(" #
- !subst("$_self", "$" # lhsArg # ".getType()", transform) #
- ", $" # rhsArg # ".getType())">> {
- string lhs = lhsArg;
- string rhs = rhsArg;
+ comparator # "(" #
+ !foldl(transform, !range(lhsArg), acc, i, !subst("$arg" # i, "$" # lhsArg[i] # ".getType()", acc)) #
+ ", $" # rhsArg # ".getType()" # ")">> {
+ list<string> args = lhsArg;
+ string target = rhsArg;
string transformer = transform;
}
+// A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`.
+// An optional comparator function may be provided that changes the above form
+// into: `comparator(transform(lhs.getType()), rhs.getType())`.
+class TypesMatchWith<string summary, string lhsArg, string rhsArg,
+ string transform, string comparator = "std::equal_to<>()">
+ : InferTypesFrom<summary, [lhsArg], rhsArg,
+ !subst("$_self", "$arg0", transform),
+ comparator>;
+
// The same as TypesMatchWith but if either `lhsArg` or `rhsArg` are optional
// and not present returns success.
class OptionalTypesMatchWith<string summary, string lhsArg, string rhsArg,
diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp
index 20a43ef15d09e..ec5561b89ea74 100644
--- a/mlir/lib/TableGen/Operator.cpp
+++ b/mlir/lib/TableGen/Operator.cpp
@@ -387,7 +387,8 @@ void Operator::populateTypeInferenceInfo(
if (getTrait("::mlir::OpTrait::SameOperandsAndResultType")) {
// Check for a non-variable length operand to use as the type anchor.
auto *operandI = llvm::find_if(arguments, [](const Argument &arg) {
- NamedTypeConstraint *operand = llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
+ NamedTypeConstraint *operand =
+ llvm::dyn_cast_if_present<NamedTypeConstraint *>(arg);
return operand && !operand->isVariableLength();
});
if (operandI == arguments.end())
@@ -396,7 +397,7 @@ void Operator::populateTypeInferenceInfo(
// All result types are inferred from the operand type.
int operandIdx = operandI - arguments.begin();
for (int i = 0; i < getNumResults(); ++i)
- resultTypeMapping.emplace_back(operandIdx, "$_self");
+ resultTypeMapping.emplace_back(operandIdx, "$arg0");
allResultsHaveKnownTypes = true;
traits.push_back(Trait::create(inferTrait->getDefInit()));
@@ -424,12 +425,12 @@ void Operator::populateTypeInferenceInfo(
for (auto [idx, infer] : llvm::enumerate(inference)) {
if (getResult(idx).constraint.getBuilderCall()) {
infer.sources.emplace_back(InferredResultType::mapResultIndex(idx),
- "$_self");
+ "$arg0");
infer.inferred = true;
}
}
- // Use `AllTypesMatch` and `TypesMatchWith` operation traits to build the
+ // Use `AllTypesMatch` and `InferTypesFrom` operation traits to build the
// result type inference graph.
for (const Trait &trait : traits) {
const Record &def = trait.getDef();
@@ -445,10 +446,11 @@ void Operator::populateTypeInferenceInfo(
if (&traitDef->getDef() == inferTrait)
return;
- // The `TypesMatchWith` trait represents a 1 -> 1 type inference edge with a
+ // The `InferTypesFrom` trait represents a 1 -> 1 type inference edge with a
// type transformer.
- if (def.isSubClassOf("TypesMatchWith")) {
- int target = argumentsAndResultsIndex.lookup(def.getValueAsString("rhs"));
+ if (def.isSubClassOf("InferTypesFrom")) {
+ int target =
+ argumentsAndResultsIndex.lookup(def.getValueAsString("target"));
// Ignore operand type inference.
if (InferredResultType::isArgIndex(target))
continue;
@@ -457,8 +459,10 @@ void Operator::populateTypeInferenceInfo(
// If the type of the result has already been inferred, do nothing.
if (infer.inferred)
continue;
- int sourceIndex =
- argumentsAndResultsIndex.lookup(def.getValueAsString("lhs"));
+ std::vector<StringRef> args = def.getValueAsListOfStrings("args");
+ assert(args.size() == 1 &&
+ "multiple arguments for result inference not yet supported.");
+ int sourceIndex = argumentsAndResultsIndex.lookup(args[0]);
infer.sources.emplace_back(sourceIndex,
def.getValueAsString("transformer").str());
// Locally propagate inferredness.
@@ -493,7 +497,7 @@ void Operator::populateTypeInferenceInfo(
for (int resultIndex : resultIndices) {
ResultTypeInference &infer = inference[resultIndex];
if (!infer.inferred) {
- infer.sources.assign(1, {*fullyInferredIndex, "$_self"});
+ infer.sources.assign(1, {*fullyInferredIndex, "$arg0"});
infer.inferred = true;
}
}
@@ -504,7 +508,7 @@ void Operator::populateTypeInferenceInfo(
if (resultIndex == otherResultIndex)
continue;
inference[resultIndex].sources.emplace_back(
- InferredResultType::unmapResultIndex(otherResultIndex), "$_self");
+ InferredResultType::unmapResultIndex(otherResultIndex), "$arg0");
}
}
}
diff --git a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
index 2848cb994231b..33e4b9a623636 100644
--- a/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
+++ b/mlir/test/lib/Dialect/Test/TestOpsSyntax.td
@@ -653,6 +653,19 @@ def FormatTypesMatchContextOp : TEST_Op<"format_types_match_context", [
let assemblyFormat = "attr-dict $value `:` type($value)";
}
+//===----------------------------------------------------------------------===//
+// InferTypesFrom type inference
+
+def FormatTypesMatchMultipleVarOp : TEST_Op<"format_types_match_multiple_var", [
+ InferTypesFrom<"result type is a tuple of types of value1 and value2",
+ ["value1", "result"], "value2",
+ "TupleType::get($_ctxt, {$arg0, $arg1})">]> {
+ let arguments = (ins AnyType:$value1,
+ AnyType:$value2);
+ let results = (outs AnyType:$result);
+ let assemblyFormat = "attr-dict $value1 `,` $value2 `:` type($value1) `->` type($result)";
+}
+
//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference in assembly format
diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir
index 08b0c52413a75..445826afb5ed5 100644
--- a/mlir/test/mlir-tblgen/op-format.mlir
+++ b/mlir/test/mlir-tblgen/op-format.mlir
@@ -4,6 +4,8 @@
%i64 = "foo.op"() : () -> (i64)
// CHECK: %[[I32:.*]] =
%i32 = "foo.op"() : () -> (i32)
+// CHECK: %[[I64_I32_TUP:.*]]
+%i64_i32_tuple = "foo.op"() : () -> (tuple<i64, i32>)
// CHECK: %[[MEMREF:.*]] =
%memref = "foo.op"() : () -> (memref<1xf64>)
@@ -481,6 +483,13 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64
// CHECK: test.format_types_match_context %[[I64]] : i64
%ignored_res6 = test.format_types_match_context %i64 : i64
+//===----------------------------------------------------------------------===//
+// InferTypesFrom type inference
+//===----------------------------------------------------------------------===//
+
+// CHECK: test.format_types_match_multiple_var
+%ignored_res6a = test.format_types_match_multiple_var %i64, %i64_i32_tuple : i64 -> i32
+
//===----------------------------------------------------------------------===//
// InferTypeOpInterface type inference
//===----------------------------------------------------------------------===//
diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
index 629e863dac5e3..ab50686e67ba6 100644
--- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp
@@ -3707,7 +3707,6 @@ void OpEmitter::genTypeInterfaceMethods() {
typeStr = ("operands[" + Twine(arg.operandOrAttributeIndex()) +
"].getType()")
.str();
-
// If this is an attribute, index into the attribute dictionary.
} else {
auto *attr =
@@ -3743,7 +3742,8 @@ void OpEmitter::genTypeInterfaceMethods() {
continue;
}
body << " ::mlir::Type odsInferredType" << inferredTypeIdx++ << " = "
- << tgfmt(infer.getTransformer(), &fctx.withSelf(typeStr)) << ";\n";
+ << tgfmt(infer.getTransformer(), &fctx.addSubst("arg0", typeStr))
+ << ";\n";
constructedIndices[i] = inferredTypeIdx - 1;
}
}
diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
index fe724e86d6707..81edf36f9d4d7 100644
--- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp
@@ -303,22 +303,33 @@ struct OperationFormat {
std::optional<int> getBuilderIdx() const { return builderIdx; }
void setBuilderIdx(int idx) { builderIdx = idx; }
+ int getNumArgs() const { return resolver.size(); }
+
/// Get the variable this type is resolved to, or nullptr.
- const NamedTypeConstraint *getVariable() const {
- return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver);
+ const NamedTypeConstraint *getVariable(int i) const {
+ return resolver.empty()
+ ? nullptr
+ : llvm::dyn_cast_if_present<const NamedTypeConstraint *>(
+ resolver[i]);
}
/// Get the attribute this type is resolved to, or nullptr.
- const NamedAttribute *getAttribute() const {
- return llvm::dyn_cast_if_present<const NamedAttribute *>(resolver);
+ const NamedAttribute *getAttribute(int i) const {
+ return resolver.empty()
+ ? nullptr
+ : llvm::dyn_cast_if_present<const NamedAttribute *>(
+ resolver[i]);
}
/// Get the transformer for the type of the variable, or std::nullopt.
std::optional<StringRef> getVarTransformer() const {
return variableTransformer;
}
- void setResolver(ConstArgument arg, std::optional<StringRef> transformer) {
+ void setResolver(const SmallVector<ConstArgument, 1> &arg,
+ std::optional<StringRef> transformer) {
resolver = arg;
variableTransformer = transformer;
- assert(getVariable() || getAttribute());
+ assert(llvm::all_of(llvm::seq<int>(arg.size()), [&](int i) {
+ return getVariable(i) || getAttribute(i);
+ }));
}
private:
@@ -327,7 +338,7 @@ struct OperationFormat {
std::optional<int> builderIdx;
/// If the type is resolved based upon another operand or result, this is
/// the variable or the attribute that this type is resolved to.
- ConstArgument resolver;
+ SmallVector<ConstArgument, 1> resolver;
/// If the type is resolved based upon another operand or result, this is
/// a transformer to apply to the variable when resolving.
std::optional<StringRef> variableTransformer;
@@ -1685,23 +1696,25 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
std::optional<StringRef> transformer = resolver.getVarTransformer();
if (!transformer)
continue;
- // Ensure that we don't verify the same variables twice.
- const NamedTypeConstraint *variable = resolver.getVariable();
- if (!variable || !verifiedVariables.insert(variable).second)
- continue;
+ for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
+ // Ensure that we don't verify the same variables twice.
+ const NamedTypeConstraint *variable = resolver.getVariable(i);
+ if (!variable || !verifiedVariables.insert(variable).second)
+ continue;
- auto constraint = variable->constraint;
- body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
- << " (void)type;\n"
- << " if (!("
- << tgfmt(constraint.getConditionTemplate(),
- &verifierFCtx.withSelf("type"))
- << ")) {\n"
- << formatv(" return parser.emitError(parser.getNameLoc()) << "
- "\"'{0}' must be {1}, but got \" << type;\n",
- variable->name, constraint.getSummary())
- << " }\n"
- << " }\n";
+ auto constraint = variable->constraint;
+ body << " for (::mlir::Type type : " << variable->name << "Types) {\n"
+ << " (void)type;\n"
+ << " if (!("
+ << tgfmt(constraint.getConditionTemplate(),
+ &verifierFCtx.withSelf("type"))
+ << ")) {\n"
+ << formatv(" return parser.emitError(parser.getNameLoc()) << "
+ "\"'{0}' must be {1}, but got \" << type;\n",
+ variable->name, constraint.getSummary())
+ << " }\n"
+ << " }\n";
+ }
}
// Initialize the set of buildable types.
@@ -1717,26 +1730,30 @@ void OperationFormat::genParserTypeResolution(Operator &op, MethodBody &body) {
auto emitTypeResolver = [&](TypeResolution &resolver, StringRef curVar) {
if (std::optional<int> val = resolver.getBuilderIdx()) {
body << "odsBuildableType" << *val;
- } else if (const NamedTypeConstraint *var = resolver.getVariable()) {
- if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
- FmtContext fmtContext;
- fmtContext.addSubst("_ctxt", "parser.getContext()");
- if (var->isVariadic())
- fmtContext.withSelf(var->name + "Types");
- else
- fmtContext.withSelf(var->name + "Types[0]");
- body << tgfmt(*tform, &fmtContext);
- } else {
- body << var->name << "Types";
- if (!var->isVariadic())
- body << "[0]";
+ } else if (std::optional<StringRef> tform = resolver.getVarTransformer()) {
+ FmtContext fmtContext;
+ fmtContext.addSubst("_ctxt", "parser.getContext()");
+ for (int i = 0, e = resolver.getNumArgs(); i < e; ++i) {
+ std::string substName = "arg" + std::to_string(i);
+ if (const NamedTypeConstraint *var = resolver.getVariable(i)) {
+ if (var->isVariadic())
+ fmtContext.addSubst(substName, var->name + "Types");
+ else
+ fmtContext.addSubst(substName, var->name + "Types[0]");
+ } else if (const NamedAttribute *attr = resolver.getAttribute(i)) {
+ fmtContext.addSubst(substName, attr->name + "Attr.getType()");
+ } else {
+ assert(false && "resolver arguements should be a type constraint or "
+ "an attribute");
+ }
}
- } else if (const NamedAttribute *attr = resolver.getAttribute()) {
- if (std::optional<StringRef> tform = resolver.getVarTransformer())
- body << tgfmt(*tform,
- &FmtContext().withSelf(attr->name + "Attr.getType()"));
- else
- body << attr->name << "Attr.getType()";
+ body << tgfmt(*tform, &fmtContext);
+ } else if (const NamedTypeConstraint *var = resolver.getVariable(0)) {
+ body << var->name << "Types";
+ if (!var->isVariadic())
+ body << "[0]";
+ } else if (const NamedAttribute *attr = resolver.getAttribute(0)) {
+ body << attr->name << "Attr.getType()";
} else {
body << curVar << "Types";
}
@@ -2717,7 +2734,7 @@ class OpFormatParser : public FormatParser {
/// type as well as an optional transformer to apply to that type in order to
/// properly resolve the type of a variable.
struct TypeResolutionInstance {
- ConstArgument resolver;
+ SmallVector<ConstArgument, 1> resolver;
std::optional<StringRef> transformer;
};
@@ -2827,7 +2844,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc,
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false);
} else if (def.getName() == "SameOperandsAndResultType") {
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true);
- } else if (def.isSubClassOf("TypesMatchWith")) {
+ } else if (def.isSubClassOf("InferTypesFrom")) {
handleTypesMatchConstraint(variableTyResolver, def);
} else if (!op.allResultTypesKnown()) {
// This doesn't check the name directly to handle
@@ -3228,9 +3245,9 @@ void OpFormatParser::handleAllTypesMatchConstraint(
// Mark this value as the type resolver for the other variables.
for (unsigned j = 0; j != i; ++j)
- variableTyResolver[values[j]] = {arg, std::nullopt};
+ variableTyResolver[values[j]] = {{arg}, std::nullopt};
for (unsigned j = i + 1; j != e; ++j)
- variableTyResolver[values[j]] = {arg, std::nullopt};
+ variableTyResolver[values[j]] = {{arg}, std::nullopt};
}
}
@@ -3251,21 +3268,26 @@ void OpFormatParser::handleSameTypesConstraint(
// Set the resolvers for each operand and result.
for (unsigned i = 0, e = op.getNumOperands(); i != e; ++i)
if (!seenOperandTypes.test(i))
- variableTyResolver[op.getOperand(i).name] = {resolver, std::nullopt};
+ variableTyResolver[op.getOperand(i).name] = {{resolver}, std::nullopt};
if (includeResults) {
for (unsigned i = 0, e = op.getNumResults(); i != e; ++i)
if (!seenResultTypes.test(i))
- variableTyResolver[op.getResultName(i)] = {resolver, std::nullopt};
+ variableTyResolver[op.getResultName(i)] = {{resolver}, std::nullopt};
}
}
void OpFormatParser::handleTypesMatchConstraint(
StringMap<TypeResolutionInstance> &variableTyResolver, const Record &def) {
- StringRef lhsName = def.getValueAsString("lhs");
- StringRef rhsName = def.getValueAsString("rhs");
+ std::vector<StringRef> args = def.getValueAsListOfStrings("args");
+ StringRef target = def.getValueAsString("target");
StringRef transformer = def.getValueAsString("transformer");
- if (ConstArgument arg = findSeenArg(lhsName))
- variableTyResolver[rhsName] = {arg, transformer};
+
+ SmallVector<ConstArgument, 1> resolutionArgs;
+ llvm::for_each(args, [&](StringRef arg) {
+ if (ConstArgument seenArg = findSeenArg(arg))
+ resolutionArgs.push_back(seenArg);
+ });
+ variableTyResolver[target] = {resolutionArgs, transformer};
}
ConstArgument OpFormatParser::findSeenArg(StringRef name) {
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall, only minor comments :))
I think it makes sense for someone else more familiar with the type inference to look at it as well however
// A type constraint that denotes `transform(unpack(lhs.getTypes())) == rhs.getType()`. | ||
// An optional comparator function may be provided that changes the above form | ||
// into: `comparator(transform(lhs.getType()), rhs.getType())`. | ||
class TypesMatchWith<string summary, string lhsArg, string rhsArg, | ||
string transform, string comparator = "std::equal_to<>()"> | ||
// into: `comparator(transform(unpack(lhs.getTypes())), rhs.getType())`. | ||
class InferTypesFrom<string summary, list<string> lhsArg, string rhsArg, | ||
string transform, | ||
string comparator = "std::equal_to<>()"> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unpack
seems to be a bit confusing to me, I would keep it as transform(lhs)
or maybe transform(types(lhs))
.
Could we also document the requirements of the transform
string? E.g. that the types of arguments in lhs
are expected to be referenced as $argN
assert(args.size() == 1 && | ||
"multiple arguments for result inference not yet supported."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure how we've handled this in the past, but I feel llvm::report_fatal_error
might be more appropriate here given this is a backend limitation that should also be reported in release builds
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ODS error handling isn't great, but report_fatal_error
would at least mean the user is more likely to actually see the error message
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tablegen error reporting is quite awful, but the best that can be done here is likely PrintFatalError
. Look at other uses in this file for that, there are some TableGen utilities that do this in a uniform way.
void setResolver(const SmallVector<ConstArgument, 1> &arg, | ||
std::optional<StringRef> transformer) { | ||
resolver = arg; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
void setResolver(const SmallVector<ConstArgument, 1> &arg, | |
std::optional<StringRef> transformer) { | |
resolver = arg; | |
void setResolver(SmallVector<ConstArgument, 1> arg, | |
std::optional<StringRef> transformer) { | |
resolver = std::move(arg); |
ulta nit
@@ -303,22 +303,33 @@ struct OperationFormat { | |||
std::optional<int> getBuilderIdx() const { return builderIdx; } | |||
void setBuilderIdx(int idx) { builderIdx = idx; } | |||
|
|||
int getNumArgs() const { return resolver.size(); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could the doc comments here be adjusted? Ditto below
@@ -2827,7 +2844,7 @@ LogicalResult OpFormatParser::verify(SMLoc loc, | |||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/false); | |||
} else if (def.getName() == "SameOperandsAndResultType") { | |||
handleSameTypesConstraint(variableTyResolver, /*includeResults=*/true); | |||
} else if (def.isSubClassOf("TypesMatchWith")) { | |||
} else if (def.isSubClassOf("InferTypesFrom")) { | |||
handleTypesMatchConstraint(variableTyResolver, def); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we rename this function
if (ConstArgument seenArg = findSeenArg(arg)) | ||
resolutionArgs.push_back(seenArg); | ||
}); | ||
variableTyResolver[target] = {resolutionArgs, transformer}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
variableTyResolver[target] = {resolutionArgs, transformer}; | |
variableTyResolver[target] = {std::move(resolutionArgs), transformer}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with some minor comments, and Markus's as well
return resolver.empty() | ||
? nullptr | ||
: llvm::dyn_cast_if_present<const NamedTypeConstraint *>( | ||
resolver[i]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return resolver.empty() | |
? nullptr | |
: llvm::dyn_cast_if_present<const NamedTypeConstraint *>( | |
resolver[i]); | |
if (resolver.empty()) | |
return nullptr | |
return llvm::dyn_cast_if_present<const NamedTypeConstraint *>(resolver[i]); |
1 less line of code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same below
@@ -396,7 +397,7 @@ void Operator::populateTypeInferenceInfo( | |||
// All result types are inferred from the operand type. | |||
int operandIdx = operandI - arguments.begin(); | |||
for (int i = 0; i < getNumResults(); ++i) | |||
resultTypeMapping.emplace_back(operandIdx, "$_self"); | |||
resultTypeMapping.emplace_back(operandIdx, "$arg0"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So independent of the operand this is arg0?
This patch adds support for inferring operand types from multiple operand types.
The patch introduces a new
InferTypesFrom
class and makes the backend rely on it, since it's a more general class. The olderTypesMatchWith
class is backwards compatible with this change and works the same as before.Inferring result types could also be added with this change, but is more complex as we need to generate a different builder call as well (which needs more intrusive changes into the operation definition builder)