Skip to content

Commit

Permalink
Fix infinite loop in getSliceToPartition
Browse files Browse the repository at this point in the history
  • Loading branch information
ardaunal committed Dec 19, 2024
1 parent 8706035 commit 91d37e3
Showing 1 changed file with 14 additions and 23 deletions.
37 changes: 14 additions & 23 deletions lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ void fixTaskId(triton::FuncOp &funcOp) {
LDBG("resulting");
defOp->dump();
});
}
}
}
if (operand.hasOneUse() &&
!oneVecCoversTheOther(asyncTaskIds, defTaskIds)) {
Expand Down Expand Up @@ -112,19 +112,16 @@ bool needToSlice(Value v, int dim, int size) {
return shape.size() > dim && shape[dim] > size;
}

bool getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize,
void getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize,
SetVector<Operation *> &backwardSlice) {
auto newOpInserted = false;
SmallVector<Value> queue = {root};
while (!queue.empty()) {
auto v = queue.back();
queue.pop_back();
if (!needToSlice(v, dim, sliceSize))
continue;
if (auto op = v.getDefiningOp()) {
auto inserted = backwardSlice.insert(op);
newOpInserted |= inserted;
if (inserted) {
if (backwardSlice.insert(op)) {
if (op->hasTrait<OpTrait::Elementwise>() ||
isa<arith::ConstantOp, arith::ExtSIOp, arith::ExtUIOp,
arith::ExtFOp, BroadcastOp, ExpandDimsOp, MakeRangeOp, SplatOp,
Expand Down Expand Up @@ -153,12 +150,10 @@ bool getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize,
}
}
}
return newOpInserted;
};

bool getForwardSliceToPartition(Value root, unsigned dim, int sliceSize,
void getForwardSliceToPartition(Value root, unsigned dim, int sliceSize,
SetVector<Operation *> &forwardSlice) {
auto newOpInserted = false;
SmallVector<Value> queue = {root};
llvm::SmallDenseSet<Value> seen;
while (!queue.empty()) {
Expand All @@ -178,42 +173,38 @@ bool getForwardSliceToPartition(Value root, unsigned dim, int sliceSize,
if (seen.count(operand.get())) {
queue.push_back(forOp->getResult(operand.getOperandNumber()));
forwardSlice.insert(forOp);
newOpInserted = true;
}
}
}
}
}
}
return newOpInserted;
};

// Compute a closure of all ops originated from or being dependent on by the
// root op.
void getSliceToPartition(Value root, unsigned dim, int sliceSize,
SetVector<Operation *> &slice) {
auto newOpInserted = false;
while (!newOpInserted) {
newOpInserted |= getBackwardSliceToPartition(root, dim, sliceSize, slice);
size_t prevSize = slice.size();
do {
prevSize = slice.size();
getBackwardSliceToPartition(root, dim, sliceSize, slice);
SetVector<Operation *> forwardSlice;
newOpInserted |=
getForwardSliceToPartition(root, dim, sliceSize, forwardSlice);
getForwardSliceToPartition(root, dim, sliceSize, forwardSlice);
slice.insert(forwardSlice.begin(), forwardSlice.end());
for (auto op : forwardSlice) {
if (op->hasTrait<OpTrait::Elementwise>() ||
isa<tt::StoreOp, ExperimentalDescriptorStoreOp>(op)) {
for (OpOperand &operand : op->getOpOperands()) {
newOpInserted |=
getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice);
getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice);
}
} else if (auto dotOp = dyn_cast<nvidia_gpu::WarpGroupDotOp>(op)) {
newOpInserted |= getBackwardSliceToPartition(
dim == 0 ? dotOp.getA() : dotOp.getB(), dim, sliceSize, slice);
newOpInserted |=
getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice);
getBackwardSliceToPartition(dim == 0 ? dotOp.getA() : dotOp.getB(), dim,
sliceSize, slice);
getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice);
}
}
}
} while (prevSize != slice.size());
}

struct DataPartitionScheme {
Expand Down

0 comments on commit 91d37e3

Please sign in to comment.