Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AMD] Predicate tt.dot to override select in pipeline epilogue #4694

Merged
merged 3 commits into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 40 additions & 13 deletions test/TritonGPU/loop-pipeline.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,13 @@
// AMD: %[[LOCAL_LOAD_27:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_28:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[MULF_29:.*]] = arith.mulf %[[LOCAL_LOAD_28]], %{{.*}}
// AMD: %[[DOT_30:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[DOT_30]], %{{.*}}#2
// AMD: %[[IF_30:.*]] = scf.if %[[CMPI_26]]
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_27]], %[[MULF_29]], %{{.*}}#2
// AMD: scf.yield %[[DOT_32]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#2
// AMD: }
// AMD: %[[SELECT_31:.*]] = arith.select %[[CMPI_26]], %[[IF_30]], %{{.*}}#2
// AMD: triton_gpu.local_dealloc %{{.*}}
// AMD: triton_gpu.local_dealloc %{{.*}}

Expand Down Expand Up @@ -200,8 +205,10 @@ tt.func @matmul_loop(%lb : index, %ub : index, %step : index,
// AMD: triton_gpu.local_store %{{.+}}, %[[SUBVIEW1]]
// AMD: scf.yield
// AMD-COUNT-2: triton_gpu.local_load
// AMD: %[[IF1:.*]] = scf.if
// AMD: %[[DOT1:.*]] = tt.dot
// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[DOT1]], %[[FOR]]#2
// AMD: scf.yield %[[DOT1]]
// AMD: %[[SEL1:.*]] = arith.select %{{.*}}, %[[IF1]], %[[FOR]]#2
// AMD-COUNT-2: triton_gpu.local_dealloc
// AMD: scf.yield %[[SEL1]]

Expand Down Expand Up @@ -407,19 +414,29 @@ tt.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
// AMD: %[[CMPI_29:.*]] = arith.cmpi sge, %[[ADDI_28]], %{{.*}}
// AMD: %[[LOCAL_LOAD_30:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_31:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[DOT_32:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0
// AMD: %[[IF_32:.*]] = scf.if %[[CMPI_27]]
// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_30]], %[[LOCAL_LOAD_31]], %{{.*}}#0
// AMD: scf.yield %[[DOT_43]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#0
// AMD: }
// AMD: %[[ADDI_33:.*]] = arith.addi %{{.*}}#3, %{{.*}}
// AMD: %[[CMPI_34:.*]] = arith.cmpi slt, %[[ADDI_33]], %{{.*}}
// AMD: %[[SELECT_35:.*]] = arith.select %[[CMPI_34]], %[[ADDI_33]], %{{.*}}
// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_0]][%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: %[[MEMDESC_SUBVIEW_36:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %{{.*}}#6, %[[MEMDESC_SUBVIEW_36]]
// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %[[LOCAL_ALLOC_1]][%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: %[[MEMDESC_SUBVIEW_37:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_35]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %{{.*}}#7, %[[MEMDESC_SUBVIEW_37]]
// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[DOT_32]], %{{.*}}#0
// AMD: %[[SELECT_38:.*]] = arith.select %[[CMPI_27]], %[[IF_32]], %{{.*}}#0
// AMD: %[[LOCAL_LOAD_39:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_36]]
// AMD: %[[LOCAL_LOAD_40:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_37]]
// AMD: %[[DOT_41:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]]
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[DOT_41]], %[[SELECT_38]]
// AMD: %[[IF_41:.*]] = scf.if %[[CMPI_29]]
// AMD: %[[DOT_43:.*]] = tt.dot %[[LOCAL_LOAD_39]], %[[LOCAL_LOAD_40]], %[[SELECT_38]]
// AMD: scf.yield %[[DOT_43]]
// AMD: } else {
// AMD: scf.yield %[[SELECT_38]]
// AMD: }
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_29]], %[[IF_41]], %[[SELECT_38]]
// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_0]]
// AMD: triton_gpu.local_dealloc %[[LOCAL_ALLOC_1]]

