Skip to content

Commit

Permalink
Only handle non-iteration dims
Browse files Browse the repository at this point in the history
  • Loading branch information
IanWood1 committed Dec 28, 2024
1 parent 612804e commit 98cf839
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -487,12 +487,14 @@ struct FoldAttentionWithProducerReshapeByExpansion final
};

/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
/// The `update` tensor is scanned from left to right, starting from the second
/// element. The number of unit dimensions are counted until reaching a non unit
/// dim.
struct FoldScatterUnitDims final : public OpRewritePattern<ScatterOp> {
FoldScatterUnitDims(MLIRContext *context, linalg::ControlDropUnitDims options,
PatternBenefit benefit = 1)
/// The dims in `update` between the batch dims and the continuous slice
/// represent the indexed dimensions. Remove the leading unit dims from the
/// indexed dims.
struct FoldScatterNonIterationUnitDims final
: public OpRewritePattern<ScatterOp> {
FoldScatterNonIterationUnitDims(MLIRContext *context,
linalg::ControlDropUnitDims options,
PatternBenefit benefit = 1)
: OpRewritePattern<ScatterOp>(context, benefit),
options(std::move(options)) {}

Expand All @@ -507,16 +509,18 @@ struct FoldScatterUnitDims final : public OpRewritePattern<ScatterOp> {
llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();

// Find the first `numDimsToDrop` unit dimensions in the update tensor,
// these are the ones that can be dropped.
// Find the number of leading unit dimensions
int64_t rankOfContiguousSlice =
scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
ArrayRef<int64_t> indexedDims =
scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
int64_t numDimsToDrop =
llvm::find_if(scatterOp.getUpdateSliceShape(),
[](int64_t val) { return val != 1; }) -
updateShape.begin() - 1;
llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
scatterOp.getUpdateSliceShape().begin() - 1;

int64_t batchRank = scatterOp.getBatchRank();
llvm::erase_if(canDrop, [&](unsigned dimPos) {
return dimPos < batchRank || dimPos >= batchRank + numDimsToDrop;
return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
});
if (canDrop.empty()) {
return failure();
Expand Down Expand Up @@ -777,7 +781,7 @@ SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {

void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
patterns.add<FoldScatterUnitDims>(patterns.getContext(), options);
patterns.add<FoldScatterNonIterationUnitDims>(patterns.getContext(), options);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,17 @@ util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2x
// CHECK-SAME: to tensor<?x16x4x128xf16>
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter
// CHECK-SAME: ins(%[[COLLAPSE]]

// -----

// TODO: remove other unit dims.
util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
%0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
^bb0(%arg3: f16, %arg4: f16):
iree_linalg_ext.yield %arg3 : f16
} -> tensor<?x2x1x4x128xf16>
util.return %0 : tensor<?x2x1x4x128xf16>
}
// CHECK-LABEL: func public @scatter_noop
// CHECK-NOT: tensor.collapse_shape
// CHECK: %[[SCATTER:.+]] = iree_linalg_ext.scatter

0 comments on commit 98cf839

Please sign in to comment.