Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LinalgExt] Drop the unit dims on scatter ops 2/3 #19450

Merged
merged 3 commits into from
Jan 6, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,73 @@ struct FoldAttentionWithProducerReshapeByExpansion final
linalg::ControlFusionFn controlFoldingReshapes;
};

/// Remove the unit dims from `iree_linalg_ext.scatter` 's `update` operand.
/// The dims in `update` between the batch dims and the contiguous slice
/// represent the indexed dimensions. Remove the leading unit dims from the
/// indexed dims.
struct FoldScatterNonIterationUnitDims final
    : public OpRewritePattern<ScatterOp> {
  FoldScatterNonIterationUnitDims(MLIRContext *context,
                                  linalg::ControlDropUnitDims options,
                                  PatternBenefit benefit = 1)
      : OpRewritePattern<ScatterOp>(context, benefit),
        options(std::move(options)) {}

  LogicalResult matchAndRewrite(ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    // Only the reassociative-reshape strategy is implemented; bail out for
    // the extract/insert-slice strategy.
    if (options.rankReductionStrategy !=
        linalg::ControlDropUnitDims::RankReductionStrategy::
            ReassociativeReshape) {
      return rewriter.notifyMatchFailure(
          scatterOp, "Only reassociative reshape strategy supported");
    }
    llvm::SmallVector<unsigned> canDrop = options.controlFn(scatterOp);
    const ArrayRef<int64_t> updateShape = scatterOp.getUpdateType().getShape();

    // The update-slice shape (update shape minus the batch dims) ends with
    // `rankOfContiguousSlice` dims that are copied contiguously from/to the
    // original tensor; everything before those are the indexed dims.
    int64_t rankOfContiguousSlice =
        scatterOp.getOriginalType().getRank() - scatterOp.getIndexDepth();
    ArrayRef<int64_t> indexedDims =
        scatterOp.getUpdateSliceShape().drop_back(rankOfContiguousSlice);
    // Position (within the update-slice shape) of the first non-unit indexed
    // dim, minus one. Together with the inclusive `>` bound below this keeps
    // exactly the `numDimsToDrop + 1` leading unit indexed dims. When
    // `indexedDims` is empty or starts with a non-unit dim, this is -1 and
    // `canDrop` is emptied out below.
    int64_t numDimsToDrop =
        llvm::find_if(indexedDims, [](int64_t val) { return val != 1; }) -
        scatterOp.getUpdateSliceShape().begin() - 1;

    int64_t batchRank = scatterOp.getBatchRank();
    // Restrict the control-function's choices to the leading unit indexed
    // dims; batch dims and dims at/after the first non-unit dim are kept.
    llvm::erase_if(canDrop, [&](unsigned dimPos) {
      return dimPos < batchRank || dimPos > batchRank + numDimsToDrop;
    });
    if (canDrop.empty()) {
      return failure();
    }

    // Build the update shape with the droppable dims removed.
    SmallVector<int64_t> droppedUpdateShape;
    droppedUpdateShape.reserve(updateShape.size() - canDrop.size());
    for (auto [idx, dimLen] : llvm::enumerate(updateShape)) {
      if (!llvm::is_contained(canDrop, idx)) {
        droppedUpdateShape.push_back(dimLen);
      }
    }

    // Collapse the unit dims away and redirect the scatter to consume the
    // rank-reduced update tensor in place.
    auto reassoc =
        getReassociationIndicesForCollapse(updateShape, droppedUpdateShape);
    assert(reassoc.has_value() && "expected reassociation to be valid");
    auto collapseOp = rewriter.create<tensor::CollapseShapeOp>(
        scatterOp.getLoc(),
        RankedTensorType::get(droppedUpdateShape,
                              scatterOp.getUpdateType().getElementType()),
        scatterOp.getUpdates(), reassoc.value());

    rewriter.modifyOpInPlace(scatterOp, [&]() {
      scatterOp.setOperand(ScatterOp::kUpdatesOpNum, collapseOp.getResult());
    });
    return success();
  }

private:
  linalg::ControlDropUnitDims options;
};

} // namespace

/// Return the `reassociation` indices to use to collapse the operand when the
Expand Down Expand Up @@ -708,4 +775,14 @@ void populateFoldReshapeOpsByExpansionPatterns(
patterns.getContext(), controlFoldingReshapes);
}

/// Default control function: every loop dimension of the fusion-interface op
/// is a candidate for unit-dim dropping.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op) {
  auto fusionOp = cast<LinalgFusionOpInterface>(op);
  const unsigned numLoops = fusionOp.getNumLoops();
  SmallVector<unsigned> allLoopDims;
  allLoopDims.reserve(numLoops);
  for (unsigned dim = 0; dim < numLoops; ++dim) {
    allLoopDims.push_back(dim);
  }
  return allLoopDims;
}