Expand Down Expand Up @@ -981,7 +998,12 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// AMD: %[[CMPI_24:.*]] = arith.cmpi sge, %[[ADDI_23]], %{{.*}}
// AMD: %[[LOCAL_LOAD_25:.*]] = triton_gpu.local_load %{{.*}}#4
// AMD: %[[LOCAL_LOAD_26:.*]] = triton_gpu.local_load %{{.*}}#5
// AMD: %[[DOT_27:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0
// AMD: %[[IF_27:.*]] = scf.if %[[CMPI_22]]
// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_25]], %[[LOCAL_LOAD_26]], %{{.*}}#0
// AMD: scf.yield %[[DOT_47]]
// AMD: } else {
// AMD: scf.yield %{{.*}}#0
// AMD: }
// AMD: %[[ADDPTR_28:.*]] = tt.addptr %{{.*}}#1, %{{.*}}
// AMD: %[[SPLAT_29:.*]] = tt.splat %[[CMPI_24]]
// AMD: %[[LOAD_30:.*]] = tt.load %[[ADDPTR_28]], %[[SPLAT_29]]
Expand All @@ -998,11 +1020,16 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// AMD: triton_gpu.local_store %[[LOAD_30]], %[[MEMDESC_SUBVIEW_40]]
// AMD: %[[MEMDESC_SUBVIEW_41:.*]] = triton_gpu.memdesc_subview %{{.*}}[%[[SELECT_39]], %{{.*}}, %{{.*}}]
// AMD: triton_gpu.local_store %[[LOAD_36]], %[[MEMDESC_SUBVIEW_41]]
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[DOT_27]], %{{.*}}#0
// AMD: %[[SELECT_42:.*]] = arith.select %[[CMPI_22]], %[[IF_27]], %{{.*}}#0
// AMD: %[[LOCAL_LOAD_43:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_40]]
// AMD: %[[LOCAL_LOAD_44:.*]] = triton_gpu.local_load %[[MEMDESC_SUBVIEW_41]]
// AMD: %[[DOT_45:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]]
// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[DOT_45]], %[[SELECT_42]]
// AMD: %[[IF_45:.*]] = scf.if %[[CMPI_24]]
// AMD: %[[DOT_47:.*]] = tt.dot %[[LOCAL_LOAD_43]], %[[LOCAL_LOAD_44]], %[[SELECT_42]]
// AMD: scf.yield %[[DOT_47]]
// AMD: } else {
// AMD: scf.yield %[[SELECT_42]]
// AMD: }
// AMD: %[[SELECT_46:.*]] = arith.select %[[CMPI_24]], %[[IF_45]], %[[SELECT_42]]
// AMD: triton_gpu.local_dealloc %{{.*}}
// AMD: triton_gpu.local_dealloc %{{.*}}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,25 @@ createStreamOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule,
return allocs;
}

static Operation *streamPredication(RewriterBase &rewriter, Operation *op,
Value pred) {
// The epilogue peeling generates a select for the stage output. This causes
// too much register pressure with the loop result and the epilogue-dot in
// regs for the select. Conditionally executing the dot will allow the backend
// to optimize the select away as redundant.
if (auto dotOp = dyn_cast<tt::DotOp>(op)) {
auto loc = dotOp->getLoc();
auto ifOp = rewriter.create<scf::IfOp>(loc, dotOp.getResult().getType(),
pred, /*withElseRegion=*/true);
auto thenB = ifOp.getThenBodyBuilder();
auto yield = thenB.create<scf::YieldOp>(loc, dotOp.getResult());
dotOp->moveBefore(yield);
ifOp.getElseBodyBuilder().create<scf::YieldOp>(loc, dotOp.getC());
return ifOp;
}
return tt::predicateOp(rewriter, op, pred);
}

static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
tt::PipeliningOption &options) {
// Schedule the loads and root ops (dot ops) in the loop. This will give us
Expand Down Expand Up @@ -599,7 +618,7 @@ static bool preprocessLoopAndBuildSchedule(scf::ForOp &forOp, int numStages,
s = std::move(schedule);
};
options.peelEpilogue = true;
options.predicateFn = tt::predicateOp;
options.predicateFn = streamPredication;
options.supportDynamicLoops = true;

OpBuilder builder(forOp);
Expand Down
Loading