[MLIR][Shape] Support >2 args in shape.broadcast folder #126808

Open · wants to merge 2 commits into main

Conversation

@mtsokol (Contributor) commented Feb 11, 2025

Hi!

As the title says, this PR adds support for more than two arguments in the shape.broadcast folder by calling getBroadcastedShape sequentially on the operands.
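
For illustration, using the shapes from the new test case below, the fold reduces the operands pairwise from left to right (a rough sketch of the intermediate results):

    [2, 1]    broadcast [7, 2, 1] -> [7, 2, 1]
    [7, 2, 1] broadcast [1, 10]   -> [7, 2, 10]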

@llvmbot (Member) commented Feb 11, 2025

@llvm/pr-subscribers-mlir

Author: Mateusz Sokół (mtsokol)

Changes

Hi!

As the title says, this PR adds support for more than two arguments in the shape.broadcast folder by calling getBroadcastedShape sequentially on the operands.


Full diff: https://github.com/llvm/llvm-project/pull/126808.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/Shape/IR/Shape.cpp (+21-13)
  • (modified) mlir/lib/Dialect/Traits.cpp (+1-1)
  • (modified) mlir/test/Dialect/Shape/canonicalize.mlir (+13)
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c403..daa33ea865a5c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,32 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
     return getShapes().front();
   }
 
-  // TODO: Support folding with more than 2 input shapes
-  if (getShapes().size() > 2)
+  if (!adaptor.getShapes().front())
     return nullptr;
 
-  if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
-    return nullptr;
-  auto lhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
-          .getValues<int64_t>());
-  auto rhsShape = llvm::to_vector<6>(
-      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+  auto firstShape = llvm::to_vector<6>(
+      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
           .getValues<int64_t>());
+
   SmallVector<int64_t, 6> resultShape;
+  resultShape.clear();
+  std::copy(firstShape.begin(), firstShape.end(), std::back_inserter(resultShape));
 
-  // If the shapes are not compatible, we can't fold it.
-  // TODO: Fold to an "error".
-  if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
-    return nullptr;
+  for (auto next : adaptor.getShapes().drop_front()) {
+    if (!next)
+      return nullptr;
+    auto nextShape = llvm::to_vector<6>(
+        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+    SmallVector<int64_t, 6> tmpShape;
+    // If the shapes are not compatible, we can't fold it.
+    // TODO: Fold to an "error".
+    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+      return nullptr;
+
+    resultShape.clear();
+    std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
+  }
 
   Builder builder(getContext());
   return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
     if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
       // One or both dimensions is unknown. Follow TensorFlow behavior:
       // - If either dimension is greater than 1, we assume that the program is
-      //   correct, and the other dimension will be broadcast to match it.
+      //   correct, and the other dimension will be broadcasted to match it.
       // - If either dimension is 1, the other dimension is the output.
       if (*i1 > 1) {
         *iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
 
 // -----
 
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+  // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+  %0 = shape.const_shape [2, 1] : tensor<2xindex>
+  %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+  %2 = shape.const_shape [1, 10] : tensor<2xindex>
+  %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+  return %3 : !shape.shape
+}
+
+// -----
+
 // Rhs is a scalar.
 // CHECK-LABEL: func @f
 func.func @f(%arg0 : !shape.shape) -> !shape.shape {

for (auto next : adaptor.getShapes().drop_front()) {
if (!next)
return nullptr;
auto nextShape = llvm::to_vector<6>(
Contributor (Author):
In the getBroadcastedShape implementation the shape vector size is hardcoded to 6, so I did the same here. Does that make sense? From the outside it looks like an arbitrary value.

Member:
Yes, semi. If I recall, it was either the default used elsewhere in an ML framework where this code was used, or the max rank across a set of ML models. But it is a bit arbitrary. Elsewhere folks also use SmallVector's default. (The latter is probably a little more arbitrary, but neither is very fine-tuned.)
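
(Side note, in case it helps future readers: the 6 only sets llvm::SmallVector's inline capacity; higher-rank shapes still work, they just spill into a heap allocation. A minimal illustration with made-up values:)

    // The template parameter is only the number of elements stored inline.
    llvm::SmallVector<int64_t, 6> shape;
    shape.append({1, 2, 3, 4, 5, 6}); // rank 6: stays in inline storage
    shape.push_back(7);               // rank 7: reallocates on the heap, still correct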

Comment on lines 675 to 676
resultShape.clear();
std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
Contributor (Author):
I followed the getBroadcastedShape implementation, and I'm not sure this is the best way to handle vectors/shapes here, so a penny for your thoughts!
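
For comparison, one possible alternative sketch (not necessarily what the reviewers prefer) would drop the clear() + std::copy pair and move the temporary into the accumulator instead:

    SmallVector<int64_t, 6> tmpShape;
    // If the shapes are not compatible, we can't fold it.
    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
      return nullptr;
    // Move the freshly computed shape into the accumulator; no element-wise copy.
    resultShape = std::move(tmpShape);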


github-actions bot commented Feb 11, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@mtsokol force-pushed the shape-broadcast-fold-vararg branch from c4a815d to c3cf613 on February 11, 2025 at 22:15
@jpienaar (Member) left a comment

Overall looks good, thanks.

SmallVector<int64_t, 6> resultShape;
resultShape.clear();
std::copy(firstShape.begin(), firstShape.end(),
Member:
Why is firstShape needed vs directly initializing resultShape?

Contributor (Author):
I think it isn't needed here - updated!
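
For reference, the simplified form could look roughly like this (a sketch only; it assumes the null check on the first operand stays as in the patch):

    // Seed the accumulator directly from the first constant shape, with no
    // intermediate firstShape vector or copy.
    auto resultShape = llvm::to_vector<6>(
        llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
            .getValues<int64_t>());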

@mtsokol force-pushed the shape-broadcast-fold-vararg branch from c3cf613 to b259199 on February 19, 2025 at 16:57
@mtsokol requested a review from jpienaar on February 20, 2025 at 09:34
@jpienaar (Member) left a comment

LG, modulo checking formatting.

return nullptr;

resultShape.clear();
std::copy(tmpShape.begin(), tmpShape.end(),
Member:
Was this what clang-format produced?

@mtsokol (Contributor, Author) commented Mar 9, 2025:
@jpienaar Yes, that's correct - it was produced by clang-format. Here's another place where std::copy is formatted the same way:

std::copy(Overrides.begin(), Overrides.end(),
          reinterpret_cast<ModuleMacro **>(this + 1));
