Skip to content

Commit

Permalink
[tilingInterface] Update the tile sizes to i64 attr type (iree-org#17761
Browse files Browse the repository at this point in the history
)

Update the tile sizes to contain i64Attrs instead of arith.constant.
Somehow it's giving dynamic shapes in tensor.extract_slice since the
arith.constant op isn't folded or seen as a constant.

To fix Issue: iree-org#17441
  • Loading branch information
pashu123 authored Jun 28, 2024
1 parent dcba7c5 commit e38cc7f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,19 @@ getDistributeLBAndStep(OpBuilder &b, Location loc, OpFoldResult lb,
return {distributeLB, distributeStep};
}

// Helper function to change arith.constant to i64 attribute.
static void changeArithCstToI64Attr(OpBuilder &b,
MutableArrayRef<OpFoldResult> constants) {
for (OpFoldResult &val : constants) {
if (auto dyn_cast = llvm::dyn_cast_if_present<Value>(val)) {
APInt intVal;
if (matchPattern(dyn_cast, m_ConstantInt(&intVal))) {
val = b.getI64IntegerAttr(intVal.getSExtValue());
}
}
}
}

//===----------------------------------------------------------------------===//
// TileDispatchUsingSCFForOp implementation.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -166,6 +179,13 @@ static SmallVector<scf::ForOp> generateTileLoopNest(
loops.push_back(loop);
builder.setInsertionPoint(loop.getBody()->getTerminator());
}

// Update the sizes if it contains arith.index with i64 attrs.
// TODO: tensor.extract_slice is unable to determine the
// result type if arith.constant is present. This is a workaround
// to ensure that the result type is determined.
changeArithCstToI64Attr(builder, sizes);

return loops;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ hal.executable private @static_3d_sort {
// CHECK: %[[DIM_X:.+]] = gpu.block_dim x
// CHECK: scf.for %[[IV_X:.+]] = %[[TID_X]] to %{{.+}} step %[[DIM_X]]
// CHECK: %[[DEST:.+]] = memref.subview %[[WG_OUTPUT]][0, 0, %[[IV_X]]]
// CHECK: %[[CAST:.+]] = memref.cast %[[DEST]]
// CHECK: iree_linalg_ext.sort
// CHECK-SAME: dimension(1)
// CHECK-SAME: outs(%[[CAST]]
// CHECK-SAME: outs(%[[DEST]]

0 comments on commit e38cc7f

Please sign in to comment.