Skip to content

Commit e0b92bb

Browse files
committed
Drop failure case for stablehlo.dynamic_broadcast_in_dim
The failure to broadcast dynamically makes the assumption the input dynamic shape could be expanded by being `1`. This should be handled by an earlier trasform to materialize a known broadcast if we intend to support both cases.
1 parent 4bf77d2 commit e0b92bb

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

stablehlo/conversions/linalg/tests/miscellaneous.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,24 @@ func.func @constant() -> tensor<i32> {
206206

207207
// -----
208208

209+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d1)>
210+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0, d1)>
211+
212+
// CHECK-LABEL: @dynamic_broadcast
213+
// CHECK-SAME: %[[VAL_0:[a-zA-Z0-9_]*]]
214+
// CHECK: %[[C1:.+]] = arith.constant 1 : index
215+
// CHECK: %[[EXTRACT:.+]] = tensor.extract %arg1[%[[C1]]] : tensor<2xi32>
216+
// CHECK: %[[CAST:.+]] = arith.index_cast %[[EXTRACT]] : i32 to index
217+
// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[CAST]]) : tensor<1x?xf32>
218+
// CHECK: %[[GENERIC:.+]] = linalg.generic {indexing_maps = [#[[MAP0]], #[[MAP1]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_0]] : tensor<?xf32>) outs(%[[EMPTY]] : tensor<1x?xf32>)
219+
func.func public @dynamic_broadcast(%arg0: tensor<?xf32>, %arg1 : tensor<2xi32>) -> (tensor<1x?xf32>) {
220+
%c = stablehlo.constant dense<1> : tensor<1xi32>
221+
%4 = stablehlo.dynamic_broadcast_in_dim %arg0, %arg1, dims = [1] : (tensor<?xf32>, tensor<2xi32>) -> tensor<1x?xf32>
222+
return %4 : tensor<1x?xf32>
223+
}
224+
225+
// -----
226+
209227
// CHECK-LABEL: func @elided_constant
210228
// CHECK: %[[CONSTANT:.*]] = arith.constant dense_resource<__elided__> : tensor<1024xf32>
211229
func.func @elided_constant() -> tensor<1024xf32> {

stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,8 +678,7 @@ struct HloDynamicBroadcastInDimConverter final
678678
// Use static type info.
679679
auto bcastDims = op.getBroadcastDimensions();
680680
for (auto [idx, dim] : llvm::enumerate(operandType.getShape())) {
681-
if (ShapedType::isDynamic(dim)) continue;
682-
681+
// We can assume if the input is dynamic it is not expanding.
683682
bool isExpanding = dim == 1;
684683
dimExprs[idx] = isExpanding ? rewriter.getAffineConstantExpr(0)
685684
: rewriter.getAffineDimExpr(bcastDims[idx]);

0 commit comments

Comments
 (0)