/// Populate `patterns` with the LinalgExt unit-extent-dim folding patterns,
/// configured by `options`.
void populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options) {
  MLIRContext *context = patterns.getContext();
  patterns.add<FoldScatterNonIterationUnitDims>(context, options);
}

} // namespace mlir::iree_compiler::IREE::LinalgExt
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@ void populateBubbleTransposeFromLinalgExtOps(
RewritePatternSet &patterns,
const linalg::ControlFusionFn &controlFusionFn);

/// Default function to drop unit dims for LinalgExt ops.
SmallVector<unsigned> defaultControlDropUnitDims(Operation *op);

/// Drop unit extent dims from LinalgExt ops.
void populateFoldUnitExtentDimsPatterns(
RewritePatternSet &patterns, const linalg::ControlDropUnitDims &options);

/// Helper struct to hold the results of collapsing an operation.
struct CollapseResult {
SmallVector<Value> results;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
//===----------------------------------------------------------------------===//

#include "iree/compiler/Dialect/Flow/Transforms/RegionOpUtils.h"
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
#include "iree/compiler/Dialect/LinalgExt/Transforms/Transforms.h"
#include "iree/compiler/Dialect/Stream/IR/StreamTypes.h"
#include "iree/compiler/Dialect/Util/Analysis/Explorer.h"
#include "iree/compiler/DispatchCreation/Passes.h"
Expand Down Expand Up @@ -151,9 +153,14 @@ void FoldUnitExtentDimsPass::runOnOperation() {
if (!IREE::Flow::isNonNullAndOutsideDispatch(op)) {
return SmallVector<unsigned>{};
}
if (isa<IREE::LinalgExt::LinalgExtOp>(op)) {
return IREE::LinalgExt::defaultControlDropUnitDims(op);
}
return defaultFn(op);
};
linalg::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns, options);
IREE::LinalgExt::populateFoldUnitExtentDimsPatterns(foldUnitDimsPatterns,
options);
linalg::populateMoveInitOperandsToInputPattern(foldUnitDimsPatterns);
if (failed(
applyPatternsGreedily(moduleOp, std::move(foldUnitDimsPatterns)))) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,47 @@ module @fold_stream_parameter {
// CHECK: util.global private mutable @[[GLOBAL:.+]] = #stream.parameter.named<"module"::"global"> : tensor<10xf32>
// CHECK: util.func public @fold_stream_parameter
// CHECK: %[[LOAD:.+]] = util.global.load @[[GLOBAL]] : tensor<10xf32>

// -----

// Update operand is tensor<?x1x2x16x4x128xf16> with batch rank 1 (indices are
// tensor<?x1xi32>): the unit dim at position 1 is a leading unit indexed dim,
// so it is expected to be collapsed away.
util.func public @scatter0(%arg0: tensor<?x1x2x16x4x128xf16>, %arg1: tensor<?x1xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x2x16x4x128xf16>, tensor<?x1xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter0
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//  CHECK-SAME:     to tensor<?x2x16x4x128xf16>
//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
//  CHECK-SAME:     ins(%[[COLLAPSE]]

// -----

// Index depth is 2 (dimension_map = [0, 1]), so both unit dims at positions 1
// and 2 of the update tensor are leading unit indexed dims and both are
// expected to be collapsed away.
util.func public @scatter1(%arg0: tensor<?x1x1x16x4x128xf16>, %arg1: tensor<?x2xi32>, %arg2: tensor<?x2x16x4x128xf16>) -> tensor<?x2x16x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<?x1x1x16x4x128xf16>, tensor<?x2xi32>) outs(%arg2 : tensor<?x2x16x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x16x4x128xf16>
  util.return %0 : tensor<?x2x16x4x128xf16>
}
// CHECK-LABEL: func public @scatter1
//       CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape
//  CHECK-SAME:     to tensor<?x16x4x128xf16>
//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
//  CHECK-SAME:     ins(%[[COLLAPSE]]

// -----

// TODO: remove other unit dims.
// Here the indices tensor<1x?x1x2xi32> presumably gives batch rank 3, so all
// unit dims in the update are batch (or trailing-slice) dims rather than
// leading unit indexed dims — the pattern should not fire (no collapse_shape).
util.func public @scatter_noop(%arg0: tensor<1x?x1x1x4x128xf16>, %arg1: tensor<1x?x1x2xi32>, %arg2: tensor<?x2x1x4x128xf16>) -> tensor<?x2x1x4x128xf16> {
  %0 = iree_linalg_ext.scatter dimension_map = [0, 1] unique_indices(true) ins(%arg0, %arg1 : tensor<1x?x1x1x4x128xf16>, tensor<1x?x1x2xi32>) outs(%arg2 : tensor<?x2x1x4x128xf16>) {
  ^bb0(%arg3: f16, %arg4: f16):
    iree_linalg_ext.yield %arg3 : f16
  } -> tensor<?x2x1x4x128xf16>
  util.return %0 : tensor<?x2x1x4x128xf16>
}
// CHECK-LABEL: func public @scatter_noop
//   CHECK-NOT:   tensor.collapse_shape
//       CHECK:   %[[SCATTER:.+]] = iree_linalg_ext.scatter
Loading