diff --git a/test/TritonGPU/loop-pipeline.mlir b/test/TritonGPU/loop-pipeline.mlir index 90b310425d31..2b9eb6ac0803 100644 --- a/test/TritonGPU/loop-pipeline.mlir +++ b/test/TritonGPU/loop-pipeline.mlir @@ -84,8 +84,13 @@ // AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4 // AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5 // AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}} -// AMD: %[[DOT_30:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2 -// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[DOT_30]], %{{.*}}#2 +// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]] +// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2 +// AMD: scf.yield %[[DOT_32]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#2 +// AMD: } +// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2 // AMD: triton_gpu.local_dealloc %{{.*}} // AMD: triton_gpu.local_dealloc %{{.*}} @@ -200,8 +205,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index, // AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]] // AMD: scf.yield // AMD-COUNT-2: triton_gpu.local_load +// AMD: %[[IF1:.*]] = scf.if // AMD: %[[DOT1:.*]] = tt.dot -// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[DOT1]], %[[FOR]]#2 +// AMD: scf.yield %[[DOT1]] +// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2 // AMD-COUNT-2: triton_gpu.local_dealloc // AMD: scf.yield %[[SEL1]] @@ -407,19 +414,29 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index, // AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %[[ADDI_28]], %{{.*}} // AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %{{.*}}#4 // AMD: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0 +// AMD: %[[IF_32:.*]] = scf.if %[[CMPI_27]] +// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_43]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } // AMD: %[[ADDI_33:.*]] = arith.addi %{{.*}}#3, %{{.*}} // AMD: %[[CMPI_34:.*]] = arith.cmpi slt, %[[ADDI_33]], %{{.*}} // AMD: %[[SELECT_35:.*]] = arith.select %[[CMPI_34]], %[[ADDI_33]], %{{.*}} -// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_35]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}] // AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_36]] -// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_35]], %{{.*}}, %{{.*}}] +// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}] // AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_37]] -// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[DOT_32]], %{{.*}}#0 +// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[IF_32]], %{{.*}}#0 // AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_36]] // AMD: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_37]] -// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]] -// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[DOT_41]], %[[SELECT_38]] +// AMD: %[[IF_41:.*]] = scf.if %[[CMPI_29]] +// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]] +// AMD: scf.yield %[[DOT_43]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_38]] +// AMD: } +// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[IF_41]], %[[SELECT_38]] // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]] // AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]] @@ -981,7 +998,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_23]], %{{.*}} // AMD: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %{{.*}}#4 // AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#5 -// AMD: %[[DOT_27:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0 +// AMD: %[[IF_27:.*]] = scf.if %[[CMPI_22]] +// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0 +// AMD: scf.yield %[[DOT_47]] +// AMD: } else { +// AMD: scf.yield %{{.*}}#0 +// AMD: } // AMD: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}#1, %{{.*}} // AMD: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_24]] // AMD: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]] @@ -998,11 +1020,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : // AMD: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_40]] // AMD: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}] // AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]] -// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[DOT_27]], %{{.*}}#0 +// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[IF_27]], %{{.*}}#0 // AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_40]] // AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_41]] -// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]] -// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[DOT_45]], %[[SELECT_42]] +// AMD: %[[IF_45:.*]] = scf.if %[[CMPI_24]] +// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]] +// AMD: scf.yield %[[DOT_47]] +// AMD: } else { +// AMD: scf.yield %[[SELECT_42]] +// AMD: } +// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[IF_45]], %[[SELECT_42]] // AMD: triton_gpu.local_dealloc %{{.*}} // AMD: triton_gpu.local_dealloc %{{.*}} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp index 565361451d93..5a076a7bd24d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp @@ -543,6 +543,25 @@ createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, return allocs; } +static Operation *streamPredication(RewriterBase &rewriter, Operation *op, + Value pred) { + // The epilogue peeling generates a select for the stage output. This causes + // too much register pressure with the loop result and the epilogue-dot in + // regs for the select. Conditionally executing the dot will allow the backend + // to optimize the select away as redundant. + if (auto dotOp = dyn_cast(op)) { + auto loc = dotOp->getLoc(); + auto ifOp = rewriter.create(loc, dotOp.getResult().getType(), + pred, /*withElseRegion=*/true); + auto thenB = ifOp.getThenBodyBuilder(); + auto yield = thenB.create(loc, dotOp.getResult()); + dotOp->moveBefore(yield); + ifOp.getElseBodyBuilder().create(loc, dotOp.getC()); + return ifOp; + } + return tt::predicateOp(rewriter, op, pred); +} + static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages, tt::PipeliningOption &options) { // Schedule the loads and root ops (dot ops) in the loop. This will give us @@ -599,7 +618,7 @@ static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages, s = std::move(schedule); }; options.peelEpilogue = true; - options.predicateFn = tt::predicateOp; + options.predicateFn = streamPredication; options.supportDynamicLoops = true; OpBuilder builder(forOp);