diff --git a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp index 383a750eed01..27000cc280b7 100644 --- a/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp +++ b/lib/Conversion/SCFToCalyx/SCFToCalyx.cpp @@ -1986,6 +1986,26 @@ class BuildParGroups : public calyx::FuncOpPartialLoweringPattern { rewriter.create(newParOp.getLoc()); rewriter.replaceOp(scfParOp, newParOp); + + auto containsIfOp = [](scf::ParallelOp parOp) -> bool { + bool hasIfOp = false; + parOp.walk([&](scf::IfOp ifOp) { + hasIfOp = true; + return WalkResult::interrupt(); + }); + return hasIfOp; + }; + if (containsIfOp(newParOp)) { + auto *context = newParOp.getContext(); + RewritePatternSet patterns(newParOp.getContext()); + scf::IfOp::getCanonicalizationPatterns(patterns, context); + if (failed( + applyPatternsGreedily(newParOp->getParentOfType(), + std::move(patterns)))) { + return failure(); + } + } + + return success(); } }; @@ -2081,7 +2101,6 @@ class BuildControl : public calyx::FuncOpPartialLoweringPattern { if (walkResult.wasInterrupted()) return failure(); - } else if (auto *forSchedPtr = std::get_if(&group); forSchedPtr) { auto forOp = forSchedPtr->forOp; @@ -2780,6 +2799,8 @@ void SCFToCalyxPass::runOnOperation() { /// This pass inlines scf.ExecuteRegionOp's by adding control-flow. addGreedyPattern(loweringPatterns); + /// Partially evaluate the scf.ParallelOp and optionally apply the + /// scf.IfOp canonicalization patterns. addOncePattern(loweringPatterns, patternState, funcMap, *loweringState); diff --git a/test/Conversion/SCFToCalyx/convert_simple.mlir b/test/Conversion/SCFToCalyx/convert_simple.mlir index c3edd0a975da..d3ab09f536b4 100644 --- a/test/Conversion/SCFToCalyx/convert_simple.mlir +++ b/test/Conversion/SCFToCalyx/convert_simple.mlir @@ -581,3 +581,52 @@ module { return } } + +// Test lowering scf.parallel when there is a nested scf.if that can be +// canonicalized. 
See: https://github.com/llvm/circt/issues/8086
+
+// -----
+
+// CHECK:   calyx.control {
+// CHECK:     calyx.seq {
+// CHECK:       calyx.par {
+// CHECK:         calyx.seq {
+// CHECK:           calyx.enable @bb0_0
+// CHECK:           calyx.enable @bb0_1
+// CHECK:         }
+// CHECK:         calyx.seq {
+// CHECK:           calyx.enable @bb0_2
+// CHECK:           calyx.enable @bb0_3
+// CHECK:         }
+// CHECK:         calyx.seq {
+// CHECK:           calyx.enable @bb0_4
+// CHECK:           calyx.enable @bb0_5
+// CHECK:         }
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }
+
+module {
+  func.func @main(%arg0 : memref<6xi32>, %arg1 : memref<6xi32>) {
+    %c2 = arith.constant 2 : index
+    %c1 = arith.constant 1 : index
+    %c3 = arith.constant 3 : index
+    %c0 = arith.constant 0 : index
+    scf.parallel (%arg2, %arg3) = (%c0, %c0) to (%c3, %c2) step (%c1, %c1) {  // %arg2 in [0, 3), %arg3 in [0, 2)
+      %4 = arith.shli %arg3, %c2 : index  // %arg3 << 2
+      %5 = arith.addi %4, %arg2 : index  // load index; NOTE(review): reaches 6 at (%arg2 = 2, %arg3 = 1), one past memref<6xi32> - confirm this is intended
+      %6 = memref.load %arg0[%5] : memref<6xi32>
+      %7 = arith.shli %arg2, %c1 : index  // %arg2 << 1
+      %8 = arith.addi %7, %arg3 : index  // store index, in [0, 5]
+      %9 = arith.remui %8, %c2 : index  // %8 mod 2
+      %10 = arith.cmpi eq, %9, %c0 : index  // true when %8 is even
+      scf.if %10 {  // the nested scf.if this test is about; it has no else branch
+        memref.store %6, %arg1[%8] : memref<6xi32>
+        scf.yield
+      }
+      scf.reduce
+    }
+    return
+  }
+}