Skip to content

Commit

Permalink
[AMD] Predicate tt.dot to override select in pipeline epilogue (#4694)
Browse files Browse the repository at this point in the history
This change wraps `tt.dot` in a predicated `scf.if` keyed on the pipeline's
peeled-iteration condition.
The result is that the final stage is conditional and the peeling
`select` is optimized away in the backend.
  • Loading branch information
sjw36 committed Sep 10, 2024
1 parent 944f634 commit 845d75a
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 14 deletions.
53 changes: 40 additions & 13 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}}
// AMD: %[[DOT_30:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[DOT_30]], %{{.*}}#2
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]]
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
// AMD: scf.yield %[[DOT_32]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#2
// AMD: }
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2
// AMD: triton_gpu.local_dealloc %{{.*}}
// AMD: triton_gpu.local_dealloc %{{.*}}

Expand Down Expand Up @@ -200,8 +205,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]]
// AMD: scf.yield
// AMD-COUNT-2: triton_gpu.local_load
// AMD: %[[IF1:.*]] = scf.if
// AMD: %[[DOT1:.*]] = tt.dot
// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[DOT1]], %[[FOR]]#2
// AMD: scf.yield %[[DOT1]]
// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2
// AMD-COUNT-2: triton_gpu.local_dealloc
// AMD: scf.yield %[[SEL1]]

Expand Down Expand Up @@ -407,19 +414,29 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
// AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %[[ADDI_28]], %{{.*}}
// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0
// AMD: %[[IF_32:.*]] = scf.if %[[CMPI_27]]
// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0
// AMD: scf.yield %[[DOT_43]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#0
// AMD: }
// AMD: %[[ADDI_33:.*]] = arith.addi %{{.*}}#3, %{{.*}}
// AMD: %[[CMPI_34:.*]] = arith.cmpi slt, %[[ADDI_33]], %{{.*}}
// AMD: %[[SELECT_35:.*]] = arith.select %[[CMPI_34]], %[[ADDI_33]], %{{.*}}
// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_36]]
// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_37]]
// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[DOT_32]], %{{.*}}#0
// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[IF_32]], %{{.*}}#0
// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_36]]
// AMD: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_37]]
// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]]
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[DOT_41]], %[[SELECT_38]]
// AMD: %[[IF_41:.*]] = scf.if %[[CMPI_29]]
// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]]
// AMD: scf.yield %[[DOT_43]]
// AMD: } else {
// AMD: scf.yield %[[SELECT_38]]
// AMD: }
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[IF_41]], %[[SELECT_38]]
// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]]
// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]]

Expand Down Expand Up @@ -981,7 +998,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// AMD: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_23]], %{{.*}}
// AMD: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[DOT_27:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0
// AMD: %[[IF_27:.*]] = scf.if %[[CMPI_22]]
// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0
// AMD: scf.yield %[[DOT_47]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#0
// AMD: }
// AMD: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}#1, %{{.*}}
// AMD: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_24]]
// AMD: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]]
Expand All @@ -998,11 +1020,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// AMD: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_40]]
// AMD: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]]
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[DOT_27]], %{{.*}}#0
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[IF_27]], %{{.*}}#0
// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_40]]
// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_41]]
// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]]
// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[DOT_45]], %[[SELECT_42]]
// AMD: %[[IF_45:.*]] = scf.if %[[CMPI_24]]
// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]]
// AMD: scf.yield %[[DOT_47]]
// AMD: } else {
// AMD: scf.yield %[[SELECT_42]]
// AMD: }
// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[IF_45]], %[[SELECT_42]]
// AMD: triton_gpu.local_dealloc %{{.*}}
// AMD: triton_gpu.local_dealloc %{{.*}}

Expand Down
21 changes: 20 additions & 1 deletion third_party/amd/lib/TritonAMDGPUTransforms/StreamPipelineV2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,25 @@ createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
return allocs;
}

/// Predicate `op` for the peeled pipeline epilogue.
///
/// For `tt.dot` the generic predication path emits an `arith.select` over the
/// stage result, which keeps both the loop-carried accumulator and the
/// epilogue dot result live in registers simultaneously. Instead, wrap the
/// dot in an `scf.if` on `pred` (then: the dot; else: the incoming
/// accumulator `C`), letting the backend fold the redundant select away.
/// All other ops fall back to the default `tt::predicateOp` handling.
static Operation *streamPredication(RewriterBase &rewriter, Operation *op,
                                    Value pred) {
  auto dot = dyn_cast<tt::DotOp>(op);
  if (!dot)
    return tt::predicateOp(rewriter, op, pred);

  auto loc = dot->getLoc();
  // scf.if yielding the dot's result type, with both regions materialized.
  auto guard = rewriter.create<scf::IfOp>(loc, dot.getResult().getType(),
                                          pred, /*withElseRegion=*/true);
  // Build the then-region yield first, then relocate the existing dot in
  // front of it so the dot executes only when `pred` holds.
  auto thenBuilder = guard.getThenBodyBuilder();
  auto thenYield = thenBuilder.create<scf::YieldOp>(loc, dot.getResult());
  dot->moveBefore(thenYield);
  // When the predicate is false, pass the accumulator through unchanged.
  guard.getElseBodyBuilder().create<scf::YieldOp>(loc, dot.getC());
  return guard;
}

static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
tt::PipeliningOption &options) {
// Schedule the loads and root ops (dot ops) in the loop. This will give us
Expand Down Expand Up @@ -599,7 +618,7 @@ static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
s = std::move(schedule);
};
options.peelEpilogue = true;
options.predicateFn = tt::predicateOp;
options.predicateFn = streamPredication;
options.supportDynamicLoops = true;

OpBuilder builder(forOp);
Expand Down

0 comments on commit 845d75a

Please sign in to comment.