[MLIR][Shape] Support >2 args in shape.broadcast folder #126808
base: main
Conversation
@llvm/pr-subscribers-mlir

Author: Mateusz Sokół (mtsokol)

Changes

Hi! As the title says, this PR adds support for >2 arguments in the shape.broadcast folder by sequentially calling getBroadcastedShape.

Full diff: https://github.com/llvm/llvm-project/pull/126808.diff

3 Files Affected:
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index 65efc88e9c403..daa33ea865a5c 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -649,24 +649,32 @@ OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
return getShapes().front();
}
- // TODO: Support folding with more than 2 input shapes
- if (getShapes().size() > 2)
+ if (!adaptor.getShapes().front())
return nullptr;
- if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
- return nullptr;
- auto lhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
- .getValues<int64_t>());
- auto rhsShape = llvm::to_vector<6>(
- llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
+ auto firstShape = llvm::to_vector<6>(
+ llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
.getValues<int64_t>());
+
SmallVector<int64_t, 6> resultShape;
+ resultShape.clear();
+ std::copy(firstShape.begin(), firstShape.end(), std::back_inserter(resultShape));
- // If the shapes are not compatible, we can't fold it.
- // TODO: Fold to an "error".
- if (!OpTrait::util::getBroadcastedShape(lhsShape, rhsShape, resultShape))
- return nullptr;
+ for (auto next : adaptor.getShapes().drop_front()) {
+ if (!next)
+ return nullptr;
+ auto nextShape = llvm::to_vector<6>(
+ llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
+
+ SmallVector<int64_t, 6> tmpShape;
+ // If the shapes are not compatible, we can't fold it.
+ // TODO: Fold to an "error".
+ if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
+ return nullptr;
+
+ resultShape.clear();
+ std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
+ }
Builder builder(getContext());
return builder.getIndexTensorAttr(resultShape);
diff --git a/mlir/lib/Dialect/Traits.cpp b/mlir/lib/Dialect/Traits.cpp
index a7aa25eae2644..6e62a33037eb8 100644
--- a/mlir/lib/Dialect/Traits.cpp
+++ b/mlir/lib/Dialect/Traits.cpp
@@ -84,7 +84,7 @@ bool OpTrait::util::getBroadcastedShape(ArrayRef<int64_t> shape1,
if (ShapedType::isDynamic(*i1) || ShapedType::isDynamic(*i2)) {
// One or both dimensions is unknown. Follow TensorFlow behavior:
// - If either dimension is greater than 1, we assume that the program is
- // correct, and the other dimension will be broadcast to match it.
+ // correct, and the other dimension will be broadcasted to match it.
// - If either dimension is 1, the other dimension is the output.
if (*i1 > 1) {
*iR = *i1;
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index cf439c9c1b854..9ed4837a2fe7e 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -86,6 +86,19 @@ func.func @broadcast() -> !shape.shape {
// -----
+// Variadic case including extent tensors.
+// CHECK-LABEL: @broadcast_variadic
+func.func @broadcast_variadic() -> !shape.shape {
+ // CHECK: shape.const_shape [7, 2, 10] : !shape.shape
+ %0 = shape.const_shape [2, 1] : tensor<2xindex>
+ %1 = shape.const_shape [7, 2, 1] : tensor<3xindex>
+ %2 = shape.const_shape [1, 10] : tensor<2xindex>
+ %3 = shape.broadcast %0, %1, %2 : tensor<2xindex>, tensor<3xindex>, tensor<2xindex> -> !shape.shape
+ return %3 : !shape.shape
+}
+
+// -----
+
// Rhs is a scalar.
// CHECK-LABEL: func @f
func.func @f(%arg0 : !shape.shape) -> !shape.shape {
  for (auto next : adaptor.getShapes().drop_front()) {
    if (!next)
      return nullptr;
    auto nextShape = llvm::to_vector<6>(
In the getBroadcastedShape implementation the shape vector size is hardcoded to 6, so I did it similarly here. Does that make sense? It looks like an arbitrary value from the outside.
Yes, semi. If I recall correctly, it was either the default elsewhere in an ML framework where this was used or the max rank across a set of ML models. But it is a bit arbitrary. Elsewhere folks also use the default of SmallVector. (The latter is probably a little more arbitrary, but neither is very fine-tuned.)
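For context (not part of the PR): the 6 is simply llvm::SmallVector's inline capacity, so shapes of rank up to 6 avoid a heap allocation. A minimal standalone sketch:

```cpp
#include "llvm/ADT/SmallVector.h"
#include <cstdint>

void smallVectorCapacityDemo() {
  // Explicit inline capacity of 6: ranks up to 6 stay in the on-stack buffer.
  llvm::SmallVector<int64_t, 6> shape = {7, 2, 10};
  // Default inline capacity: LLVM picks a reasonable size for the element type.
  llvm::SmallVector<int64_t> other = {7, 2, 10};
  // Growing past the inline capacity transparently switches to heap storage.
  for (int64_t i = 0; i < 10; ++i)
    shape.push_back(i);
}
```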
mlir/lib/Dialect/Shape/IR/Shape.cpp
Outdated
    resultShape.clear();
    std::copy(tmpShape.begin(), tmpShape.end(), std::back_inserter(resultShape));
I followed the getBroadcastedShape implementation and I'm not sure if it's the best way to handle vectors/shapes here, so a penny for your thoughts!
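As a side note (a sketch, not code from the PR): the clear-and-copy pair could also be written as a single assign or move, assuming the same SmallVector types used in the fold:

```cpp
#include "llvm/ADT/SmallVector.h"
#include <cstdint>
#include <utility>

// Overwrite the running result with the freshly computed shape in place.
void replaceByAssign(llvm::SmallVector<int64_t, 6> &resultShape,
                     const llvm::SmallVector<int64_t, 6> &tmpShape) {
  resultShape.assign(tmpShape.begin(), tmpShape.end());
}

// Or, when the temporary is dead afterwards, steal its storage instead.
void replaceByMove(llvm::SmallVector<int64_t, 6> &resultShape,
                   llvm::SmallVector<int64_t, 6> &&tmpShape) {
  resultShape = std::move(tmpShape);
}
```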
✅ With the latest revision this PR passed the C/C++ code formatter.
Force-pushed from c4a815d to c3cf613
Overall looks good, thanks.
mlir/lib/Dialect/Shape/IR/Shape.cpp
Outdated
  SmallVector<int64_t, 6> resultShape;
  resultShape.clear();
  std::copy(firstShape.begin(), firstShape.end(),
Why is firstShape needed vs directly initializing resultShape?
I think it isn't needed here - updated!
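For illustration, a sketch of what the suggestion amounts to (not necessarily the exact code of the final revision): initialize the running result directly from the first constant shape, dropping the intermediate firstShape copy.

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/BuiltinAttributes.h"

// Build the initial result shape straight from the first operand's constant
// attribute; the caller is assumed to have checked that it is non-null.
llvm::SmallVector<int64_t, 6>
initResultShape(llvm::ArrayRef<mlir::Attribute> shapes) {
  return llvm::to_vector<6>(
      llvm::cast<mlir::DenseIntElementsAttr>(shapes.front())
          .getValues<int64_t>());
}
```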
Force-pushed from c3cf613 to b259199
LG, modulo checking formatting.
      return nullptr;

    resultShape.clear();
    std::copy(tmpShape.begin(), tmpShape.end(),
Was this what clang-format produced?
@jpienaar Yes, that's correct - it was produced by clang-format. Here's another place where std::copy is formatted the same way:
llvm-project/clang/include/clang/Lex/MacroInfo.h, lines 536 to 537 in 74ca579:

  std::copy(Overrides.begin(), Overrides.end(),
            reinterpret_cast<ModuleMacro **>(this + 1));
Hi!

As the title says, this PR adds support for >2 arguments in the shape.broadcast folder by sequentially calling getBroadcastedShape.
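To show the left-fold idea in isolation, here is a standalone sketch over plain vectors that handles only static extents; broadcastPair is a simplified stand-in for OpTrait::util::getBroadcastedShape, which additionally handles dynamic dimensions:

```cpp
#include <algorithm>
#include <cstdint>
#include <optional>
#include <vector>

// Broadcast two static shapes; std::nullopt means the extents are incompatible.
static std::optional<std::vector<int64_t>>
broadcastPair(const std::vector<int64_t> &a, const std::vector<int64_t> &b) {
  std::vector<int64_t> result(std::max(a.size(), b.size()));
  for (size_t i = 0; i < result.size(); ++i) {
    // Align the shapes at their trailing dimensions; missing leading dims act as 1.
    size_t padA = result.size() - a.size();
    size_t padB = result.size() - b.size();
    int64_t da = i < padA ? 1 : a[i - padA];
    int64_t db = i < padB ? 1 : b[i - padB];
    if (da != db && da != 1 && db != 1)
      return std::nullopt;
    result[i] = std::max(da, db);
  }
  return result;
}

// Fold any number of shapes left to right, mirroring the updated folder's loop.
static std::optional<std::vector<int64_t>>
broadcastAll(const std::vector<std::vector<int64_t>> &shapes) {
  std::vector<int64_t> result = shapes.front();
  for (size_t i = 1; i < shapes.size(); ++i) {
    auto next = broadcastPair(result, shapes[i]);
    if (!next)
      return std::nullopt;
    result = std::move(*next);
  }
  return result;
}

// Example: broadcastAll({{2, 1}, {7, 2, 1}, {1, 10}}) yields {7, 2, 10},
// matching the new @broadcast_variadic test above.
```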