Skip to content

Commit

Permalink
Review comment - nit
Browse files Browse the repository at this point in the history
  • Loading branch information
Abhishek-Varma committed Nov 29, 2024
1 parent 5257be1 commit 327fc7d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,14 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
do {
Value::user_range users = computeOp->getResult(0).getUsers();
if (!llvm::hasSingleElement(users)) {
computeOp->emitOpError("Expected only one user of the compute op");
LLVM_DEBUG(llvm::dbgs()
<< "Expected only one user of the compute op\n");
return signalPassFailure();
}

Operation *terminatorStoreOp = *(users.begin());
Operation *candidateSliceOp = *(users.begin());
if (!(isa<tensor::InsertSliceOp, tensor::ParallelInsertSliceOp>(
terminatorStoreOp))) {
candidateSliceOp))) {
computeOp = computeOp->getParentOfType<LoopLikeOpInterface>();
LLVM_DEBUG(llvm::dbgs()
<< "Going to reattempt fusion because didn't find "
Expand All @@ -94,9 +95,9 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
continue;
}
std::optional<scf::SCFFuseConsumerOfSliceResult> fusedConsumer =
scf::tileAndFuseConsumerOfSlice(rewriter, terminatorStoreOp);
scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp);
if (!fusedConsumer) {
terminatorStoreOp->emitOpError(
candidateSliceOp->emitOpError(
"Failed to fuse any consumer op into the producer");
return signalPassFailure();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,6 @@ module {

// -----

func.func @no_user_of_producer(%arg0: tensor<64xf32>) {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c8 = arith.constant 8 : index
scf.for %arg1 = %c0 to %c8 step %c1 {
// expected-error @+1 {{Expected only one user of the compute op}}
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>
}
return
}

// -----

func.func @no_consumer_fusion(%arg0: tensor<64xf32>) -> tensor<64xf32> {
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
Expand Down Expand Up @@ -388,16 +375,6 @@ module {

// -----

func.func @no_user_of_producer(%arg0: tensor<64xf32>) {
scf.forall (%arg1, %arg2) in (1,2) {
// expected-error @+1 {{Expected only one user of the compute op}}
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>
}
return
}

// -----

func.func @no_consumer_fusion(%arg0: tensor<64xf32>) -> tensor<64xf32> {
%0 = scf.forall (%arg1, %arg2) in (1,2) shared_outs(%out = %arg0) -> tensor<64xf32> {
%1 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>} ins(%arg0, %arg0 : tensor<64xf32>, tensor<64xf32>) outs(%arg0 : tensor<64xf32>) -> tensor<64xf32>
Expand Down

0 comments on commit 327fc7d

Please sign in to comment.