Skip to content

Commit

Permalink
[DispatchCreation] Changes to dispatch region in preparation for horizontal fusion changes. (#19876)
Browse files Browse the repository at this point in the history

Current dispatch region formation handles consumer fusion by making the
consumer the root of the DAG moved into dispatches. For cases where we
have more than one consumer that don't have a direct dependency, this
approach does not work. This changes dispatch region formation to keep
the root operation as is, and to move consumers into the dispatch
iteratively. This required a few additional changes

1) Move the method `moveOperandDefs` into a utility function.
2) Changes to how the dynamic dims of the results of the created
`flow.dispatch.region` ops are resolved.

---------

Signed-off-by: MaheshRavishankar <[email protected]>
Signed-off-by: Ian Wood <[email protected]>
Co-authored-by: Ian Wood <[email protected]>
  • Loading branch information
MaheshRavishankar and IanWood1 authored Feb 13, 2025
1 parent eb58f82 commit 78ec7f2
Show file tree
Hide file tree
Showing 9 changed files with 425 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void TensorDimTrackingRewriter::notifyOperationErased(Operation *op) {
void TensorDimTrackingRewriter::notifyOperationInserted(Operation *op,
InsertPoint previous) {
IRRewriter::Listener::notifyOperationInserted(op, previous);
if (isa<tensor::DimOp>(op))
auto dimOp = dyn_cast<tensor::DimOp>(op);
if (dimOp && isa<OpResult>(dimOp.getSource()))
dimOps.insert(op);
}

Expand All @@ -60,16 +61,21 @@ LogicalResult simplifyDimOps(RewriterBase &rewriter,
std::optional<int64_t> idx = dimOp.getConstantIndex();
if (!idx.has_value())
continue;

if (isa<BlockArgument>(dimOp.getSource())) {
continue;
}

// Only DimOps with ranked tensors are supported.
auto tensorType =
llvm::dyn_cast<RankedTensorType>(dimOp.getSource().getType());
if (!tensorType)
continue;

OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(dimOp);
if (!tensorType.isDynamicDim(*idx)) {
// Rewrite static dimension with constant.
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(dimOp);
int64_t size = tensorType.getShape()[*idx];
rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(dimOp, size);
continue;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,18 +266,8 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
// Value is an OpResult.
Operation *op = value.getDefiningOp();
OpResult opResult = llvm::cast<OpResult>(value);
b.setInsertionPoint(op);

// Case 3: Value is tied. Reify the dimensions of the tied operand.
auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
if (tiedOp) {
Value tiedOperand = tiedOp.getTiedResultOperand(value);
if (tiedOperand && tiedOperand.getType() == value.getType())
return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
createTensorDimOps);
}

// Case 4: Query ShapeAwareOpInterface.
// Case 3: Query ShapeAwareOpInterface.
auto shapeAwareOp = dyn_cast<IREE::Util::ShapeAwareOpInterface>(op);
if (shapeAwareOp) {
ValueRange dims =
Expand All @@ -286,6 +276,15 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
return success();
}

// Case 4: Value is tied. Reify the dimensions of the tied operand.
auto tiedOp = dyn_cast<IREE::Util::TiedOpInterface>(op);
if (tiedOp) {
Value tiedOperand = tiedOp.getTiedResultOperand(value);
if (tiedOperand && tiedOperand.getType() == value.getType())
return reifyDynamicResultDimsImpl(b, tiedOperand, dynamicDims,
/*createTensorDimOps=*/true);
}

// Case 5: Query ReifyRankedShapedTypeOpInterface.
auto reifyShapeOp = dyn_cast<ReifyRankedShapedTypeOpInterface>(op);
if (reifyShapeOp) {
Expand All @@ -308,8 +307,14 @@ reifyDynamicResultDimsImpl(OpBuilder &b, Value value,
}

/// Reify the dynamic dimensions of the given value.
/// Deprecated. Use `getOptimizedDynamicResultDims` instead.
LogicalResult reifyDynamicResultDims(OpBuilder &b, Value value,
SmallVectorImpl<Value> &dynamicDims) {

OpBuilder::InsertionGuard g(b);
if (auto op = value.getDefiningOp()) {
b.setInsertionPoint(op);
}
return reifyDynamicResultDimsImpl(b, value, dynamicDims,
/*createTensorDimOps=*/true);
}
Expand Down Expand Up @@ -473,7 +478,7 @@ movePrecedingOpsIntoDispatchRegion(RewriterBase &rewriter,
rewriter.setInsertionPoint(target);
SmallVector<Value> &dims =
dispatchOpNewResultsDynamicDims.emplace_back();
if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
return target->emitOpError(
"failed to reify dynamic dims of result to be yielded from "
"dispatch region");
Expand Down Expand Up @@ -554,9 +559,10 @@ moveFollowingOpIntoDispatchRegion(RewriterBase &rewriter, Operation *target,
for (auto [index, result] : llvm::enumerate(target->getResults())) {
replacedValues.push_back(result);
yieldedResults.push_back(clonedTarget->getResult(index));
rewriter.setInsertionPoint(target);
OpBuilder::InsertionGuard g1(rewriter);
rewriter.setInsertionPoint(regionOp);
SmallVector<Value> &dims = dispatchOpNewResultsDynamicDims.emplace_back();
if (failed(reifyDynamicResultDims(rewriter, result, dims))) {
if (failed(getOptimizedDynamicResultDims(rewriter, result, dims))) {
return target->emitOpError(
"failed to reify dynamic dims of result to be yielded from "
"dispatch region");
Expand Down
7 changes: 6 additions & 1 deletion compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,12 @@ isContractionOpSequence(Value yielded) {
/// Recognize an operation that is horizontally fused contraction.
/// TODO: The logic below is quite convoluted. Might be better
/// off having a dedicated operation for this.
bool isaHorizontallyFusedContraction(linalg::LinalgOp linalgOp) {
bool isaHorizontallyFusedContraction(Operation *op) {
auto linalgOp = dyn_cast_or_null<linalg::GenericOp>(op);
if (!linalgOp) {
return false;
}

if (linalgOp->getNumResults() == 1) {
return false;
}
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Dialect/LinalgExt/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ bool isGatherlikeOp(Operation *op);
/// Check if a given operation is a horizontally fused contraction operation.
/// The expectation is that the LHS is common, and all the operands are
/// different RHS.
bool isaHorizontallyFusedContraction(linalg::LinalgOp genericOp);
bool isaHorizontallyFusedContraction(Operation *op);

} // namespace mlir::iree_compiler::IREE::LinalgExt
#endif // IREE_COMPILER_DIALECT_LINALGEXT_UTILS_UTILS_H_
Loading

0 comments on commit 78ec7f2

Please sign in to comment.