[TableGen] Add a !listflatten operator to TableGen #109346

Merged
merged 1 commit into llvm:main from tg_listflatten_operator
Sep 24, 2024

@jurahul jurahul commented Sep 19, 2024

Add a !listflatten operator that will transform an input list of type list<list<X>> to list<X> by concatenating elements of the constituent lists of the input argument.
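
To make the intended semantics concrete, here is a minimal standalone C++ sketch of the same one-level flattening, with `std::vector` standing in for TableGen lists. The helper name `flattenOnce` is hypothetical and is not part of LLVM or of this patch; it only mirrors the described `list<list<X>>` to `list<X>` transformation.

```cpp
#include <vector>

// Hypothetical helper (not part of LLVM): removes exactly one level of
// nesting by appending the elements of each inner list, in order.
template <typename T>
std::vector<T> flattenOnce(const std::vector<std::vector<T>> &Nested) {
  std::vector<T> Flat;
  for (const std::vector<T> &Inner : Nested)
    Flat.insert(Flat.end(), Inner.begin(), Inner.end());
  return Flat;
}
```

Only one level of nesting is removed per call, so the element type `X` may itself be a list type.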

jurahul commented Sep 19, 2024

My goal is to use this operator to flatten the Intrinsic TypeSig field and eliminate the nested loop in IntrinsicEmitter. I also have some other uses in mind, so I hope it will be useful in general.

@jurahul jurahul marked this pull request as ready for review September 20, 2024 00:47

llvmbot commented Sep 20, 2024

@llvm/pr-subscribers-tablegen

Author: Rahul Joshi (jurahul)

Changes

Add a !listflatten operator that will transform an input list of type list<list<X>> to list<X> by concatenating elements of the constituent lists of the input argument.


Full diff: https://github.com/llvm/llvm-project/pull/109346.diff

8 Files Affected:

  • (modified) llvm/docs/TableGen/ProgRef.rst (+11-6)
  • (modified) llvm/include/llvm/TableGen/Record.h (+2-1)
  • (modified) llvm/lib/TableGen/Record.cpp (+25)
  • (modified) llvm/lib/TableGen/TGLexer.cpp (+1)
  • (modified) llvm/lib/TableGen/TGLexer.h (+1)
  • (modified) llvm/lib/TableGen/TGParser.cpp (+26-5)
  • (added) llvm/test/TableGen/listflatten-error.td (+6)
  • (added) llvm/test/TableGen/listflatten.td (+29)
diff --git a/llvm/docs/TableGen/ProgRef.rst b/llvm/docs/TableGen/ProgRef.rst
index dcea3b721dae27..69cfaeb5f8442e 100644
--- a/llvm/docs/TableGen/ProgRef.rst
+++ b/llvm/docs/TableGen/ProgRef.rst
@@ -223,12 +223,12 @@ TableGen provides "bang operators" that have a wide variety of uses:
                : !div         !empty       !eq          !exists      !filter
                : !find        !foldl       !foreach     !ge          !getdagarg
                : !getdagname  !getdagop    !gt          !head        !if
-               : !interleave  !isa         !le          !listconcat  !listremove
-               : !listsplat   !logtwo      !lt          !mul         !ne
-               : !not         !or          !range       !repr        !setdagarg
-               : !setdagname  !setdagop    !shl         !size        !sra
-               : !srl         !strconcat   !sub         !subst       !substr
-               : !tail        !tolower     !toupper     !xor
+               : !interleave  !isa         !le          !listconcat  !listflatten
+               : !listremove  !listsplat   !logtwo      !lt          !mul
+               : !ne          !not         !or          !range       !repr
+               : !setdagarg   !setdagname  !setdagop    !shl         !size
+               : !sra         !srl         !strconcat   !sub         !subst
+               : !substr      !tail        !tolower     !toupper     !xor
 
 The ``!cond`` operator has a slightly different
 syntax compared to other bang operators, so it is defined separately:
@@ -1832,6 +1832,11 @@ and non-0 as true.
     This operator concatenates the list arguments *list1*, *list2*, etc., and
     produces the resulting list. The lists must have the same element type.
 
+``!listflatten(``\ *list*\ ``)``
+    This operator flattens a list of lists *list* and produces a list with all
+    elements of the constituent lists concatenated. If *list* is of type
+    ``list<list<X>>`` the resulting list is of type ``list<X>``.
+
 ``!listremove(``\ *list1*\ ``,`` *list2*\ ``)``
     This operator returns a copy of *list1* removing all elements that also occur in
     *list2*. The lists must have the same element type.
diff --git a/llvm/include/llvm/TableGen/Record.h b/llvm/include/llvm/TableGen/Record.h
index 5348c1177f63ed..4cd73c3f675527 100644
--- a/llvm/include/llvm/TableGen/Record.h
+++ b/llvm/include/llvm/TableGen/Record.h
@@ -847,7 +847,8 @@ class UnOpInit : public OpInit, public FoldingSetNode {
     EMPTY,
     GETDAGOP,
     LOG2,
-    REPR
+    REPR,
+    LISTFLATTEN,
   };
 
 private:
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index ff2da3badb3628..1f403e19339a2a 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -987,6 +987,28 @@ Init *UnOpInit::Fold(Record *CurRec, bool IsFinal) const {
       }
     }
     break;
+
+  case LISTFLATTEN:
+    ListInit *LHSList = dyn_cast<ListInit>(LHS);
+    if (!LHSList)
+      break;
+    ListRecTy *InnerListTy = dyn_cast<ListRecTy>(LHSList->getElementType());
+    if (!InnerListTy)
+      break;
+    std::vector<Init *> Flattened;
+    bool Failed = false;
+    // Concatenate elements of all the inner lists.
+    for (Init *InnerInit : LHSList->getValues()) {
+      ListInit *InnerList = dyn_cast<ListInit>(InnerInit);
+      if (!InnerList) {
+        Failed = true;
+        break;
+      }
+      for (Init *InnerElem : InnerList->getValues())
+        Flattened.push_back(InnerElem);
+    }
+    if (!Failed)
+      return ListInit::get(Flattened, InnerListTy->getElementType());
   }
   return const_cast<UnOpInit *>(this);
 }
@@ -1011,6 +1033,9 @@ std::string UnOpInit::getAsString() const {
   case EMPTY: Result = "!empty"; break;
   case GETDAGOP: Result = "!getdagop"; break;
   case LOG2 : Result = "!logtwo"; break;
+  case LISTFLATTEN:
+    Result = "!listflatten";
+    break;
   case REPR:
     Result = "!repr";
     break;
diff --git a/llvm/lib/TableGen/TGLexer.cpp b/llvm/lib/TableGen/TGLexer.cpp
index 62a884e01a5306..8fe7f69ecf8e59 100644
--- a/llvm/lib/TableGen/TGLexer.cpp
+++ b/llvm/lib/TableGen/TGLexer.cpp
@@ -628,6 +628,7 @@ tgtok::TokKind TGLexer::LexExclaim() {
           .Case("foreach", tgtok::XForEach)
           .Case("filter", tgtok::XFilter)
           .Case("listconcat", tgtok::XListConcat)
+          .Case("listflatten", tgtok::XListFlatten)
           .Case("listsplat", tgtok::XListSplat)
           .Case("listremove", tgtok::XListRemove)
           .Case("range", tgtok::XRange)
diff --git a/llvm/lib/TableGen/TGLexer.h b/llvm/lib/TableGen/TGLexer.h
index 9adc03ccc72b85..4fa4d84d0535d3 100644
--- a/llvm/lib/TableGen/TGLexer.h
+++ b/llvm/lib/TableGen/TGLexer.h
@@ -122,6 +122,7 @@ enum TokKind {
   XSRL,
   XSHL,
   XListConcat,
+  XListFlatten,
   XListSplat,
   XStrConcat,
   XInterleave,
diff --git a/llvm/lib/TableGen/TGParser.cpp b/llvm/lib/TableGen/TGParser.cpp
index 1a60c2a567a297..20de3cc4dad9e9 100644
--- a/llvm/lib/TableGen/TGParser.cpp
+++ b/llvm/lib/TableGen/TGParser.cpp
@@ -1190,6 +1190,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
   case tgtok::XNOT:
   case tgtok::XToLower:
   case tgtok::XToUpper:
+  case tgtok::XListFlatten:
   case tgtok::XLOG2:
   case tgtok::XHead:
   case tgtok::XTail:
@@ -1235,6 +1236,11 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       Code = UnOpInit::NOT;
       Type = IntRecTy::get(Records);
       break;
+    case tgtok::XListFlatten:
+      Lex.Lex(); // eat the operation.
+      Code = UnOpInit::LISTFLATTEN;
+      Type = IntRecTy::get(Records); // Bogus type used here.
+      break;
     case tgtok::XLOG2:
       Lex.Lex();  // eat the operation
       Code = UnOpInit::LOG2;
@@ -1309,7 +1315,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
       }
     }
 
-    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL) {
+    if (Code == UnOpInit::HEAD || Code == UnOpInit::TAIL ||
+        Code == UnOpInit::LISTFLATTEN) {
       ListInit *LHSl = dyn_cast<ListInit>(LHS);
       TypedInit *LHSt = dyn_cast<TypedInit>(LHS);
       if (!LHSl && !LHSt) {
@@ -1328,6 +1335,8 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
         TokError("empty list argument in unary operator");
         return nullptr;
       }
+      bool UseElementType =
+          Code == UnOpInit::HEAD || Code == UnOpInit::LISTFLATTEN;
       if (LHSl) {
         Init *Item = LHSl->getElement(0);
         TypedInit *Itemt = dyn_cast<TypedInit>(Item);
@@ -1335,12 +1344,24 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
           TokError("untyped list element in unary operator");
           return nullptr;
         }
-        Type = (Code == UnOpInit::HEAD) ? Itemt->getType()
-                                        : ListRecTy::get(Itemt->getType());
+        Type = UseElementType ? Itemt->getType()
+                              : ListRecTy::get(Itemt->getType());
       } else {
         assert(LHSt && "expected list type argument in unary operator");
         ListRecTy *LType = dyn_cast<ListRecTy>(LHSt->getType());
-        Type = (Code == UnOpInit::HEAD) ? LType->getElementType() : LType;
+        Type = UseElementType ? LType->getElementType() : LType;
+      }
+
+      // for listflatten, we expect a list of lists.
+      if (Code == UnOpInit::LISTFLATTEN) {
+        ListRecTy *InnerListTy = dyn_cast<ListRecTy>(Type);
+        if (!InnerListTy) {
+          TokError("expected argument of type list of list in !listflatten "
+                   "operator");
+          return nullptr;
+        }
+        // listflatten will convert list<list<X>> to list<X>.
+        Type = ListRecTy::get(InnerListTy->getElementType());
       }
     }
 
@@ -1378,7 +1399,7 @@ Init *TGParser::ParseOperation(Record *CurRec, RecTy *ItemType) {
 
   case tgtok::XExists: {
     // Value ::= !exists '<' Type '>' '(' Value ')'
-    Lex.Lex(); // eat the operation
+    Lex.Lex(); // eat the operation.
 
     RecTy *Type = ParseOperatorType();
     if (!Type)
diff --git a/llvm/test/TableGen/listflatten-error.td b/llvm/test/TableGen/listflatten-error.td
new file mode 100644
index 00000000000000..e18528a08e6bf6
--- /dev/null
+++ b/llvm/test/TableGen/listflatten-error.td
@@ -0,0 +1,6 @@
+// RUN: not llvm-tblgen %s 2>&1 | FileCheck %s  -DFILE=%s
+
+// CHECK: [[FILE]]:[[@LINE+2]]:33: error: expected argument of type list of list in !listflatten operator
+class Flatten<list<int> A> {
+    list<int> F = !listflatten(A);
+}
diff --git a/llvm/test/TableGen/listflatten.td b/llvm/test/TableGen/listflatten.td
new file mode 100644
index 00000000000000..20119b24cce0a3
--- /dev/null
+++ b/llvm/test/TableGen/listflatten.td
@@ -0,0 +1,29 @@
+
+// RUN: llvm-tblgen %s | FileCheck %s
+
+class Flatten<list<int> A, list<int> B> {
+    list<int> Flat1 = !listflatten([A, B, [6], [7, 8]]);
+
+    list<list<int>> X = [A, B];
+    list<int> Flat2 = !listflatten(!listconcat(X, [[7]]));
+
+    // Generate a nested list of integers.
+    list<int> Y0 = [1, 2, 3, 4];
+    list<list<int>> Y1 = !foreach(elem, Y0, [elem]);
+    list<list<list<int>>> Y2 = !foreach(elem, Y1, [elem]);
+    list<list<list<list<int>>>> Y3 = !foreach(elem, Y2, [elem]);
+
+    // Flatten it completely.
+    list<int> Flat3=!listflatten(!listflatten(!listflatten(Y3)));
+
+    // Flatten it partially.
+    list<list<list<int>>> Flat4 = !listflatten(Y3);
+    list<list<int>> Flat5 = !listflatten(!listflatten(Y3));
+}
+
+// CHECK: list<int> Flat1 = [1, 2, 3, 4, 5, 6, 7, 8];
+// CHECK: list<int> Flat2 = [1, 2, 3, 4, 5, 7];
+// CHECK: list<int> Flat3 = [1, 2, 3, 4];
+// CHECK{LITERAL}: list<list<list<int>>> Flat4 = [[[1]], [[2]], [[3]], [[4]]];
+def F : Flatten<[1,2], [3,4,5]>;
+
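
The nesting cases in listflatten.td can be traced with a small standalone C++ sketch, again using `std::vector` in place of TableGen lists. The names `wrapEach` and `flattenLevel` are hypothetical stand-ins for `!foreach(elem, L, [elem])` and `!listflatten`; this is an illustration of the expected values, not the TableGen implementation.

```cpp
#include <vector>

// Hypothetical stand-in for !foreach(elem, L, [elem]): wraps every element
// in a single-element list, adding one level of nesting.
template <typename T>
std::vector<std::vector<T>> wrapEach(const std::vector<T> &L) {
  std::vector<std::vector<T>> Out;
  for (const T &Elem : L)
    Out.push_back(std::vector<T>{Elem});
  return Out;
}

// Hypothetical stand-in for !listflatten: removes one level of nesting.
template <typename T>
std::vector<T> flattenLevel(const std::vector<std::vector<T>> &Nested) {
  std::vector<T> Flat;
  for (const std::vector<T> &Inner : Nested)
    Flat.insert(Flat.end(), Inner.begin(), Inner.end());
  return Flat;
}
```

Wrapping `{1, 2, 3, 4}` three times and flattening three times returns the original list, which corresponds to the `Flat3` check. Flattening once leaves two levels of nesting around each element, as in the `Flat4` check.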

@arsenm arsenm left a comment

This could probably be implemented with existing primitives, but I don't think we really need to minimize the number of operators.

llvm/lib/TableGen/Record.cpp (outdated review thread, resolved)

jurahul commented Sep 20, 2024

I realized after I implemented this that existing fold primitives can also implement this (and I see at least one example in an AMDGPU td file that does this). But this is much easier to use.
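
For comparison, the fold-based alternative can be sketched in standalone C++ with `std::accumulate`: start from an empty list and concatenate each inner list onto the accumulator. This is only an analogy, presumably corresponding to combining `!foldl` with `!listconcat` in TableGen. The function name `flattenWithFold` is hypothetical, and the AMDGPU code mentioned above is not reproduced here.

```cpp
#include <numeric>
#include <vector>

// Hypothetical illustration of flattening via a left fold: the accumulator
// starts empty and each inner list is appended to it in turn.
std::vector<int> flattenWithFold(const std::vector<std::vector<int>> &Nested) {
  return std::accumulate(
      Nested.begin(), Nested.end(), std::vector<int>{},
      [](std::vector<int> Acc, const std::vector<int> &Inner) {
        Acc.insert(Acc.end(), Inner.begin(), Inner.end());
        return Acc;
      });
}
```

The result matches a direct flatten, but the caller has to spell out the initial value and the combining step each time, which is the usability gap a dedicated operator closes.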

@jurahul jurahul merged commit 66c8dce into llvm:main Sep 24, 2024
9 checks passed
@jurahul jurahul deleted the tg_listflatten_operator branch September 24, 2024 13:01
augusto2112 pushed a commit to augusto2112/llvm-project that referenced this pull request Sep 26, 2024
xgupta pushed a commit to xgupta/llvm-project that referenced this pull request Oct 4, 2024