diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp index fa68a6073b02..50b8e7b7ead1 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateAMDMatmul.cpp @@ -76,9 +76,8 @@ static bool isChainDot(tt::DotOp &dotOp) { bwdOpt.filter = filter; mlir::SetVector bwdSlices; // search backward of the operand 0 of the dot - auto oper0 = dotOp.getOperand(0).getDefiningOp(); - mlir::getBackwardSlice(dyn_cast(oper0), &bwdSlices, bwdOpt); - int i = 0; + mlir::Operation* oper0 = dotOp.getOperand(0).getDefiningOp(); + mlir::getBackwardSlice(oper0, &bwdSlices, bwdOpt); for (Operation *op : bwdSlices) { if (isa(op) && (op != dotOp)) { return true; @@ -88,7 +87,6 @@ static bool isChainDot(tt::DotOp &dotOp) { return false; } - SmallVector warpsPerTile(tt::DotOp dotOp, const ArrayRef shape, int numWarps, SmallVector shapePerWarp) {