Skip to content

Commit

Permalink
[DataPartition] Support paritioning if results.
Browse files Browse the repository at this point in the history
  • Loading branch information
htyu committed Dec 20, 2024
1 parent 67f51cc commit 8c8e02d
Showing 1 changed file with 56 additions and 8 deletions.
64 changes: 56 additions & 8 deletions lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,12 +168,11 @@ void getForwardSliceToPartition(Value root, unsigned dim, int sliceSize,
if (op->getNumResults() > 0)
seen.insert(op->getResult(0));
if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (seen.count(operand.get())) {
queue.push_back(forOp->getResult(operand.getOperandNumber()));
forwardSlice.insert(forOp);
}
auto parentOp = yieldOp->getParentOp();
for (OpOperand &operand : yieldOp->getOpOperands()) {
if (seen.count(operand.get())) {
queue.push_back(parentOp->getResult(operand.getOperandNumber()));
forwardSlice.insert(parentOp);
}
}
}
Expand Down Expand Up @@ -543,7 +542,56 @@ Operation *sliceOp(Operation *op, int offset,
mappings.map(regionArg, newRegionArg);
reverseMappings.map(newRegionArg, regionArg);
}

} else if (auto ifOp = dyn_cast<scf::IfOp>(op)) {
// Slice the yield op and update if results
auto thenYieldOp = ifOp.thenYield();
auto elseYieldOp = ifOp.elseYield();
auto newThenYieldOp = sliceOp(thenYieldOp, offset, builder, mappings,
reverseMappings, partitionScheme);
sliceOp(elseYieldOp, offset, builder, mappings, reverseMappings,
partitionScheme);
assert(newThenYieldOp->getNumOperands() > ifOp->getNumResults() &&
"no need to slice if op");
// Clone ifOp with updated results but re-use the original regions.
builder.setInsertionPoint(op);
SmallVector<Type, 4> newResultTypes;
for (auto thenResult : thenYieldOp.getResults()) {
newResultTypes.push_back(thenResult.getType());
}
auto newIfOp = builder.create<scf::IfOp>(ifOp.getLoc(), newResultTypes,
ifOp.getCondition());
// Move the original regions to the cloned operation.
newIfOp.getThenRegion().takeBody(ifOp.getThenRegion());
newIfOp.getElseRegion().takeBody(ifOp.getElseRegion());
newOp = newIfOp;
newIfOp->setAttrs(ifOp->getAttrs());
partitionScheme.ops.insert(newIfOp);
ifOp->setAttr("to_be_removed", builder.getUnitAttr());

// Replace ifOp with newIfOp
for (unsigned i = 0; i < ifOp.getNumResults(); ++i)
ifOp.getResult(i).replaceAllUsesWith(newIfOp.getResult(i));

// Map if results based on the mapping for yield
for (auto &v : thenYieldOp->getOpOperands()) {
auto newV = mappings.lookupOrNull(v.get());
if (newV) {
int operandIndex = v.getOperandNumber();
// find the corresponding operand index of newV in newYieldOp
int newOperandIndex = -1;
for (int i = 0; i < newThenYieldOp->getNumOperands(); ++i) {
if (newThenYieldOp->getOperand(i) == newV) {
newOperandIndex = i;
break;
}
}
assert(newOperandIndex >= 0 && "newV not found in newYieldOp");
auto newResult = newIfOp.getResult(operandIndex);
auto newSlicedResult = newIfOp.getResult(newOperandIndex);
mappings.map(newResult, newSlicedResult);
reverseMappings.map(newSlicedResult, newResult);
}
}
} else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
int num = yieldOp.getNumOperands();
for (int i = 0; i < num; i++) {
Expand All @@ -565,7 +613,7 @@ Operation *sliceOp(Operation *op, int offset,
newOp->walk(
[&](Operation *childOp) { setAsyncTaskIds(childOp, sliceTaskIds); });
} else {
llvm_unreachable("unsupported value type");
llvm_unreachable("unsupported op type");
}

LLVM_DEBUG({
Expand Down

0 comments on commit 8c8e02d

Please sign in to comment.