diff --git a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp index 2e76e9fe5c6b..a8611dc4e9f7 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TileDispatchUsingInterface.cpp @@ -65,6 +65,19 @@ getDistributeLBAndStep(OpBuilder &b, Location loc, OpFoldResult lb, return {distributeLB, distributeStep}; } +// Helper function to change arith.constant to i64 attribute. +static void changeArithCstToI64Attr(OpBuilder &b, + MutableArrayRef constants) { + for (OpFoldResult &val : constants) { + if (auto dyn_cast = llvm::dyn_cast_if_present(val)) { + APInt intVal; + if (matchPattern(dyn_cast, m_ConstantInt(&intVal))) { + val = b.getI64IntegerAttr(intVal.getSExtValue()); + } + } + } +} + //===----------------------------------------------------------------------===// // TileDispatchUsingSCFForOp implementation. //===----------------------------------------------------------------------===// @@ -166,6 +179,13 @@ static SmallVector generateTileLoopNest( loops.push_back(loop); builder.setInsertionPoint(loop.getBody()->getTerminator()); } + + // Update the sizes if it contains arith.index with i64 attrs. + // TODO: tensor.extract_slice is unable to determine the + // result type if arith.constant is present. This is a workaround + // to ensure that the result type is determined. + changeArithCstToI64Attr(builder, sizes); + return loops; } diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir index ee4c519aa245..fe2ddfd38e59 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/tile_and_distribute_sort.mlir @@ -40,7 +40,6 @@ hal.executable private @static_3d_sort { // CHECK: %[[DIM_X:.+]] = gpu.block_dim x // CHECK: scf.for %[[IV_X:.+]] = %[[TID_X]] to %{{.+}} step %[[DIM_X]] // CHECK: %[[DEST:.+]] = memref.subview %[[WG_OUTPUT]][0, 0, %[[IV_X]]] -// CHECK: %[[CAST:.+]] = memref.cast %[[DEST]] // CHECK: iree_linalg_ext.sort // CHECK-SAME: dimension(1) -// CHECK-SAME: outs(%[[CAST]] +// CHECK-SAME: outs(%[[DEST]]