-
Notifications
You must be signed in to change notification settings - Fork 12k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TableGen] Add a !listflatten operator to TableGen #109346
Conversation
130b8ba
to
d7bb17d
Compare
My goal is to use this to flatten Intrinsic |
d7bb17d
to
7cc20d7
Compare
@llvm/pr-subscribers-tablegen Author: Rahul Joshi (jurahul) ChangesAdd a !listflatten operator that will transform an input list of type Full diff: https://github.com/llvm/llvm-project/pull/109346.diff 8 Files Affected:
diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index dcea3b721dae27..69cfaeb5f8442e 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
: !div !empty !eq !exists !filter
: !find !foldl !foreach !ge !getdagarg
: !getdagname !getdagop !gt !head !if
- : !interleave !isa !le !listconcat !listremove
- : !listsplat !logtwo !lt !mul !ne
- : !not !or !range !repr !setdagarg
- : !setdagname !setdagop !shl !size !sra
- : !srl !strconcat !sub !subst !substr
- : !tail !tolower !toupper !xor
+ : !interleave !isa !le !listconcat !listflatten
+ : !listremove !listsplat !logtwo !lt !mul
+ : !ne !not !or !range !repr
+ : !setdagarg !setdagname !setdagop !shl !size
+ : !sra !srl !strconcat !sub !subst
+ : !substr !tail !tolower !toupper !xor
The ``!cond`` operator has a slightly different
syntax compared to other bang operators, so it is defined separately:
@@ -1832,6 +1832,11 @@ and non-0 as true.
This operator concatenates the list arguments *list1*, *list2*, etc., and
produces the resulting list. The lists must have the same element type.
+``!listflatten(``\ *list*\ ``)``
+ This operator flattens a list of lists *list* and produces a list with all
+ elements of the constituent lists concatenated. If *list* is of type
+ ``list<list<X>>`` the resulting list is of type ``list<X>``.
+
``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
This operator returns a copy of *list1* removing all elements that also occur in
*list2*. The lists must have the same element type.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 5348c1177f63ed..4cd73c3f675527 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
EMPTY,
GETDAGOP,
LOG2,
- REPR
+ REPR,
+ LISTFLATTEN,
};
private:
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index ff2da3badb3628..1f403e19339a2a 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -987,6 +987,28 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
}
}
break;
+
+ case LISTFLATTEN:
+ ListInit *LHSList = dyn_cast<ListInit>(LHS);
+ if (!LHSList)
+ break;
+ ListRecTy *InnerListTy = cast<ListRecTy>(LHSList->getElementType());
+ if (!InnerListTy)
+ break;
+ std::vector<Init *> Flattened;
+ bool Failed = false;
+ // Concatenate elements of all the inner lists.
+ for (Init *InnerInit : LHSList->getValues()) {
+ ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
+ if (!InnerList) {
+ Failed = true;
+ break;
+ }
+ for (Init *InnerElem : InnerList->getValues())
+ Flattened.push_back(InnerElem);
+ }
+ if (!Failed)
+ return ListInit::get(Flattened, InnerListTy->getElementType());
}
return const_cast<UnOpInit *>(this);
}
@@ -1011,6 +1033,9 @@ std::string UnOpInit::getAsString() const {
case EMPTY: Result = "!empty"; break;
case GETDAGOP: Result = "!getdagop"; break;
case LOG2 : Result = "!logtwo"; break;
+ case LISTFLATTEN:
+ Result = "!listflatten";
+ break;
case REPR:
Result = "!repr";
break;
diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp
index 62a884e01a5306..8fe7f69ecf8e59 100644
--- a/llvm/lib/TableGen/TGLexer.cpp
+++ b/llvm/lib/TableGen/TGLexer.cpp
@@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
.Case("foreach", tgtok::XForEach)
.Case("filter", tgtok::XFilter)
.Case("listconcat", tgtok::XListConcat)
+ .Case("listflatten", tgtok::XListFlatten)
.Case("listsplat", tgtok::XListSplat)
.Case("listremove", tgtok::XListRemove)
.Case("range", tgtok::XRange)
diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h
index 9adc03ccc72b85..4fa4d84d0535d3 100644
--- a/llvm/lib/TableGen/TGLexer.h
+++ b/llvm/lib/TableGen/TGLexer.h
@@ -122,6 +122,7 @@ enum TokKind {
XSRL,
XSHL,
XListConcat,
+ XListFlatten,
XListSplat,
XStrConcat,
XInterleave,
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 1a60c2a567a297..20de3cc4dad9e9 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XNOT:
case tgtok::XToLower:
case tgtok::XToUpper:
+ case tgtok::XListFlatten:
case tgtok::XLOG2:
case tgtok::XHead:
case tgtok::XTail:
@@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
Code = UnOpInit::NOT;
Type = IntRecTy::get(Records);
break;
+ case tgtok::XListFlatten:
+ Lex.Lex(); // eat the operation.
+ Code = UnOpInit::LISTFLATTEN;
+ Type = IntRecTy::get(Records); // Bogus type used here.
+ break;
case tgtok::XLOG2:
Lex.Lex(); // eat the operation
Code = UnOpInit::LOG2;
@@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
}
}
- if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
+ if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
+ Code == UnOpInit::LISTFLATTEN) {
ListInit *LHSl = dyn_cast<ListInit>(LHS);
TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
if (!LHSl && !LHSt) {
@@ -1328,6 +1335,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
TokError("empty list argument in unary operator");
return nullptr;
}
+ bool UseElementType =
+ Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
if (LHSl) {
Init *Item = LHSl->getElement(0);
TypedInit *Itemt = dyn_cast<TypedInit>(Item);
@@ -1335,12 +1344,24 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
TokError("untyped list element in unary operator");
return nullptr;
}
- Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
- : ListRecTy::get(Itemt->getType());
+ Type = UseElementType ? Itemt->getType()
+ : ListRecTy::get(Itemt->getType());
} else {
assert(LHSt && "expected list type argument in unary operator");
ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
- Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
+ Type = UseElementType ? LType->getElementType() : LType;
+ }
+
+ // for listflatten, we expect a list of lists.
+ if (Code == UnOpInit::LISTFLATTEN) {
+ ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
+ if (!InnerListTy) {
+ TokError("expected argument of type list of list in !listflatten "
+ "operator");
+ return nullptr;
+ }
+ // listflatten will convert list<list<X>> to list<X>.
+ Type = ListRecTy::get(InnerListTy->getElementType());
}
}
@@ -1378,7 +1399,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
case tgtok::XExists: {
// Value ::= !exists '<' Type '>' '(' Value ')'
- Lex.Lex(); // eat the operation
+ Lex.Lex(); // eat the operation.
RecTy *Type = ParseOperatorType();
if (!Type)
diff --git a/llvm/test/TableGen/listflatten-error.td b/llvm/test/TableGen/listflatten-error.td
new file mode 100644
index 00000000000000..e18528a08e6bf6
--- /dev/null
+++ b/llvm/test/TableGen/listflatten-error.td
@@ -0,0 +1,6 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s -DFILE=%s
+
+// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected argument of type list of list in !listflatten operator
+class Flatten<list<int> A> {
+ list<int> F = !listflatten(A);
+}
diff --git a/llvm/test/TableGen/listflatten.td b/llvm/test/TableGen/listflatten.td
new file mode 100644
index 00000000000000..20119b24cce0a3
--- /dev/null
+++ b/llvm/test/TableGen/listflatten.td
@@ -0,0 +1,29 @@
+
+// RUN: llvm-tblgen %s | FileCheck %s
+
+class Flatten<list<int> A, list<int> B> {
+ list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);
+
+ list<list<int>> X = [A, B];
+ list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));
+
+ // Generate a nested list of integers.
+ list<int> Y0 = [1, 2, 3, 4];
+ list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
+ list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
+ list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);
+
+ // Flatten it completely.
+ list<int> Flat3=!listflatten(!listflatten(!listflatten(Y3)));
+
+ // Flatten it partially.
+ list<list<list<int>>> Flat4 = !listflatten(Y3);
+ list<list<int>> Flat5 = !listflatten(!listflatten(Y3));
+}
+
+// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
+// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
+// CHECK: list<int> Flat3 = [1, 2, 3, 4];
+// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
+def F : Flatten<[1,2], [3,4,5]>;
+
|
7cc20d7
to
9c1866d
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could probably implement this with existing primitives, but I don't think we need to really minimize the number of operators
I realized after I implemented this that existing fold primitives can also implement this (and I see atleast one example in an AMDGPU td file that does this). But this is much easier to use. |
9c1866d
to
44997da
Compare
Add a !listflatten operator that will transform an input list of type `list<list<X>>` to `list<X>` by concatenating elements of the constituent lists of the input argument.
44997da
to
56f2511
Compare
Add a !listflatten operator that will transform an input list of type `list<list<X>>` to `list<X>` by concatenating elements of the constituent lists of the input argument.
Add a !listflatten operator that will transform an input list of type `list<list<X>>` to `list<X>` by concatenating elements of the constituent lists of the input argument.
Add a !listflatten operator that will transform an input list of type
list<list<X>>
tolist<X>
by concatenating elements of the constituent lists of the input argument.