From 56df264a366b70cd26fa84dd60256a22f27ba733 Mon Sep 17 00:00:00 2001 From: Bert Maher Date: Thu, 14 Nov 2024 12:52:21 -0800 Subject: [PATCH] [warpspec] Add experimental support for warp specialization with user annotations This commit is a squash generated by: ``` git diff --stat b62b221a...9755e293a -- . ':(exclude)python/gemmbench' ':(exclude)python/hstuBench' ':(exclude)third_party/proton' ``` Additional modifications: - Update README.md with Warp Specialization Support - Propagate mma layout to following elementwise operations. --- README.md | 211 ++- .../PatternTritonGPUOpToLLVM.h | 4 + .../Conversion/TritonGPUToLLVM/Utility.h | 42 + .../Dialect/TritonGPU/Transforms/Passes.td | 105 ++ .../TritonGPU/Transforms/PipeliningUtility.h | 9 + .../Dialect/TritonGPU/Transforms/Schedule.h | 6 +- .../TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td | 119 ++ .../TritonNvidiaGPU/Transforms/Utility.h | 129 ++ include/triton/Tools/Sys/GetEnv.hpp | 5 + lib/Analysis/Allocation.cpp | 39 +- lib/Analysis/Membar.cpp | 11 + .../TritonGPUToLLVM/AllocateSharedMemory.cpp | 1 + lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 8 +- .../TritonGPUToLLVM/RegReallocOpToLLVM.cpp | 47 + .../Transforms/RewriteTensorPointer.cpp | 7 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 12 +- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 10 +- .../TritonGPU/Transforms/CMakeLists.txt | 6 + .../TritonGPU/Transforms/LoopScheduling.cpp | 622 +++++++ lib/Dialect/TritonGPU/Transforms/PingPong.cpp | 186 +++ .../Pipeliner/MatmulLoopPipeline.cpp | 229 ++- .../Transforms/Pipeliner/PipelineExpander.cpp | 58 +- .../Pipeliner/PipeliningUtility.cpp | 78 + .../Transforms/Pipeliner/Schedule.cpp | 14 +- .../Pipeliner/TMAStoresPipeline.cpp | 16 +- .../Transforms/RemoveLayoutConversions.cpp | 16 + .../TritonGPU/Transforms/TaskIdPropagate.cpp | 407 +++++ .../TritonGPU/Transforms/WSCodePartition.cpp | 1424 +++++++++++++++++ .../TritonGPU/Transforms/WSDataPartition.cpp | 680 ++++++++ .../TritonGPU/Transforms/WSLowering.cpp | 349 ++++ lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp | 13 + .../TritonNvidiaGPU/Transforms/CMakeLists.txt | 1 + .../Transforms/TMALowering.cpp | 20 +- .../TritonNvidiaGPU/Transforms/Utility.cpp | 162 ++ python/src/ir.cc | 32 +- python/src/passes.cc | 11 + python/triton/compiler/code_generator.py | 18 + python/triton/language/__init__.py | 1 + python/triton/language/core.py | 21 +- python/triton/runtime/autotuner.py | 27 +- .../tutorials/10-warp-specialized-matmul.py | 319 ++++ python/tutorials/mm.py | 201 +++ test/TritonGPU/combine.mlir | 42 + test/TritonGPU/comp-pipeline.mlir | 102 ++ .../WarpSpecialization/async_propagate.mlir | 63 + .../WarpSpecialization/ws_code_partition.mlir | 306 ++++ .../WarpSpecialization/ws_data_partition.mlir | 136 ++ .../WarpSpecialization/ws_lowering.mlir | 237 +++ .../TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 2 +- third_party/nvidia/backend/compiler.py | 12 + .../include/Dialect/NVGPU/IR/NVGPUOps.td | 42 + .../lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp | 109 +- .../TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp | 108 +- .../ConvertLayoutOpToLLVM.cpp | 4 +- .../LoadStoreOpToLLVM.cpp | 44 +- .../TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp | 14 + .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 9 + 58 files changed, 6819 insertions(+), 88 deletions(-) create mode 100644 include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h create mode 100644 lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/PingPong.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp create mode 100644 lib/Dialect/TritonGPU/Transforms/WSLowering.cpp create mode 100644 lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp create mode 100644 python/tutorials/10-warp-specialized-matmul.py create mode 100644 python/tutorials/mm.py create mode 100644 test/TritonGPU/comp-pipeline.mlir create mode 100644 test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir create mode 100644 test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir create mode 100644 test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir create mode 100644 test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir diff --git a/README.md b/README.md index 4685ae30f..555c4bd1a 100644 --- a/README.md +++ b/README.md @@ -1,16 +1,211 @@ -
- Triton logo -
-The Triton Conference is happening again on September 17th, 2024 in Fremont (CA)! +# Warp Specialization Support + + +Warp specialization enhances kernel performance by utilizing an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. The data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed a Triton DSL extension that allows users to partition their kernel into asynchronous tasks (which map to warp groups on NVIDIA GPU), which naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. The following sections provide a breakdown of the compiler features developed to enable warp specialization. + + +## Asynchronous Tasks + +Warp specialization is built on top of the concept of partitioning the user’s program into asynchronous tasks (referred to as "async tasks" or “tasks” in the following sections). Each async task will be executed by a standalone warp group on the supported hardware, to achieve instruction level parallelism. Optimally and automatically partitioning async tasks is quite a challenge for the compiler. As a result, the Triton DSL has been extended to allow users to perform manual partitioning. + +The language extension is built around the Python context manager, designed to be simple and intuitive. Such extension is platform-agnostic, i.e., on platforms where warp specialization is not supported, the user annotation will be ignored, with no impact on correctness and performance. + +For instance, a warp-specialized GEMM implementation might look like this: + +```python +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_M) + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_m = pid // num_pid_m + pid_n = pid % num_pid_n + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_K) + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + # Use tl.async_task to specify warp-specialized code + with tl.async_task([0]): + a = tl.load(a_ptrs) + b = tl.load(b_ptrs) + acc += tl.dot(a, b) + a_ptrs += BLOCK_K * stride_ak + b_ptrs += BLOCK_K * stride_bk + c = acc.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] + # Use tl.async_task to specify warp-specialized code + with tl.async_task([1]): + tl.store(c_ptrs, c) +``` + +By wrapping a code block within the **tl.async_task** statement, the user specifies that the block will be executed by a certain number of warp groups, as defined by the statement. In the example above, the load operations are assigned to task 0, while the store operations are handled by task 1. Operations that are explicitly specified with a task id are known as anchor operations, and they affect the task assignment for the remaining operations. + + +The non-anchor operations are assigned to a task by the compiler in the following way: + +- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation. +- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are another anchor operation. +- Control or data dependencies shared between tasks are included in all those tasks. + +For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let's use source code annotations. After task propagation: + +```python +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) # async_task 0, 1 + num_pid_m = tl.cdiv(M, BLOCK_M) # async_task 0, 1 + num_pid_n = tl.cdiv(N, BLOCK_N) # async_task 0, 1 + pid_m = pid // num_pid_m # async_task 0, 1 + pid_n = pid % num_pid_n # async_task 0, 1 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) # async_task 0, 1 + offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) # async_task 0, 1 + offs_k = tl.arange(0, BLOCK_K) # async_task 0 + a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) # async_task 0 + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) # async_task 1 + for k in range(0, tl.cdiv(K, BLOCK_K)): # async_task 0, 1 + a = tl.load(a_ptrs) # async_task 0 + b = tl.load(b_ptrs) # async_task 0 + acc += tl.dot(a, b) # async_task 1 + a_ptrs += BLOCK_K * stride_ak # async_task 0 + b_ptrs += BLOCK_K * stride_bk # async_task 0 + c = acc.to(tl.float16) # async_task 1 + c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :] # async_task 1 + tl.store(c_ptrs, c) # async_task 1 +``` + +## Data Partitioning + +To further improve performance, the user may choose to split the same workload across two async tasks This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other group can execute other operations in parallel. This can be easily achieved by annotating the store operation with two tasks: + +```python +with tl.async_task([1,2]): + tl.store(c_ptr) +``` + +The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous—for instance, if it results in a smaller-than-native `wgmma` instruction—the compiler will instead attempt to split along the N dimension. + +The transformed code for the above GEMM kernel with a configured tile size [128, 256, 64] will look like below (using source annotations instead of IR for illustration) + +```python +@triton.jit +def matmul_persistent_ws_kernel( + a_ptr, b_ptr, c_ptr, M, N, K, + stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, + BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, +): + pid = tl.program_id(axis=0) # async_task 0, 1, 2 + num_pid_m = tl.cdiv(M, BLOCK_M) # async_task 0, 1, 2 + num_pid_n = tl.cdiv(N, BLOCK_N) # async_task 0, 1, 2 + pid_m = pid // num_pid_m # async_task 0, 1, 2 + pid_n = pid % num_pid_n # async_task 0, 1, 2 + offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2) # async_task 0, 1, 2 + offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M) # async_task 0, 1, 2 + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_N) # async_task 0, 1, 2 + offs_k = tl.arange(0, BLOCK_K) # async_task 0 + a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak) # async_task 0 + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn) # async_task 0 + acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32) # async_task 1 + acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32) # async_task 2 + for k in range(0, tl.cdiv(K, BLOCK_K)): # async_task 0, 1, 2 + a_1 = tl.load(a_ptrs_1) # async_task 0 + a_2 = tl.load(a_ptrs_2) # async_task 0 + b = tl.load(b_ptrs) # async_task 0 + acc_1 += tl.dot(a_1, b) # async_task 1 + acc_2 += tl.dot(a_2, b) # async_task 2 + a_ptrs_1 += BLOCK_K * stride_ak # async_task 0 + a_ptrs_2 += BLOCK_K * stride_ak # async_task 0 + b_ptrs += BLOCK_K * stride_bk # async_task 0 + c_1 = acc_1.to(tl.float16) # async_task 1 + c_2 = acc_2.to(tl.float16) # async_task 2 + c_ptrs_1 = c_ptr_1 + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :] # async_task 1 + c_ptrs_2 = c_ptr_2 + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :] # async_task 2 + tl.store(c_ptrs_1, c_1) # async_task 1 + tl.store(c_ptrs_2, c_2) # async_task 2 +``` + +## Code Partitioning + +We assume all operations are already marked with a list of taskIds. We first find all communications required between warp groups. Each communication starts from a load operation with a single taskId, and ends at a direct user of the load which belongs to a different taskId. For `ForOps` containing a communication channel, we add additional arguments: `phase` and `bufferIndex`. + +We introduce a tuning configuration: `num_buffers_warp_spec`. For each communication channel, if it is within a `forOp`, we use an array of buffers in SMEM to save the results, and size of the array is determined by `num_buffers_warp_spec`. We also use an array of barriers for each communication channel that is inside a `ForOp`. At this pass, four new operations are introduced to correctly synchronize between the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of the four new ops take a token, a buffer Index. `ProducerAcquire` and `ConsumerWait` take an additional phase operand. + + +For `ForOps` with multiple task Ids, we clone one copy for each taskId, each copy contains the operations with the specific taskId. In the end, we create multiple `IfOps`, one for each possible taskId. We go through the body of the function, clone the op for each attached task Id and put the cloned op in the right `IfOp`. + +To adjust register usage, we introduce two new ops: `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we decide to insert either `RegAllocOp` or `RegDeallocOp`. The current heuristic is simple: if the task Id is 0, we add `RegDeallocOp`, otherwise we use `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`. + +This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA, the producer will become +`ProducerAcquire` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer will contain `wait_barrier` -> ops -> `ConsumerRelease`. For non-TMA loads, the producer will become `ProducerAcquire` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommitOp`, and the consumer will contain `ConsumerWaitOp` -> ops -> `ConsumerRelease`. + + +# Performance + +## Flash Attention + +``` + (Batch, Heads, SeqLen, Dhead) triton_tutorial_flash_v2_ws-tflops triton_tutorial_flash_v2_tma_ws-tflops triton_tutorial_flash_v2-tflops +------------------------------- ------------------------------------ ---------------------------------------- --------------------------------- + (8, 16, 8192, 128) 548.783 561.539 482.664 +``` + +Benchmarking instructions: +``` +git clone //github.com/pytorch-labs/tritonbench.git +cd tritonbench +TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma_ws,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics tflops --batch 8 --n-heads 16 --d-head 128 +``` + +## FP8 Rowwise GEMM with FP8_FAST_ACC=False + + +With warp specialization: + +``` + (M, N, K) _triton-tflops cutlass-tflops cutlass-speedup cutlass-accuracy cublas_12.1-tflops cublas_12.1-speedup cublas_12.1-accuracy +------------------ ---------------- ---------------- ----------------- ------------------ -------------------- --------------------- ---------------------- +(8192, 8192, 8192) 1028.33 380.182 0.369707 1 1033.47 1.00499 1 +``` + +Without warp specialization: + +``` + (M, N, K) _triton-tflops cutlass-tflops cutlass-speedup cutlass-accuracy cublas_12.1-tflops cublas_12.1-speedup cublas_12.1-accuracy +------------------ ---------------- ---------------- ----------------- ------------------ -------------------- --------------------- ---------------------- +(8192, 8192, 8192) 919.595 384.423 0.418035 1 1035.43 1.12596 1 +``` + + +Benchmarking instructions: + +``` +git clone //github.com/pytorch-labs/tritonbench.git +cd tritonbench +git checkout hoy/fp8gemmWS +python install.py --fbgemm +python run.py --op fp8_gemm_rowwise --m 8192 --n 8192 --k 8192 --no_fp8_fast_accum --warp_specialization +python run.py --op fp8_gemm_rowwise --m 8192 --n 8192 --k 8192 --no_fp8_fast_accum +``` + + -If you are interested in attending, please fill up [this form](https://docs.google.com/forms/d/e/1FAIpQLSecHC1lkalcm0h3JDUbspekDX5bmBvMxgVTLaK3e-61bzDDbg/viewform). -| **`Documentation`** | **`Nightly Wheels`** | -|-------------------- | -------------------- | -| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) | +# General information about Triton # Triton diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 29af2c5f7..ebb54a8a8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -102,6 +102,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + PatternBenefit benefit); + } // namespace triton } // namespace mlir diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index b209a02b4..bbbba9459 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -13,6 +13,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include "triton/Tools/LinearLayout.h" #include "triton/Tools/StrUtil.h" #include "triton/Tools/Sys/GetEnv.hpp" @@ -144,6 +145,20 @@ using namespace mlir::triton; namespace mlir { namespace triton { +static inline void insertBarrier(PatternRewriter &rewriter, Operation *op) { + auto barrierOp = rewriter.create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + // TODO: Change hard code style of numThreads. + const int numThreads = 128; + barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads)); + } +} + // Delinearize supposing order is [0, 1, .. , n] template llvm::SmallVector getMultiDimIndexImpl(T linearIndex, @@ -371,6 +386,20 @@ inline Value getStackPointer(RewriterBase &rewriter, return funcOp.getArgument(funcOp.getNumArguments() - 1); } +static Operation *getWarpGroupId(Operation *op) { + auto funcOp = op->getParentOfType(); + Operation *getWarpId = nullptr; + funcOp.walk([&](Operation *op) -> void { + if (isa(op)) { + assert(getWarpId == nullptr); + getWarpId = op; + } + }); + assert(getWarpId); + getWarpId->dump(); + return getWarpId; +} + inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Operation *op) { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); @@ -381,6 +410,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, .getValue() .getZExtValue(); Value offVal = i32_val(offset); + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")).getValue().getZExtValue(); + if (copy != 1) { + Operation *getWarpId = getWarpGroupId(op); + Value warpsPerWG = i32_val(4); + Value wgId = udiv(getWarpId->getResult(0), warpsPerWG); + // (wgId - 1) * allocation.size + offset + auto singleSize = cast(op->getAttr("allocation.size")).getValue().getZExtValue(); + Value sub1 = sub(wgId, i32_val(1)); + Value temp = mul(sub1, i32_val(singleSize)); + offVal = add(temp, offVal); + } + } Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal); return base; } diff --git a/include/triton/Dialect/TritonGPU/Transforms/Passes.td b/include/triton/Dialect/TritonGPU/Transforms/Passes.td index f2b79d222..d375e3801 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Passes.td +++ b/include/triton/Dialect/TritonGPU/Transforms/Passes.td @@ -179,4 +179,109 @@ def TritonGPUOptimizeAccumulatorInit: Pass<"tritongpu-optimize-accumulator-init" "mlir::triton::TritonDialect"]; } +def TritonGPUTaskIdPropagate : Pass<"triton-gpu-taskid-propagate", "mlir::ModuleOp"> { + let summary = "Propagate async_task_id annotations based on dependencies"; + + let description = [{ + This pass propagates the `async_task_id` annotation to the dependencies + of any op that has it set. This has the functional effect of partitioning + the graph into multiple async tasks, based on the initial annotation. + }]; + + let dependentDialects = [ + "mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect" + ]; + + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSCodePartition: Pass<"tritongpu-warp-spec-code-partition", "mlir::ModuleOp"> { + let summary = "TritonGPU warp specialization code partition"; + + let description = "This pass generates warp specialized code baed on task id attributes."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numBuffers", "num-buffers", + "int32_t", /*default*/"0", + "number of buffering for producer-consumer">, + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization">, + Option<"regDecProducer", "producer-reg-dec", + "int32_t", /*default*/"40", + "register decrement for producer warp group">, + Option<"regIncConsumer", "consumer-reg-inc", + "int32_t", /*default*/"232", + "register indrement for consumer warp group"> + ]; +} + +def TritonGPUWSDataPartition : Pass<"tritongpu-warp-spec-data-partition", "mlir::ModuleOp"> { + let summary = "Warp specialization data partition"; + + let description = "This pass partitions operations into multiple suboperations which operate on smaller data shapes"; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUWSLowering : Pass<"tritongpu-warp-spec-lowering", "mlir::ModuleOp"> { + let summary = "Warp specialization lowering"; + + let description = "This pass lowers warp specializtion related operations."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization"> + ]; +} + +def TritonGPUPingPongSync: Pass<"tritongpu-ping-pong-sync", "mlir::ModuleOp"> { + let summary = "TritonGPU experiemental ping pong schedule"; + + let description = "This pass generates warp specialized code baed on warp group id attributes."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + let options = [ + Option<"numConsumerGroups", "num-consumer-groups", + "int32_t", /*default*/"0", + "number of consumer warp groups for warp specialization">, + Option<"partitionStyle", "partition-style", + "int32_t", /*default*/"0", + "partition style for multiple consumer warp groups"> + ]; +} + +// #ifdef __FACEBOOK__ +def TritonGPULoopScheduling: Pass<"tritongpu-loop-scheduling", "mlir::ModuleOp"> { + let summary = "Generate loop scheduling for SWP"; + + let description = "This pass sets up stages and clustering for software pipelining."; + + let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect", + "mlir::triton::TritonDialect"]; + let options = [ + Option<"numStages", "num-stages", + "int32_t", /*default*/"3", + "number of pipeline stages"> + ]; +} +// #endif #endif diff --git a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h index 88f062a01..db89d0dec 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h +++ b/include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h @@ -29,6 +29,15 @@ void addOps(scf::ForOp forOp, int stage, /// mutable. void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse, Value val); + +// Begin __FACEBOOK__ CompPipe +/// Create a map from load ops to their indirection level and the +/// final use of the load op (another load op, or a dot op). +/// Indirection level is "0" for the load op directly used by the dot op, +/// "1" for the load op used by the load op used by the dot op, and so on. +llvm::SmallVector> +loadOpsToIndirectionLevelAndUse(scf::ForOp forOp); +// End __FACEBOOK__ CompPipe } // namespace triton } // namespace mlir diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index 1dd1fc686..4bd8ff79e 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -84,8 +84,10 @@ class CoarseSchedule { return true; } - void insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, - bool includeArg); + void + insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, + bool includeArg, + DenseMap *additionalDep = nullptr); void erase(Operation *op) { opToStageAndCluster.erase(op); } diff --git a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td index 243b93436..3ce1d80de 100644 --- a/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td +++ b/include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUOps.td @@ -43,6 +43,38 @@ class TTNG_Op traits = []> : !listconcat(traits, [VerifyTensorLayoutsTrait])> { } +def TTNG_MBarrierArriveOp : TTNG_Op<"mbarrier_arrive", [AttrSizedOperandSegments, + MemoryEffects<[MemWrite]>]> { + let summary = "mbarrier arrive"; + + let description = [{ + This operation defining the arriving action for a mbarrier. + txCount: + An optional attribute that set tx-count. This Op will be lowered into + mbarrier.arrive.expect_tx if the optional attribute exist. + trackAsyncOp: + If true, this op will be lowered into cp.async.mbarrier.arrive.noinc. + pred: + Only perform arrive action when pred is true. + remoteCtaId: + if set, perform an remote arrive action. + + Example: + + triton_nvidia_gpu.mbarrier_arrive %0 {trackAsyncOp = false} : !tt.ptr + + }]; + + let arguments = (ins TT_MemDescType:$mbarrier, + Optional:$pred, + Optional:$remoteCtaId, + I1Attr: $trackAsyncOp, + DefaultValuedAttr: $txCount + ); + + let assemblyFormat = "operands attr-dict `:` type(operands)"; +} + def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { let arguments = (ins BoolAttr:$bCluster); @@ -57,6 +89,31 @@ def TTNG_FenceAsyncSharedOp : TTNG_Op<"fence_async_shared"> { }]; } +def TTNG_GetCanonicalWarpIdOp : TTNG_Op<"get_canonical_warp_id", [Pure]> { + let description = [{ + Returns the one dimensional warpId when it's used for producing warp uniform values. + }]; + + let results = (outs I32:$result); + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_NamedBarrierArriveOp : TTNG_Op<"bar_arrive", []> { + let summary = "named barrier arrive"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def TTNG_NamedBarrierWaitOp : TTNG_Op<"bar_wait", []> { + let summary = "named barrier wait"; + + let arguments = (ins I32:$bar, I32: $numThreads); + + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + def TTNG_ClusterArriveOp : TTNG_Op<"cluster_arrive", []> { let arguments = (ins I1Attr:$relaxed); let assemblyFormat = "attr-dict"; @@ -249,4 +306,66 @@ def TTNG_TMAStoreWait : TTNG_Op<"async_tma_store_wait"> { let assemblyFormat = "attr-dict"; } +def TTNG_GetAsyncTaskIdOp : TTNG_Op<"get_async_task_id", [Pure]> { + let results = (outs I32:$result); + + let builders = [OpBuilder<(ins)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +// +// Token +// + +def TTNG_CreateTokenOp : TTNG_Op<"create_token"> { + let results = (outs TensorOf<[TTNG_TokenType]>:$result); + + let arguments = (ins I32Attr:$num); + + let builders = [OpBuilder<(ins "uint32_t":$num)>]; + + let assemblyFormat = "attr-dict `:` type($result)"; +} + +def TTNG_ProducerAcquireOp : TTNG_Op<"producer_acquire"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1:$phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ProducerCommitOp : TTNG_Op<"producer_commit"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerWaitOp : TTNG_Op<"consumer_wait"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx, I1: $phase); + + let assemblyFormat = "$token `,` $idx `,` $phase attr-dict `:` type(operands)"; +} + +def TTNG_ConsumerReleaseOp : TTNG_Op<"consumer_release"> { + let arguments = (ins TensorOf<[TTNG_TokenType]>:$token, I32:$idx); + + let assemblyFormat = "$token `,` $idx attr-dict `:` type(operands)"; +} + +def TTNG_RegAllocOp : TTNG_Op<"reg_alloc", []> { + let summary = "register allocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + +def TTNG_RegDeallocOp : TTNG_Op<"reg_dealloc", []> { + let summary = "register deallocation"; + + let arguments = (ins I32Attr: $regCount); + + let assemblyFormat = "$regCount attr-dict"; +} + #endif diff --git a/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h new file mode 100644 index 000000000..6a200ebe9 --- /dev/null +++ b/include/triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h @@ -0,0 +1,129 @@ + +#ifndef TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ +#define TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "llvm/ADT/MapVector.h" + +namespace mlir { + +// 0 is reserved for default sync. +// TODO: comprehensive mechanism to globally manage namedbarrier. +static int const nameBarrierIdBegin = 1; +static int nameBarrierIdEnd = 16; + +/// Helper functions for async task +typedef int AsyncTaskId; +SmallVector getAsyncTaskIds(Operation *op); +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds); +SmallVector getNestedAsyncTaskIds(Operation *op); +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks); +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId); +void removeAsyncTaskIds(Operation *op); +SmallVector getMutexBarIds(Operation *op); +SmallVector getMutexNumThreads(Operation *op); + +static Value GetCanonicalWarpId(RewriterBase &rewriter, Location loc) { + return rewriter.create( + loc, rewriter.getI32Type()); +} + +static Value getClusterCTAId(RewriterBase &rewriter, Location loc) { + return rewriter.create(loc, + rewriter.getI32Type()); +} + +class OpBuilderWithAsyncTaskIds : public OpBuilder { +public: + OpBuilderWithAsyncTaskIds(MLIRContext *context) : OpBuilder(context) {} + + explicit OpBuilderWithAsyncTaskIds(Operation *op) + : OpBuilder(op) + { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy createWithAsyncTaskIds(Args &&...args) { + OpTy op = create(std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + +private: + SmallVector asyncTaskIds; +}; + +class PatternRewriterWithAsyncTaskIds { +public: + PatternRewriterWithAsyncTaskIds(PatternRewriter &rewriter, Operation *op) + : rewriter(&rewriter) { + setAsyncTaskIdsFromOp(op); + } + + void setAsynTaskIdsFromArray(ArrayRef newAsyncTaskIds) { + asyncTaskIds = SmallVector(newAsyncTaskIds.begin(), newAsyncTaskIds.end()); + } + + void setAsyncTaskIdsFromOp(Operation *op) { + setAsynTaskIdsFromArray(getAsyncTaskIds(op)); + } + + void setAsyncTaskIdsFromValueUsers(Value value) { + SetVector asyncTaskIdSet; + for (Operation *user : value.getUsers()) + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(user)) + asyncTaskIdSet.insert(asyncTaskId); + setAsynTaskIdsFromArray(asyncTaskIdSet.getArrayRef()); + } + + template + OpTy create(Location location, Args &&...args) { + OpTy op = rewriter->create(location, std::forward(args)...); + if (!asyncTaskIds.empty()) + setAsyncTaskIds(op, asyncTaskIds); + return op; + } + + template + OpTy replaceOpWithNewOp(Operation *op, Args &&...args) { + auto newOp = + rewriter->replaceOpWithNewOp(op, std::forward(args)...); + return newOp; + } + +private: + PatternRewriter* rewriter; + SmallVector asyncTaskIds; +}; + +/// Constant task ids +constexpr AsyncTaskId kLoadAsyncTaskId = 0; +constexpr AsyncTaskId kDotAsyncTaskId = 1; + +bool isWSCandidateLoad(Operation *op); +bool isWSSupported(ModuleOp m, int computeCapability); + +} // namespace mlir + +#endif // TRITON_DIALECT_TRITONNVIDIAGPU_TRANSFORMS_UTILITY_H_ diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 43e7df135..ef77cccc6 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -30,6 +30,11 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_LLVM_DEBUG_ONLY", "USE_IR_LOC", "NVPTX_ENABLE_DUMP", + "PEEL_LAST_ITER", + "ENABLE_PINGPONG", + "HACK_ASYNC_DOT", + "SWP_FOR_CONSUMER", + "HARDCODE_TASKID_FA", // clang-format on }; diff --git a/lib/Analysis/Allocation.cpp b/lib/Analysis/Allocation.cpp index b44b75601..eda8628c7 100644 --- a/lib/Analysis/Allocation.cpp +++ b/lib/Analysis/Allocation.cpp @@ -13,6 +13,8 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/SmallVector.h" using ::mlir::triton::gpu::AMDMfmaEncodingAttr; @@ -189,6 +191,15 @@ class AllocationAnalysis { auto shapePerCTA = triton::gpu::getShapePerCTA(allocType); auto bytes = product(shapePerCTA) * allocType.getElementTypeBitWidth() / 8; + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")) + .getValue() + .getZExtValue(); + op->setAttr( + "allocation.size", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), bytes)); + bytes = bytes * copy; + } auto alignment = alloc.getAlignmentOrDefault(); allocation->addBuffer(result, bytes, @@ -251,6 +262,15 @@ class AllocationAnalysis { isa(srcTy.getElementType()) ? elems * kPtrBitWidth / 8 : elems * std::max(8, srcTy.getElementTypeBitWidth()) / 8; + if (op->hasAttr("allocation.copy")) { + auto copy = cast(op->getAttr("allocation.copy")) + .getValue() + .getZExtValue(); + op->setAttr( + "allocation.size", + IntegerAttr::get(IntegerType::get(op->getContext(), 32), bytes)); + bytes = bytes * copy; + } maybeAddScratchBuffer(op, bytes, scratchAlignment); } else if (isa(op)) { @@ -370,8 +390,18 @@ class AllocationAnalysis { // range. auto *op = opScratchIter.first; auto *buffer = opScratchIter.second; - bufferRange.insert({buffer, Interval(operationId.lookup(op), - operationId.lookup(op) + 1)}); + // Extend live range when asyncTaskId is not empty (i.e when we have + // warp spec). + if (getAsyncTaskIds(op).empty()) { + bufferRange.insert({buffer, Interval(operationId.lookup(op), + operationId.lookup(op) + 1)}); + } else { + // FIXME: This range makes scratch buffers used in warp-specialized + // regions conflict with everything else in the program, which is + // too conservative, but safe. A better approach would make them + // conflict with buffers live in other warp-specialized regions. + bufferRange.insert({buffer, Interval(0, operationId.size())}); + } } }; processScratchMemory(allocation->opScratch); @@ -408,6 +438,11 @@ class AllocationAnalysis { auto maxId = std::numeric_limits::min(); std::for_each(liveOperations.begin(), liveOperations.end(), [&](Operation *liveOp) { + if (!getAsyncTaskIds(liveOp).empty()) { + minId = 0; + maxId = operationId.size(); + return; + } if (operationId[liveOp] < minId) { minId = operationId[liveOp]; } diff --git a/lib/Analysis/Membar.cpp b/lib/Analysis/Membar.cpp index bb106238e..1bad78b5a 100644 --- a/lib/Analysis/Membar.cpp +++ b/lib/Analysis/Membar.cpp @@ -6,6 +6,7 @@ #include "mlir/Dialect/GPU/IR/GPUDialect.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include namespace mlir { @@ -95,6 +96,16 @@ void MembarAnalysis::visitTerminator(Operation *op, void MembarAnalysis::insertBarrier(Operation *op, OpBuilder *builder) { OpBuilder::InsertionGuard g(*builder); auto barrierOp = builder->create(op->getLoc()); + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + // TODO: Change hard code style of numThreads. + const int numThreads = 128; + barrierOp->setAttr("bar_id", builder->getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", builder->getI64IntegerAttr(numThreads)); + } } void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo, diff --git a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp index aae9faf0e..ac2e77061 100644 --- a/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp +++ b/lib/Conversion/TritonGPUToLLVM/AllocateSharedMemory.cpp @@ -40,6 +40,7 @@ struct AllocateSharedMemory } if (offset == -1) return; + if (op->hasAttr("allocation.offset")) return; op->setAttr("allocation.offset", IntegerAttr::get(IntegerType::get(ctx, 32), offset)); }); diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index cca2830b0..ecf3d34c7 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -15,6 +15,7 @@ add_triton_library(TritonGPUToLLVM ConvertLayoutOpToLLVM.cpp ControlFlowOpToLLVM.cpp FuncOpToLLVM.cpp + RegReallocOpToLLVM.cpp SPMDOpToLLVM.cpp DecomposeUnsupportedConversions.cpp PrintOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 403cac9de..d9a24ddc6 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -213,7 +213,7 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); + insertBarrier(rewriter, op); } auto successful = targetInfo.processReplicaUsingStMatrix( rewriter, loc, smemBase, vals, srcTy, @@ -224,7 +224,7 @@ struct ConvertLayoutOpConversion multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); } - barrier(); + insertBarrier(rewriter, op); processReplica(loc, rewriter, /*stNotRd*/ false, dstTy, outNumCTAsEachRep, multiDimRepId, outVec, paddedRepShape, origRepShape, outOrd, outVals, smemBase); @@ -581,7 +581,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion llvm::MapVector outVals; for (int i = 0; i < iterations; i++) { if (i != 0) - barrier(); + insertBarrier(rewriter, op); auto &inRegs = inRegsForIter[i]; auto &outRegs = outRegsForIter[i]; @@ -605,7 +605,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion } } - barrier(); + insertBarrier(rewriter, op); for (int j = 0; j < outSize / iterations; j += scratchConfig.outVec) { auto outRegSlice = outRegs[j]; diff --git a/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp new file mode 100644 index 000000000..d7dca0397 --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/RegReallocOpToLLVM.cpp @@ -0,0 +1,47 @@ +#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { +struct RegAllocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegAllocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegAllocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; + +struct RegDeallocOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::RegDeallocOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::RegDeallocOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getRegCount()); + return success(); + } +}; +} // namespace + +void mlir::triton::populateRegReallocOpToLLVMPatterns( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + PatternBenefit benefit) { + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + return; +} diff --git a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp index 4d40e0f31..f796775cb 100644 --- a/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp +++ b/lib/Dialect/Triton/Transforms/RewriteTensorPointer.cpp @@ -314,10 +314,14 @@ class RewriteTensorPointerPass loadOp.getLoc(), newPtr, newMask, newOther, loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); op->getResult(0).replaceAllUsesWith(newResult); + if (op->getAttr("async_task_id")) + newResult->setAttr("async_task_id", op->getAttr("async_task_id")); } else if (auto storeOp = dyn_cast(op)) { - builder.create(storeOp.getLoc(), newPtr, + auto newOp = builder.create(storeOp.getLoc(), newPtr, storeOp.getValue(), newMask, storeOp.getCache(), storeOp.getEvict()); + if (op->getAttr("async_task_id")) + newOp->setAttr("async_task_id", op->getAttr("async_task_id")); } // Erase the original operation @@ -413,6 +417,7 @@ class RewriteTensorPointerPass auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), op.getUpperBound(), op.getStep(), newIterOperands); + newForOp->setAttrs(op->getAttrs()); // Create value mapping. Note that for tensor pointers, we use identity // mapping. It may refer to a value in the old loop, but we will rewrite it diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a454fef56..a1c2990fb 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -2717,8 +2717,9 @@ struct CanonicalizeConvertFromAlloc auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp( + auto newAlloc = rewriter.replaceOpWithNewOp( op, op->getResult(0).getType(), convert.getSrc()); + newAlloc->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2734,8 +2735,9 @@ struct CanonicalizeConvertFromLocalStore auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - rewriter.replaceOpWithNewOp(op, convert.getSrc(), - op.getDst()); + auto store = rewriter.replaceOpWithNewOp(op, convert.getSrc(), + op.getDst()); + store->setAttrs(op->getAttrs()); return mlir::success(); } }; @@ -2854,8 +2856,10 @@ struct CanonicalizeConvertFromConvert // cvt(cvt(x, type1), type2) -> cvt(x, type2) if (auto cvt = dyn_cast(arg)) { auto srcType = op.getSrc().getType(); - rewriter.replaceOpWithNewOp( + auto origAttrs = op->getAttrs(); + auto newOp = rewriter.replaceOpWithNewOp( op, op->getResultTypes().front(), cvt.getSrc()); + newOp->setAttrs(origAttrs); return success(); } diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index d9bbd51bd..51d482aaa 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -8,6 +8,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Passes.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/Support/Debug.h" @@ -291,6 +292,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { versionMinor, warpsPerTile, CTALayout, instrShape); } + PatternRewriterWithAsyncTaskIds taskIdRewriter(rewriter, dotOp); auto newRetType = RankedTensorType::get( oldRetType.getShape(), oldRetType.getElementType(), mmaEnc); // convert accumulator @@ -305,7 +307,7 @@ class BlockedToMMA : public mlir::OpRewritePattern { bool allowTranspose = eltType.isF16() || eltType.isBF16(); a = getSharedMemoryMMAOperand(a, rewriter, 0, allowTranspose); b = getSharedMemoryMMAOperand(b, rewriter, 1, allowTranspose); - newDot = rewriter.create( + newDot = taskIdRewriter.create( dotOp.getLoc(), newRetType, a, b, newAcc, nullptr, dotOp.getInputPrecision(), dotOp.getMaxNumImpreciseAcc(), false); } else { @@ -327,9 +329,9 @@ class BlockedToMMA : public mlir::OpRewritePattern { auto newBType = RankedTensorType::get( oldBType.getShape(), oldBType.getElementType(), newBEncoding); b = rewriter.create(b.getLoc(), newBType, b); - newDot = rewriter.create(dotOp.getLoc(), newRetType, a, b, newAcc, - dotOp.getInputPrecision(), - dotOp.getMaxNumImpreciseAcc()); + newDot = taskIdRewriter.create(dotOp.getLoc(), newRetType, a, b, + newAcc, dotOp.getInputPrecision(), + dotOp.getMaxNumImpreciseAcc()); } // convert dot instruction rewriter.replaceOpWithNewOp(dotOp, oldRetType, diff --git a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt index 99e2ac3c9..d3ef9a145 100644 --- a/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonGPU/Transforms/CMakeLists.txt @@ -3,6 +3,7 @@ add_triton_library(TritonGPUTransforms Coalesce.cpp F32DotTC.cpp CombineTensorSelectAndIf.cpp + LoopScheduling.cpp ReduceDataDuplication.cpp OptimizeAccumulatorInit.cpp OptimizeDotOperands.cpp @@ -18,6 +19,11 @@ add_triton_library(TritonGPUTransforms RemoveLayoutConversions.cpp ReorderInstructions.cpp Utility.cpp + TaskIdPropagate.cpp + WSDataPartition.cpp + WSCodePartition.cpp + WSLowering.cpp + PingPong.cpp DEPENDS TritonGPUTransformsIncGen diff --git a/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp new file mode 100644 index 000000000..4b8eef70c --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/LoopScheduling.cpp @@ -0,0 +1,622 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#define DEBUG_TYPE "triton-loop-schedule" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +// Begin __FACEBOOK__ CompPipe +static void scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return; + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + } + unsigned stagesBetweenLoads = + ceil(numStages - 2, maxIndirectionLevel + 1); + + tt::CoarseSchedule::Cluster rootUsersCluster = schedule.clusters.newAtFront(); + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the root uses of all LoadOp(s) and should be + // always present in the opInfo + if (!isa(use)) { + rootUsers.insert(use); + schedule.insert(use, numStages - 1, rootUsersCluster); + } + } + + SmallVector loadsClusters; + for (int i = 0; i < maxIndirectionLevel + 1; i++) { + loadsClusters.push_back(schedule.clusters.newAtBack()); + } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + schedule.insert(loadOp, stage, loadsClusters[indLevel]); + } +} + +static tt::CoarseSchedule::Cluster +schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, + DenseSet &rootUsers, int numStages) { + // afterPrologue : first cluster curently but we will add a cluster at front + // and a cluster at back + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + + // Look for the IfOp that is in the backward slice any of the currently + // scheduled ops and put it at the beginning of the loop. + DenseMap ifsToStage; + // Go stage by stage. + for (int stage = 0; stage < numStages; stage++) { + for (auto [op, stage_, cluster] : schedule.getOpsInOrder(forOp)) { + if (stage_ != stage) + continue; + SetVector backwardSlice; + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + getBackwardSlice((Operation *)op, &backwardSlice, opt); + + for (auto op : backwardSlice) { + if (auto ifOp = dyn_cast(op)) { + ifsToStage.insert({ifOp, stage}); + } + } + } + } + tt::CoarseSchedule::Cluster prologueCluster = schedule.clusters.newAtFront(); + for (auto [ifOp, stage] : ifsToStage) { + schedule.insert(ifOp, stage, prologueCluster); + } + // Look for the IfOp that is in the forward slice of the root users and put it + // at the end of the loop. + tt::CoarseSchedule::Cluster epilogueCluster = schedule.clusters.newAtBack(); + for (auto rootUser : rootUsers) { + SetVector forwardSlice; + getForwardSlice(rootUser, &forwardSlice); + + int stage = schedule[rootUser].first; + for (auto op : forwardSlice) { + scf::IfOp ifOp = dyn_cast(op); + if (ifOp == nullptr) { + // check if the op is in the body of an if op that's part of the loop + auto parentOp = op->getParentOp(); + if (parentOp != nullptr && + parentOp->getParentOp() == forOp.getOperation()) { + ifOp = dyn_cast(parentOp); + } + } + if (ifOp) { + schedule.insertIfAbsent(ifOp, stage, + epilogueCluster); // after prefetch extracts + } + } + } + return afterPrologue; +} + +static const char *kLoopScheduleAttrName = "tt.loop_schedule"; +std::string getLoopScheduleOrDefault(scf::ForOp forOp) { + if (!forOp->hasAttr(kLoopScheduleAttrName)) + return "default"; + return (cast(forOp->getAttr(kLoopScheduleAttrName))).str(); +} +// End __FACEBOOK__ CompPipe + +static bool isHeavyComputation(Operation *op) { + // include exp2, mulf, addf 1D. Somehow we don't go through reduction + // when checking dependencies + if (!isa(op) && !isa(op) && + !isa(op)) + return false; + auto tensorTy = dyn_cast(op->getOperand(0).getType()); + if (!tensorTy) + return false; + if (tensorTy.getRank() < 1) + return false; + return true; +} + +// Find all consumer_waits needed for a given dot. Assume we have this sequence +// consumer_wait -> subview -> local_load -> dot +// or +// consumer_wait -> subview -> dot +// with TMA +// wait_barrier -> subview -> trans -> dot +// We assume consumer_wait and subview are right next to each other. We start +// from consumer_wait or wait_barrier, find the subview and check if the subview +// feeds into the dot. +static DenseSet getConsumerWaits(Operation *dot, + scf::ForOp forOp) { + llvm::SmallVector deps; + DenseSet seen; + // Get dependencies of the DotOp, stop when hitting Subview or another Dot + std::function dfs = [&](Operation *op, + Operation *baseOp) { + if (!seen.insert(op).second) + return; + if (op != baseOp && + op->hasTrait()) // do not go through Dots + return; + if (isa(op)) { + deps.push_back(op); + return; + } + if (isa(op) || op->hasTrait()) + deps.push_back(op); + + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, baseOp); + } + } + }; + dfs(dot, dot); + DenseSet depSet; + for (auto *op : deps) { + depSet.insert(op); + } + // Go through loop body, check for the sequence. + Operation *currentWait = nullptr; + unsigned seqNum = 0; + DenseSet waits; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto wait = dyn_cast(op)) { + currentWait = &op; + seqNum = 1; + continue; + } + if (auto wait = dyn_cast(op)) { + currentWait = &op; + seqNum = 1; + continue; + } + if (currentWait && seqNum == 1) { + if (isa(op)) + continue; + // subview must be next to wait minus some constants + // we should try to associate a barrier with a buffer + if (auto view = dyn_cast(op)) { + seqNum = 2; + if (depSet.count(&op)) + waits.insert(currentWait); + } else { + currentWait = nullptr; + seqNum = 0; + } + continue; + } + } + return waits; +} + +static void +getListOfProducerAcquires(scf::ForOp forOp, + SmallVector &producerAquires) { + auto funcOp = forOp->getParentOfType(); + funcOp.walk([&](scf::ForOp forOp) { + auto taskArr = mlir::getAsyncTaskIds(forOp); + if (taskArr.size() == 1 && taskArr[0] == 0) { + // Producer warp group ForOp. + forOp.walk([&](Operation *op) { + if (isa(op)) + producerAquires.push_back(op); + }); + } + }); +} + +// FIXME: need to know the corresponding wait/release for a given load. +static Operation * +getConsumerReleaseForWait(Operation *wait, scf::ForOp forOp, + SmallVector &producerAquires, + bool firstLoad) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (auto release = dyn_cast(op)) { + if (isa(wait)) { + // TMA case, only match with producerAquires (1st operand). + // For data partitioning, 4 tokens inside the loop. First 2 + // producerAcquires correspond to firstLoad (loadK). Last 2 correspond + // to secondLoad (loadV). + assert(producerAquires.size() == 4); + if (release->getOperand(0) == + producerAquires[firstLoad ? 0 : 2]->getOperand(0)) + return release; + if (release->getOperand(0) == + producerAquires[firstLoad ? 1 : 3]->getOperand(0)) + return release; + continue; + } + bool isMatch = true; + unsigned i = 0; + for (Value operand : wait->getOperands()) { + if (i >= release->getNumOperands()) + break; + if (operand != release->getOperand(i)) { + isMatch = false; + break; + } + ++i; + } + if (isMatch) + return release; + } + } + return nullptr; +} + +#define GEN_PASS_DEF_TRITONGPULOOPSCHEDULING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPULoopSchedulingPass + : public impl::TritonGPULoopSchedulingBase { +public: + using impl::TritonGPULoopSchedulingBase< + TritonGPULoopSchedulingPass>::TritonGPULoopSchedulingBase; + + // Begin __FACEBOOK__ CompPipe + int getNumStagesOrDefault(scf::ForOp forOp) { + // Use the attribute attached to the loop if it exists otherwise use the + // global control. + if (!forOp->hasAttr(mlir::triton::kNumStagesAttrName)) + return numStages; + return mlir::cast( + forOp->getAttr(mlir::triton::kNumStagesAttrName)) + .getInt(); + } + + tt::CoarseSchedule::Cluster + getDefaultLoopSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + DenseSet rootUsers; + scheduleLoads(forOp, schedule, rootUsers, numStages); + return schedulePrologueAndEpilogue(forOp, schedule, rootUsers, numStages); + } + + // Check for warp spec consumer group. Assume two dots. + bool + isFlashAttention(scf::ForOp forOp, + llvm::SmallVector> + &loadOpToIndLevelAndUse, + SmallVector &keyOps, + DenseSet &heavyCompOps) { + SmallVector loads; + SmallVector dots; + for (Operation &op : forOp.getBody()->without_terminator()) { + // Check for loop-carried dependencies. + // We have two loadOps, one feeding the first dot, and the other feeding + // the second dot. + if (isa(op)) { + loads.push_back(&op); + } + if (op.hasTrait()) { + dots.push_back(&op); + } + } + // Check for async_task_id. + auto taskArr = mlir::getAsyncTaskIds(forOp); + bool isConsumerWG = taskArr.size() != 1 ? false : taskArr[0] != 0; + if (dots.size() != 2 || (loads.size() != 2 && !isConsumerWG)) + return false; + + Operation *secondDot = dots[1]; + DenseSet seen; + DenseSet tracedDots; + // Make sure there is a dependency path from firstDot to secondDot. + // This means we need to do computation pipelining to break the dependency. + std::function dfs = [&](Operation *op) { + if (!seen.insert(op).second) + return; + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + if (defOp->hasTrait()) { + // Stop tracing when hitting a dot. + tracedDots.insert(defOp); + } else { + if (isHeavyComputation(defOp)) + heavyCompOps.insert(defOp); + dfs(defOp); + } + } + } + }; + dfs(secondDot); + if (tracedDots.size() != 1) + return false; + + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) { + if (dist != 0) + return false; + } + + keyOps.push_back(loads.size() == 0 ? nullptr : loads[0]); // FIXME + keyOps.push_back(loads.size() == 0 ? nullptr : loads[1]); + keyOps.push_back(dots[0]); + keyOps.push_back(secondDot); + return true; + } + + tt::CoarseSchedule::Cluster + getFAFirstDotSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + // if (loadOpToIndLevelAndUse.empty()) + // return schedule.clusters.begin(); + + // Check to see if the for loop matches the pattern for flash attention. + // If yes, move the first dot to its own stage (numStages - 2), the + // rest of the computation will be in stage (numStages - 1). The two loads + // will be in stage 0 and 1. + SmallVector keyOps; + DenseSet heavyCompOps; + if (!isFlashAttention(forOp, loadOpToIndLevelAndUse, keyOps, + heavyCompOps)) { + LDBG("isFlashAttention returns false"); + return schedule.clusters.begin(); + } + + // firstLoad: keyOps[0] + tt::CoarseSchedule::Cluster rootUsersCluster = + schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster loadCluster = schedule.clusters.newAtBack(); + bool isConsumerWG = keyOps[0] == nullptr; + if (!isConsumerWG) { + schedule.insert(keyOps[0], 0, loadCluster); + schedule.insert(keyOps[1], 1, loadCluster); + } else { + // Check producer warp group to get the list of ProducerAcquires (assume + // they are in order matching firstLoad and secondLoad). Then match + // ConsumerReleases with them. With TMA, align consumerRleases with + // consumerWaits, assuming consumerWaits happen in order matching + // firstLoad and secondLoad. + SmallVector producerAquires; + getListOfProducerAcquires(forOp, producerAquires); + // dependency from consumer_wait to subview, then to consumer_release + // Assume this group of ops: consumer_wait, subview, local_load. Find the + // corresponding consumer_release for the consumer_wait by checking the + // operands. The local_load needed by firstDot will be in the same stage + // cluseter as firstDot. + DenseSet ConsumerWaitsForDot1 = + getConsumerWaits(keyOps[2], forOp); + for (auto *op : ConsumerWaitsForDot1) { + schedule.insert(op, isConsumerWG ? 0 : numStages - 2, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, true); + schedule.insert(consumerRelease, isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + LLVM_DEBUG({ + LDBG("firstDot wait "); + op->dump(); + LDBG("firstDot release "); + consumerRelease->dump(); + }); + } + DenseSet ConsumerWaitsForDot2 = + getConsumerWaits(keyOps[3], forOp); + for (auto *op : ConsumerWaitsForDot2) { + schedule.insert(op, numStages - 1, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, false); + schedule.insert(consumerRelease, numStages - 1, rootUsersCluster); + LLVM_DEBUG({ + LDBG("secondDot wait "); + op->dump(); + LDBG("secondDot release "); + consumerRelease->dump(); + }); + } + } + schedule.insert(keyOps[2], isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + schedule.insert(keyOps[3], numStages - 1, rootUsersCluster); + return schedule.clusters.begin(); + } + + tt::CoarseSchedule::Cluster + getFASecondDotSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages) { + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + // if (loadOpToIndLevelAndUse.empty()) + // return schedule.clusters.begin(); + + // Check to see if the for loop matches the pattern for flash attention. + // If yes, move the second dot to its own stage (numStages - 1), the + // rest of the computation will be in stage (numStages - 2). The two loads + // will be in stage 0 and 1. + SmallVector keyOps; + DenseSet heavyCompOps; + if (!isFlashAttention(forOp, loadOpToIndLevelAndUse, keyOps, + heavyCompOps)) { + LDBG("isFlashAttention returns false"); + return schedule.clusters.begin(); + } + // Go through loop body + for (Operation &op : forOp.getBody()->without_terminator()) { + if (isHeavyComputation(&op)) + heavyCompOps.insert(&op); + } + // keyOps: load0, load1, dot0, dot1 + // Dot0(i+1) + // Dot1(i) + // Softmax(i+1): includes MUL0(i+1) + // MUL1(i+1) + tt::CoarseSchedule::Cluster rootUsersCluster = + schedule.clusters.newAtFront(); + tt::CoarseSchedule::Cluster nextCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster nextNextCluster = schedule.clusters.newAtBack(); + tt::CoarseSchedule::Cluster loadCluster = schedule.clusters.newAtBack(); + bool isConsumerWG = keyOps[0] == nullptr; + if (!isConsumerWG) { + schedule.insert(keyOps[0], 0, loadCluster); + schedule.insert(keyOps[1], 1, loadCluster); + } else { + SmallVector producerAquires; + getListOfProducerAcquires(forOp, producerAquires); + + DenseSet ConsumerWaitsForDot1 = + getConsumerWaits(keyOps[2], forOp); + for (auto *op : ConsumerWaitsForDot1) { + schedule.insert(op, isConsumerWG ? 0 : numStages - 2, rootUsersCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, true); + assert(consumerRelease); + schedule.insert(consumerRelease, isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + LLVM_DEBUG({ + LDBG("firstDot wait "); + op->dump(); + LDBG("firstDot release "); + consumerRelease->dump(); + }); + } + DenseSet ConsumerWaitsForDot2 = + getConsumerWaits(keyOps[3], forOp); + for (auto *op : ConsumerWaitsForDot2) { + schedule.insert(op, numStages - 1, nextCluster); + Operation *consumerRelease = + getConsumerReleaseForWait(op, forOp, producerAquires, false); + schedule.insert(consumerRelease, numStages - 1, nextCluster); + LLVM_DEBUG({ + LDBG("secondDot wait "); + op->dump(); + LDBG("secondDot release "); + consumerRelease->dump(); + }); + } + } + schedule.insert(keyOps[2], isConsumerWG ? 0 : numStages - 2, + rootUsersCluster); + schedule.insert(keyOps[3], numStages - 1, nextCluster); + // Softmax(i+1), MUL1(i+1) in nextNextCluster + for (auto *heavyOp : heavyCompOps) + schedule.insert(heavyOp, isConsumerWG ? 0 : numStages - 2, + nextNextCluster); + return schedule.clusters.begin(); + } + // End __FACEBOOK__ CompPipe + + void runOnOperation() override { + // Begin __FACEBOOK__ CompPipe + SmallVector loops; + getOperation()->walk([&](scf::ForOp forOp) { + // Bail out for loops with num_stage <= 1 or loop without loop_schedule + if (getNumStagesOrDefault(forOp) > 1 && + forOp->hasAttr(kLoopScheduleAttrName)) + loops.push_back(forOp); + }); + + if (loops.empty()) + return; + for (scf::ForOp forOp : loops) { + int loopNumStages = getNumStagesOrDefault(forOp); + tt::CoarseSchedule coarseSchedule(loopNumStages); + tt::CoarseSchedule::Cluster afterPrologue; + + std::string loopSchedule = getLoopScheduleOrDefault(forOp); + if (loopSchedule == "default") { + afterPrologue = + getDefaultLoopSchedule(forOp, coarseSchedule, loopNumStages); + } else if (loopSchedule == "FA_firstDot") { + afterPrologue = + getFAFirstDotSchedule(forOp, coarseSchedule, loopNumStages); + } else if (loopSchedule == "FA_secondDot") { + afterPrologue = + getFASecondDotSchedule(forOp, coarseSchedule, loopNumStages); + } else { + assert(false && "unrecognized loop schedule"); + } + // Go through schedule and assign (stage, cluster). + // shift so afterPrologue will be at clusterId 0 + auto ctx = forOp.getContext(); + for (auto [op, stage_, cluster] : coarseSchedule.getOpsInOrder(forOp)) { + op->setAttr("loop.stage", + IntegerAttr::get(IntegerType::get(ctx, 32), stage_)); + op->setAttr("loop.cluster", + IntegerAttr::get(IntegerType::get(ctx, 32), + *cluster - *afterPrologue)); + LLVM_DEBUG({ + LDBG("set stage " << stage_ << " cluster " << (*cluster)); + op->dump(); + }); + } + } + // End __FACEBOOK__ CompPipe + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/PingPong.cpp b/lib/Dialect/TritonGPU/Transforms/PingPong.cpp new file mode 100644 index 000000000..757eba67e --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/PingPong.cpp @@ -0,0 +1,186 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include + +#define DEBUG_TYPE "triton-ping-pong-sync" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace tt = mlir::triton; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +// Returns the taskId if op has a single taskId, otherwise, returns -1. +static int getSingleTaskId(Operation *op) { + if (!op->hasAttr("async_task_id")) + return -1; + auto taskArray = op->getAttrOfType("async_task_id"); + if (taskArray.getValues().size() > 1) + return -1; + return taskArray.getValues()[0]; +} + +// Treat exp2, mulf, addf, reduce as expensive computation when data type is +// a tensor type of 1D or higher. +static bool isExpensiveComp(Operation *op) { + if (!isa(op) && !isa(op) && + !isa(op) && !isa(op)) + return false; + auto tensorTy = dyn_cast(op->getOperand(0).getType()); + return tensorTy && tensorTy.getRank() >= 1; +} + +static Value createGetAsyncTaskId(OpBuilder &builder, Operation *op) { + auto loc = op->getLoc(); + return builder.create(loc); +} + +#define GEN_PASS_DEF_TRITONGPUPINGPONGSYNC +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUPingPongSyncPass + : public impl::TritonGPUPingPongSyncBase { +public: + using impl::TritonGPUPingPongSyncBase< + TritonGPUPingPongSyncPass>::TritonGPUPingPongSyncBase; + + enum class ResourceType { + Gemm, + OtherComp, + }; + + void getNestedFor(scf::IfOp ifOp, SmallVector &loops) { + ifOp->walk([&](scf::ForOp forOp) { loops.push_back(forOp); }); + } + void runOnFuncOp(triton::FuncOp funcOp) { + // Insert sync points in ForOp for consumer warp groups. Enable this pass + // when number of consumer warp groups == 2. + if (numConsumerGroups != 2) + return; + + SmallVector loops; + // Identify ForOps for consumer warp groups. Here we assume taskId 0 is for + // producer. This pass handles the case of a single forOp for two consumer + // warp groups. + getOperation()->walk([&](scf::IfOp ifOp) { + int wgId = getSingleTaskId(ifOp); + // Assume taskId 0 is for producer. + if (wgId == 1 || wgId == 2) { + getNestedFor(ifOp, loops); + } + }); + + if (!mlir::triton::tools::getBoolEnv("ENABLE_PINGPONG")) + return; + if (loops.size() != 1) + return; + + Operation *startOfGemm = nullptr; + Operation *endOfGemm = nullptr; + // FIXME: only handle the first loop. + auto forOp = loops[0]; + OpBuilder builder(forOp); + // A simple heuristic for now: + // Mark the start of a gemm section when hitting a DotLike op. + // Mark the end of a gemm section once hitting a expensive cuda op. + for (auto &op : forOp.getBody()->without_terminator()) { + if (startOfGemm && endOfGemm) + break; + bool isCudaCore = isExpensiveComp(&op); + if (op.hasTrait() && !isCudaCore && + startOfGemm == nullptr) { + startOfGemm = &op; + continue; + } + if (!op.hasTrait() && isCudaCore && startOfGemm) { + endOfGemm = &op; + break; + } + } + if (!startOfGemm || !endOfGemm) + return; + + LLVM_DEBUG({ + LDBG("found start of tensor core ops"); + startOfGemm->dump(); + }); + LLVM_DEBUG({ + LDBG("found end of tensor core ops"); + endOfGemm->dump(); + }); + + // FIXME: hard-code using named barrier 9 and 10 in this pass. + // Prior to the forOp, add "bar.arrive 9, 256" only when task Id is 2. + // At startOfGemm, insert "bar.sync 8+taskId, 256" + // At endOfGemm, insert "bar.arrive 11-taskId, 256" + builder.setInsertionPoint(forOp); + auto forLoc = forOp->getLoc(); + + // FIXME: hard-code total number of threads to be 256 when numConsumerGroups + // is 2. + Value numThreads = builder.create(forLoc, 256, 32); + Value c_9 = builder.create(forLoc, 9, 32); + + // "bar.arrive 9, 256" only when task Id is 2. + Value c_2 = builder.create(forLoc, 2, 32); + Value curTaskId = createGetAsyncTaskId(builder, forOp); + auto pred = builder.create(forLoc, arith::CmpIPredicate::eq, + curTaskId, c_2); + auto ifOp = builder.create(forLoc, pred, /*else=*/false); + builder.setInsertionPoint(ifOp.thenYield()); + builder.create(forLoc, c_9, numThreads); + + // At startOfGemm, insert "bar.sync 8+taskId, 256" + // 8 + taskId: 9 for taskId 1 and 10 for taskId 2. + builder.setInsertionPoint(startOfGemm); + auto loc = startOfGemm->getLoc(); + Value c_8 = builder.create(loc, 8, 32); + Value syncBarrier = builder.create(loc, c_8, curTaskId); + builder.create(loc, syncBarrier, numThreads); + + // At endOfGemm, insert "bar.arrive 11-taskId, 256" + // 11 - taskId: 10 for taskId 1 and 9 for taskId2. + builder.setInsertionPoint(endOfGemm); + auto loc2 = endOfGemm->getLoc(); + Value c_11 = builder.create(loc2, 11, 32); + Value arriveBarrier = builder.create(loc2, c_11, curTaskId); + builder.create(loc2, arriveBarrier, numThreads); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 5cc537d5f..54cdf22a1 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -16,6 +16,8 @@ #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" @@ -56,7 +58,8 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, tt::CoarseSchedule &schedule, tt::CoarseSchedule::Cluster prefetchCluster, llvm::MapVector &loadToInfo, - int numStages) { + int numStages, + DenseMap &TMAUserToWait) { OpBuilder builder(forOp); Value zero = builder.create(forOp.getLoc(), 0, 32); // Replace the load with insert/extract slice. @@ -113,6 +116,7 @@ static void createAsyncCopy(scf::ForOp &forOp, tt::LoadOp loadOp, Value alloc, loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + TMAUserToWait[viewLoad] = wait; // viewLoad will depend on barrierWait if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); @@ -157,7 +161,8 @@ static void createTMAAsyncCopy( scf::ForOp &forOp, tt::ExperimentalDescriptorLoadOp loadOp, Value alloc, Value insertIdx, Value extractIdx, Value barrier, Operation *waitOp, Value phase, tt::CoarseSchedule &schedule, - llvm::MapVector &loadToInfo, int numStages) { + llvm::MapVector &loadToInfo, int numStages, + DenseMap &TMAUserToWait) { assert(phase && "Phase value is required for TMA async copy."); OpBuilder builder(forOp); Attribute sharedMemorySpace = @@ -189,6 +194,7 @@ static void createTMAAsyncCopy( loadOffsets[0] = extractIdx; auto viewLoad = builder.create(loc, subviewTy, alloc, loadOffsets); + TMAUserToWait[viewLoad] = waitOp; // viewLoad will depend on barrierWait if (isMMV3Load) { auto alloc = cast((*loadOp->getUsers().begin())); replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); @@ -590,6 +596,156 @@ scheduleLoads(scf::ForOp forOp, tt::CoarseSchedule &schedule, return loadToInfo; } +// Begin __FACEBOOK__ CompPipe +static bool loopHasSchedule(scf::ForOp forOp) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + return true; + } + } + return false; +} + +static tt::CoarseSchedule::Cluster +getLoopSchedule(scf::ForOp forOp, tt::CoarseSchedule &schedule, + /*DenseSet &rootUsers,*/ int numStages, + llvm::MapVector &loadToInfo) { + ModuleOp moduleOp = forOp->getParentOfType(); + tt::ModuleAxisInfoAnalysis axisInfoAnalysis(moduleOp); + tt::CoarseSchedule::Cluster afterPrologue = schedule.clusters.begin(); + auto taskArr = mlir::getAsyncTaskIds(forOp); + // We either have a single task Id with a merged IfOp for all consumers + // or we have one task Id for each IfOp per consumer. + // We should not see a list of task Ids here. + bool isConsumerWG = taskArr.size() != 1 ? false : taskArr[0] != 0; + + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + llvm::SmallVector> + loadOpToIndLevelAndUse = + mlir::triton::loadOpsToIndirectionLevelAndUse(forOp); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + bool dataPartition = mlir::triton::tools::getBoolEnv("SWP_FOR_CONSUMER"); + // When there are no load operations, continue computation pipelining if + // dataPartition is true and isConsumerWG is true. Early exit otherwise. + if (!(isConsumerWG && dataPartition) && loadOpToIndLevelAndUse.empty()) + return afterPrologue; + + // Check which loads are good for pipelining, and assign them + // memory layouts. + llvm::MapVector loadToInfoT = + assignMemoryLayouts(loadOpToIndLevelAndUse, axisInfoAnalysis); + loadToInfo = loadToInfoT; + + if (!(isConsumerWG && dataPartition) && loadToInfo.empty()) + return afterPrologue; + + // reconstrcut schedule from annotations of (stage, cluster) + int maxClusterId = 0, minClusterId = 0; + bool hasSchedule = false; + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + auto clusterId = cast(op.getAttr("loop.cluster")) + .getValue() + .getSExtValue(); + LLVM_DEBUG({ + LDBG("saw cluster " << clusterId); + op.dump(); + }); + if (!hasSchedule) { + minClusterId = clusterId; + maxClusterId = clusterId; + hasSchedule = true; + continue; + } + minClusterId = (clusterId < minClusterId) ? clusterId : minClusterId; + maxClusterId = (clusterId > maxClusterId) ? clusterId : maxClusterId; + } + } + assert(hasSchedule); + LDBG("minCluster " << minClusterId << " max " << maxClusterId); + DenseMap clusters; + for (int i = minClusterId; i < maxClusterId + 1; i++) { + clusters.insert({i, schedule.clusters.newAtBack()}); + } + // afterPrologue should be the first cluster after ifOps? + for (Operation &op : forOp.getBody()->without_terminator()) { + if (op.hasAttr("loop.stage") && op.hasAttr("loop.cluster")) { + auto stage = + cast(op.getAttr("loop.stage")).getValue().getZExtValue(); + auto clusterId = cast(op.getAttr("loop.cluster")) + .getValue() + .getSExtValue(); + schedule.insert(&op, stage, clusters[clusterId]); + LLVM_DEBUG({ + LDBG("insert stage " << stage << " cluster " << clusterId << " " + << *clusters[clusterId]); + op.dump(); + }); + } + } + + // Distance from the load to the use. This needs to be re-worked. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + loadToInfo[loadOp].distToUse = + schedule[use].first - schedule[loadOp].first; + } + return clusters[0]; + } + + // If there is a use chain of load -> dot -> dot, we can ignore the second dot + // here. + // Start from loadOp, check uses and stop the recursion when hitting a dot. + DenseSet seen; + llvm::SmallVector> loadOpToDirectUses; + std::function dfsUse = + [&](Operation *op, Operation *use) { + if (!seen.insert(use).second) + return; + if (use->hasTrait()) { + loadOpToDirectUses.push_back(std::make_tuple(op, use)); + return; + } + for (auto &tUse : use->getUses()) { + Operation *useOp = tUse.getOwner(); + if (useOp && useOp->getBlock() == op->getBlock()) { + dfsUse(op, useOp); + } + } + }; + DenseSet loadOps; + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(loadOp) == 0) + continue; + if (!loadOps.insert(loadOp).second) + continue; + seen.clear(); + dfsUse(loadOp, loadOp); + } + for (auto [loadOp, use] : loadOpToDirectUses) { + LLVM_DEBUG({ + LDBG("loadOpToDirectUses " << schedule[use].first << " " + << schedule[loadOp].first); + loadOp->dump(); + use->dump(); + }); + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + return clusters[0]; +} +// End __FACEBOOK__ CompPipe + // Schedule the prologue and epilogue `if` ops in the loop, pushing them as // close to the loop boundaries as possible. Return the cluster after the // prologue (or the beginning of the loop if there is no prologue). @@ -652,8 +808,10 @@ schedulePrologueAndEpilogue(scf::ForOp forOp, tt::CoarseSchedule &schedule, // Add dependencies of anchor ops to the coarse schedule. Schedule them to // the same stage and ordering cluster as the anchor op. -static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, - int numStages) { +static void +scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, + int numStages, + DenseMap &TMAUserToWait) { SmallVector> opsInOrder = schedule.getOpsInOrder(forOp); // Schedule dependencies stage by stage. @@ -661,7 +819,7 @@ static void scheduleDependencies(scf::ForOp forOp, tt::CoarseSchedule &schedule, for (auto [op, stage_, cluster] : opsInOrder) { if (stage_ != stage) continue; - schedule.insertDepsOfOp(op, stage, cluster, false); + schedule.insertDepsOfOp(op, stage, cluster, false, &TMAUserToWait); } } } @@ -818,7 +976,7 @@ struct AsyncLoad { }; // Create barriers and wait ops for the async loads. Barriers may be shared by -// multiple loads is the schedule allows it. +// multiple loads if the schedule allows it. static void createTMABarrierAndWait( scf::ForOp &forOp, SmallVector &asyncLoads, Value insertIdx, Value extractIdx, Value phase, int numBuffers, tt::CoarseSchedule &schedule, @@ -905,7 +1063,7 @@ static void createTMABarrierAndWait( Value pred = builder.create(loc, 1, 1); Operation *expect = builder.create( forOp.getLoc(), barrier, sizeInBytes, pred); - auto [stage, cluster] = schedule[asyncLoads[0].loadOp]; + auto [stage, cluster] = schedule[group[0]->loadOp]; schedule.insert(expect, stage, cluster); builder.setInsertionPointAfter(group.back()->loadOp); @@ -926,7 +1084,8 @@ static void createTMABarrierAndWait( static SmallVector createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, llvm::MapVector &loadToInfo, - SmallVector &barriers, int numStages) { + SmallVector &barriers, int numStages, + DenseMap &TMAUserToWait) { // Calculate the number of buffers needed for each load. // TODO pawel: we could do more fine-grained allocation here and // allocate only the number of buffers that specific loads need. @@ -1017,12 +1176,13 @@ createAsyncOps(scf::ForOp &forOp, tt::CoarseSchedule &schedule, for (AsyncLoad &asyncLoad : asyncLoads) { if (auto loadOp = dyn_cast(asyncLoad.loadOp)) { createAsyncCopy(forOp, loadOp, asyncLoad.alloc, insertIdx, extractIdx, - schedule, prefetchCluster, loadToInfo, numStages); + schedule, prefetchCluster, loadToInfo, numStages, + TMAUserToWait); } else { auto descLoad = cast(asyncLoad.loadOp); createTMAAsyncCopy(forOp, descLoad, asyncLoad.alloc, insertIdx, extractIdx, asyncLoad.barrier, asyncLoad.waitOp, phase, - schedule, loadToInfo, numStages); + schedule, loadToInfo, numStages, TMAUserToWait); } } SmallVector newYieldOperands = {insertIdx, extractIdx}; @@ -1060,9 +1220,24 @@ bool mlir::triton::preProcessLoopAndGetSchedule( // a scaffold for the final schedule. DenseSet rootUsers; tt::CoarseSchedule coarseSchedule(numStages); - llvm::MapVector loadToInfo = - scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); - if (loadToInfo.empty()) + // Begin __FACEBOOK__ CompPipe + bool hasSchedule = loopHasSchedule(forOp); + llvm::MapVector loadToInfo; + tt::CoarseSchedule::Cluster afterPrologue; + if (!hasSchedule) { + loadToInfo = scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + } else { + afterPrologue = getLoopSchedule(forOp, coarseSchedule, + /*rootUsers,*/ numStages, loadToInfo); + } + // vanilla + // llvm::MapVector loadToInfo = + // scheduleLoads(forOp, coarseSchedule, rootUsers, numStages); + // End __FACEBOOK__ CompPipe + auto taskArr = mlir::getAsyncTaskIds(forOp); + bool isConsumerWG = taskArr.size() != 1 ? false : taskArr[0] != 0; + bool dataPartition = mlir::triton::tools::getBoolEnv("SWP_FOR_CONSUMER"); + if (!(isConsumerWG && dataPartition) && loadToInfo.empty()) return false; LLVM_DEBUG({ @@ -1070,24 +1245,33 @@ bool mlir::triton::preProcessLoopAndGetSchedule( coarseSchedule.dump(); }); - tt::CoarseSchedule::Cluster afterPrologue = - schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, numStages); + // Begin __FACEBOOK__ CompPipe + if (!hasSchedule) { + afterPrologue = schedulePrologueAndEpilogue(forOp, coarseSchedule, + rootUsers, numStages); + } + // vanilla + // tt::CoarseSchedule::Cluster afterPrologue = + // schedulePrologueAndEpilogue(forOp, coarseSchedule, rootUsers, + // numStages); + // End __FACEBOOK__ CompPipe LLVM_DEBUG({ LDBG("Coarse schedule with prologue and epilogue:"); coarseSchedule.dump(); }); SmallVector barriers; + DenseMap TMAUserToWait; // Convert the loads into async loads and create the allocs. - SmallVector allocs = - createAsyncOps(forOp, coarseSchedule, loadToInfo, barriers, numStages); + SmallVector allocs = createAsyncOps( + forOp, coarseSchedule, loadToInfo, barriers, numStages, TMAUserToWait); LLVM_DEBUG({ LDBG("Coarse schedule with async loads:"); coarseSchedule.dump(); }); - scheduleDependencies(forOp, coarseSchedule, numStages); + scheduleDependencies(forOp, coarseSchedule, numStages, TMAUserToWait); LLVM_DEBUG({ LDBG("Coarse schedule with dependencies:"); coarseSchedule.dump(); @@ -1116,7 +1300,10 @@ bool mlir::triton::preProcessLoopAndGetSchedule( std::vector> &s) { s = std::move(schedule); }; - options.peelEpilogue = false; + bool hasLoopSchedule = forOp->hasAttr("tt.loop_schedule"); + bool PeelLastIter = + ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && hasLoopSchedule; + options.peelEpilogue = PeelLastIter ? true : false; options.predicateFn = tt::predicateOp; options.supportDynamicLoops = true; options.annotateFn = [](Operation *op, @@ -1467,6 +1654,10 @@ static std::optional dotCanBeProperlyAsync(ttng::WarpGroupDotOp dotOp, return iterArgIdx; } + if (::triton::tools::getBoolEnv("HACK_ASYNC_DOT")) { + return iterArgIdx; + } + // Rule 3b: Are all users of the dot's result from iteration i-1 after the // first `warp_group_dot_wait {pendings=0}` op? If so, the dot can be // properly async, but we have to thread its result from iteration i-1 through diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp index 2f186e3c5..6be8745e5 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipelineExpander.cpp @@ -32,6 +32,7 @@ #include "llvm/Support/MathExtras.h" #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Tools/Sys/GetEnv.hpp" #define DEBUG_TYPE "triton-loop-pipelining" #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") @@ -452,8 +453,10 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( Type t = ub.getType(); Location loc = forOp.getLoc(); // newUb = ub - maxStage * step + // peel last iteration of ops, newUb = ub - step + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; Value maxStageValue = rewriter.create( - loc, rewriter.getIntegerAttr(t, maxStage)); + loc, rewriter.getIntegerAttr(t, PeelLastIter ? 1 : maxStage)); Value maxStageByStep = rewriter.create(loc, step, maxStageValue); newUb = rewriter.create(loc, ub, maxStageByStep); @@ -461,6 +464,7 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop( auto newForOp = rewriter.create(forOp.getLoc(), forOp.getLowerBound(), newUb, forOp.getStep(), newLoopArg); + newForOp->setAttrs(forOp->getAttrs()); // When there are no iter args, the loop body terminator will be created. // Since we always create it below, remove the terminator if it was created. if (!newForOp.getBody()->empty()) @@ -485,11 +489,16 @@ LogicalResult LoopPipelinerInternal::createKernel( mapping.map(arg.value(), newForOp.getRegionIterArgs()[arg.index()]); } SmallVector predicates(maxStage + 1, nullptr); - if (!peelEpilogue) { + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; + if (!peelEpilogue || PeelLastIter) { // Create a predicate for each stage except the last stage. Location loc = newForOp.getLoc(); Type t = ub.getType(); - for (unsigned i = 0; i < maxStage; i++) { + // predicates[i] = indVar < c = indVar < ub - (maxStage - i) * step + // if peeling last iteration only, S2 should always be executed. + // only create predicates for S0 to S1 + int iEnd = PeelLastIter ? maxStage - 1 : maxStage; + for (unsigned i = 0; i < iEnd; i++) { // c = ub - (maxStage - i) * step Value c = rewriter.create( loc, ub, @@ -619,12 +628,29 @@ LogicalResult LoopPipelinerInternal::createKernel( // If there is a live range spanning across more than 2 stages we need to // add extra arg. for (unsigned i = 1; i < numVersionReturned; i++) { + LLVM_DEBUG({ + llvm::dbgs() << "set valueMapping3: version " << version + << " lastUseStage " << it.second.lastUseStage + << " defStage " << it.second.defStage << " "; + it.first.dump(); + }); setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), version++); yieldOperands.push_back( newForOp.getBody()->getArguments()[yieldOperands.size() + 1 + newForOp.getNumInductionVars()]); } + // Map [key, version] to result of newForOp. + if (PeelLastIter && it.second.lastUseStage == maxStage) { + // we only need version maxStage for ops in stage maxStage + version += maxStage - 1; // loop body contains the first epilogue + } + LLVM_DEBUG({ + llvm::dbgs() << "set valueMapping: version " << version + << " lastUseStage " << it.second.lastUseStage << " defStage " + << it.second.defStage << " "; + it.first.dump(); + }); setValueMapping(it.first, newForOp->getResult(yieldOperands.size()), version++); yieldOperands.push_back(mapping.lookupOrDefault(it.first)); @@ -640,10 +666,14 @@ LogicalResult LoopPipelinerInternal::createKernel( for (unsigned int stage = 1; stage <= maxStage; stage++) setValueMapping(forOp.getRegionIterArgs()[retVal.index()], retVal.value(), stage); - } else if (defStage->second > 0) { + } else if (defStage->second > 0 && + (!PeelLastIter || defStage->second > maxStage - 1)) { + // If PeelLastIter is false, no change. If it is true, only enter when + // defStage->second is bigger than 1. setValueMapping(forOp.getRegionIterArgs()[retVal.index()], newForOp->getResult(retVal.index()), - maxStage - defStage->second + 1); + maxStage - defStage->second + 1 + + (PeelLastIter ? maxStage - 1 : 0)); } } rewriter.create(forOp.getLoc(), yieldOperands); @@ -693,17 +723,33 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter, // Emit `maxStage - 1` epilogue part that includes operations from stages // [i; maxStage]. - for (int64_t i = 1; i <= maxStage; i++) { + bool PeelLastIter = ::triton::tools::getBoolEnv("PEEL_LAST_ITER") && peelEpilogue; + for (int64_t i = PeelLastIter ? maxStage : 1; i <= maxStage; i++) { SmallVector> returnMap(returnValues.size()); for (Operation *op : opOrder) { if (stages[op] < i) continue; + LLVM_DEBUG({ + llvm::errs() << "clone "; + op->dump(); + }); unsigned currentVersion = maxStage - stages[op] + i; unsigned nextVersion = currentVersion + 1; Operation *newOp = cloneAndUpdateOperands(rewriter, op, [&](OpOperand *newOperand) { auto it = valueMapping.find(newOperand->get()); if (it != valueMapping.end()) { + LLVM_DEBUG({ + llvm::errs() << "find valueMapping: version " + << (maxStage - stages[op] + i) << " "; + newOperand->get().dump(); + unsigned tmp = 0; + for (auto v : it->second) { + llvm::errs() << "idx " << tmp << ": "; + v.dump(); + ++tmp; + } + }); Value replacement = it->second[currentVersion]; newOperand->set(replacement); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp index 1a3162f17..eabb5fe7c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/PipeliningUtility.cpp @@ -81,6 +81,37 @@ Operation *mlir::triton::predicateOp(RewriterBase &rewriter, Operation *op, return op; } + if (isa(op)) + return op; + if (auto wait = dyn_cast(op)) { + rewriter.setInsertionPoint(wait); + auto ifOp = + rewriter.create(wait->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto wait = dyn_cast(op)) { + rewriter.setInsertionPoint(wait); + auto ifOp = + rewriter.create(wait->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(wait, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto release = dyn_cast(op)) { + rewriter.setInsertionPoint(release); + auto ifOp = + rewriter.create(release->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(release, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + if (auto arrive = dyn_cast(op)) { + rewriter.setInsertionPoint(arrive); + auto ifOp = + rewriter.create(arrive->getLoc(), pred, /*else=*/false); + rewriter.moveOpBefore(arrive, ifOp.thenBlock(), ifOp.thenBlock()->begin()); + return ifOp; + } + assert("don't know how to predicate this op" && false); return op; } @@ -159,6 +190,7 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, trans.getOrderAttr()); } assert(newVal); + newVal.getDefiningOp()->setAttrs(user->getAttrs()); replaceUsesAndPropagateType(builder, user, newVal); opsToDelete.push_back(use.getOwner()); } @@ -173,3 +205,49 @@ void mlir::triton::replaceUsesAndPropagateType(OpBuilder &builder, for (Operation *op : opsToDelete) op->erase(); } + +// Begin __FACEBOOK__ CompPipe +llvm::SmallVector> +mlir::triton::loadOpsToIndirectionLevelAndUse(scf::ForOp forOp) { + llvm::SmallVector> + loadOpToIndLevelAndUse; + DenseSet seen; + + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + if (!seen.insert(op).second) + return; + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.push_back(std::make_tuple(op, distance, use)); + use = op; + distance++; + } + for (Value operand : op->getOperands()) { + Value v = operand; + Operation *defOp = v.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!op.hasTrait()) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } + + return loadOpToIndLevelAndUse; +} +// End __FACEBOOK__ CompPipe diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index 1116b70a0..1d10595c3 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -15,9 +15,15 @@ namespace tt = mlir::triton; namespace ttg = mlir::triton::gpu; namespace ttng = mlir::triton::nvidia_gpu; -void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, - tt::CoarseSchedule::Cluster cluster, - bool includeArg) { +void tt::CoarseSchedule::insertDepsOfOp( + Operation *op, int stage, tt::CoarseSchedule::Cluster cluster, + bool includeArg, DenseMap *additionalDep) { + // Look in additionalDep. + if (additionalDep && additionalDep->find(op) != additionalDep->end()) { + Operation *wait = (*additionalDep)[op]; + if (insertIfAbsent(wait, stage, cluster)) + insertDepsOfOp(wait, stage, cluster, includeArg, additionalDep); + } for (Value operand : op->getOperands()) { Value v = operand; llvm::SmallDenseSet seen; @@ -36,7 +42,7 @@ void tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, Operation *defOp = v.getDefiningOp(); if (defOp && defOp->getBlock() == op->getBlock()) { if (insertIfAbsent(defOp, stage, cluster)) { - insertDepsOfOp(defOp, stage, cluster, includeArg); + insertDepsOfOp(defOp, stage, cluster, includeArg, additionalDep); } } } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp index 7985d25b9..eee8219ba 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/TMAStoresPipeline.cpp @@ -1,6 +1,8 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" using namespace mlir; namespace tt = mlir::triton; @@ -29,7 +31,7 @@ getTMAStores(scf::ForOp forOp) { static Value createAlloc(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp) { - OpBuilder builder(forOp); + OpBuilderWithAsyncTaskIds builder(forOp); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); auto ctaLayout = ttg::getCTALayout(ty.getEncoding()); @@ -44,7 +46,7 @@ static Value createAlloc(scf::ForOp &forOp, Type memdescType = tt::MemDescType::get(ty.getShape(), ty.getElementType(), encoding, sharedMemorySpace, /*mutableMemory*/ true); - Value alloc = builder.create(storeOp->getLoc(), + Value alloc = builder.createWithAsyncTaskIds(storeOp->getLoc(), memdescType, Value()); return alloc; } @@ -52,7 +54,7 @@ static Value createAlloc(scf::ForOp &forOp, static void createTMAAsyncCopy(scf::ForOp &forOp, tt::ExperimentalDescriptorStoreOp storeOp, Value alloc) { - OpBuilder builder(storeOp); + OpBuilderWithAsyncTaskIds builder(storeOp); auto loc = storeOp.getLoc(); auto ty = cast(storeOp.getSrc().getType()); auto order = ttg::getOrder(ty.getEncoding()); @@ -60,10 +62,10 @@ static void createTMAAsyncCopy(scf::ForOp &forOp, // Put wait before the local_store make the store truly async. We know // that we are the only user of the CopyLocalToGlobal. - builder.create(loc, 0); - builder.create(loc, storeOp.getSrc(), alloc); - builder.create(loc, false); - builder.create( + builder.createWithAsyncTaskIds(loc, 0); + builder.createWithAsyncTaskIds(loc, storeOp.getSrc(), alloc); + builder.createWithAsyncTaskIds(loc, false); + builder.createWithAsyncTaskIds( loc, storeOp.getDescPtr(), storeOp.getIndices(), alloc); storeOp->erase(); diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index b6d855a05..d7a4ce9d6 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -169,6 +169,7 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { SmallVector queue = {op->getResult(0)}; SetVector forwardSlice; llvm::SmallDenseSet seen; + llvm::SmallDenseSet seenOps; // facebook T170066846 while (!queue.empty()) { Value currentValue = queue.back(); queue.pop_back(); @@ -211,6 +212,17 @@ bool hasConvertToMMATransisitiveUse(Operation *op, Attribute encoding) { } } } + + // facebook begin T170066846 + // If an op is visited more than once, it indicates a loop and that + // mulitple operands of the op share the same mma layout. Allow the + // propagation to avoid unncessary layout conversion within the loop. + if (op->hasTrait()) { + if (!seenOps.insert(op).second == true) + return true; + } + // facebook end T170066846 + bool isMMAV3 = isa(encoding) && cast(encoding).getVersionMajor() == 3; @@ -528,6 +540,8 @@ Value LayoutPropagation::getValueAs(Value value, Attribute encoding) { tensorType.getElementType(), encoding); Value converted = rewriter.create(value.getLoc(), tmpType, rewrittenValue); + if (value.getDefiningOp()) + converted.getDefiningOp()->setAttrs(value.getDefiningOp()->getAttrs()); // TODO: we could cache the conversion. return converted; } @@ -770,6 +784,7 @@ Operation *LayoutPropagation::rewriteOp(Operation *op) { auto newType = RankedTensorType::get(tensorType.getShape(), tensorType.getElementType(), encoding); auto cvt = rewriter.create(op->getLoc(), newType, src); + cvt->setAttrs(op->getAttrs()); map(op->getResult(0), cvt.getResult()); return cvt.getOperation(); } @@ -1171,6 +1186,7 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( tensorType.getShape(), tensorType.getElementType(), *srcEncoding); auto newConvertOp = builder.create( convertOp.getLoc(), newType, extOrBroadcatOp->getOperand(0)); + newConvertOp->setAttrs(convertOp->getAttrs()); Operation *newExtOrBroadcast = builder.clone(*extOrBroadcatOp); newExtOrBroadcast->setOperand(0, newConvertOp.getResult()); auto oldExtOrBroadcastType = diff --git a/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp new file mode 100644 index 000000000..ddd85dee5 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/TaskIdPropagate.cpp @@ -0,0 +1,407 @@ +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +#define DEBUG_TYPE "triton-gpu-taskid-propagate" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = ::mlir::triton; +namespace ttg = ::mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; + +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUTASKIDPROPAGATE +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// Return all Ops that are marked with target task +void getAsyncTaskOps(triton::FuncOp funcOp, DenseSet &asyncTaskOps, + int asyncTaskId) { + funcOp.walk([&](Operation *op) -> void { + if (auto attr = + op->getAttrOfType("async_task_id")) { + for (auto val : attr.getValues()) { + if (val == asyncTaskId) { + asyncTaskOps.insert(op); + break; + } + } + } + }); +} + +void getAllParentOps(DenseSet &parentOps, Operation *targetOp) { + auto op = targetOp; + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + parentOps.insert(parent); + op = parent; + } else { + break; + } + } +} + +void getAllParentOps(triton::FuncOp funcOp, DenseSet &parentOps, + int asyncTaskId) { + DenseSet targetOps; + getAsyncTaskOps(funcOp, targetOps, asyncTaskId); + for (auto op : targetOps) { + getAllParentOps(parentOps, op); + } +} + +void labelByUsers(Operation *op, ArrayRef allAsyncTasks) { + for (Value result : op->getResults()) { + for (Operation *userOp : result.getUsers()) { + if (!userOp->hasAttr("async_task_id")) { + labelByUsers(userOp, allAsyncTasks); + } + addAsyncTaskIds(op, getAsyncTaskIds(userOp)); + } + } + if (!op->hasAttr("async_task_id")) { + addAsyncTaskIds(op, allAsyncTasks); + } +} + +/// Because we set some special filter rules in populateAsyncTaskRegion, +/// there may be unlabeled Ops, e.g. YieldOps, some definingOps of ForOps. +/// or Ops without relations to asyncTaskOps +void populateUnlabledOpsAtLast(triton::FuncOp funcOp, + ArrayRef allAsyncTasks) { + // Label asyncTasks' parentOps + for (int i : allAsyncTasks) { + DenseSet asyncTaskParentOps; + getAllParentOps(funcOp, asyncTaskParentOps, i); + for (auto op : asyncTaskParentOps) { + addAsyncTaskIds(op, {i}); + } + } + + // Get unlabeled Ops + DenseSet unlabeledOps; + funcOp.walk([&](Operation *op) -> void { + if (isa(op) || isa(op) || + isa(op)) { + return; + } + if (!op->hasAttr("async_task_id")) { + unlabeledOps.insert(op); + } + }); + + // Label Ops using its parentOp + for (auto op : unlabeledOps) { + if (auto parent = op->getParentOp()) { + if (!isa(parent)) { + if (!parent->hasAttr("async_task_id")) { + LLVM_DEBUG({ + LDBG("op and parent: "); + op->dump(); + parent->dump(); + }); + continue; + } + assert(parent->hasAttr("async_task_id")); + auto asyncTasks = getAsyncTaskIds(parent); + setAsyncTaskIds(op, asyncTasks); + unlabeledOps.erase(op); + } + } + } + + // Label Ops using dependency + for (auto op : unlabeledOps) { + labelByUsers(op, allAsyncTasks); + unlabeledOps.erase(op); + } + assert(unlabeledOps.size() == 0); +} + +#ifndef NDEBUG +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +struct AsyncTaskIdsCompare { + static SmallVector getEmptyKey() { + SmallVector V; + V.push_back(reinterpret_cast(-1)); + return V; + } + + static SmallVector getTombstoneKey() { + SmallVector V; + V.push_back(reinterpret_cast(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector &V) { + return static_cast(llvm::hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector &LHS, + const SmallVector &RHS) { + return LHS == RHS; + } +}; + +// Make sure the def chain contains the right taskId. +bool verifyTaskId(triton::FuncOp &funcOp, + const llvm::DenseSet& anchorOps) { + bool retCode = true; + DenseSet, AsyncTaskIdsCompare> anchorAsyncTasks; + for (auto anchorOp : anchorOps) { + anchorAsyncTasks.insert(getAsyncTaskIds(anchorOp)); + } + + funcOp.walk([&](Operation *op) { + // Skip control ops + if (llvm::isa(op)) + return; + + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.empty()) { + LLVM_DEBUG({ + LDBG("Op does not have task id"); + op->dump(); + }); + llvm_unreachable("Op does not have task id"); + } + + auto partitionShouldBeUsedSpecified = [](Operation *op) { + if (isa(op)) + return true; + if (isa(op)) + return true; + if (op->hasTrait()) + return true; + return false; + }; + + if (!anchorAsyncTasks.contains(asyncTaskIds)) { + if (partitionShouldBeUsedSpecified(op)) { + LLVM_DEBUG({ + LDBG("async tasks not specified by user"); + op->dump(); + }); + llvm_unreachable("async tasks not specified by user"); + } + } + + assert(!asyncTaskIds.empty() && "Op does not have task id"); + + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + if (llvm::isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + LLVM_DEBUG({ + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + // print defOp and op + LDBG("Def op does not cover op"); + LDBG("Def op"); + defOp->dump(); + LDBG("op"); + op->dump(); + } + }); + assert(oneVecCoversTheOther(defTaskIds, asyncTaskIds) && + "defTaskIds should cover asyncTaskIds"); + } + }); + return retCode; +} +#endif + +void backwardPropagateTaskIds(Operation *op, + const llvm::DenseSet &anchors) { + SmallVector queue; + auto asyncTasks = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + + DenseSet seen; + for (auto anchor : anchors) { + if (anchor != op) + for (auto result : anchor->getResults()) + seen.insert(result); + } + + while (!queue.empty()) { + auto value = queue.pop_back_val(); + if (!seen.insert(value).second) { + continue; + } + + // Handle BlockArguments of for loops (i.e. loop carried dependences). + if (auto blockArg = dyn_cast(value)) { + auto parent = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parent)) { + // Propagate to the control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to the initializer. + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + queue.push_back(forOp.getTiedLoopInit(blockArg)->get()); + // Propagate to the yield. + auto idx = blockArg.getArgNumber() - forOp.getNumInductionVars(); + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + addAsyncTaskIds(forOp, asyncTasks); + } + } + continue; + } + + auto op = value.getDefiningOp(); + addAsyncTaskIds(op, asyncTasks); + + // Handle for loops. + if (auto forOp = dyn_cast(op)) { + // Propagate to control operands. + auto control = + forOp.getOperands().take_front(forOp.getNumControlOperands()); + queue.insert(queue.end(), control.begin(), control.end()); + // Propagate to arguments. + unsigned idx = cast(value).getResultNumber(); + queue.push_back(forOp.getOperand(idx + forOp.getNumControlOperands())); + // Propagate to yield. + queue.push_back(forOp.getBody()->getTerminator()->getOperand(idx)); + continue; + } + + // Handle conditionals. + if (auto ifOp = dyn_cast(op)) { + queue.push_back(ifOp.getCondition()); + unsigned idx = cast(value).getResultNumber(); + if (ifOp.elseBlock()) { + queue.push_back(ifOp.elseYield()->getOperand(idx)); + } + queue.push_back(ifOp.thenYield()->getOperand(idx)); + continue; + } + + // Handle normal ops. + for (Value operand : op->getOperands()) { + queue.push_back(operand); + } + } +} + +void backwardPropagateTaskIds(llvm::DenseSet &anchorOps) { + for (Operation *op : anchorOps) { + backwardPropagateTaskIds(op, anchorOps); + } +} + +void populateTaskIdsForControlDependencies( + llvm::DenseSet &anchorOps) { + for (auto op : anchorOps) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (!asyncTaskIds.empty()) { + while (auto parent = op->getParentOp()) { + if (!isa(parent) && !isa(parent)) { + setAsyncTaskIds(parent, asyncTaskIds); + backwardPropagateTaskIds(parent, anchorOps); + op = parent; + } else { + break; + } + } + } + } +} + +class TritonGPUTaskIdPropagatePass + : public impl::TritonGPUTaskIdPropagateBase { +public: + using impl::TritonGPUTaskIdPropagateBase< + TritonGPUTaskIdPropagatePass>::TritonGPUTaskIdPropagateBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + llvm::DenseSet anchorOps; + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty() && + !isa(op)) + anchorOps.insert(op); + }); + + populateTaskIdsForControlDependencies(anchorOps); + + LLVM_DEBUG({ + LDBG("after populateTaskIdsForControlDependencies "); + funcOp->dump(); + }); + + backwardPropagateTaskIds(anchorOps); + + LLVM_DEBUG({ + LDBG("after backwardPropagateTaskIds "); + funcOp->dump(); + }); + + DenseSet allAsyncTasks; + funcOp->walk([&](Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + allAsyncTasks.insert(asyncTasks.begin(), asyncTasks.end()); + }); + SmallVector allAsyncTasksVec(allAsyncTasks.begin(), + allAsyncTasks.end()); + populateUnlabledOpsAtLast(funcOp, allAsyncTasksVec); + + LLVM_DEBUG({ + LDBG("after populateUnlabledOpsAtLast "); + funcOp->dump(); + }); + +#ifndef NDEBUG + verifyTaskId(funcOp, anchorOps); +#endif + } + + void runOnOperation() override { + if (numConsumerGroups == 0) { + getOperation()->walk([&](triton::FuncOp funcOp) { + funcOp.walk([&](mlir::Operation *op) { + auto asyncTasks = getAsyncTaskIds(op); + if (!asyncTasks.empty()) + op->removeAttr("async_task_id"); + }); + }); + return; + } + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp new file mode 100644 index 000000000..2ae27f467 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSCodePartition.cpp @@ -0,0 +1,1424 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" +#include +#include + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define GEN_PASS_DEF_TRITONGPUWSCODEPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +#define DEBUG_TYPE "tritongpu-warp-spec-code-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +std::pair scanRegUsage(ArrayRef opList, + AsyncTaskId asyncTaskId, int regDecProducer, + int regIncConsumer) { + // TODO: scan ops to estimate register usage + if (asyncTaskId == 0) { + // deallocate registers + return {regDecProducer == 0 ? 40 : regDecProducer, false}; + } else { + // allocate registers + return {regIncConsumer == 0 ? 232 : regIncConsumer, true}; + } +} + +// Create IfOp for each ayncTaskId. +DenseMap SpecializeRegion(triton::FuncOp funcOp, + int regDecProducer, + int regIncConsumer) { + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(context); + auto loc = funcOp.getLoc(); + + // Collect original operations + SmallVector opList; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &op : block.getOperations()) + opList.push_back(&op); + } + + // Create GetAsyncTaskIdOp. + Block *lastBlock = &funcOp.getBody().back(); + auto returnOp = llvm::cast(lastBlock->getTerminator()); + builder.setInsertionPoint(returnOp); + Value curAsyncTaskId = builder.create(loc); + + // Resources for each asyncTaskId: builder, IfOp, and IRMapping. + DenseMap> + tasksToBuilders; + DenseMap tasksToIfOp; + DenseMap tasksToIRMappings; + + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(funcOp)) { + // Create IfOp for each asyncTaskId. + Value cond = builder.create( + loc, arith::CmpIPredicate::eq, curAsyncTaskId, + builder.create(loc, asyncTaskId, 32)); + + auto ifOp = builder.create(loc, cond); + tasksToIfOp[asyncTaskId] = ifOp; + setAsyncTaskIds(ifOp, {asyncTaskId}); + + // Create OpBuilderWithAsyncTaskIds for each taskId. + auto taskBuilder = std::make_shared(context); + tasksToBuilders[asyncTaskId] = taskBuilder; + taskBuilder->setAsynTaskIdsFromArray({asyncTaskId}); + + // Decide if this taskId is a producer or a consumer, and create either + // RegAllocOp or RegDeallocOp accordingly. + auto regAlloc = + scanRegUsage(opList, asyncTaskId, regDecProducer, regIncConsumer); + taskBuilder->setInsertionPointToStart(&(ifOp.getThenRegion().front())); + if (regAlloc.second) + taskBuilder->create( + loc, taskBuilder->getI32IntegerAttr(regAlloc.first)); + else + taskBuilder->create( + loc, taskBuilder->getI32IntegerAttr(regAlloc.first)); + + // Set insertion point before yieldOp. + auto yieldOp = ifOp.thenYield(); + setAsyncTaskIds(yieldOp, {asyncTaskId}); + taskBuilder->setInsertionPoint(yieldOp); + } + + // Clone all operations into the corresponding if blocks. If the operation has + // multiple taskIds, it will be cloned for multiple if blocks. + // If the original code has an IfOp, we should only clone its + // body with the right asyncTaskId, instead of cloning the IfOp. + SmallVector cloned; + for (Operation *op : opList) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 0) + continue; + cloned.push_back(op); + if (auto ifOp = dyn_cast(op)) { + DenseMap tasksToThisIfOp; + // TODO: handle outputs of this IfOp. + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + auto ifOpForTask = tasksToBuilders[asyncTaskId]->create( + loc, mapping.lookup(ifOp.getCondition())); + tasksToThisIfOp[asyncTaskId] = ifOpForTask; + auto newYieldOp = ifOpForTask.thenYield(); + tasksToBuilders[asyncTaskId]->setInsertionPoint(newYieldOp); + } + // Handle thenRegion of this IfOp. + for (Operation &thenOp : ifOp.thenBlock()->without_terminator()) { + LLVM_DEBUG({ + LDBG("specialize thenBlock inside ifOp "); + thenOp.dump(); + }); + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(&thenOp)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = + tasksToBuilders[asyncTaskId]->clone(thenOp, mapping); + for (unsigned i = 0; i < thenOp.getNumResults(); ++i) + mapping.map(thenOp.getResult(i), newOp->getResult(i)); + } + } + if (!ifOp.elseBlock()) + continue; // Done with this IfOp, continue to the next op. + // Handle elseRegion of the IfOp. + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + auto newYieldOp = tasksToThisIfOp[asyncTaskId].elseYield(); + tasksToBuilders[asyncTaskId]->setInsertionPoint(newYieldOp); + } + for (Operation &thenOp : ifOp.elseBlock()->without_terminator()) { + LLVM_DEBUG({ + LDBG("specialize elseBlock inside ifOp "); + thenOp.dump(); + }); + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(&thenOp)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = + tasksToBuilders[asyncTaskId]->clone(thenOp, mapping); + for (unsigned i = 0; i < thenOp.getNumResults(); ++i) + mapping.map(thenOp.getResult(i), newOp->getResult(i)); + } + } + } else { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(op)) { + IRMapping &mapping = tasksToIRMappings[asyncTaskId]; + Operation *newOp = tasksToBuilders[asyncTaskId]->clone(*op, mapping); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + } + } + + LLVM_DEBUG({ + LDBG("\n\nWith task Id checks"); + funcOp.dump(); + }); + + // Remove original operations that have been cloned in reverse order. + for (auto it = cloned.rbegin(); it != cloned.rend(); ++it) { + Operation *op = *it; + LLVM_DEBUG({ + LDBG("erasing op "); + op->dump(); + }); + // For debugging purposes, check to see if the original op is still in use. + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + op->erase(); + } + return tasksToIfOp; +} + +struct Channel { +public: + using Relation = std::pair>; + + Channel(int producer, SmallVector &consumers, Operation *src, + Operation *dst, Value srcOperand) + : relation(producer, consumers), srcOp(src), dstOp(dst), + srcOperand(srcOperand) {} + + bool operator==(const Channel &c) { + return relation == c.relation && srcOp == c.srcOp && dstOp == c.dstOp; + } + + Relation relation; // producer task Id, a list of consumer task Ids + Operation *srcOp; + Operation *dstOp; + Value srcOperand; +}; + +// Loads will be in producer warp groups. For now, we only allow a single +// warp group/task for a producer. For each LoadOp, create a channel from it +// to any direct user which belongs to a different taskId. +void collectAsyncChannels(SmallVector> &channels, + triton::FuncOp &funcOp) { + funcOp.walk([&](Operation *op) { + if (isa(op)) { + auto producerTaskIds = getAsyncTaskIds(op); + if (producerTaskIds.empty() || producerTaskIds.size() > 1) { + LLVM_DEBUG({ + LDBG(" ignoring load ops without async task id or with multiple task " + "ids: "); + op->dump(); + }); + return; + } + auto producerTaskId = producerTaskIds.front(); + + for (auto result : op->getResults()) { + if (result.use_empty()) { + continue; + } + for (Operation *userOp : result.getUsers()) { + auto consumerTaskIds = getAsyncTaskIds(userOp); + if (consumerTaskIds.empty()) + continue; + // Remove producer task id from consumerTaskIds. + auto iter = std::remove(consumerTaskIds.begin(), + consumerTaskIds.end(), producerTaskId); + consumerTaskIds.erase(iter, consumerTaskIds.end()); + // Add a channel from the single producer task to consumerTaskIds. + if (consumerTaskIds.size() > 0) { + channels.push_back(std::make_unique( + producerTaskId, consumerTaskIds, op, userOp, result)); + } + } + } + } + }); + + LLVM_DEBUG({ + LDBG("Async channels:"); + for (auto &channel : channels) { + LDBG("producer op: " << channel->relation.first); + channel->srcOp->dump(); + for (auto &asyncTaskId : channel->relation.second) + LDBG("consumer: " << asyncTaskId); + channel->dstOp->dump(); + } + }); +} + +// Update map, which will be keyed by dstOp of the channel. Use mapKeyVec to +// enforce deterministic order for map. +void groupChannels(SmallVector &channels, + DenseMap> &map, + SmallVector &mapKeyVec) { + // Two channels can be combined if + // src1 and src2 are in the same block and + // (dst1 == dst2 or + // (dst1 and dst2 are in the same block, both have a single user, and + // dst1User == dst2User and dst1User is in the same block as dst1)) + auto channelCanBeMerged = [](Channel *c1, Channel *c2) -> bool { + if (c1->srcOp->getBlock() != c2->srcOp->getBlock()) + return false; + Operation *dst1 = c1->dstOp, *dst2 = c2->dstOp; + if (dst1 == dst2) + return true; + if (dst1->getBlock() != dst2->getBlock() || !dst1->hasOneUse() || + !dst2->hasOneUse()) + return false; + Operation *dst1User = *(dst1->getUsers().begin()); + Operation *dst2User = *(dst2->getUsers().begin()); + return dst1User == dst2User && dst1User->getBlock() == dst1->getBlock(); + }; + assert(channels.size() > 0 && "channel size is zero"); + // Compare with existing channels in the map to see if it can be combined. + for (auto *c0 : channels) { + bool merged = false; + for (auto &kv : map) { + if (kv.second.size() > 0 && channelCanBeMerged(c0, kv.second.front())) { + kv.second.push_back(c0); + merged = true; + break; + } + } + if (!merged) { // Create a new entry. + auto *keyOp = c0->dstOp; + if (!map.count(keyOp)) + mapKeyVec.push_back(keyOp); + map[keyOp].push_back(c0); + } + } + + // Reorder channels associated with one entry based on program order of the + // producers. + for (auto &kv : map) { + if (kv.second.size() > 1) { + auto &allOps = kv.second.front()->srcOp->getBlock()->getOperations(); + std::sort( + kv.second.begin(), kv.second.end(), [&](Channel *a, Channel *b) { + auto itrA = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == a->srcOp; + }); + auto itrB = + std::find_if(allOps.begin(), allOps.end(), [&](Operation &op) { + Operation *opPointer = &op; + return opPointer == b->srcOp; + }); + assert(itrA != allOps.end() && itrB != allOps.end()); + return std::distance(itrA, itrB) < 0; + }); + } + } +} + +// Reorder producer ops to unblock consumers interleavingly. +void reorderProducerOps(SmallVector &channels) { + if (channels.size() <= 1) + return; + + // Bail out if channels are not in the same block + auto block = channels.front()->srcOp->getBlock(); + for (auto &channel : channels) { + if (channel->srcOp->getBlock() != block) { + return; + } + } + + // Group channels by the first consumer taskId of each channel. Smaller taskId + // has higher priority. + // TODO: consider consumer priority + std::map> groupedProducerOps; + for (auto &channel : channels) { + auto asyncTaskId = channel->relation.second.front(); + groupedProducerOps[asyncTaskId].push_back(channel); + } + + // No need to reorder if all channels are in the same group. + if (groupedProducerOps.size() <= 1) + return; + + // Sort each group by number of consumers. + for (auto &group : groupedProducerOps) { + std::sort(group.second.begin(), group.second.end(), + [&](Channel *a, Channel *b) { + return a->relation.second.size() < b->relation.second.size(); + }); + } + + // Start from the first producer in channels. Iterate through the groups + // which are ordered by the first consumer taskId. Within each group, channels + // are ordered by number of consumers. + Operation *currOp = channels.front()->srcOp; + for (auto &group : groupedProducerOps) { + for (auto &channel : group.second) { + channel->srcOp->moveAfter(currOp); + currOp = channel->srcOp; + } + } + + // Move backward dependency slice close to producer ops. + // Start from the last producer op backwards and move backward slice to + // before each op. This guarantees that the backward slice of each op is + // scheduled as late as possible. + for (auto &group : reverse(groupedProducerOps)) { + for (auto &channel : reverse(group.second)) { + BackwardSliceOptions opt; + opt.omitBlockArguments = true; + SetVector backwardSlice; + getBackwardSlice(channel->srcOp, &backwardSlice, opt); + for (auto &op : backwardSlice) { + if (op->getBlock() == block) + op->moveBefore(channel->srcOp); + } + } + } + + LLVM_DEBUG({ + LDBG("\n"); + LDBG("after reordering producer ops"); + currOp->getParentOfType().dump(); + LDBG("\n"); + }); +} + +bool isInnermostLoop(scf::ForOp forOp) { + for (Operation &nestedOp : forOp.getBody()->getOperations()) { + if (isa(nestedOp)) { + return false; + } + } + return true; +} + +// Add phase and bufferIndex to be used when lowering the producer. +scf::ForOp createNewLoop(scf::ForOp forOp, int numBuffers, + scf::ForOp &parentForOp) { + auto loc = forOp.getLoc(); + Block *body = forOp.getBody(); + + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + builder.setAsynTaskIdsFromArray(getNestedAsyncTaskIds(forOp)); + builder.setInsertionPoint(forOp); + + Value numBuffersVal = + builder.createWithAsyncTaskIds(loc, numBuffers, 32); + + // Step 1: Append bufferIdx and phase as forOp arguments. + Value phase = + body->insertArgument(body->getNumArguments(), builder.getI1Type(), loc); + Value bufferIdx = + body->insertArgument(body->getNumArguments(), builder.getI32Type(), loc); + + // Step 2: Generate bufferIdx and phase for next iteration: + // nextBufferIdx = bufferIdx + 1 + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + // nextBufferIdx = nextBufferIdx >= numBuffers ? 0 : nextBufferIdx + auto yieldOp = llvm::cast(body->getTerminator()); + builder.setInsertionPoint(yieldOp); + Value one = builder.createWithAsyncTaskIds(loc, 1, 32); + Value zero = builder.createWithAsyncTaskIds(loc, 0, 32); + Value _1_1b = builder.createWithAsyncTaskIds(loc, 1, 1); + // nextBufferIdx = bufferIdx + 1 + Value nextBufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + Value bufferGECond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::uge, nextBufferIdx, numBuffersVal); + Value bufferLTCond = builder.createWithAsyncTaskIds( + loc, arith::CmpIPredicate::ult, nextBufferIdx, numBuffersVal); + if (isInnermostLoop(forOp)) { + // nextBufferIdx >= numBuffers ? nextBufferIdx - numBuffers : nextBufferIdx + Value moduloBufferIdx = builder.createWithAsyncTaskIds( + loc, nextBufferIdx, numBuffersVal); + nextBufferIdx = builder.createWithAsyncTaskIds( + loc, bufferGECond, moduloBufferIdx, nextBufferIdx); + } + + // nextPhase = ((nextBufferIdx < numBuffers && curPhase) || + // (nextBufferIdx >= numBuffers && curPhase^1)) + Value flipPhase = + builder.createWithAsyncTaskIds(loc, phase, _1_1b); + Value cond0 = builder.createWithAsyncTaskIds( + loc, bufferGECond, flipPhase); + Value cond1 = builder.createWithAsyncTaskIds( + loc, bufferLTCond, phase); + Value nextPhase = + builder.createWithAsyncTaskIds(loc, cond0, cond1); + + // Step 3: Add nextBufferIdx and nextPhase to yieldOp. + yieldOp->insertOperands(yieldOp.getNumOperands(), {nextPhase, nextBufferIdx}); + + // Step 4: Create loop arguments for the new ForOp. + SmallVector newLoopArgs; + for (auto operand : forOp.getInitArgs()) + newLoopArgs.push_back(operand); + + builder.setInsertionPoint(forOp); + Value initBufferIdx, initPhase; + zero = builder.createWithAsyncTaskIds(loc, 0, 32); + // Set initial values for bufferIdx and phase. + if (parentForOp) { + // Assume parent ForOp has bufferIdx as the last argument. + initBufferIdx = parentForOp.getBody()->getArguments().back(); + + // numSteps = ((upperBound - lowerBound) + forOpStep - 1) / forOpStep + Value numSteps = builder.createWithAsyncTaskIds( + loc, forOp.getUpperBound(), forOp.getLowerBound()); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + Value one = + builder.createWithAsyncTaskIds(loc, 1, 32); + Value two = + builder.createWithAsyncTaskIds(loc, 2, 32); + numSteps = + builder.createWithAsyncTaskIds(loc, numSteps, one); + numSteps = builder.createWithAsyncTaskIds(loc, numSteps, + forOp.getStep()); + + // initBufferIdx = (parentForOp.bufferIdx * numSteps) % numBuffers + // tmpIdx = parentForOp.bufferIdx * numSteps + // initBufferIdx = tmpIdx - tmpIdx / numBuffers * numBuffers + // initPhase = (tmpIdx / numBuffers) & 1 + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numSteps); + Value bufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, numBuffersVal); + initBufferIdx = builder.createWithAsyncTaskIds( + loc, initBufferIdx, + builder.createWithAsyncTaskIds(loc, bufferIdx, + numBuffersVal)); + bufferIdx = + builder.createWithAsyncTaskIds(loc, bufferIdx, one); + initPhase = builder.createWithAsyncTaskIds( + loc, builder.getI1Type(), bufferIdx); + } else { + // Set initial phase to false, and initial bufferIdx to 0. + initBufferIdx = zero; + initPhase = builder.createWithAsyncTaskIds(loc, 0, 1); + } + newLoopArgs.append({initPhase, initBufferIdx}); + + // Step 5: Create newForOp and take the region of the original forOp. + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + newForOp.getRegion().takeBody(forOp.getRegion()); + + // Step 6: Replace forOp with newForOp. + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + forOp.erase(); + + return newForOp; +} + +// Find top-level ops which contain at least one channel. If a channel's srcOp +// and dstOp belong to the inner loop, the outer loop will be part of +// asyncTaskOps. +SmallVector +getTaskTopRegion(triton::FuncOp funcOp, + const SmallVector &channels) { + SmallVector asyncTaskOps; + auto isAsyncTaskTopOp = [&](Operation *taskTopOp) -> bool { + for (auto c : channels) { + Operation *producer = c->srcOp, *consumer = c->dstOp; + while (producer && !isa(producer->getParentOp())) { + producer = producer->getParentOp(); + } + while (consumer && !isa(consumer->getParentOp())) { + consumer = consumer->getParentOp(); + } + if (producer == taskTopOp && consumer == taskTopOp) + return true; + } + return false; + }; + for (auto &block : funcOp.getBody().getBlocks()) { + for (Operation &bodyOp : block.getOperations()) { + Operation *op = &bodyOp; + if (op->getNumRegions() <= 0) + continue; + // If this op does not contain both a producer taskId and a consumer + // taskId, continue. + if (getAsyncTaskIds(op).size() == 1) + continue; + if (isAsyncTaskTopOp(op)) + asyncTaskOps.push_back(op); + } + } + return asyncTaskOps; +} + +// For ForOps in taskTopOps, create new ForOp for each by adding phase, +// bufferIdx to the arguments. +void appendBufferIdxArgs(SmallVector &taskTopOps, int numBuffers) { + SmallVector orderedForOps; + for (auto &op : taskTopOps) { + op->walk([&](Operation *subOp) { + if (auto forOp = dyn_cast(subOp)) { + orderedForOps.push_back(forOp); + } + }); + } + + for (auto &origForOp : orderedForOps) { + scf::ForOp parentForOp = origForOp->getParentOfType(); + scf::ForOp newForOp; + // for(...) -> for(..., phase, bufferIdx) + newForOp = createNewLoop(origForOp, numBuffers, parentForOp); + // origForOp is erased in createNewLoop. If origForOp is a top operation + // (i.e in taskTopOps), make sure taskTopOps is updated with the newForOp. + auto asyncTaskLoopForItr = std::find(taskTopOps.begin(), taskTopOps.end(), + origForOp.getOperation()); + if (asyncTaskLoopForItr != taskTopOps.end()) { + // Update taskTopOps. + *asyncTaskLoopForItr = newForOp.getOperation(); + } + } +} + +// Create an allocation to hold the mbarriers. +static Value createBarrierAlloc(triton::FuncOp funcOp, unsigned distance) { + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(funcOp.getContext()); + Location loc = funcOp.getLoc(); + auto context = funcOp.getContext(); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = tt::MemDescType::get( + {distance}, builder.getI64Type(), barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value barrierAlloc = builder.create( + loc, barrierMemDescType, Value()); + for (unsigned i = 0; i < distance; i++) { + Value idx = builder.create(loc, i, 32); + Value barrierView = builder.create( + loc, singleBarrierMemDescType, barrierAlloc, idx); + builder.create(funcOp->getLoc(), barrierView, 1); + } + return barrierAlloc; +} + +// map: channels are grouped together. +// Go through each group, check the first channel in the group, create a token +// for each consumer taskId. Return a map that maps each channel + consumer +// taskId to a token. Also update barrierAllocMap that maps each channel + +// consumer taskId to a BarrierAlloc. +DenseMap> +createToken(const DenseMap> &map, + const SmallVector &mapKeyVec, triton::FuncOp funcOp, + int numBuffers, int numConsumerGroups, + DenseMap> &barrierAllocMap) { + DenseMap> ret; + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (auto *key : mapKeyVec) { + auto it = map.find(key); + for (auto consumerAsyncTaskId : it->second.front()->relation.second) { + Value v; + if (it->second.front()->srcOp->getParentOfType()) { + v = builder.create(funcOp.getLoc(), numBuffers); + } else { + v = builder.create(funcOp.getLoc(), 1); + } + // Channels in the group share the same set of tokens. + for (auto &c : it->second) + ret[c][consumerAsyncTaskId] = v; + + auto producerOp = it->second.front()->srcOp; + if (isa(producerOp)) { + Value bAlloc = createBarrierAlloc(funcOp, numBuffers); + // Channels in the group share the same set of tokens. + for (auto &c : it->second) { + ret[c][consumerAsyncTaskId] = v; + barrierAllocMap[c][consumerAsyncTaskId] = bAlloc; + } + } + } + } + return ret; +} + +// Create a buffer array for each channel, if the producer is in a ForOp, +// the buffer array will contain numBuffers. +DenseMap createBuffer(const SmallVector &channels, + triton::FuncOp funcOp, int numBuffers, + int numConsumerGroups) { + DenseMap bufferMap; + MLIRContext *context = funcOp.getContext(); + OpBuilder builder(funcOp); + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + for (const auto &c : channels) { + if (auto tensorType = dyn_cast(c->srcOperand.getType())) { + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get( + context, sliceShape, order, CTALayout, elemType); + auto sliceType = + RankedTensorType::get(sliceShape, elemType, sharedLayout); + + // Get shape, layout and type of the complete buffer + SmallVector bufferShape(sliceShape.begin(), sliceShape.end()); + if (c->srcOp->getParentOfType()) + bufferShape.insert(bufferShape.begin(), numBuffers); + else + bufferShape.insert(bufferShape.begin(), 1); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto bufferType = + RankedTensorType::get(bufferShape, elemType, sharedLayout); + Type memdescType = + tt::MemDescType::get(bufferShape, elemType, sharedLayout, + sharedMemorySpace, /*mutableMemory*/ true); + Value buffer; + if (isa(c->srcOp)) { + buffer = + builder.create(funcOp.getLoc(), memdescType); + } else { + buffer = builder.create(funcOp.getLoc(), memdescType, + c->srcOperand); + } + bufferMap[c] = buffer; + } else { + llvm_unreachable("Unexpected result type"); + } + } + return bufferMap; +} + +static Operation *createAsyncCopy(const DenseMap &bufferMap, + Channel *c, Operation *op, + SmallVector &asyncTasksPC, + Value bufferIdx, Value bufferIdxExtract) { + auto loadOp = cast(op); + auto buffer = bufferMap.find(c)->second; + MLIRContext *context = loadOp->getContext(); + OpBuilderWithAsyncTaskIds builder(context); + builder.setInsertionPoint(loadOp->getParentOp()); + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + builder.setInsertionPoint(loadOp); + Value loadResult = loadOp.getResult(); + auto tensorType = dyn_cast(loadResult.getType()); + if (!tensorType) + return nullptr; + // Get basic information from tensorType + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemory=*/true); + Value zero = builder.createWithAsyncTaskIds( + loadOp.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + builder.setAsyncTaskIdsFromOp(loadOp); + builder.setInsertionPointAfter(loadOp); + auto view = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, copyOffsets); + // Create cp.async + Operation *copy = + builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getPtr(), view, loadOp.getMask(), + loadOp.getOther(), loadOp.getCache(), loadOp.getEvict(), + loadOp.getIsVolatile()); + + // Extract part. + builder.setAsyncTaskIdsFromValueUsers(loadResult); + builder.setInsertionPoint(c->dstOp); + SmallVector loadOffsets(sliceType.getRank() + 1, zero); + loadOffsets[0] = bufferIdxExtract; + auto viewLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), subviewTy, buffer, loadOffsets); + auto sharedLoad = builder.createWithAsyncTaskIds( + loadOp.getLoc(), loadOp.getType(), viewLoad /*,wait->getResult(0)*/); + // Replace all uses of loadResult + loadResult.replaceAllUsesWith(sharedLoad.getResult()); + loadOp.erase(); + return copy; +} + +static int getTMALoadSize(tt::ExperimentalDescriptorLoadOp &tmaLoad) { + auto tensorTy = cast(tmaLoad->getResult(0).getType()); + int loadSize = product(tensorTy.getShape()); + return loadSize * tensorTy.getElementType().getIntOrFloatBitWidth() / 8; +} + +Value getBarrierForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Value barrierAlloc, Value bufferIdx) { + auto context = barrierAlloc.getContext(); + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType barrierTy = tt::MemDescType::get( + {1}, builder.getI64Type(), + cast(barrierAlloc.getType()).getEncoding(), + sharedMemorySpace, + /*mutableMemory=*/true); + + // Create barrierForTMA from barrierAlloc. + return builder.createWithAsyncTaskIds( + barrierAlloc.getLoc(), barrierTy, barrierAlloc, + ArrayRef({bufferIdx})); +} + +Value getBufferForPipelineStage(OpBuilderWithAsyncTaskIds &builder, + Type loadType, Value buffer, Value bufferIdx, + bool mutableMem) { + auto context = buffer.getContext(); + auto tensorType = dyn_cast(loadType); + assert(tensorType); + + auto order = ttg::getOrder(tensorType.getEncoding()); + auto CTALayout = ttg::getCTALayout(tensorType.getEncoding()); + auto elemType = tensorType.getElementType(); + + // Get shape, layout and type of a slice + auto sliceShape = tensorType.getShape(); + auto sharedLayout = ttg::SharedEncodingAttr::get(context, sliceShape, order, + CTALayout, elemType); + auto sliceType = RankedTensorType::get(sliceShape, elemType, sharedLayout); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + tt::MemDescType subviewTy = + tt::MemDescType::get(sliceType.getShape(), sliceType.getElementType(), + sliceType.getEncoding(), sharedMemorySpace, + /*mutableMemOry=*/mutableMem); + + Value zero = builder.createWithAsyncTaskIds( + buffer.getLoc(), 0, 32); + SmallVector copyOffsets(sliceType.getRank() + 1, zero); + copyOffsets[0] = bufferIdx; + + return builder.createWithAsyncTaskIds( + buffer.getLoc(), subviewTy, buffer, copyOffsets); +} + +Operation * +optimizeTMALoads(OpBuilderWithAsyncTaskIds &builder, + SmallVector &tmaLoads, + SmallVector &buffers, Value barrierAlloc, + Value bufferIdx, Value bufferIdxExtract, Value phase, + Operation *headProducer, Operation *headConsumer) { + auto loc = barrierAlloc.getLoc(); + + // Compute the total size of the loads. + int sizeInBytes = 0; + for (auto &tmaLoad : tmaLoads) { + sizeInBytes += getTMALoadSize(tmaLoad); + } + + // For each of the following ops, we will operate on a subview of each value + // according to the pipeline stage. + + // Create a barrier_expect with the appropriate size and insert it before the + // first load. + builder.setInsertionPoint(headProducer); + builder.setAsyncTaskIdsFromOp(headProducer); + auto prodBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdx); + auto pred = builder.createWithAsyncTaskIds(loc, 1, 1); + auto expect = builder.createWithAsyncTaskIds( + loc, prodBarrier, sizeInBytes, pred); + + // Convert all the producers to async_tma_copy_global_to_local + Operation *copy = nullptr; + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage(builder, tmaLoad.getType(), + buffer, bufferIdx, true); + copy = builder.createWithAsyncTaskIds( + loc, tmaLoad.getDescPtr(), tmaLoad.getIndices(), prodBarrier, + pipelineBuffer, pred); + } + + // Create a wait_barrier before the first consumer. + builder.setInsertionPoint(headConsumer); + builder.setAsyncTaskIdsFromOp(headConsumer); + auto consBarrier = + getBarrierForPipelineStage(builder, barrierAlloc, bufferIdxExtract); + phase = builder.createWithAsyncTaskIds( + loc, builder.getI32Type(), phase); + auto wait = builder.createWithAsyncTaskIds( + loc, consBarrier, phase); + + // Convert all the consumers to local_load + for (auto [tmaLoad, buffer] : zip(tmaLoads, buffers)) { + auto pipelineBuffer = getBufferForPipelineStage( + builder, tmaLoad.getType(), buffer, bufferIdxExtract, false); + auto sharedLoad = builder.createWithAsyncTaskIds( + loc, tmaLoad.getType(), pipelineBuffer); + + Value loadResult = tmaLoad.getResult(); + tmaLoad.getResult().replaceAllUsesWith(sharedLoad.getResult()); + tmaLoad.erase(); + } + return copy; +} + +// Lower producers for channels. Here channels are grouped in "map". tokenMap +// tracks the set of tokens for each channel. +void buildAsyncComm( + const DenseMap> &map, + const DenseMap> &tokenMap, + const DenseMap> &barrierAllocMap, + const DenseMap &bufferMap, int numBuffers, + int numConsumerGroups) { + + // Find the operation that is along producer's parent chain, and its parent + // is the same op as producer's parent. Here p is producer, and c is consumer. + auto getSameLevelOp = [](Operation *p, Operation *c) -> Operation * { + while (!isa(c)) { + if (c->getParentOp() == p->getParentOp()) { + return c; + } + c = c->getParentOp(); + } + llvm_unreachable("Failed to find consumer's same level Op with producer"); + }; + + auto consumerReleaseHeutistic = [&](Operation *p, Operation *c, + int consumerAsyncTaskId) -> Operation * { + if (c->getBlock() != p->getBlock()) + return getSameLevelOp(p, c); + for (auto it = c->getBlock()->rbegin(); it != c->getBlock()->rend(); ++it) { + if (!it->hasAttr("async_task_id")) + continue; + auto asyncAttr = it->getAttrOfType("async_task_id") + .getValues(); + if (asyncAttr.size() == 1 && asyncAttr[0] == consumerAsyncTaskId) + return &(*it); + } + return nullptr; + }; + + auto getAsyncTasks = [&](Operation *p, Operation *c, + SmallVector &asyncTaskP, + SmallVector &asyncTaskC, + SmallVector &asyncTasksPC) -> void { + asyncTaskP = getNestedAsyncTaskIds(p); + asyncTaskC = getNestedAsyncTaskIds(c); + asyncTasksPC.reserve(asyncTaskP.size() + asyncTaskC.size()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskP.begin(), + asyncTaskP.end()); + asyncTasksPC.insert(asyncTasksPC.end(), asyncTaskC.begin(), + asyncTaskC.end()); + }; + + // Go through each channel group. + for (auto kv : map) { + auto headProducer = kv.second.front()->srcOp; + auto tailProducer = kv.second.back()->srcOp; + auto headConsumer = kv.second.front()->dstOp; + auto tailConsumer = kv.second.back()->dstOp; + // We have one set of tokens for each channel group. + auto tokens = tokenMap.find(kv.second.front())->second; + + SmallVector asyncTaskP, asyncTaskC, asyncTasksPC; + getAsyncTasks(headProducer, headConsumer, asyncTaskP, asyncTaskC, + asyncTasksPC); + OpBuilderWithAsyncTaskIds builder(headProducer->getContext()); + if (auto funcOp = dyn_cast(headProducer->getParentOp())) { + builder.setInsertionPointToStart(&(funcOp.getBody().front())); + } else { + builder.setInsertionPoint(headProducer->getParentOp()); + } + builder.setAsynTaskIdsFromArray(asyncTasksPC); + + Value bufferIdx; + Value phase = Value(); + if (auto forOp = headProducer->getParentOfType()) { + // We already added phase, bufferIdx to the ForOp. + auto tSize = forOp.getBody()->getArguments().size(); + assert(tSize >= 2); + bufferIdx = forOp.getBody()->getArguments().back(); + phase = forOp.getBody()->getArgument(tSize - 2); // next to last argument + } else { + // Producer is not in a ForOp, create phase and bufferIdx here. + bufferIdx = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 32); + phase = builder.createWithAsyncTaskIds( + headProducer->getLoc(), 0, 1); + } + + assert((isa(headProducer)) && + "producer must be a LoadOp or tma LoadOp"); + builder.setAsynTaskIdsFromArray(asyncTaskP); + for (auto token : tokens) { + // Insert ProducerAcquireOp before the producer. + builder.setInsertionPoint(headProducer); + builder.createWithAsyncTaskIds( + headProducer->getLoc(), token.second, bufferIdx, phase); + + // Insert ProducerCommitOp if producer is LoadOp. For TMA, TMA lowering + // will handle the ProducerCommit. + if (isa(headProducer)) { + builder.setInsertionPointAfter(tailProducer); + builder.createWithAsyncTaskIds( + tailProducer->getLoc(), token.second, bufferIdx); + } + } + + for (auto token : tokens) { + builder.setAsynTaskIdsFromArray(token.first); + // Insert ConsumerWaitOp + if (!isa(headProducer)) { + auto consumerWaitPoint = getSameLevelOp(headProducer, headConsumer); + builder.setInsertionPoint(consumerWaitPoint); + builder.createWithAsyncTaskIds( + headConsumer->getLoc(), token.second, bufferIdx, phase); + } + + // Insert ConsumerReleaseOp. + auto consumerReleasePoint = + consumerReleaseHeutistic(tailProducer, tailConsumer, token.first); + builder.setInsertionPointAfter(consumerReleasePoint); + builder.createWithAsyncTaskIds( + consumerReleasePoint->getLoc(), token.second, bufferIdx); + } + + SmallVector tmaLoads; + SmallVector buffers; + // Go through all channels in this channel group. + for (auto &c : kv.second) { + assert( + (isa(c->srcOp)) && + "producer must be a LoadOp or tma LoadOp"); + bool insideLoop = c->srcOp->getParentOfType() != nullptr; + if (isa(c->srcOp)) { + // After createAsyncCopy, c->srcOp/headProducer are no longer valid. + createAsyncCopy(bufferMap, c, c->srcOp, asyncTasksPC, bufferIdx, + bufferIdx); + } else if (auto tmaLoad = + dyn_cast(c->srcOp)) { + tmaLoads.push_back(tmaLoad); + buffers.push_back(bufferMap.find(c)->second); + } + } + + // Optimize TMA loads. + if (tmaLoads.size() > 0) { + auto barrierAllocs = barrierAllocMap.find(kv.second.front())->second; + // TODO: we created one Alloc for each consumer taskId, but here, we + // only use the first Alloc. + auto barrierAlloc = barrierAllocs.begin()->second; + optimizeTMALoads(builder, tmaLoads, buffers, barrierAlloc, bufferIdx, + bufferIdx, phase, headProducer, headConsumer); + } + } +} + +// Collect argument indices that are used by the specific taskId. +static SmallVector collectBlockArgsForTask( + scf::ForOp forOp, int asyncTaskId, + DenseMap &blockArgToYieldOperand) { + DenseSet seen; + // Collect argument indices that can be reached along the definition chain. + // If reaching a BlockArgument, visit the corresponding yield operand. + SetVector argIndices; + std::function dfs = [&](Operation *op) { + if (!seen.insert(op).second) + return; + for (Value operand : op->getOperands()) { + if (auto blockArg = dyn_cast(operand)) { + if (!blockArgToYieldOperand[blockArg]) + continue; + argIndices.insert(blockArg.getArgNumber() - + forOp.getNumInductionVars()); + operand = blockArgToYieldOperand[blockArg]; + } + Operation *depOp = operand.getDefiningOp(); + assert(depOp && "Unexpected Value with no defining op"); + if (depOp->getBlock() != forOp.getBody()) + continue; + assert(hasAsyncTaskId(depOp, asyncTaskId) && "Dependency error"); + dfs(depOp); + } + }; + + // Start from operations that are marked with this asyncTaskId explicitly and + // check dependency with DFS traversal. + forOp.walk([&](Operation *op) { + if (hasAsyncTaskId(op, asyncTaskId) && !isa(op)) + dfs(op); + }); + + SmallVector args(argIndices.begin(), argIndices.end()); + llvm::sort(args); + return args; +} + +DenseMap +createForOpsForEachAsyncTaskId(scf::ForOp forOp) { + // Collect operation list for each asyncTaskId. + DenseMap> opList; + for (Operation &op : forOp.getBody()->without_terminator()) { + auto ids = getAsyncTaskIds(&op); + for (AsyncTaskId asyncTaskId : ids) + opList[asyncTaskId].push_back(&op); + } + + // Prepare blockArgToYieldOperand mapping. + DenseMap blockArgToYieldOperand; + auto yieldOp = llvm::cast(forOp.getBody()->getTerminator()); + assert(yieldOp.getNumOperands() == forOp.getNumRegionIterArgs()); + for (unsigned i = 0; i < forOp.getNumRegionIterArgs(); ++i) + blockArgToYieldOperand[forOp.getRegionIterArg(i)] = yieldOp.getOperand(i); + + auto loc = forOp.getLoc(); + OpBuilderWithAsyncTaskIds builder(forOp.getContext()); + DenseMap asyncTasksToForOp; + + // Create newForOp for each task Id. + for (AsyncTaskId asyncTaskId : getNestedAsyncTaskIds(forOp)) { + auto usedArgs = + collectBlockArgsForTask(forOp, asyncTaskId, blockArgToYieldOperand); + + // Prepare newLoopArgs. + SmallVector newLoopArgs; + for (unsigned argNumber : usedArgs) + newLoopArgs.push_back(forOp.getInitArgs()[argNumber]); + + // Create newForOp. + builder.setAsynTaskIdsFromArray({asyncTaskId}); + builder.setInsertionPoint(forOp); + auto newForOp = builder.createWithAsyncTaskIds( + loc, forOp.getLowerBound(), forOp.getUpperBound(), forOp.getStep(), + newLoopArgs); + if (forOp->getAttr("tt.loop_schedule")) + newForOp->setAttr("tt.loop_schedule", forOp->getAttr("tt.loop_schedule")); + + // Initialize Value mapping from forOp to newForOp + IRMapping mapping; + mapping.map(forOp.getInductionVar(), newForOp.getInductionVar()); + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldArg = forOp.getRegionIterArgs()[usedArgs[i]]; + auto newArg = newForOp.getRegionIterArgs()[i]; + mapping.map(oldArg, newArg); + } + + // Clone all operations with this asyncTaskId to newForOp. + builder.setInsertionPointToStart(newForOp.getBody()); + for (Operation *op : opList[asyncTaskId]) { + Operation *newOp = builder.clone(*op, mapping); + setAsyncTaskIds(newOp, {asyncTaskId}); + for (unsigned i = 0; i < op->getNumResults(); ++i) + mapping.map(op->getResult(i), newOp->getResult(i)); + } + + // Create YieldOp for newForOp. + SmallVector newYieldOperands; + for (unsigned i : usedArgs) { + LDBG("lookup operand " << i); + newYieldOperands.push_back(mapping.lookup(yieldOp.getOperand(i))); + } + bool createNewYield = true; + if (newForOp.getBody()->mightHaveTerminator()) { + auto initialYield = + llvm::cast(newForOp.getBody()->getTerminator()); + if (newYieldOperands.size() == 0) { + setAsyncTaskIds(initialYield, {asyncTaskId}); + createNewYield = false; + } + } + if (createNewYield) { + auto newYieldOp = + builder.create(yieldOp.getLoc(), newYieldOperands); + setAsyncTaskIds(newYieldOp, {asyncTaskId}); + } + + // Replace results of forOp with results of newForOp. + for (unsigned i = 0; i < usedArgs.size(); ++i) { + auto oldResult = forOp.getResult(usedArgs[i]); + auto newResult = newForOp.getResult(i); + oldResult.replaceUsesWithIf(newResult, [&](OpOperand &operand) -> bool { + return hasAsyncTaskId(operand.getOwner(), asyncTaskId); + }); + } + + asyncTasksToForOp[asyncTaskId] = newForOp; + } + + return asyncTasksToForOp; +} + +// Input asyncTaskTopOp can be an IfOp that contains a ForOp. We clone +// the ForOp for each asyncTaskId. +DenseMap +asyncTaskDivision(Operation *asyncTaskTopOp) { + DenseMap asyncTaskTopOpMap; + Operation *mainForOp = asyncTaskTopOp; + if (auto ifOp = dyn_cast(asyncTaskTopOp)) { + // Find the outmost ForOp inside. Assume only a single ForOp. + Operation *nestedFor = nullptr; + asyncTaskTopOp->walk([&](Operation *op) { + if (auto forOp = dyn_cast(op)) { + assert(nestedFor == nullptr); + nestedFor = op; + } + }); + assert(nestedFor && "can't find ForOp in a top-level IfOp"); + mainForOp = nestedFor; + } + asyncTaskTopOp->walk([&](Operation *op) { + auto ids = getAsyncTaskIds(op); + if (op->getNumRegions() > 0 && ids.size() > 1) { + if (auto forOp = dyn_cast(op)) { + // Create a cloned ForOp for each taskId and return the map. + auto forOps = createForOpsForEachAsyncTaskId(forOp); + if (op == mainForOp) { + for (auto kv : forOps) { + auto f = kv.second; + auto id = getAsyncTaskIds(f.getOperation()); + assert(id.size() == 1 && + "generated ForOp doesn't have one and only one asyncTaskId"); + asyncTaskTopOpMap[id.front()] = f.getOperation(); + } + } + // For debugging purposes, check to see if it is safe to erase the + // original ForOp. + bool hasIssue = false; + for (Operation &opT : forOp.getBody()->without_terminator()) { + // Check to see if opT is used in another block. + for (unsigned i = 0; i < opT.getNumResults(); ++i) + for (Operation *user : opT.getResult(i).getUsers()) { + if (user->getBlock() != opT.getBlock()) { + hasIssue = true; + LLVM_DEBUG({ + LDBG("-- op has user in another block"); + opT.dump(); + user->dump(); + }); + } + } + } + if (hasIssue) { + for (Operation &opT : forOp.getBody()->without_terminator()) { + LLVM_DEBUG({ + LDBG("addr " << (&opT) << ": "); + opT.dump(); + }); + } + } + bool hasUse = false; + for (unsigned i = 0; i < op->getNumResults(); ++i) { + for (Operation *user : op->getResult(i).getUsers()) { + hasUse = true; + LLVM_DEBUG({ + LDBG("op has use "); + user->dump(); + }); + } + } + ModuleOp moduleOp = forOp->getParentOfType(); + LLVM_DEBUG({ + LDBG("erase ForOp"); + forOp.dump(); + }); + forOp.erase(); + LDBG("done erasing ForOp"); + } else if (auto ifOp = dyn_cast(op)) { + // The ForOp inside this ifOp will be cloned. + LDBG("IfOp in asyncTaskDivision"); + } else if (auto whileOp = dyn_cast(op)) { + LDBG("WhileOp in asyncTaskDivision"); + } else { + llvm_unreachable("Unexpected Op with regions"); + } + } + }); + assert(asyncTaskTopOpMap.size() > 0 && "AsyncTask division failed"); + return asyncTaskTopOpMap; +} + +void cloneAsyncTaskLoopForEachAsyncTaskId( + SmallVector &asyncTaskTopOps) { + SmallVector newBackBone; + + for (Operation *op : asyncTaskTopOps) { + auto loc = op->getLoc(); + OpBuilderWithAsyncTaskIds builder(op->getContext()); + builder.setInsertionPoint(op); + // Step 1: create a cloned forOp for each taskId based on the original + // ForOp that is in this top-level operation. + DenseMap newAsyncTaskLoops = + asyncTaskDivision(op); + + // Step 2: remove irrelevant Ops from the cloned ForOps. + for (auto kv : newAsyncTaskLoops) { + SmallVector deleteOps; + AsyncTaskId targetId = kv.first; + Operation *newAsyncTaskLoop = kv.second; + newAsyncTaskLoop->walk([&](Operation *subOp) { + auto ids = getAsyncTaskIds(subOp); + if (std::find(ids.begin(), ids.end(), targetId) == ids.end()) { + deleteOps.push_back(subOp); + } + }); + for (auto it = deleteOps.rbegin(); it != deleteOps.rend(); ++it) { + (*it)->erase(); + } + } + } +} + +class TritonGPUWSCodePartitionPass + : public impl::TritonGPUWSCodePartitionBase { +public: + using impl::TritonGPUWSCodePartitionBase< + TritonGPUWSCodePartitionPass>::TritonGPUWSCodePartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + // Disable code partitioning when numBuffers is 0. + if (numBuffers == 0) + return; + + // Step 1: collect all communications between producers and consumers. + SmallVector> channelsOrigin; + collectAsyncChannels(channelsOrigin, funcOp); + SmallVector channels; + for (const auto &c : channelsOrigin) { + channels.push_back(c.get()); + } + if (channels.empty()) { + return; + } + + // Step 2: group channels where each entry of the map is keyed by the dstOp. + DenseMap> map; + SmallVector mapKeyVec; + groupChannels(channels, map, mapKeyVec); + + // Step 3: reorder producer ops and the backward slices of the producer ops. + reorderProducerOps(channels); + + // Step 4: find top-level ops that contain a channel, also create new ForOps + // by adding phase and bufferIdx to the original ForOps, erase the original + // ForOps. + SmallVector asyncTaskTopOps = + getTaskTopRegion(funcOp, channels); + appendBufferIdxArgs(asyncTaskTopOps, numBuffers); + + // Step 5: Create tokens, and buffers. A set of tokens for each group of + // channels and an array of buffers for each channel. + DenseMap> barrierAllocMap; + DenseMap> tokenMap = createToken( + map, mapKeyVec, funcOp, numBuffers, numConsumerGroups, barrierAllocMap); + DenseMap bufferMap = + createBuffer(channels, funcOp, numBuffers, numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nafter createBuffer"); + funcOp.dump(); + }); + + // Step 6: add async communication ops (ProducerAcquire etc). Also lower the + // loads. + buildAsyncComm(map, tokenMap, barrierAllocMap, bufferMap, numBuffers, + numConsumerGroups); + LLVM_DEBUG({ + LDBG("\n\nwith SyncOps"); + funcOp.dump(); + }); + + // If loadResult has a single use which is LocalAlloc, we can get rid of + // sharedLoad and replace all uses of LocalAlloc with viewLoad. + DenseMap opsToReplace; + funcOp.walk([&](ttg::LocalAllocOp localAlloc) { + if (auto src = localAlloc.getSrc()) { + if (auto localLoad = dyn_cast(src.getDefiningOp())) { + opsToReplace[localAlloc] = localLoad.getSrc(); + } + } + }); + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + for (auto kv : opsToReplace) + replaceUsesAndPropagateType(builder, kv.getFirst(), kv.getSecond()); + LLVM_DEBUG({ + LDBG("\n\nsimplify localLoad + localAlloc"); + funcOp.dump(); + }); + + // Clone taskTopOp, remove irrelevant blockArgument for {forOp, ifOp} + cloneAsyncTaskLoopForEachAsyncTaskId(asyncTaskTopOps); + LLVM_DEBUG({ + LDBG("\n\nwith Loop Split"); + funcOp.dump(); + }); + + auto ret = SpecializeRegion(funcOp, regDecProducer, regIncConsumer); + LLVM_DEBUG({ + LDBG("\n\nwith IfOps"); + funcOp.dump(); + }); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + LLVM_DEBUG({ + LDBG("post pass"); + getOperation()->dump(); + }); + return; + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp new file mode 100644 index 000000000..9e9f0a6a3 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSDataPartition.cpp @@ -0,0 +1,680 @@ +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/Dominance.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Verifier.h" +#include "mlir/Interfaces/InferTypeOpInterface.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Transforms/Passes.h" +#include "mlir/Transforms/RegionUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-data-partition" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +static bool oneVecCoversTheOther(SmallVector &one, + SmallVector &other) { + // Every element of other appears in one. + for (AsyncTaskId t : other) { + // If t doesn't appear in one, return false. + bool found = false; + for (AsyncTaskId t2 : one) { + if (t2 == t) { + found = true; + break; + } + } + if (!found) + return false; + } + return true; +} + +// Make sure the def chain contains the right taskId. +bool fixTaskId(triton::FuncOp &funcOp) { + bool retCode = true; + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (!defOp) + continue; + // Do not update loads. + if (isa(defOp)) + continue; + auto defTaskIds = getAsyncTaskIds(defOp); + // Make sure defTaskIds cover asyncTaskIds. Call addAsyncTaskIds if + // necessary. + if (!oneVecCoversTheOther(defTaskIds, asyncTaskIds)) { + retCode = false; + // Const ops with same value but different task ids can be folded. + if (isa(defOp)) { + LLVM_DEBUG({ + LDBG("fixing taskId for"); + defOp->dump(); + }); + addAsyncTaskIds(defOp, asyncTaskIds); + LLVM_DEBUG({ + LDBG("resulting"); + defOp->dump(); + }); + } + } + } + }); + return retCode; +} + +static SmallVector getShape(Value v) { + auto type = v.getType(); + if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } else if (auto type = dyn_cast(v.getType())) { + return {type.getShape().begin(), type.getShape().end()}; + } + return {}; +} + +bool needToSlice(Value v, int dim, int size) { + auto shape = getShape(v); + return shape.size() > dim && shape[dim] > size; +} + +bool getBackwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &backwardSlice) { + auto newOpInserted = false; + SmallVector queue = {root}; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!needToSlice(v, dim, sliceSize)) + continue; + if (auto op = v.getDefiningOp()) { + auto inserted = backwardSlice.insert(op); + newOpInserted |= inserted; + if (inserted) { + if (op->hasTrait() || + isa(op)) { + for (Value operand : op->getOperands()) + queue.push_back(operand); + } else if (auto dotOp = dyn_cast(op)) { + queue.push_back(dim == 0 ? dotOp.getA() : dotOp.getB()); + queue.push_back(dotOp.getC()); + } else { + llvm_unreachable("Unexpected op"); + } + } + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(bbAargOwner)) { + // track initial value + auto initArg = forOp.getInitArgs()[bbArg.getArgNumber() - 1]; + queue.push_back(initArg); + // track yield value + auto yieldArg = forOp.getYieldedValues()[bbArg.getArgNumber() - 1]; + queue.push_back(yieldArg); + } + } + } + return newOpInserted; +}; + +bool getForwardSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &forwardSlice) { + auto newOpInserted = false; + SmallVector queue = {root}; + llvm::SmallDenseSet seen; + while (!queue.empty()) { + auto v = queue.back(); + queue.pop_back(); + if (!seen.insert(v).second) + continue; + if (!needToSlice(v, dim, sliceSize)) + continue; + getForwardSlice(v, &forwardSlice); + for (Operation *op : forwardSlice) { + if (op->getNumResults() > 0) + seen.insert(op->getResult(0)); + if (auto yieldOp = dyn_cast(op)) { + if (auto forOp = dyn_cast(yieldOp->getParentOp())) { + for (OpOperand &operand : yieldOp->getOpOperands()) { + if (seen.count(operand.get())) { + queue.push_back(forOp->getResult(operand.getOperandNumber())); + forwardSlice.insert(forOp); + newOpInserted = true; + } + } + } + } + } + } + return newOpInserted; +}; + +// Compute a closure of all ops originated from or being dependent on by the +// root op. +void getSliceToPartition(Value root, unsigned dim, int sliceSize, + SetVector &slice) { + auto newOpInserted = false; + while (!newOpInserted) { + newOpInserted |= getBackwardSliceToPartition(root, dim, sliceSize, slice); + SetVector forwardSlice; + newOpInserted |= + getForwardSliceToPartition(root, dim, sliceSize, forwardSlice); + slice.insert(forwardSlice.begin(), forwardSlice.end()); + for (auto op : forwardSlice) { + if (op->hasTrait() || + isa(op)) { + for (OpOperand &operand : op->getOpOperands()) { + newOpInserted |= + getBackwardSliceToPartition(operand.get(), dim, sliceSize, slice); + } + } else if (auto dotOp = dyn_cast(op)) { + newOpInserted |= getBackwardSliceToPartition( + dim == 0 ? dotOp.getA() : dotOp.getB(), dim, sliceSize, slice); + newOpInserted |= + getBackwardSliceToPartition(dotOp.getC(), dim, sliceSize, slice); + } + } + } +} + +struct DataPartitionScheme { + // Which dimension to partition. For dot, dim 0 means along M dimension, 1 + // means along N dimensiont. + unsigned partitionDim = 0; + unsigned numPartitions = 0; + SetVector ops; +}; + +bool computePartitionScheme(triton::FuncOp &funcOp, + DataPartitionScheme &partitionScheme) { + // Do not partition producer tasks + + // Use dot to drive the partition + SetVector dots; + + // check all dot ops that have more than one async task id + funcOp.walk([&](Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 1) { + if (auto dotWaitOp = dyn_cast(op)) { + dots.insert(dotWaitOp); + } + } + }); + + // Checking if all dots can be partitioned in the same way + int numWarps = + TritonGPUDialect::getNumWarps(funcOp->getParentOfType()); + for (auto dotOp : dots) { + // partition along M first, otherwise along N + RankedTensorType dotType = dotOp.getType(); + LLVM_DEBUG({ + LDBG("Computing partition scheme for"); + dotOp.dump(); + LDBG("\n"); + }); + auto shapePerCTA = getShapePerCTA(dotType); + if (shapePerCTA.size() != 2) { + LDBG("partition not possible: shapePerCTA " << shapePerCTA.size()); + return false; + } + auto CTALayout = getCTALayout(dotType.getEncoding()); + auto asyncTaskIds = getAsyncTaskIds(dotOp); + int sliceSizeM = shapePerCTA[0] / asyncTaskIds.size(); + int sliceSizeN = shapePerCTA[1] / asyncTaskIds.size(); + int partitionDim, partitionSize; + Value partitionOperand; + + if (sliceSizeM >= 64) { + LLVM_DEBUG({ LDBG("partition along M\n"); }); + partitionDim = 0; + partitionSize = sliceSizeM; + partitionOperand = dotOp.getA(); + } else if (sliceSizeN >= 256) { + LLVM_DEBUG({ LDBG("partition along N\n"); }); + partitionDim = 1; + partitionSize = sliceSizeN; + partitionOperand = dotOp.getB(); + } else { + LDBG("partition not possible: " << sliceSizeM << " " << sliceSizeN); + return false; + } + + if (partitionScheme.numPartitions == 0) { + partitionScheme.partitionDim = partitionDim; + partitionScheme.numPartitions = asyncTaskIds.size(); + } else { + if (partitionScheme.partitionDim != partitionDim || + partitionScheme.numPartitions != asyncTaskIds.size()) { + LDBG("partition not possible, in conflict with previous partition\n"); + return false; + } + } + + // Partition the slice closure + SetVector &slice = partitionScheme.ops; + getSliceToPartition(dotOp.getD(), partitionDim, partitionSize, slice); + + LLVM_DEBUG({ + partitionOperand.dump(); + LDBG("\n"); + LDBG(" slice:"); + for (auto &op : slice) { + op->dump(); + } + LDBG("\n"); + }); + + for (auto op : partitionScheme.ops) { + auto opTaskIds = getAsyncTaskIds(op); + // skip check for control flow ops + if (isa(op)) + continue; +#if 0 + if (opTaskIds.size() > partitionScheme.numPartitions) { + LLVM_DEBUG({ + LDBG("partition not possible: numPartitions" << opTaskIds.size() << " " << partitionScheme.numPartitions); + op->dump(); + }); + return false; + } +#endif + } + } + + return !partitionScheme.ops.empty(); +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme); + +Operation *sliceOp(Operation *op, int offset, + OpBuilderWithAsyncTaskIds &builder, IRMapping &mappings, + IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (!partitionScheme.ops.contains(op)) + return op; + if (mappings.contains(op)) + return mappings.lookupOrNull(op); + if (reverseMappings.contains(op)) + return op; + + LLVM_DEBUG({ + LDBG("slicing:"); + op->dump(); + LDBG("\n"); + }); + + int dim = partitionScheme.partitionDim; + int numOfPartitions = partitionScheme.numPartitions; + + auto asyncTaskIds = getAsyncTaskIds(op); + SmallVector sliceTaskIds; + if (asyncTaskIds.size() == numOfPartitions) { + // We are slicing the op for consumer only + sliceTaskIds.push_back(asyncTaskIds[offset]); + } else if (asyncTaskIds.size() == 1) { + // We are slicing the op for producer only + sliceTaskIds.push_back(asyncTaskIds.front()); + } else if (asyncTaskIds.size() > numOfPartitions) { + // We are slicing the op for both producer and consumer + sliceTaskIds.push_back(asyncTaskIds.front()); + sliceTaskIds.push_back(asyncTaskIds[offset + 1]); + } else { + llvm_unreachable("Unexpected asyncTaskIds.size()"); + } + + builder.setAsynTaskIdsFromArray(sliceTaskIds); + auto cloneAndSetResultType = [&](Operation *op) { + builder.setInsertionPoint(op); + auto newOp = builder.clone(*op, mappings); + setAsyncTaskIds(newOp, sliceTaskIds); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + // set result shape + if (!op->getResults().empty()) { + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = + MemDescType::get(shape, type.getElementType(), type.getEncoding(), + type.getMemorySpace(), type.getMutableMemory()); + newV.setType(newType); + } else if (auto type = dyn_cast(v.getType())) { + SmallVector shape{type.getShape().begin(), + type.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newType = RankedTensorType::get(shape, type.getElementType(), + type.getEncoding()); + newV.setType(newType); + } + + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + return newOp; + }; + + // slice operands first + Operation *newOp; + if (op->hasTrait() || + isa( + op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto constOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + auto valAttr = cast(constOp.getValueAttr()); + auto valType = cast(valAttr.getType()); + SmallVector shape{valType.getShape().begin(), + valType.getShape().end()}; + int sliceSize = shape[dim] / numOfPartitions; + shape[dim] = sliceSize; + auto newValType = valType.clone(shape); + auto newValAttr = valAttr.resizeSplat(newValType); + newOp = builder.createWithAsyncTaskIds(op->getLoc(), + newValAttr); + // Do not drop original task id as constant folding may lose one constant. + setAsyncTaskIds(newOp, getAsyncTaskIds(op)); + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (auto makeRangeOp = dyn_cast(op)) { + builder.setInsertionPoint(op); + int newRangeStart = makeRangeOp.getStart(); + int newRangeEnd = makeRangeOp.getEnd(); + int sliceSize = (newRangeEnd - newRangeStart) / numOfPartitions; + newRangeStart += offset * sliceSize; + newRangeEnd = newRangeStart + sliceSize; + auto v = op->getResult(0); + auto type = cast(v.getType()); + auto newType = RankedTensorType::get({sliceSize}, builder.getI32Type(), + type.getEncoding()); + newOp = builder.createWithAsyncTaskIds( + op->getLoc(), newType, newRangeStart, newRangeEnd); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } else if (isa(op)) { + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + // TODO: slice store base ptr + newOp = cloneAndSetResultType(op); + } else if (isa( + op)) { + SmallVector shape; + Value coordVal; + if (auto loadOp = dyn_cast(op)) { + coordVal = loadOp.getIndices()[dim]; + shape = getShape(loadOp.getResult()); + } else if (auto storeOp = dyn_cast(op)) { + coordVal = storeOp.getIndices()[dim]; + shape = getShape(storeOp.getSrc()); + } + auto newCoordVal = coordVal; + if (offset) { + builder.setInsertionPointAfter(coordVal.getDefiningOp()); + Value offsetVal = builder.createWithAsyncTaskIds( + op->getLoc(), offset * shape[dim] / numOfPartitions, 32); + newCoordVal = builder.createWithAsyncTaskIds( + op->getLoc(), coordVal, offsetVal); + mappings.map(coordVal, newCoordVal); + reverseMappings.map(newCoordVal, coordVal); + } + + newOp = cloneAndSetResultType(op); + if (isa(op)) { + // map load result + auto v = op->getResult(0); + auto newV = newOp->getResult(0); + mappings.map(v, newV); + reverseMappings.map(newV, v); + } + } else if (auto dotOp = dyn_cast(op)) { + // Only hanlde A and accumulator + sliceOp(dim == 0 ? dotOp.getA() : dotOp.getB(), offset, builder, mappings, + reverseMappings, partitionScheme); + sliceOp(dotOp.getC(), offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else if (auto forOp = dyn_cast(op)) { + // Add new loop arguments + SmallVector newLoopArgs; + for (auto initArg : forOp.getInitArgs()) + newLoopArgs.push_back(initArg); + DenseMap newArgIdices; + for (unsigned i = 0; i < forOp.getInitArgs().size(); i++) { + auto initArg = forOp.getInitArgs()[i]; + auto newInitArgOp = sliceOp(initArg.getDefiningOp(), offset, builder, + mappings, reverseMappings, partitionScheme); + auto newInitArg = newInitArgOp->getResult(0); + if (newInitArg != initArg) { + newLoopArgs.append({newInitArg}); + forOp.getBody()->insertArgument(forOp.getBody()->getNumArguments(), + newInitArg.getType(), forOp.getLoc()); + newArgIdices[i] = newLoopArgs.size() - 1; + } + } + + // Create newForOp and take the region of forOp + builder.setInsertionPoint(op); + auto newForOp = builder.createWithAsyncTaskIds( + forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newLoopArgs); + assert(newForOp.getRegionIterArgs().size() == + newForOp.getInitArgs().size()); + newForOp->setAttrs(forOp->getAttrs()); + partitionScheme.ops.insert(newForOp); + newOp = newForOp; + + // Replace forOp with newForOp + newForOp.getRegion().takeBody(forOp.getRegion()); + for (unsigned i = 0; i < forOp.getNumResults(); ++i) + forOp.getResult(i).replaceAllUsesWith(newForOp.getResult(i)); + op->setAttr("to_be_removed", builder.getUnitAttr()); + + // Map new loop arguments + for (auto argIndex : newArgIdices) { + Value v = newForOp.getResult(argIndex.first); + Value newV = newForOp.getResult(argIndex.second); + mappings.map(v, newV); + reverseMappings.map(newV, v); + + auto regionArg = newForOp.getRegionIterArg(argIndex.first); + auto newRegionArg = newForOp.getRegionIterArg(argIndex.second); + mappings.map(regionArg, newRegionArg); + reverseMappings.map(newRegionArg, regionArg); + } + + } else if (auto yieldOp = dyn_cast(op)) { + int num = yieldOp.getNumOperands(); + for (int i = 0; i < num; i++) { + auto operand = yieldOp.getOperand(i); + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + if (auto newV = mappings.lookupOrNull(operand)) + yieldOp->insertOperands(op->getNumOperands(), newV); + } + newOp = op; + } else if (auto reduceOp = dyn_cast(op)) { + assert(reduceOp.getAxis() != partitionScheme.partitionDim && + "reduce should not happen on the partitioned dimension"); + for (Value operand : op->getOperands()) + sliceOp(operand, offset, builder, mappings, reverseMappings, + partitionScheme); + newOp = cloneAndSetResultType(op); + } else { + llvm_unreachable("unsupported value type"); + } + + LLVM_DEBUG({ + LDBG("resulting"); + newOp->dump(); + LDBG("\n"); + }); + mappings.map(op, newOp); + reverseMappings.map(newOp, op); + return newOp; +} + +Operation *sliceOp(Value v, int offset, OpBuilderWithAsyncTaskIds &builder, + IRMapping &mappings, IRMapping &reverseMappings, + DataPartitionScheme &partitionScheme) { + if (auto op = v.getDefiningOp()) { + return sliceOp(op, offset, builder, mappings, reverseMappings, + partitionScheme); + } else { + assert(isa(v) && "value is not an operation or block "); + auto bbArg = cast(v); + Operation *bbAargOwner = bbArg.getOwner()->getParentOp(); + return sliceOp(bbAargOwner, offset, builder, mappings, reverseMappings, + partitionScheme); + } +} + +void partitionTasks(triton::FuncOp &funcOp) { + + // op -> (partition dim, num of partitions) + DataPartitionScheme partitionScheme; + if (!computePartitionScheme(funcOp, partitionScheme)) + return; + + for (int i = 0; i < partitionScheme.numPartitions; i++) { + OpBuilderWithAsyncTaskIds builder(funcOp.getContext()); + IRMapping mappings, reverseMappings; + + LLVM_DEBUG({ LDBG("partitioning op for task " << i << ":\n"); }); + + // TODO: compute a topological order for partitionScheme.ops and + // slice in that order. + int numOps = partitionScheme.ops.size(); + for (int j = 0; j < numOps; j++) { + auto op = partitionScheme.ops[j]; + sliceOp(op, i, builder, mappings, reverseMappings, partitionScheme); + } + + // clean up + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (op->hasAttr("to_be_removed")) + opsToDelete.push_back(op); + } + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + } + + // clean up + + SmallVector opsToDelete; + for (auto op : partitionScheme.ops) { + if (isa(op)) + continue; + bool notUsed = true; + for (auto result : op->getResults()) { + if (!result.getUsers().empty()) { + notUsed = false; + break; + } + } + if (notUsed) + opsToDelete.push_back(op); + } + + LLVM_DEBUG({ + LDBG("opsToDelete:\n"); + for (auto op : opsToDelete) { + LDBG("op: "); + op->dump(); + } + LDBG("\n"); + }); + for (auto op : opsToDelete) { + partitionScheme.ops.remove(op); + op->erase(); + } + LLVM_DEBUG({ + LDBG("prior to clean up:"); + funcOp.dump(); + }); + + // delete block arguments + RewritePatternSet cleanUpPatterns(funcOp.getContext()); + populateForOpDeadArgumentElimination(cleanUpPatterns); + scf::ForOp::getCanonicalizationPatterns(cleanUpPatterns, funcOp.getContext()); + if (applyPatternsAndFoldGreedily(funcOp, std::move(cleanUpPatterns)) + .failed()) { + llvm_unreachable("failed to clean up"); + // signalPassFailure(); + } + + // Make sure original ops are not used + LLVM_DEBUG({ + LDBG("after partition"); + funcOp.dump(); + LDBG("\n"); + }); + fixTaskId(funcOp); +} + +#define GEN_PASS_DEF_TRITONGPUWSDATAPARTITION +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +class TritonGPUWSDataPartitionPass + : public impl::TritonGPUWSDataPartitionBase { +public: + using impl::TritonGPUWSDataPartitionBase< + TritonGPUWSDataPartitionPass>::TritonGPUWSDataPartitionBase; + + void runOnFuncOp(triton::FuncOp funcOp) { + if (numConsumerGroups == 0) + return; + partitionTasks(funcOp); + } + + void runOnOperation() override { + getOperation()->walk([&](triton::FuncOp funcOp) { runOnFuncOp(funcOp); }); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp new file mode 100644 index 000000000..c2bf31fc5 --- /dev/null +++ b/lib/Dialect/TritonGPU/Transforms/WSLowering.cpp @@ -0,0 +1,349 @@ +#include "mlir/Transforms/Passes.h" +#include "triton/Dialect/TritonGPU/Transforms/Passes.h" + +#include + +#include "mlir/IR/OperationSupport.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/Triton/IR/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" + +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; +namespace ttng = ::mlir::triton::nvidia_gpu; +namespace mlir { +namespace triton { +namespace gpu { + +#define DEBUG_TYPE "tritongpu-warp-spec-lowering" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +enum class LoadType { + LoadAsyncOp, + LoadTMAOp, +}; + +static Value createThreadIdOp(OpBuilder &builder, Location loc) { + Value threadId = builder.create<::mlir::gpu::ThreadIdOp>( + loc, builder.getIndexType(), ::mlir::gpu::Dimension::x); + auto cast = builder.create( + loc, TypeRange{builder.getIntegerType(32)}, ValueRange{threadId}); + return cast.getResult(0); +} + +// Lower to use GetCanonicalWarpIdOp. +// In Hopper, each task is a warpgroup consisting of 4 warps. +static const int WARPS_PER_TASK = 4; +static const int THREADS_PER_TASK = 128; +void lowerGetAsyncTaskIdOp(Operation *parentOp, int numConsumerGroups) { + DenseSet eraseOps; + parentOp->walk([&](ttng::GetAsyncTaskIdOp op) { + auto loc = op.getLoc(); + OpBuilder builder(op); + Value _4 = builder.create(loc, WARPS_PER_TASK, 32); + Value warpId = builder.create(loc); + Value asyncTaskId = builder.create(loc, warpId, _4); + op.getResult().replaceAllUsesWith(asyncTaskId); + + LLVM_DEBUG({ + LDBG("erasing GetAsyncTask"); + op->dump(); + }); + eraseOps.insert(op); + }); + for (Operation *op : eraseOps) + op->erase(); +} + +//===----------------------------------------------------------------------===// +// Lower token operations +//===----------------------------------------------------------------------===// + +LoadType scanLoadTypes(ttng::CreateTokenOp createTokenOp) { + std::set loadTypes; + createTokenOp->getBlock()->walk([&](Operation *op) { + if (auto asyncCopy = dyn_cast(op)) { + loadTypes.insert(LoadType::LoadAsyncOp); + } else if (auto asyncCopy = + dyn_cast(op)) { + loadTypes.insert(LoadType::LoadTMAOp); + } + }); + assert(loadTypes.size() > 0 && "no async copy in the block"); + assert(loadTypes.size() == 1 && "block contains both async copy and tma"); + return *loadTypes.begin(); +} + +Value getMBarrierPhaseBit(OpBuilder &builder, Operation *op, + bool emptyBarrier) { + auto loc = op->getLoc(); + assert(isa(op) || isa(op)); + Value curPhase; + if (auto acq = dyn_cast(op)) + curPhase = acq.getPhase(); + else if (auto wait = dyn_cast(op)) + curPhase = wait.getPhase(); + if (emptyBarrier) { + // curPhase = curPhase xor True for emptyBarrier. + Value _1_1b = builder.create(loc, 1, 1); + curPhase = builder.create(loc, curPhase, _1_1b); + } + LLVM_DEBUG(curPhase.dump()); + return curPhase; +} + +void processProducerAcquireOp(OpBuilder &builder, ttng::ProducerAcquireOp op, + Value bufferEmpty) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, true); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferEmpty, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processProducerCommitOp(OpBuilder &builder, ttng::ProducerCommitOp op, + Value bufferFull, LoadType loadType) { + auto loc = op.getLoc(); + int txCnt = 0; + ttng::MBarrierArriveOp arriveOp; + + if (loadType == LoadType::LoadAsyncOp) { + // Each thread arrives. + Value pred = builder.create(loc, 1, 1); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ true, + txCnt); + } else { + // Only thread 0 arrives for TMA load. + Value _0 = builder.create(loc, 0, 32); + Value threadId = createThreadIdOp(builder, loc); + Value pred = builder.create(loc, arith::CmpIPredicate::eq, + threadId, _0); + arriveOp = builder.create( + loc, bufferFull, pred, /*remoteCTAId*/ nullptr, /*trackAsyncOp*/ false, + txCnt); + } + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerWaitOp(OpBuilder &builder, ttng::ConsumerWaitOp op, + Value bufferFull) { + auto loc = op.getLoc(); + Value phase = getMBarrierPhaseBit(builder, op, false); + auto i32Ty = builder.getIntegerType(32); + phase = builder.create(loc, i32Ty, phase); + auto waitOp = builder.create(loc, bufferFull, phase); + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(waitOp, getAsyncTaskIds(op.getOperation())); +} + +void processConsumerReleaseOp(OpBuilder &builder, ttng::ConsumerReleaseOp op, + Value bufferEmpty, int numCTAs) { + auto loc = op.getLoc(); + Value _0 = builder.create(loc, 0, 32); + Value _4 = builder.create(loc, 4, 32); + Value _8 = builder.create(loc, 8, 32); + Value _32 = builder.create(loc, 32, 32); + Value _threadPerTask = + builder.create(loc, THREADS_PER_TASK, 32); + + // threadId = threadId % THREADS_PER_TASK + Value threadId = builder.create( + loc, createThreadIdOp(builder, loc), _threadPerTask); + // k = threadId / 8 + Value k = builder.create(loc, threadId, _8); + // row = k / 4 + Value row = builder.create(loc, k, _4); + // col = k % 4 + Value col = builder.create(loc, k, _4); + // remoteCTAId = (col ^ row) * 4 + col + Value remoteCTAId = builder.create( + loc, + Value{builder.create( + loc, Value{builder.create(loc, col, row)}, _4)}, + col); + + // pred0 = threadId % 8 == 0 + Value pred0 = builder.create( + loc, arith::CmpIPredicate::eq, + builder.create(loc, threadId, _8), _0); + // pred1 = remoteCTAId < numCTAs + Value pred1 = builder.create( + loc, arith::CmpIPredicate::ult, remoteCTAId, + builder.create(loc, numCTAs, 32)); + + // pred = pred0 & pred1 + Value pred = builder.create(loc, pred0, pred1); + // bufferEmpty arrive + auto arriveOp = builder.create(loc, bufferEmpty, pred, + remoteCTAId, false, 0); + + assert(op.getOperation()->hasAttr("async_task_id")); + setAsyncTaskIds(arriveOp, getAsyncTaskIds(op.getOperation())); +} + +void lowerTokenOperations(Operation *parentOp, int numCTAs, + int numConsumerGroups) { + SmallVector deprecatedOps; + parentOp->walk([&](ttng::CreateTokenOp createTokenOp) { + LoadType loadType = scanLoadTypes(createTokenOp); + MLIRContext *context = createTokenOp.getContext(); + OpBuilder builder(createTokenOp); + Location loc = createTokenOp.getLoc(); + + Attribute sharedMemorySpace = + triton::gpu::SharedMemorySpaceAttr::get(context); + auto barrierCTALayout = + ttg::CTALayoutAttr::get(context, /*CTAsPerCGA=*/{1}, + /*CTASplitNum=*/{1}, /*CTAOrder=*/{0}); + auto barrierEncoding = + ttg::SharedEncodingAttr::get(context, 1, 1, 1, {0}, barrierCTALayout); + Type barrierMemDescType = + tt::MemDescType::get({createTokenOp.getNum()}, builder.getI64Type(), + barrierEncoding, sharedMemorySpace, + /*mutableMemory=*/true); + Type singleBarrierMemDescType = + tt::MemDescType::get({1}, builder.getI64Type(), barrierEncoding, + sharedMemorySpace, /*mutableMemory=*/true); + Value bufferFullArray = builder.create( + loc, barrierMemDescType, Value()); + Value bufferEmptyArray = builder.create( + loc, barrierMemDescType, Value()); + + for (unsigned i = 0; i < createTokenOp.getNum(); i++) { + Value idx = builder.create(loc, i, 32); + Value barrierFullView = builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + unsigned bufferFullCount = + loadType == LoadType::LoadTMAOp ? 1 : THREADS_PER_TASK; + builder.create(loc, barrierFullView, + bufferFullCount); + + Value barrierEmptyView = builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + unsigned bufferEmptyCount = numCTAs; + builder.create(loc, barrierEmptyView, numCTAs); + } + + if (numCTAs == 1) { + builder.create(loc); + } else { + // Make sure that MBarriers are initialized in all CTAs. + builder.create(loc, false); + builder.create(loc); + } + + // Helper function for extracting one index from bufferFullArray. + auto extractBufferFull = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferFullArray, idx); + }; + + // Helper function for extracting one index from bufferEmptyArray. + auto extractBufferEmpty = [&](Location loc, Value idx) -> Value { + return builder.create( + loc, singleBarrierMemDescType, bufferEmptyArray, idx); + }; + + // Process token users: ProducerAcquireOp, ProducerCommitOp, ConsumerWaitOp, + // and ConsumerReleaseOp. + for (Operation *user : createTokenOp.getResult().getUsers()) { + auto loc = user->getLoc(); + builder.setInsertionPoint(user); + if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processProducerAcquireOp(builder, op, bufferEmpty); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processProducerCommitOp(builder, op, bufferFull, loadType); + } else if (auto op = dyn_cast(user)) { + Value bufferFull = extractBufferFull(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferFull.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerWaitOp(builder, op, bufferFull); + } else if (auto op = dyn_cast(user)) { + Value bufferEmpty = extractBufferEmpty(loc, op.getIdx()); + assert(user->hasAttr("async_task_id")); + setAsyncTaskIds(bufferEmpty.getDefiningOp(), getAsyncTaskIds(user)); + processConsumerReleaseOp(builder, op, bufferEmpty, numCTAs); + } else { + llvm_unreachable("Unexpected user of token"); + } + deprecatedOps.push_back(user); + } + + deprecatedOps.push_back(createTokenOp); + }); + for (auto op : deprecatedOps) { + op->erase(); + } + + // Insert a cluster barrier before the kernel exits. Without this barrier, + // mbarrier_remote_arrive will fail if the remote CTA already exits. + if (numCTAs > 1) { + parentOp->walk([&](triton::FuncOp funcOp) { + Block *block = &funcOp.getBody().front(); + auto returnOp = llvm::cast(block->getTerminator()); + OpBuilder builder(returnOp); + auto loc = returnOp.getLoc(); + builder.create(loc, false); + builder.create(loc); + }); + } +} + +#define GEN_PASS_DEF_TRITONGPUWSLOWERING +#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc" + +// This pass lowers WS-specific operations. +class TritonGPUWSLowering + : public impl::TritonGPUWSLoweringBase { +public: + using impl::TritonGPUWSLoweringBase< + TritonGPUWSLowering>::TritonGPUWSLoweringBase; + + void runOnOperation() override { + // Disable WarpSpec if numConsumerGroups is zero. + if (numConsumerGroups == 0) + return; + ModuleOp mod = getOperation(); + int numCTAs = ttg::TritonGPUDialect::getNumCTAs(mod); + + lowerGetAsyncTaskIdOp(mod, numConsumerGroups); + lowerTokenOperations(mod, numCTAs, numConsumerGroups); + + // We assume number of warps per warp group is 4. + // With Warp Spec, the effective warps per CTA is + // number of warp groups * 4, but within each warp group, layout will use + // num_warps of 4, since tensors are not distributed between the groups. + // + // Loads usually happen in one producer warp groups. num_warps of 4 makes + // sense because only the 4 warps from the producer warp group are + // participating in the load. + // + // But at some point (at least when we launch the kernel!) we really do need + // to know that the CTA has 8 or 12 warps in it. Attribute + // "num-warp-groups-per-cta" can be used to calculate the total number of + // warps. + auto builder = OpBuilder::atBlockBegin(mod.getBody()); + mod->setAttr("triton_gpu.num-warp-groups-per-cta", + builder.getI32IntegerAttr(1 + numConsumerGroups)); + } +}; + +} // namespace gpu +} // namespace triton +} // namespace mlir diff --git a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp index 37c69eef8..888e93bb0 100644 --- a/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp @@ -93,6 +93,19 @@ LogicalResult WarpGroupDotWaitOp::inferReturnTypes( return mlir::success(); } +///--- Async related ops --- +void GetAsyncTaskIdOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state) { + build(builder, state, builder.getI32Type()); +} + +void CreateTokenOp::build(::mlir::OpBuilder &builder, + ::mlir::OperationState &state, uint32_t num) { + auto tokenType = TokenType::get(builder.getContext()); + auto resultType = RankedTensorType::get({num}, tokenType); + build(builder, state, resultType, num); +} + static LogicalResult verifyBarrierType(Operation *op, MemDescType barrierType) { if (!barrierType.getElementType().isInteger(64) || barrierType.getShape() != ArrayRef({1})) diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt index 5adebc352..001d96214 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonNvidiaGPUTransforms FenceInsertion.cpp PlanCTA.cpp TMALowering.cpp + Utility.cpp DEPENDS TritonNvidiaGPUTransformsIncGen diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp index 0938432c7..c1bf9ca8c 100644 --- a/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/TMALowering.cpp @@ -7,6 +7,7 @@ #include "triton/Dialect/TritonGPU/IR/Attributes.h" #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" #include "triton/Dialect/TritonNvidiaGPU/Transforms/Passes.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" #include @@ -25,7 +26,7 @@ class TMALoadLowering : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExperimentalDescriptorLoadOp op, - PatternRewriter &rewriter) const override { + PatternRewriter &baseRewriter) const override { MLIRContext *ctx = op.getContext(); Attribute sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); auto loc = op.getLoc(); @@ -42,6 +43,7 @@ class TMALoadLowering : public OpRewritePattern { MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); + PatternRewriterWithAsyncTaskIds rewriter(baseRewriter, op); Value alloc = rewriter.create(loc, memDescType); auto barrierCTALayout = CTALayoutAttr::get( /*context=*/tensorType.getContext(), /*CTAsPerCGA=*/{1}, @@ -49,7 +51,7 @@ class TMALoadLowering : public OpRewritePattern { auto barrierEncoding = SharedEncodingAttr::get(tensorType.getContext(), 1, 1, 1, {0}, barrierCTALayout); MemDescType barrierMemDescType = - MemDescType::get({1}, rewriter.getI64Type(), barrierEncoding, + MemDescType::get({1}, baseRewriter.getI64Type(), barrierEncoding, sharedMemorySpace, /*mutableMemory=*/true); Value barrierAlloc = rewriter.create(loc, barrierMemDescType); rewriter.create(loc, barrierAlloc, 1); @@ -91,11 +93,17 @@ class TMAStoreLowering MemDescType memDescType = MemDescType::get(tensorType.getShape(), tensorType.getElementType(), encoding, sharedMemorySpace, /*mutableMemory=*/true); - Value alloc = rewriter.create(loc, memDescType, op.getSrc()); - rewriter.create(loc, false); - rewriter.create( + // If op has allocation.copy, the created LocalAlloc will have it. + auto alloc = rewriter.create(loc, memDescType, op.getSrc()); + auto attrs = op->getAttrs(); + alloc->setAttrs(attrs); + auto fence = rewriter.create(loc, false); + fence->setAttrs(attrs); + auto asyncCopy = rewriter.create( loc, op.getDescPtr(), op.getIndices(), alloc); - rewriter.create(loc, 0); + asyncCopy->setAttrs(attrs); + auto tma_wait = rewriter.create(loc, 0); + tma_wait->setAttrs(attrs); rewriter.eraseOp(op); return success(); } diff --git a/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp new file mode 100644 index 000000000..83f21019f --- /dev/null +++ b/lib/Dialect/TritonNvidiaGPU/Transforms/Utility.cpp @@ -0,0 +1,162 @@ + +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" + +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "triton/Analysis/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "llvm/Support/Debug.h" +#include + +namespace mlir { + +namespace ttg = triton::gpu; + +namespace { + +bool knownSafeToIgnoreRegion(Operation *op) { + return isa(op); +} + +// Assigns `dependentSet` and returns ok if the analysis is successful. +// We do not support dependency analysis across load/store, thus a failure will +// be returned if encountering such cases. +LogicalResult getDependentPointers(Value ptr, DenseSet &dependentSet, + DenseSet &processedSet) { + // early return if processed + if (!processedSet.insert(ptr).second) + return success(); + + if (auto blockArg = dyn_cast(ptr)) { + if (!blockArg.getOwner()->isEntryBlock()) + return failure(); + auto parentOp = blockArg.getOwner()->getParentOp(); + if (auto forOp = dyn_cast(parentOp)) { + if (blockArg.getArgNumber() >= forOp.getNumInductionVars()) { + if (failed(getDependentPointers(forOp.getTiedLoopInit(blockArg)->get(), + dependentSet, processedSet))) + return failure(); + + unsigned operandIdx = + blockArg.getArgNumber() - forOp.getNumInductionVars(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(operandIdx), + dependentSet, processedSet); + } + } else if (auto funcOp = dyn_cast(parentOp)) { + dependentSet.insert(ptr); + return success(); + } + // unknown ops, return failure for correctness. + return failure(); + } + + auto definingOp = ptr.getDefiningOp(); + assert(definingOp); + if (auto makeTensorPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(makeTensorPtrOp.getBase(), dependentSet, + processedSet); + } else if (auto advanceOp = ptr.getDefiningOp()) { + return getDependentPointers(advanceOp.getPtr(), dependentSet, processedSet); + } else if (auto addPtrOp = ptr.getDefiningOp()) { + return getDependentPointers(addPtrOp.getPtr(), dependentSet, processedSet); + } else if (auto forOp = ptr.getDefiningOp()) { + unsigned idx = cast(ptr).getResultNumber(); + return getDependentPointers( + forOp.getBody()->getTerminator()->getOperand(idx), dependentSet, + processedSet); + } else if (auto ifOp = ptr.getDefiningOp()) { + unsigned idx = cast(ptr).getResultNumber(); + if (ifOp.elseBlock() && + failed(getDependentPointers(ifOp.elseYield()->getOperand(idx), + dependentSet, processedSet))) + return failure(); + return getDependentPointers(ifOp.thenYield()->getOperand(idx), dependentSet, + processedSet); + } else if (!definingOp->getNumRegions() || + knownSafeToIgnoreRegion(definingOp)) { + for (Value operand : definingOp->getOperands()) + if (failed(getDependentPointers(operand, dependentSet, processedSet))) + return failure(); + return success(); + } + // unknown ops, return failure for correctness. + return failure(); +} + +} // namespace + +//===----------------------------------------------------------------------===// +// Helper functions for async task +//===----------------------------------------------------------------------===// + +SmallVector getAsyncTaskIds(Operation *op) { + SmallVector asyncTaskIds; + if (auto attr = op->getAttrOfType("async_task_id")) + for (AsyncTaskId asyncTaskId : attr.getValues()) + asyncTaskIds.push_back(asyncTaskId); + return asyncTaskIds; +} + +bool hasAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + for (AsyncTaskId candidate : getAsyncTaskIds(op)) + if (candidate == asyncTaskId) + return true; + return false; +} + +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + +SmallVector getNestedAsyncTaskIds(Operation *op) { + SetVector asyncTaskIds; + op->walk([&](Operation *curOp) { + for (AsyncTaskId asyncTaskId : getAsyncTaskIds(curOp)) + asyncTaskIds.insert(asyncTaskId); + }); + SmallVector res(asyncTaskIds.begin(), asyncTaskIds.end()); + llvm::sort(res); + return res; +} + +void addAsyncTaskIds(Operation *op, ArrayRef asyncTasks) { + auto asyncTasksVec = getAsyncTaskIds(op); + DenseSet asyncTasksSet(asyncTasksVec.begin(), asyncTasksVec.end()); + for (int a : asyncTasks) { + if (!asyncTasksSet.contains(a)) { + asyncTasksVec.push_back(a); + } + } + if (asyncTasksVec.size() > 0) { + setAsyncTaskIds(op, asyncTasksVec); + } +} + +void removeAsyncTaskId(Operation *op, AsyncTaskId asyncTaskId) { + auto origAsyncTaskIds = getAsyncTaskIds(op); + auto end = std::remove(origAsyncTaskIds.begin(), origAsyncTaskIds.end(), asyncTaskId); + origAsyncTaskIds.erase(end, origAsyncTaskIds.end()); + if (origAsyncTaskIds.empty()) + op->removeAttr("async_task_id"); + else + setAsyncTaskIds(op, origAsyncTaskIds); +} + +void removeAsyncTaskIds(Operation *op) { + op->removeAttr("async_task_id"); +} +//===----------------------------------------------------------------------===// +// Implementations for general auto WS +//===----------------------------------------------------------------------===// + + +} // namespace mlir diff --git a/python/src/ir.cc b/python/src/ir.cc index 95e48a692..fcd1b623b 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -37,6 +37,15 @@ namespace py = pybind11; using namespace mlir; using namespace triton; +void setAsyncTaskIds(Operation *op, ArrayRef asyncTaskIds) { + SmallVector sortedAsyncTaskIds(asyncTaskIds.begin(), asyncTaskIds.end()); + sort(sortedAsyncTaskIds); + auto i32Ty = IntegerType::get(op->getContext(), 32); + auto size = static_cast(sortedAsyncTaskIds.size()); + auto vecTy = VectorType::get(size, i32Ty); + op->setAttr("async_task_id", DenseIntElementsAttr::get(vecTy, sortedAsyncTaskIds)); +} + // A custom op builder that keeps track of the last location class TritonOpBuilder { public: @@ -95,7 +104,10 @@ class TritonOpBuilder { template OpTy create(Args &&...args) { auto loc = getLastLoc(); - return builder->create(loc, std::forward(args)...); + auto ret = builder->create(loc, std::forward(args)...); + if (asyncTaskIds) + ::setAsyncTaskIds(ret, *asyncTaskIds); + return ret; } // Overload to create or fold a single result operation. @@ -114,9 +126,16 @@ class TritonOpBuilder { return builder->createOrFold(loc, std::forward(args)...); } + void setAsyncTaskIds(std::vector taskIds) { this->asyncTaskIds = taskIds; } + + void unsetAsyncTaskIds() { + this->asyncTaskIds = std::nullopt; + } + private: std::unique_ptr builder; std::unique_ptr lastLoc; + std::optional> asyncTaskIds; bool lineInfoEnabled = !triton::tools::getBoolEnv("TRITON_DISABLE_LINE_INFO"); }; @@ -368,6 +387,7 @@ void init_triton_ir(py::module &&m) { py::class_(m, "attribute", py::module_local()); py::class_(m, "integer_attr", py::module_local()); py::class_(m, "bool_attr", py::module_local()); + py::class_(m, "string_attr", py::module_local()); // Ops py::class_(m, "OpState", py::module_local()) @@ -631,6 +651,12 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, OpBuilder::InsertPoint pt) { self.restoreInsertionPoint(pt); }) + .def("set_async_task_ids", + [](TritonOpBuilder &self, std::vector v) { + self.setAsyncTaskIds(v); + }) + .def("unset_async_task_ids", + [](TritonOpBuilder &self) { self.unsetAsyncTaskIds(); }) // Attr .def("get_bool_attr", [](TritonOpBuilder &self, bool value) { @@ -640,6 +666,10 @@ void init_triton_ir(py::module &&m) { [](TritonOpBuilder &self, int32_t value) { return self.getBuilder().getI32IntegerAttr(value); }) + .def("get_string_attr", + [](TritonOpBuilder &self, const std::string &value) { + return self.getBuilder().getStringAttr(value); + }) // Use arith.ConstantOp to create constants // Constants .def("get_int1", diff --git a/python/src/passes.cc b/python/src/passes.cc index 98d8369d4..026d2f7c7 100644 --- a/python/src/passes.cc +++ b/python/src/passes.cc @@ -68,6 +68,17 @@ void init_triton_passes_ttgpuir(py::module &&m) { createTritonGPUCombineTensorSelectAndIf); ADD_PASS_WRAPPER_0("add_optimize_accumulator_init", createTritonGPUOptimizeAccumulatorInit); + ADD_PASS_OPTION_WRAPPER_1("add_ws_data_partition", + createTritonGPUWSDataPartition, int); + ADD_PASS_OPTION_WRAPPER_1("add_ws_lowering", createTritonGPUWSLowering, int); + ADD_PASS_OPTION_WRAPPER_1("add_taskid_propagate", + createTritonGPUTaskIdPropagate, int); + ADD_PASS_OPTION_WRAPPER_4("add_ws_code_partition", + createTritonGPUWSCodePartition, int, int, int, int); + ADD_PASS_OPTION_WRAPPER_2("add_ping_pong_sync", createTritonGPUPingPongSync, + int, int); + ADD_PASS_OPTION_WRAPPER_1("add_loop_scheduling", + createTritonGPULoopScheduling, int); } void init_triton_passes_convert(py::module &&m) { diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py index 19d09de85..517a2fc4d 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -802,6 +802,20 @@ def visit_UnaryOp(self, node): ast.USub: '__neg__', ast.UAdd: '__pos__', ast.Not: '__not__', ast.Invert: '__invert__' } + def visit_withitem(self, node): + return self.visit(node.context_expr) + + def visit_With(self, node): + assert len(node.items) == 1 + context = node.items[0].context_expr + withitemClass = self.visit(context.func) + if withitemClass == language.async_task: + args = [self.visit(arg) for arg in context.args] + with withitemClass(*args, _builder=self.builder): + self.visit_compound_statement(node.body) + else: + self.visit_compound_statement(node.body) + def visit_While(self, node): with enter_sub_region(self) as sr: liveins, insert_block = sr @@ -904,6 +918,7 @@ def visit_For(self, node): ast.NodeVisitor.generic_visit(self, stmt) return num_stages = None + loop_schedule = None if IteratorClass is language.range: iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments @@ -913,6 +928,7 @@ def visit_For(self, node): ub = iterator.end step = iterator.step num_stages = iterator.num_stages + loop_schedule = iterator.loop_schedule elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now @@ -986,6 +1002,8 @@ def visit_For(self, node): for_op = self.builder.create_for_op(lb, ub, step, [arg.handle for arg in init_args]) if num_stages is not None: for_op.set_attr("tt.num_stages", self.builder.get_int32_attr(num_stages)) + if loop_schedule is not None: + for_op.set_attr("tt.loop_schedule", self.builder.get_string_attr(loop_schedule.value)) self.scf_stack.append(node) self.builder.set_insertion_point_to_start(for_op.get_body(0)) diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 0a84bd86a..d18701be2 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -32,6 +32,7 @@ arange, associative_scan, assume, + async_task, atomic_add, atomic_and, atomic_cas, diff --git a/python/triton/language/core.py b/python/triton/language/core.py index e16ca2dee..38cd9d899 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2505,6 +2505,21 @@ def __next__(self): raise RuntimeError("static_range can only be used in @triton.jit'd functions") +class async_task: + """ + Context manager to run code fragments asynchronously. + """ + def __init__(self, task_ids, _builder=None): + self.task_ids = task_ids + self.builder = _builder + + def __enter__(self): + self.builder.set_async_task_ids(self.task_ids) + + def __exit__(self, exc_type, exc_value, traceback): + self.builder.unset_async_task_ids() + + class range: """ Iterator that counts upward forever. @@ -2514,7 +2529,7 @@ class range: @triton.jit def kernel(...): - for i in tl.range(10, num_stages=3): + for i in tl.range(10, num_stages=3, loop_schedule="Default"): ... :note: This is a special iterator used to implement similar semantics to Python's :code:`range` in the context of :code:`triton.jit` functions. In addition, it allows user to pass extra attributes to the compiler. @@ -2528,9 +2543,10 @@ def kernel(...): kernel argument. The kernel argument only pipelines loads that feed into :code:`dot` operations, while this attribute tries to pipeline most (though not all) loads in this loop. + :param loop_schedule: specify a scheduling policy for the loop. """ - def __init__(self, arg1, arg2=None, step=None, num_stages=None): + def __init__(self, arg1, arg2=None, step=None, num_stages=None, loop_schedule=None): if step is None: self.step = constexpr(1) else: @@ -2542,6 +2558,7 @@ def __init__(self, arg1, arg2=None, step=None, num_stages=None): self.start = arg1 self.end = arg2 self.num_stages = num_stages + self.loop_schedule = loop_schedule def __iter__(self): raise RuntimeError("tl.range can only be used in @triton.jit'd functions") diff --git a/python/triton/runtime/autotuner.py b/python/triton/runtime/autotuner.py index 59191a31b..c2d16b820 100644 --- a/python/triton/runtime/autotuner.py +++ b/python/triton/runtime/autotuner.py @@ -35,7 +35,7 @@ def __init__( 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. """ if not configs: - self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1)] + self.configs = [Config({}, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0, reg_inc_consumer=0)] else: self.configs = configs self.key_idx = [arg_names.index(k) for k in key] @@ -153,6 +153,17 @@ def run(self, *args, **kwargs): timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} bench_end = time.time() self.bench_time = bench_end - bench_start + + # __FACEBOOK__ (facebook) begin T203283446 + if os.getenv("TRITON_PRINT_AUTOTUNING_ALL", None) == "1": + print( + f'\nPrinting ALL Multiple Triton autotuning Configs with timings in sorted order for kernel {self.fn}:' + ) + sorted_configs = builtins.sorted(timings, key=timings.get) + for config in sorted_configs: + print(f'Triton autotune config: [{config}]; Triton autotune timing: {timings[config]}') + # __FACEBOOK__ (facebook) end T203283446 + self.cache[key] = builtins.min(timings, key=timings.get) self.pre_hook(args, reset_only=True) self.configs_timings = timings @@ -227,11 +238,15 @@ class Config: function are args. """ - def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, maxnreg=None, pre_hook=None): + def __init__(self, kwargs, num_warps=4, num_stages=2, num_ctas=1, num_buffers_warp_spec=0, num_consumer_groups=0, reg_dec_producer=0, reg_inc_consumer=0, maxnreg=None, pre_hook=None): self.kwargs = kwargs self.num_warps = num_warps self.num_ctas = num_ctas self.num_stages = num_stages + self.num_buffers_warp_spec = num_buffers_warp_spec + self.num_consumer_groups = num_consumer_groups + self.reg_dec_producer = reg_dec_producer + self.reg_inc_consumer = reg_inc_consumer self.maxnreg = maxnreg self.pre_hook = pre_hook @@ -243,6 +258,10 @@ def all_kwargs(self): ("num_warps", self.num_warps), ("num_ctas", self.num_ctas), ("num_stages", self.num_stages), + ("num_buffers_warp_spec", self.num_buffers_warp_spec), + ("num_consumer_groups", self.num_consumer_groups), + ("reg_dec_producer", self.reg_dec_producer), + ("reg_inc_consumer", self.reg_inc_consumer), ("maxnreg", self.maxnreg), ) if v is not None } @@ -255,6 +274,10 @@ def __str__(self): res.append(f"num_warps: {self.num_warps}") res.append(f"num_ctas: {self.num_ctas}") res.append(f"num_stages: {self.num_stages}") + res.append(f"num_buffers_warp_spec: {self.num_buffers_warp_spec}") + res.append(f"num_consumer_groups: {self.num_consumer_groups}") + res.append(f"reg_dec_producer: {self.reg_dec_producer}") + res.append(f"reg_inc_consumer: {self.reg_inc_consumer}") res.append(f"maxnreg: {self.maxnreg}") return ", ".join(res) diff --git a/python/tutorials/10-warp-specialized-matmul.py b/python/tutorials/10-warp-specialized-matmul.py new file mode 100644 index 000000000..ed51de580 --- /dev/null +++ b/python/tutorials/10-warp-specialized-matmul.py @@ -0,0 +1,319 @@ +import os +import sys + +import torch +import triton +import triton.language as tl + +sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +HAS_TMA_DESC = "nv_tma_desc_type" in dir(tl) + +if HAS_TMA_DESC: + print( + "TMA benchmarks will be running with experimental grid constant TMA descriptor.", + ) +else: + print( + "TMA benchmarks will be running without grid constant TMA descriptor.", + ) + + +class TmaAutoTuneHelper: + + # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498 + class KernelParamWrapper: + def __init__(self, desc): + self.desc = desc + + def tma_desc_cpu_ptr(self): + return self.desc.data_ptr() + + TMA_SIZE = 128 + + def __init__(self): + self.fill_1d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_1d_tma_descriptor + ) + self.fill_2d_tma_descriptor_inner = ( + triton.runtime.driver.active.utils.fill_2d_tma_descriptor + ) + if HAS_TMA_DESC: + self.descriptors = {} + else: + self.cuda_descriptors = {} + + # Call this method outside of the lambda function for grid size + def init_tma_descriptor(self, name): + if HAS_TMA_DESC: + self.descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cpu", dtype=torch.int8 + ) + else: + self.cuda_descriptors[name] = torch.empty( + TmaAutoTuneHelper.TMA_SIZE, device="cuda", dtype=torch.int8 + ) + + # Call this method inside the lambda function for grid size + def fill_1d_tma_descriptor(self, name, ptr, dim, block_dim, element_size): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_1d_tma_descriptor_inner( + ptr, dim, block_dim, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + # Call this method inside the lambda function for grid size + def fill_2d_tma_descriptor( + self, name, ptr, dim1, dim0, block_dim1, block_dim0, element_size + ): + if HAS_TMA_DESC: + desc_x = self.descriptors[name] + assert desc_x.data_ptr() % 64 == 0 + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, desc_x.data_ptr() + ) + else: + desc_x = self.cuda_descriptors[name] + buf_x = torch.empty_like(desc_x, device="cpu", pin_memory=True) + self.fill_2d_tma_descriptor_inner( + ptr, dim1, dim0, block_dim1, block_dim0, element_size, buf_x.data_ptr() + ) + desc_x.copy_(buf_x, non_blocking=True) + + def get_tma_descriptor_kernel_param(self, name): + if HAS_TMA_DESC: + assert self.descriptors[name] is not None + return self.KernelParamWrapper(self.descriptors[name]) + else: + assert self.cuda_descriptors[name] is not None + return self.cuda_descriptors[name] + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + }, + num_stages=2, + num_warps=4, + num_consumer_groups=2, + num_buffers_warp_spec=3, + ), + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_persistent_tma_ws_cooperative_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # + GROUP_SIZE_M: tl.constexpr, # +): + """Kernel for computing the matmul C = A x B. + A has shape (M, K), B has shape (K, N) and C has shape (M, N) + """ + + num_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N) + for pid in range(tl.program_id(0), num_tiles, tl.num_programs(0)): + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + # See above `Pointer Arithmetic` section for details + offs_am = pid_m * BLOCK_SIZE_M + offs_bn = pid_n * BLOCK_SIZE_N + offs_k = 0 + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + with tl.async_task([0]): + a = tl._experimental_descriptor_load( + a_ptr, + [offs_am, offs_k], + [BLOCK_SIZE_M, BLOCK_SIZE_K], + tl.float16, + ) + b = tl._experimental_descriptor_load( + b_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16 + ) + + accumulator += tl.dot(a, b) + offs_k += BLOCK_SIZE_K + + c = accumulator.to(tl.float16) + + with tl.async_task([1, 2]): + tl._experimental_descriptor_store(c_ptr, c, [offs_am, offs_bn]) + + +# %% +# We can now create a convenience wrapper function that only takes two input tensors, +# and (1) checks any shape constraint; (2) allocates the output; (3) launches the above kernel. + + +def matmul_persistent_tma_ws_cooperative(a, b): + # Check constraints. + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.dtype == b.dtype, "Incompatible dtypes" + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count + M, K = a.shape + K, N = b.shape + dtype = a.dtype + # Allocates output. + c = torch.empty((M, N), device=a.device, dtype=dtype) + + desc_helper = TmaAutoTuneHelper() + desc_helper.init_tma_descriptor("a") + desc_helper.init_tma_descriptor("b") + desc_helper.init_tma_descriptor("c") + + def grid(META): + nonlocal desc_helper + desc_helper.fill_2d_tma_descriptor( + "a", + a.data_ptr(), + M, + K, + META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_K"], + a.element_size(), + ) + + desc_helper.fill_2d_tma_descriptor( + "b", + b.data_ptr(), + K, + N, + META["BLOCK_SIZE_K"], + META["BLOCK_SIZE_N"], + b.element_size(), + ) + desc_helper.fill_2d_tma_descriptor( + "c", + c.data_ptr(), + M, + N, + META["BLOCK_SIZE_M"] // 2, + META["BLOCK_SIZE_N"], + c.element_size(), + ) + return ( + min( + NUM_SMS, + triton.cdiv(M, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ), + ) + + desc_a = desc_helper.get_tma_descriptor_kernel_param("a") + desc_b = desc_helper.get_tma_descriptor_kernel_param("b") + desc_c = desc_helper.get_tma_descriptor_kernel_param("c") + matmul_persistent_tma_ws_cooperative_kernel[grid]( + desc_a, + desc_b, + desc_c, # + M, + N, + K, # + ) + return c + + +def aten_matmul(a, b): + return a.mm(b) + + +test_impls = [ + aten_matmul, + matmul_persistent_tma_ws_cooperative, +] + + +impl_map = {fn.__name__: fn for fn in test_impls} + + +def test(): + torch.manual_seed(0) + m = 4 * 11 * 64 + n = 12 * 256 + k = 64 * 4 + a = torch.randn((m, k), device="cuda", dtype=torch.float16) + b = torch.randn((k, n), device="cuda", dtype=torch.float16) + torch_output = torch.matmul(a, b) + rtol = 0 + for fn in test_impls: + triton_output = fn(a, b) + torch.cuda.synchronize() + if torch.allclose(triton_output, torch_output, atol=1e-2, rtol=rtol): + print(f" Torch matches {fn.__name__}") + else: + print(f" Torch DOES NOT match {fn.__name__}") + print("torch output:") + print(torch_output) + print("triton output:") + print(triton_output) + + +x_vals = [(8192, 8192, i) for i in range(256, 8192 + 1, 256)] +configs = [] +configs.append( + triton.testing.Benchmark( + x_names=["M", "N", "K"], # Argument names to use as an x-axis for the plot + x_vals=x_vals, + line_arg="provider", # Argument name whose value corresponds to a different line in the plot + # Possible values for `line_arg` + # Don't compare to cublas for fp8 cases as torch.matmul doesn't support fp8 at the moment. + line_vals=[fn.__name__ for fn in test_impls], + line_names=[fn.__name__ for fn in test_impls], + # styles=[("red", "-"), ("green", "-"), ("blue", "-")], + ylabel="TFLOPS", # Label name for the y-axis + plot_name="matmul-performance-" + + ( + "fp16" + ), # Name for the plot, used also as a file name for saving the plot. + args={}, + ) +) + +@triton.testing.perf_report(configs) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + fn = impl_map[provider] + ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(lambda: fn(a, b), quantiles=quantiles) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +test() +benchmark.run(show_plots=True, print_data=True) diff --git a/python/tutorials/mm.py b/python/tutorials/mm.py new file mode 100644 index 000000000..2931fdf0b --- /dev/null +++ b/python/tutorials/mm.py @@ -0,0 +1,201 @@ +import torch + +import triton +import triton.language as tl + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=4, num_buffers_warp_spec=3, num_consumer_groups=1), + # fmt: on + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel_ws( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + with tl.async_task([0]): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + with tl.async_task([1]): + tl.store(c_ptrs, c, mask=c_mask) + + +@triton.autotune( + configs=[ + # fmt: off + triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 8}, num_stages=3, num_warps=4), + # fmt: on + ], + key=["M", "N", "K"], +) +@triton.jit +def matmul_kernel( + a_ptr, + b_ptr, + c_ptr, + M, + N, + K, + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0) + accumulator = tl.dot(a, b, accumulator) + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + c = accumulator.to(tl.float16) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +def matmul(a, b, ws=True): + assert a.shape[1] == b.shape[0], "Incompatible dimensions" + assert a.is_contiguous(), "Matrix A must be contiguous" + M, K = a.shape + K, N = b.shape + c = torch.empty((M, N), device=a.device, dtype=torch.float16) + grid = lambda META: ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + ) + kernel = matmul_kernel_ws if ws else matmul_kernel + kernel[grid]( + a, + b, + c, + M, + N, + K, + a.stride(0), + a.stride(1), + b.stride(0), + b.stride(1), + c.stride(0), + c.stride(1), + ) + return c + + +def test(): + m, n, k = 8192, 8192, 8192 + a = torch.randn((m, k), device="cuda", dtype=torch.float16) + b = torch.randn((k, n), device="cuda", dtype=torch.float16) + triton_output = matmul(a, b, ws=True) + torch_output = torch.matmul(a, b) + + print("triton:", triton_output) + print(" torch:", torch_output) + + torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0.0) + + +@triton.testing.perf_report( + [ + triton.testing.Benchmark( + x_names=["M", "N", "K"], + x_vals=[128 * i for i in range(28, 33)], + line_arg="provider", + line_vals=["cublas", "triton-warpspec", "triton-multistage"], + line_names=["cuBLAS", "Triton:WarpSpec", "Triton:MultiStage"], + ylabel="TFLOPS", + plot_name="matmul-performance-fp16", + args={}, + ) + ] +) +def benchmark(M, N, K, provider): + a = torch.randn((M, K), device="cuda", dtype=torch.float16) + b = torch.randn((K, N), device="cuda", dtype=torch.float16) + quantiles = [0.5, 0.2, 0.8] + if provider == "cublas": + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: torch.matmul(a, b), quantiles=quantiles + ) + if "triton" in provider: + ws = "warpspec" in provider + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: matmul(a, b, ws=ws), quantiles=quantiles + ) + perf = lambda ms: 2 * M * N * K * 1e-12 / (ms * 1e-3) + return perf(ms), perf(max_ms), perf(min_ms) + + +test() +benchmark.run(show_plots=True, print_data=True) diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index 281550d26..9124977f1 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2607,3 +2607,45 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : tt.return %outLHS : tensor<128x64xf32, #blocked1> } } + +// ----- + +#AL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#BL = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#CL = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#C = #triton_gpu.nvidia_mma<{versionMajor = 2, warpsPerCTA = [4, 1]}> +#A_DOT = #triton_gpu.dot_op<{opIdx = 0, parent = #C, kWidth = 2}> +#B_DOT = #triton_gpu.dot_op<{opIdx = 1, parent = #C, kWidth = 2}> + +module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.num-ctas" = 1 : i32} { + // CHECK-LABEL: matmul_add + tt.func @matmul_add(%lb : index, %ub : index, %step : index, %A : !tt.ptr, %B : !tt.ptr, %C : !tt.ptr) { + %a_ptr_init = tt.splat %A : !tt.ptr -> tensor<128x32x!tt.ptr, #AL> + %b_ptr_init = tt.splat %B : !tt.ptr -> tensor<32x128x!tt.ptr, #BL> + %c_ptr_init = tt.splat %C : !tt.ptr -> tensor<128x128x!tt.ptr, #CL> + %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #CL> + %cst = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C> + %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL> + %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL> + + %100:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL>) { + %a_ = tt.load %a_ptr : tensor<128x32x!tt.ptr, #AL> + %a = triton_gpu.convert_layout %a_ : tensor<128x32xf16, #AL> -> tensor<128x32xf16, #A_DOT> + %b_ = tt.load %b_ptr : tensor<32x128x!tt.ptr, #BL> + %b = triton_gpu.convert_layout %b_ : tensor<32x128xf16, #BL> -> tensor<32x128xf16, #B_DOT> + %c = tt.dot %a, %b, %cst : tensor<128x32xf16, #A_DOT> * tensor<32x128xf16, #B_DOT> -> tensor<128x128xf32, #C> + %t = triton_gpu.convert_layout %c : tensor<128x128xf32, #C> -> tensor<128x128xf32, #CL> + // CHECK: %[[T0:.*]] = tt.dot + // CHECK: arith.addf %{{.*}}, %[[T0]] : tensor<128x128xf32, #mma> + %t2 = arith.addf %prev_c, %t : tensor<128x128xf32, #CL> + %next_a_ptr = tt.addptr %a_ptr, %a_off : tensor<128x32x!tt.ptr, #AL>, tensor<128x32xi32, #AL> + %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr, #BL>, tensor<32x128xi32, #BL> + // CHECK: scf.yield + scf.yield %next_a_ptr, %next_b_ptr, %t2 : tensor<128x32x!tt.ptr, #AL>, tensor<32x128x!tt.ptr, #BL>, tensor<128x128xf32, #CL> + } + + // CHECK: triton_gpu.convert_layout {{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked + tt.store %c_ptr_init, %100#2 : tensor<128x128x!tt.ptr, #CL> + tt.return + } +} diff --git a/test/TritonGPU/comp-pipeline.mlir b/test/TritonGPU/comp-pipeline.mlir new file mode 100644 index 000000000..492b1d508 --- /dev/null +++ b/test/TritonGPU/comp-pipeline.mlir @@ -0,0 +1,102 @@ +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 -debug-only=triton-matmul-loop-pipeline 2>&1 | FileCheck %s --check-prefix=DEBUG +// RUN: triton-opt %s -split-input-file -tritongpu-pipeline=num-stages=4 | FileCheck %s + +// DEBUG: Final coarse schedule: +// DEBUG: Ops in stage 2 +// DEBUG-DAG: triton_nvidia_gpu.wait_barrier +// DEBUG-DAG: triton_nvidia_gpu.warp_group_dot +// DEBUG: Ops in stage 3 +// DEBUG: triton_nvidia_gpu.wait_barrier +// DEBUG: Original loop: + +#blocked = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [8], order = [0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [2, 4], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @_attn_fwd_tma(%arg3: !tt.ptr {tt.divisibility = 16 : i32}, %arg4: !tt.ptr {tt.divisibility = 16 : i32}, %arg5: !tt.ptr {tt.divisibility = 16 : i32}, %arg6: f32, %arg8: !tt.ptr {tt.divisibility = 16 : i32}, %arg9: i32, %arg10: i32, %arg11: i64, %arg14: i32 {tt.divisibility = 16 : i32}, %arg17: i32 {tt.divisibility = 16 : i32}, %arg23: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %c128_i32 = arith.constant 128 : i32 + %c0_i32 = arith.constant 0 : i32 + %cst_2 = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> + %25 = tt.experimental_descriptor_load %arg3[%arg9, %c0_i32] : !tt.ptr -> tensor<128x128xf16, #blocked1> + %26 = triton_gpu.local_alloc %25 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %27 = arith.extsi %arg14 : i32 to i64 + %28 = tt.splat %arg6 : f32 -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %29 = tt.splat %arg6 : f32 -> tensor<128x128xf32, #mma> + %30 = arith.extsi %arg17 : i32 to i64 + // CHECK: tt.experimental_descriptor_load + // CHECK: %[[QLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<128x128xf16 + // CHECK: %[[KLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16 + // CHECK: %[[VLOC:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3x128x128xf16 + // CHECK: %[[KBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64 + // CHECK: %[[VBAR:.+]] = triton_gpu.local_alloc {{.*}}tt.memdesc<3xi64 + // stage 0 iteration 0 + // CHECK: %[[K0:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c0_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K0]] + // stage 0 iteration 1 + // CHECK: %[[K1:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c1_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K1]] + // stage 1 iteration 0 + // CHECK: %[[V0:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c0_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V0]] + // stage 2 iteration 0 + // CHECK: %[[FIRSTDOT:.+]] = triton_nvidia_gpu.warp_group_dot + // stage 0 iteration 2 + // CHECK: %[[K2:.+]] = triton_gpu.memdesc_subview %[[KLOC]][%c2_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[K2]] + // stage 1 iteration 1 + // CHECK: %[[V1:.+]] = triton_gpu.memdesc_subview %[[VLOC]][%c1_i32, %c0_i32, %c0_i32] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local{{.*}} %[[V1]] + // CHECK: scf.for {{.*}} %[[ARG:.+]] = %[[FIRSTDOT]] + // CHECK: %[[KBARSUB:.+]] = triton_gpu.memdesc_subview %[[KBAR]][%[[KBARIDX:.+]]] + // CHECK: scf.if + // CHECK: triton_nvidia_gpu.wait_barrier %[[KBARSUB]] + // CHECK: %[[KLOOP:.+]] = triton_gpu.memdesc_subview %[[KLOC]] + // CHECK: tt.trans %[[KLOOP]] + // CHECK: %[[FIRSTDOTLOOP:.+]] = triton_nvidia_gpu.warp_group_dot + // CHECK: %[[WAIT:.+]]:{{[0-9]+}} = triton_nvidia_gpu.warp_group_dot_wait + // CHECK: "tt.reduce"(%[[ARG]]) + // CHECK: %[[VBARSUB:.+]] = triton_gpu.memdesc_subview %[[VBAR]] + // CHECK: triton_nvidia_gpu.wait_barrier %[[VBARSUB]] + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK: triton_nvidia_gpu.async_tma_copy_global_to_local + // CHECK: scf.yield {{.*}}%[[WAIT]]#0 + // arg26 is acc + %31:1 = scf.for %arg24 = %c0_i32 to %arg23 step %c128_i32 iter_args(%arg26 = %cst_2) -> (tensor<128x128xf32, #mma>) : i32 { + %48 = arith.divsi %arg11, %27 : i64 + %49 = arith.trunci %48 : i64 to i32 + %50 = arith.addi %arg24, %49 : i32 + // loads in different stages + %51 = tt.experimental_descriptor_load %arg4[%50, %c0_i32] {loop.stage = 0 : i32, loop.cluster = 1 : i32} : !tt.ptr -> tensor<128x128xf16, #blocked1> + %52 = triton_gpu.local_alloc %51 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %53 = tt.trans %52 {order = array} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory> + %54 = triton_nvidia_gpu.warp_group_dot %26, %53, %cst_2 {inputPrecision = 0 : i32, loop.stage = 2 : i32, loop.cluster = 0 : i32} : !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<128x128xf16, #shared1, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + %55 = "tt.reduce"(%54) <{axis = 1 : i32}> ({ + ^bb0(%arg28: f32 loc(unknown), %arg29: f32 loc(unknown)): + %80 = arith.maxnumf %arg28, %arg29 : f32 + tt.reduce.return %80 : f32 + }) : (tensor<128x128xf32, #mma>) -> tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %56 = arith.mulf %55, %28 : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> + %58 = arith.mulf %54, %29 : tensor<128x128xf32, #mma> + %59 = tt.expand_dims %56 {axis = 1 : i32} : tensor<128xf32, #triton_gpu.slice<{dim = 1, parent = #mma}>> -> tensor<128x1xf32, #mma> + %60 = tt.broadcast %59 : tensor<128x1xf32, #mma> -> tensor<128x128xf32, #mma> + %61 = arith.subf %58, %60 : tensor<128x128xf32, #mma> + %62 = math.exp2 %61 : tensor<128x128xf32, #mma> + %71 = arith.divsi %arg11, %30 : i64 + %72 = arith.extsi %arg24 : i32 to i64 + %73 = arith.addi %71, %72 : i64 + %74 = arith.trunci %73 : i64 to i32 + %75 = tt.experimental_descriptor_load %arg5[%74, %c0_i32] {loop.stage = 1 : i32, loop.cluster = 1 : i32} : !tt.ptr -> tensor<128x128xf16, #blocked1> + %76 = triton_gpu.local_alloc %75 : (tensor<128x128xf16, #blocked1>) -> !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> + %77 = arith.truncf %62 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %78 = triton_gpu.convert_layout %77 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> + %79 = triton_nvidia_gpu.warp_group_dot %78, %76, %arg26 {inputPrecision = 0 : i32, loop.stage = 3 : i32, loop.cluster = 0 : i32} : tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma}>> * !tt.memdesc<128x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x128xf32, #mma> + scf.yield %79 : tensor<128x128xf32, #mma> + } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>} + %42 = arith.truncf %31#0 : tensor<128x128xf32, #mma> to tensor<128x128xf16, #mma> + %43 = triton_gpu.convert_layout %42 : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #blocked1> + tt.experimental_descriptor_store %arg8[%arg10, %c0_i32], %43 : !tt.ptr, tensor<128x128xf16, #blocked1> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir new file mode 100644 index 000000000..1cca80d21 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/async_propagate.mlir @@ -0,0 +1,63 @@ +// RUN: triton-opt %s -split-input-file --triton-gpu-taskid-propagate=num-consumer-groups=1 | FileCheck %s + +// CHECK-LABEL: @async_kernel +// CHECK: %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 +// CHECK: %5 = tt.splat %arg2 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1024xi32> +// CHECK: %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> +// CHECK: %10 = tt.splat %arg1 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<1024x!tt.ptr> +// CHECK: tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + +module { + tt.func public @async_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.splat %1 : i32 -> tensor<1024xi32> + %4 = arith.addi %3, %2 : tensor<1024xi32> + %5 = tt.splat %arg2 : i32 -> tensor<1024xi32> + %6 = arith.cmpi slt, %4, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %4 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.load %8, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %10 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %11 = tt.addptr %10, %4 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %11, %9 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} + +// ----- + +// CHECK-LABEL: @two_consumers +// CHECK: tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 +// CHECK: tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.load {{.*}} {async_task_id = dense<0> : vector<1xi32>} +// CHECK: tt.splat %arg1 {async_task_id = dense<[1, 2]> : vector<2xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<1> : vector<1xi32>} +// CHECK: tt.store {{.*}} {async_task_id = dense<2> : vector<1xi32>} + +module { + tt.func public @two_consumers(%arg0: !tt.ptr, %arg1: !tt.ptr) attributes {noinline = false} { + %c1024_i32 = arith.constant 1024 : i32 + %0 = tt.get_program_id x : i32 + %1 = arith.muli %0, %c1024_i32 : i32 + %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32> + %3 = tt.make_range {end = 2048 : i32, start = 1024 : i32} : tensor<1024xi32> + %4 = tt.splat %1 : i32 -> tensor<1024xi32> + %5 = arith.addi %4, %2 : tensor<1024xi32> + %6 = arith.addi %4, %3 : tensor<1024xi32> + %7 = tt.splat %arg0 : !tt.ptr -> tensor<1024x!tt.ptr> + %8 = tt.addptr %7, %5 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %9 = tt.addptr %7, %6 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %10 = tt.load %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %11 = tt.load %9 {async_task_id = dense<0> : vector<1xi32>} : tensor<1024x!tt.ptr> + %12 = tt.splat %arg1 : !tt.ptr -> tensor<1024x!tt.ptr> + %13 = tt.addptr %12, %5 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + %14 = tt.addptr %12, %6 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr>, tensor<1024xi32> + tt.store %13, %10 {async_task_id = dense<1> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.store %14, %11 {async_task_id = dense<2> : vector<1xi32>} : tensor<1024x!tt.ptr> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir new file mode 100644 index 000000000..0461ce39b --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_code_partition.mlir @@ -0,0 +1,306 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-code-partition=num-buffers=1 | FileCheck %s + +// CHECK-LABEL: @matmul_kernel_one_consumer +// CHECK: %[[#TASKID:]] = triton_nvidia_gpu.get_async_task_id : i32 +// CHECK: %c0_i32 = arith.constant 0 : i32 +// CHECK: %[[#WG0:]] = arith.cmpi eq, %[[#TASKID]], %c0_i32 : i32 +// CHECK: scf.if %[[#WG0]] +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: %c1_i32 = arith.constant 1 : i32 +// CHECK: %[[#WG1:]] = arith.cmpi eq, %[[#TASKID]], %c1_i32 : i32 +// CHECK: scf.if %[[#WG1]] +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: triton_nvidia_gpu.consumer_release + + +#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_one_consumer(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked1> + %cst_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked2> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked2> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %16 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %18 = tt.splat %14 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %19 = tt.splat %14 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %20 = arith.addi %18, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %21 = arith.addi %19, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %23 = arith.remsi %20, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> + %24 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %25 = tt.splat %24 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %26 = arith.addi %25, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %27 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %28 = arith.remsi %26, %27 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %29 = tt.expand_dims %23 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked2}>> -> tensor<128x1xi32, #blocked2> + %30 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked2> + %31 = arith.muli %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> + %32 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> + %33 = tt.expand_dims %32 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked2}>> -> tensor<1x256xi32, #blocked2> + %34 = tt.broadcast %31 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %35 = tt.broadcast %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> -> tensor<128x256xi32, #blocked2> + %36 = arith.addi %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked2> + %37 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked2> + %38 = tt.addptr %37, %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %39 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %40 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %41 = tt.expand_dims %39 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %42 = tt.expand_dims %40 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<256x1xi32, #blocked1> + %43 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %44 = arith.muli %41, %43 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %45 = tt.expand_dims %28 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %46 = tt.broadcast %44 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %47 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<256x128xi32, #blocked1> + %48 = arith.addi %46, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked1> + %49 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked1> + %50 = tt.addptr %49, %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + %51 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %52 = arith.divsi %51, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %53 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %54 = tt.splat %53 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked1> + %55:3 = scf.for %arg9 = %c0_i32 to %52 step %c1_i32 iter_args(%arg10 = %cst, %arg11 = %38, %arg12 = %50) -> (tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1>) : i32 { + %74 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.subi %arg5, %74 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked2> + %77 = arith.cmpi slt, %33, %76 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked2> + %78 = tt.broadcast %77 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked2> -> tensor<128x256xi1, #blocked2> + %79 = tt.load %arg11, %78, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2> + %80 = tt.splat %75 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked1> + %81 = arith.cmpi slt, %42, %80 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked1> + %82 = tt.broadcast %81 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked1> -> tensor<256x128xi1, #blocked1> + %83 = tt.load %arg12, %82, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1> + %84 = triton_gpu.convert_layout %79 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked2> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> + %85 = triton_gpu.convert_layout %83 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked1> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> + %86 = tt.dot %84, %85, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x128xf32, #blocked> + %87 = tt.addptr %arg11, %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked2>, tensor<128x256xi32, #blocked2> + %88 = tt.addptr %arg12, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked1>, tensor<256x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1]> : vector<2xi32>} %86, %87, %88 : tensor<128x128xf32, #blocked>, tensor<128x256x!tt.ptr, #blocked2>, tensor<256x128x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1]> : vector<2xi32>} + %56 = arith.truncf %55#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked> to tensor<128x128xf16, #blocked> + %57 = tt.expand_dims %21 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %58 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %59 = arith.muli %58, %57 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %60 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %61 = tt.addptr %60, %59 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %62 = tt.expand_dims %26 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %63 = tt.broadcast %61 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x128x!tt.ptr, #blocked1> + %64 = tt.broadcast %62 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<128x128xi32, #blocked1> + %65 = tt.addptr %63, %64 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1>, tensor<128x128xi32, #blocked1> + %66 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %67 = arith.cmpi slt, %57, %66 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %68 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %69 = arith.cmpi slt, %62, %68 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked1> + %70 = tt.broadcast %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %71 = tt.broadcast %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked1> -> tensor<128x128xi1, #blocked1> + %72 = arith.andi %70, %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked1> + %73 = triton_gpu.convert_layout %56 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked> -> tensor<128x128xf16, #blocked1> + tt.store %65, %73, %72 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked1> + tt.return + } +} + +// ----- + + +// CHECK-LABEL: @matmul_kernel_two_consumers +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_dealloc 40 +// CHECK: scf.for +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_nvidia_gpu.producer_acquire +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: triton_nvidia_gpu.producer_commit +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: scf.if +// CHECK: triton_nvidia_gpu.reg_alloc 232 +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.consumer_wait +// CHECK: triton_nvidia_gpu.warp_group_dot +// CHECK: triton_nvidia_gpu.consumer_release +// CHECK: triton_nvidia_gpu.consumer_release + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel_two_consumers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<64x64xi32, #blocked> + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x128xf16, #blocked1> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %cst_2 = arith.constant {async_task_id = dense<[1, 2]> : vector<2xi32>} dense<0.000000e+00> : tensor<64x128xf32, #mma> + %0 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.divsi %1, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.divsi %3, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = arith.muli %4, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = arith.divsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %6, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = arith.subi %2, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %9 = arith.minsi %8, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %10 = arith.remsi %0, %5 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %11 = arith.remsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %12 = arith.addi %7, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %13 = arith.divsi %10, %9 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %14 = arith.muli %12, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %15 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %16 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %17 = tt.splat %14 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %18 = tt.splat %14 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = arith.addi %17, %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %20 = arith.addi %18, %16 {async_task_id = dense<1> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %21 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %22 = arith.remsi %19, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %24 = tt.make_range {async_task_id = dense<2> : vector<1xi32>, end = 128 : i32, start = 64 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = arith.addi %17, %23 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %18, %24 {async_task_id = dense<2> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.remsi %25, %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = arith.muli %13, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %29 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %30 = tt.splat %28 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %31 = arith.addi %30, %29 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %32 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %33 = arith.remsi %31, %32 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %34 = tt.expand_dims %22 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %35 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked> + %36 = arith.muli %34, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %37 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %38 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %39 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %40 = tt.broadcast %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<64x64xi32, #blocked> + %41 = arith.addi %39, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %42 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x64x!tt.ptr, #blocked> + %43 = tt.addptr %42, %41 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %44 = tt.expand_dims %27 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<64x1xi32, #blocked> + %45 = arith.muli %44, %35 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> + %46 = tt.broadcast %45 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked> -> tensor<64x64xi32, #blocked> + %47 = arith.addi %46, %40 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64xi32, #blocked> + %48 = tt.addptr %42, %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %49 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %50 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %51 = arith.muli %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %52 = tt.expand_dims %33 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %53 = tt.broadcast %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %54 = tt.broadcast %52 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %55 = arith.addi %53, %54 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128xi32, #blocked1> + %56 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x128x!tt.ptr, #blocked1> + %57 = tt.addptr %56, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %58 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %59 = arith.divsi %58, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %60 = tt.expand_dims %37 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %61 = tt.expand_dims %16 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %62 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %63 = tt.splat %62 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x128xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %true_3 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false_4 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %64:5 = scf.for %arg9 = %c0_i32 to %59 step %c1_i32 iter_args(%arg10 = %cst_2, %arg11 = %cst_2, %arg12 = %43, %arg13 = %57, %arg14 = %48) -> (tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked>) : i32 { + %93 = arith.muli %arg9, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %94 = arith.subi %arg5, %93 {async_task_id = dense<0> : vector<1xi32>} : i32 + %95 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %96 = arith.cmpi slt, %60, %95 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %97 = tt.broadcast %96 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<64x64xi1, #blocked> + %98 = tt.load %arg12, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %99 = triton_gpu.local_alloc %98 {async_task_id = dense<1> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %100 = tt.splat %94 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %101 = arith.cmpi slt, %61, %100 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %102 = tt.broadcast %101 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %103 = tt.load %arg13, %102, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %104 = triton_gpu.local_alloc %103 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x128xf16, #blocked1>) -> !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> + %105 = tt.load %arg14, %97, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked> + %106 = triton_gpu.local_alloc %105 {async_task_id = dense<2> : vector<1xi32>} : (tensor<64x64xf16, #blocked>) -> !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> + %107 = triton_nvidia_gpu.warp_group_dot %99, %104, %arg10 {async_task_id = dense<1> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %108 = triton_nvidia_gpu.warp_group_dot %106, %104, %arg11 {async_task_id = dense<2> : vector<1xi32>, inputPrecision = 0 : i32} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x128xf32, #mma> + %109 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %110 = tt.addptr %arg14, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<64x64x!tt.ptr, #blocked>, tensor<64x64xi32, #blocked> + %111 = tt.addptr %arg13, %63 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %107, %108, %109, %111, %110 : tensor<64x128xf32, #mma>, tensor<64x128xf32, #mma>, tensor<64x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked1>, tensor<64x64x!tt.ptr, #blocked> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %65 = arith.truncf %64#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %66 = arith.truncf %64#1 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf32, #mma> to tensor<64x128xf16, #mma> + %67 = tt.expand_dims %20 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %68 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %69 = arith.muli %68, %67 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %70 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<64x1x!tt.ptr, #blocked1> + %71 = tt.addptr %70, %69 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %72 = tt.expand_dims %31 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x128xi32, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %74 = tt.broadcast %72 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> -> tensor<64x128xi32, #blocked1> + %75 = tt.addptr %73, %74 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %76 = tt.expand_dims %26 {async_task_id = dense<2> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %77 = arith.muli %68, %76 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %78 = tt.addptr %70, %77 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1>, tensor<64x1xi32, #blocked1> + %79 = tt.broadcast %78 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1x!tt.ptr, #blocked1> -> tensor<64x128x!tt.ptr, #blocked1> + %80 = tt.addptr %79, %74 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1>, tensor<64x128xi32, #blocked1> + %81 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %82 = arith.cmpi slt, %67, %81 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %83 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x128xi32, #blocked1> + %84 = arith.cmpi slt, %72, %83 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi32, #blocked1> + %85 = tt.broadcast %82 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %86 = tt.broadcast %84 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x128xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %87 = arith.andi %85, %86 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %88 = arith.cmpi slt, %76, %81 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %89 = tt.broadcast %88 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x128xi1, #blocked1> + %90 = arith.andi %89, %86 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xi1, #blocked1> + %91 = triton_gpu.convert_layout %65 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %75, %91, %87 {async_task_id = dense<1> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + %92 = triton_gpu.convert_layout %66 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128xf16, #mma> -> tensor<64x128xf16, #blocked1> + tt.store %80, %92, %90 {async_task_id = dense<2> : vector<1xi32>} : tensor<64x128x!tt.ptr, #blocked1> + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir new file mode 100644 index 000000000..3816f5bc4 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_data_partition.mlir @@ -0,0 +1,136 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-data-partition=num-consumer-groups=2 | FileCheck %s + +// CHECK-LABEL: @matmul_persistent_ws_cooperative_kernel +// CHECK: %[[#GA1:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#GA2:]] = tt.load {{.*}} : tensor<64x64x!tt.ptr +// CHECK: %[[#LA1:]] = triton_gpu.local_alloc %[[#GA1]] +// CHECK: %[[#LA2:]] = triton_gpu.local_alloc %[[#GA2]] +// CHECK: %[[#GB:]] = tt.load {{.*}} : tensor<64x256x!tt.ptr +// CHECK: %[[#LB:]] = triton_gpu.local_alloc %[[#GB]] +// CHECK: %[[#C1:]] = triton_nvidia_gpu.warp_group_dot %[[#LA1]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: %[[#C2:]] = triton_nvidia_gpu.warp_group_dot %[[#LA2]], %[[#LB]], {{.*}} : !tt.memdesc<64x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<64x256xf32, #mma> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> +// CHECK: tt.store {{.*}} : tensor<64x256x!tt.ptr, #blocked1> + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 256, 16]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_persistent_ws_cooperative_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %cst = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<64> : tensor<128x64xi32, #blocked> + %c0_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 0 : i32 + %c1_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 1 : i32 + %c255_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 255 : i32 + %c63_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 63 : i32 + %c64_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 64 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 256 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 128 : i32 + %c8_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 8 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} 127 : i32 + %cst_0 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<128x64xf16, #blocked> + %cst_1 = arith.constant {async_task_id = dense<0> : vector<1xi32>} dense<0.000000e+00> : tensor<64x256xf16, #blocked1> + %cst_2 = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} dense<0.000000e+00> : tensor<128x256xf32, #mma> + %0 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %1 = arith.divsi %0, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %2 = arith.addi %arg4, %c255_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %3 = arith.divsi %2, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %4 = arith.muli %1, %3 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %5 = tt.get_program_id x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %6 = tt.get_num_programs x {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %7 = arith.muli %3, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %8 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %9 = tt.make_range {async_task_id = dense<[1, 2]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %10 = tt.splat %arg3 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %11 = tt.make_range {async_task_id = dense<[0, 1, 2]> : vector<3xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %12 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %13 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %14 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %15 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %16 = tt.broadcast %15 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> -> tensor<128x64xi32, #blocked> + %17 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x64x!tt.ptr, #blocked> + %18 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 64 : i32, start = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %19 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %20 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %21 = arith.muli %19, %20 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %22 = tt.broadcast %21 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %23 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<64x256x!tt.ptr, #blocked1> + %24 = arith.addi %arg5, %c63_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %25 = arith.divsi %24, %c64_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %26 = tt.expand_dims %14 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x64xi32, #blocked> + %27 = tt.expand_dims %18 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<64xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<64x1xi32, #blocked1> + %28 = arith.muli %arg7, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %29 = tt.splat %28 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x256xi32, #blocked1> + %30 = tt.splat %arg8 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %31 = tt.splat %arg2 {async_task_id = dense<[1, 2]> : vector<2xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked1> + %32 = tt.splat %arg3 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg4 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<1x256xi32, #blocked1> + scf.for %arg9 = %5 to %4 step %6 : i32 { + %34 = arith.divsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %35 = arith.muli %34, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %36 = arith.subi %1, %35 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %37 = arith.minsi %36, %c8_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %38 = arith.remsi %arg9, %7 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %39 = arith.remsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %40 = arith.addi %35, %39 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %41 = arith.divsi %38, %37 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %42 = arith.muli %40, %c128_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %43 = tt.splat %42 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.splat %42 {async_task_id = dense<[1, 2]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %45 = arith.addi %43, %8 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %46 = arith.addi %44, %9 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %47 = arith.remsi %45, %10 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %48 = arith.muli %41, %c256_i32 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 + %49 = tt.splat %48 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : i32 -> tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %50 = arith.addi %49, %11 {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %51 = arith.remsi %50, %12 {async_task_id = dense<0> : vector<1xi32>} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %52 = tt.expand_dims %47 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %53 = arith.muli %52, %13 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %54 = tt.broadcast %53 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked> -> tensor<128x64xi32, #blocked> + %55 = arith.addi %54, %16 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64xi32, #blocked> + %56 = tt.addptr %17, %55 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %57 = tt.expand_dims %51 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %58 = tt.broadcast %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<64x256xi32, #blocked1> + %59 = arith.addi %22, %58 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256xi32, #blocked1> + %60 = tt.addptr %23, %59 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + %true = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} true + %false = arith.constant {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} false + %61:3 = scf.for %arg10 = %c0_i32 to %25 step %c1_i32 iter_args(%arg11 = %cst_2, %arg12 = %56, %arg13 = %60) -> (tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1>) : i32 { + %76 = arith.muli %arg10, %c64_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.subi %arg5, %76 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x64xi32, #blocked> + %79 = arith.cmpi slt, %26, %78 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi32, #blocked> + %80 = tt.broadcast %79 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x64xi1, #blocked> -> tensor<128x64xi1, #blocked> + %81 = tt.load %arg12, %80, %cst_0 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked> + %82 = triton_gpu.local_alloc %81 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<128x64xf16, #blocked>) -> !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> + %83 = tt.splat %77 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<64x1xi32, #blocked1> + %84 = arith.cmpi slt, %27, %83 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi32, #blocked1> + %85 = tt.broadcast %84 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x1xi1, #blocked1> -> tensor<64x256xi1, #blocked1> + %86 = tt.load %arg13, %85, %cst_1 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1> + %87 = triton_gpu.local_alloc %86 {async_task_id = dense<[1, 2]> : vector<2xi32>} : (tensor<64x256xf16, #blocked1>) -> !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> + %88 = triton_nvidia_gpu.warp_group_dot %82, %87, %arg11 {async_task_id = dense<[1, 2]> : vector<2xi32>, inputPrecision = 0 : i32} : !tt.memdesc<128x64xf16, #shared, #triton_gpu.shared_memory> * !tt.memdesc<64x256xf16, #shared, #triton_gpu.shared_memory> -> tensor<128x256xf32, #mma> + %89 = tt.addptr %arg12, %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<128x64x!tt.ptr, #blocked>, tensor<128x64xi32, #blocked> + %90 = tt.addptr %arg13, %29 {async_task_id = dense<0> : vector<1xi32>} : tensor<64x256x!tt.ptr, #blocked1>, tensor<64x256xi32, #blocked1> + scf.yield {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} %88, %89, %90 : tensor<128x256xf32, #mma>, tensor<128x64x!tt.ptr, #blocked>, tensor<64x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + %62 = arith.truncf %61#0 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf32, #mma> to tensor<128x256xf16, #mma> + %63 = tt.expand_dims %46 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %64 = arith.muli %30, %63 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %65 = tt.addptr %31, %64 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1>, tensor<128x1xi32, #blocked1> + %66 = tt.expand_dims %50 {async_task_id = dense<[1, 2]> : vector<2xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %67 = tt.broadcast %65 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1x!tt.ptr, #blocked1> -> tensor<128x256x!tt.ptr, #blocked1> + %68 = tt.broadcast %66 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %69 = tt.addptr %67, %68 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %70 = arith.cmpi slt, %63, %32 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi32, #blocked1> + %71 = arith.cmpi slt, %66, %33 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi32, #blocked1> + %72 = tt.broadcast %70 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x1xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %73 = tt.broadcast %71 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + %74 = arith.andi %72, %73 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xi1, #blocked1> + %75 = triton_gpu.convert_layout %62 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256xf16, #mma> -> tensor<128x256xf16, #blocked1> + tt.store %69, %75, %74 {async_task_id = dense<[1, 2]> : vector<2xi32>} : tensor<128x256x!tt.ptr, #blocked1> + } {async_task_id = dense<[0, 1, 2]> : vector<3xi32>} + tt.return + } +} diff --git a/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir new file mode 100644 index 000000000..de69a59b8 --- /dev/null +++ b/test/TritonNvidiaGPU/WarpSpecialization/ws_lowering.mlir @@ -0,0 +1,237 @@ +// RUN: triton-opt %s -split-input-file --tritongpu-warp-spec-lowering=num-consumer-groups=1 | FileCheck %s + +// CHECK: %[[#PBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#CBARRIER:]] = triton_gpu.local_alloc : () -> !tt.memdesc<1xi64 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 128 +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]][%c0_i32] +// CHECK: triton_nvidia_gpu.init_barrier %[[#]], 1 +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: triton_gpu.async_copy_global_to_local +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] +// CHECK: scf.for +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#PBARRIER]] +// CHECK: triton_nvidia_gpu.wait_barrier %[[#]] +// CHECK: triton_gpu.local_load +// CHECK: triton_gpu.local_load +// CHECK: tt.dot +// CHECK: %[[#]] = triton_gpu.memdesc_subview %[[#CBARRIER]] +// CHECK: triton_nvidia_gpu.mbarrier_arrive %[[#]] + + + +#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [4, 1], order = [1, 0]}> +#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = true}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} { + tt.func public @matmul_kernel(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} { + %0 = triton_gpu.local_alloc : () -> !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %1 = triton_gpu.local_alloc : () -> !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %2 = triton_nvidia_gpu.create_token {num = 1 : i32} : tensor<1x!triton_nvidia_gpu.token> + %3 = triton_nvidia_gpu.get_async_task_id : i32 + %c0_i32 = arith.constant 0 : i32 + %4 = arith.cmpi eq, %3, %c0_i32 : i32 + scf.if %4 { + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %23 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %24 = arith.addi %23, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %26 = arith.remsi %24, %25 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %28 = tt.splat %27 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %29 = arith.addi %28, %22 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %30 = tt.splat %arg4 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %31 = arith.remsi %29, %30 {async_task_id = dense<0> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = tt.expand_dims %26 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> -> tensor<128x1xi32, #blocked1> + %33 = tt.splat %arg6 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked1> + %34 = arith.muli %32, %33 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> + %35 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> + %36 = tt.expand_dims %35 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 0, parent = #blocked1}>> -> tensor<1x256xi32, #blocked1> + %37 = tt.broadcast %34 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x1xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %38 = tt.broadcast %36 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> -> tensor<128x256xi32, #blocked1> + %39 = arith.addi %37, %38 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256xi32, #blocked1> + %40 = tt.splat %arg0 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<128x256x!tt.ptr, #blocked1> + %41 = tt.addptr %40, %39 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %42 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %43 = tt.make_range {async_task_id = dense<0> : vector<1xi32>, end = 256 : i32, start = 0 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %44 = tt.expand_dims %42 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %45 = tt.expand_dims %43 {async_task_id = dense<0> : vector<1xi32>, axis = 1 : i32} : tensor<256xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<256x1xi32, #blocked> + %46 = tt.splat %arg7 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %47 = arith.muli %44, %46 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %48 = tt.expand_dims %31 {async_task_id = dense<0> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %49 = tt.broadcast %47 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> -> tensor<256x128xi32, #blocked> + %50 = tt.broadcast %48 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<256x128xi32, #blocked> + %51 = arith.addi %49, %50 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128xi32, #blocked> + %52 = tt.splat %arg1 {async_task_id = dense<0> : vector<1xi32>} : !tt.ptr -> tensor<256x128x!tt.ptr, #blocked> + %53 = tt.addptr %52, %51 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %54 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %55 = arith.divsi %54, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %56 = arith.muli %arg7, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %57 = tt.splat %56 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x128xi32, #blocked> + %c1_i32_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %58:4 = scf.for %arg9 = %c0_i32_1 to %55 step %c1_i32_0 iter_args(%arg10 = %41, %arg11 = %53, %arg12 = %false, %arg13 = %c0_i32_5) -> (tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32) : i32 { + %59 = arith.muli %arg9, %c256_i32 {async_task_id = dense<0> : vector<1xi32>} : i32 + %60 = arith.subi %arg5, %59 {async_task_id = dense<0> : vector<1xi32>} : i32 + %61 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<1x256xi32, #blocked1> + %62 = arith.cmpi slt, %36, %61 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi32, #blocked1> + %63 = tt.broadcast %62 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x256xi1, #blocked1> -> tensor<128x256xi1, #blocked1> + triton_nvidia_gpu.producer_acquire %2, %arg13, %false {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1 + %c0_i32_6 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_7 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %64 = triton_gpu.memdesc_subview %0[%arg13, %c0_i32_6, %c0_i32_6] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %65 = triton_gpu.async_copy_global_to_local %arg10, %64 mask %63 other %cst_2 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1> -> <128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %66 = tt.splat %60 {async_task_id = dense<0> : vector<1xi32>} : i32 -> tensor<256x1xi32, #blocked> + %67 = arith.cmpi slt, %45, %66 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi32, #blocked> + %68 = tt.broadcast %67 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x1xi1, #blocked> -> tensor<256x128xi1, #blocked> + %c0_i32_8 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %c1_i32_9 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %69 = triton_gpu.memdesc_subview %1[%arg13, %c0_i32_8, %c0_i32_8] {async_task_id = dense<0> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %70 = triton_gpu.async_copy_global_to_local %arg11, %69 mask %68 other %cst {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked> -> <256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + triton_nvidia_gpu.producer_commit %2, %arg13 {async_task_id = dense<0> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32 + %71 = tt.addptr %arg10, %cst_3 {async_task_id = dense<0> : vector<1xi32>} : tensor<128x256x!tt.ptr, #blocked1>, tensor<128x256xi32, #blocked1> + %72 = tt.addptr %arg11, %57 {async_task_id = dense<0> : vector<1xi32>} : tensor<256x128x!tt.ptr, #blocked>, tensor<256x128xi32, #blocked> + %c1_i32_10 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 1 : i32 + %c0_i32_11 = arith.constant {async_task_id = dense<0> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<0> : vector<1xi32>} true + %73 = arith.addi %arg13, %c1_i32_10 {async_task_id = dense<0> : vector<1xi32>} : i32 + %74 = arith.cmpi uge, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %75 = arith.cmpi ult, %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %76 = arith.subi %73, %c1_i32_4 {async_task_id = dense<0> : vector<1xi32>} : i32 + %77 = arith.select %74, %76, %73 {async_task_id = dense<0> : vector<1xi32>} : i32 + %78 = arith.xori %arg12, %true {async_task_id = dense<0> : vector<1xi32>} : i1 + %79 = arith.andi %74, %78 {async_task_id = dense<0> : vector<1xi32>} : i1 + %80 = arith.andi %75, %arg12 {async_task_id = dense<0> : vector<1xi32>} : i1 + %81 = arith.ori %79, %80 {async_task_id = dense<0> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<0> : vector<1xi32>} %71, %72, %81, %77 : tensor<128x256x!tt.ptr, #blocked1>, tensor<256x128x!tt.ptr, #blocked>, i1, i32 + } {async_task_id = dense<0> : vector<1xi32>} + } {async_task_id = dense<0> : vector<1xi32>} + %c1_i32 = arith.constant 1 : i32 + %5 = arith.cmpi eq, %3, %c1_i32 : i32 + scf.if %5 { + %cst = arith.constant {async_task_id = dense<1> : vector<1xi32>} dense<0.000000e+00> : tensor<128x128xf32, #blocked2> + %c255_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 255 : i32 + %c127_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 127 : i32 + %c1_i32_0 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_1 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %cst_2 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<256x128xf16, #blocked> + %cst_3 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<0.000000e+00> : tensor<128x256xf16, #blocked1> + %c8_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 8 : i32 + %c128_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 128 : i32 + %c256_i32 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 256 : i32 + %cst_4 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} dense<256> : tensor<128x256xi32, #blocked1> + %6 = tt.get_program_id x {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %7 = arith.addi %arg3, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %8 = arith.divsi %7, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %9 = arith.addi %arg4, %c127_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %10 = arith.divsi %9, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %11 = arith.muli %10, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %12 = arith.divsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %13 = arith.muli %12, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %14 = arith.subi %8, %13 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %15 = arith.minsi %14, %c8_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %16 = arith.remsi %6, %11 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %17 = arith.remsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %18 = arith.addi %13, %17 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %19 = arith.divsi %16, %15 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %20 = arith.muli %18, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %21 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %22 = tt.make_range {async_task_id = dense<1> : vector<1xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %23 = tt.make_range {async_task_id = dense<[0, 1]> : vector<2xi32>, end = 128 : i32, start = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %24 = tt.splat %20 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %25 = tt.splat %20 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %26 = arith.addi %24, %21 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %27 = arith.addi %25, %22 {async_task_id = dense<1> : vector<1xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> + %28 = tt.splat %arg3 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %29 = arith.remsi %26, %28 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked1}>> + %30 = arith.muli %19, %c128_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %31 = tt.splat %30 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 -> tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %32 = arith.addi %31, %23 {async_task_id = dense<[0, 1]> : vector<2xi32>} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> + %33 = arith.addi %arg5, %c255_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %34 = arith.divsi %33, %c256_i32 {async_task_id = dense<[0, 1]> : vector<2xi32>} : i32 + %c1_i32_5 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 1 : i32 + %c0_i32_6 = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} 0 : i32 + %false = arith.constant {async_task_id = dense<[0, 1]> : vector<2xi32>} false + %35:3 = scf.for %arg9 = %c0_i32_1 to %34 step %c1_i32_0 iter_args(%arg10 = %cst, %arg11 = %false, %arg12 = %c0_i32_6) -> (tensor<128x128xf32, #blocked2>, i1, i32) : i32 { + %c0_i32_7 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_8 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_9 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %c1_i32_10 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + triton_nvidia_gpu.consumer_wait %2, %arg12, %false {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32, i1 + %54 = triton_gpu.memdesc_subview %0[%arg12, %c0_i32_7, %c0_i32_7] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> + %55 = triton_gpu.local_load %54 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<128x256xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<128x256xf16, #blocked1> + %56 = triton_gpu.convert_layout %55 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #blocked1> -> tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> + %57 = triton_gpu.memdesc_subview %1[%arg12, %c0_i32_9, %c0_i32_9] {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<1x256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> + %58 = triton_gpu.local_load %57 {async_task_id = dense<1> : vector<1xi32>} : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory, mutable> -> tensor<256x128xf16, #blocked> + %59 = triton_gpu.convert_layout %58 {async_task_id = dense<1> : vector<1xi32>} : tensor<256x128xf16, #blocked> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> + %60 = tt.dot %56, %59, %arg10, inputPrecision = tf32 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x256xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked2}>> * tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked2}>> -> tensor<128x128xf32, #blocked2> + triton_nvidia_gpu.consumer_release %2, %arg12 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x!triton_nvidia_gpu.token>, i32 + %c1_i32_11 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 1 : i32 + %c0_i32_12 = arith.constant {async_task_id = dense<1> : vector<1xi32>} 0 : i32 + %true = arith.constant {async_task_id = dense<1> : vector<1xi32>} true + %61 = arith.addi %arg12, %c1_i32_11 {async_task_id = dense<1> : vector<1xi32>} : i32 + %62 = arith.cmpi uge, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %63 = arith.cmpi ult, %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %64 = arith.subi %61, %c1_i32_5 {async_task_id = dense<1> : vector<1xi32>} : i32 + %65 = arith.select %62, %64, %61 {async_task_id = dense<1> : vector<1xi32>} : i32 + %66 = arith.xori %arg11, %true {async_task_id = dense<1> : vector<1xi32>} : i1 + %67 = arith.andi %62, %66 {async_task_id = dense<1> : vector<1xi32>} : i1 + %68 = arith.andi %63, %arg11 {async_task_id = dense<1> : vector<1xi32>} : i1 + %69 = arith.ori %67, %68 {async_task_id = dense<1> : vector<1xi32>} : i1 + scf.yield {async_task_id = dense<1> : vector<1xi32>} %60, %69, %65 : tensor<128x128xf32, #blocked2>, i1, i32 + } {async_task_id = dense<1> : vector<1xi32>} + %36 = arith.truncf %35#0 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf32, #blocked2> to tensor<128x128xf16, #blocked2> + %37 = tt.expand_dims %27 {async_task_id = dense<1> : vector<1xi32>, axis = 1 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked> + %38 = tt.splat %arg8 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %39 = arith.muli %38, %37 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %40 = tt.splat %arg2 {async_task_id = dense<1> : vector<1xi32>} : !tt.ptr -> tensor<128x1x!tt.ptr, #blocked> + %41 = tt.addptr %40, %39 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked>, tensor<128x1xi32, #blocked> + %42 = tt.expand_dims %32 {async_task_id = dense<1> : vector<1xi32>, axis = 0 : i32} : tensor<128xi32, #triton_gpu.slice<{dim = 0, parent = #blocked}>> -> tensor<1x128xi32, #blocked> + %43 = tt.broadcast %41 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1x!tt.ptr, #blocked> -> tensor<128x128x!tt.ptr, #blocked> + %44 = tt.broadcast %42 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> -> tensor<128x128xi32, #blocked> + %45 = tt.addptr %43, %44 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked>, tensor<128x128xi32, #blocked> + %46 = tt.splat %arg3 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<128x1xi32, #blocked> + %47 = arith.cmpi slt, %37, %46 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi32, #blocked> + %48 = tt.splat %arg4 {async_task_id = dense<1> : vector<1xi32>} : i32 -> tensor<1x128xi32, #blocked> + %49 = arith.cmpi slt, %42, %48 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi32, #blocked> + %50 = tt.broadcast %47 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x1xi1, #blocked> -> tensor<128x128xi1, #blocked> + %51 = tt.broadcast %49 {async_task_id = dense<1> : vector<1xi32>} : tensor<1x128xi1, #blocked> -> tensor<128x128xi1, #blocked> + %52 = arith.andi %50, %51 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xi1, #blocked> + %53 = triton_gpu.convert_layout %36 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128xf16, #blocked2> -> tensor<128x128xf16, #blocked> + tt.store %45, %53, %52 {async_task_id = dense<1> : vector<1xi32>} : tensor<128x128x!tt.ptr, #blocked> + } {async_task_id = dense<1> : vector<1xi32>} + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 8649911a7..626f41a0e 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -93,7 +93,7 @@ struct ConvertTritonAMDGPUToLLVM int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - // Hack: WSMaterialization may have changed the effective number of warps, + // Hack: WSLowering may have changed the effective number of warps, // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to // respect that here. if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index adfde57b0..6d1122fb0 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -88,6 +88,11 @@ class CUDAOptions: num_warps: int = 4 num_ctas: int = 1 num_stages: int = 3 + num_buffers_warp_spec: int = 0 + num_consumer_groups: int = 0 + reg_dec_producer: int = 0 + reg_inc_consumer: int = 0 + partition_style: int = 0 # maxnreg corresponds to the ptx parameter .maxnreg, which controls the # maximum number of 32-bit registers used by one thread. maxnreg: Optional[int] = None @@ -221,7 +226,14 @@ def make_ttgir(mod, metadata, opt, capability): if capability // 10 >= 8: passes.ttgpuir.add_optimize_accumulator_init(pm) passes.ttgpuir.add_combine_tensor_select_and_if(pm) + passes.ttgpuir.add_taskid_propagate(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_data_partition(pm, opt.num_consumer_groups) + passes.ttgpuir.add_ws_code_partition(pm, opt.num_buffers_warp_spec, opt.num_consumer_groups, + opt.reg_dec_producer, opt.reg_inc_consumer) + passes.ttgpuir.add_loop_scheduling(pm, opt.num_stages) passes.ttgpuir.add_pipeline(pm, opt.num_stages) + passes.ttgpuir.add_ping_pong_sync(pm, opt.num_consumer_groups, opt.partition_style) + passes.ttgpuir.add_ws_lowering(pm, opt.num_consumer_groups) passes.ttgpuir.add_prefetch(pm) passes.ttgpuir.add_optimize_dot_operands(pm, capability >= 80) passes.ttgpuir.add_remove_layout_conversions(pm) diff --git a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td index 31b2646db..840e0714c 100644 --- a/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td +++ b/third_party/nvidia/include/Dialect/NVGPU/IR/NVGPUOps.td @@ -39,6 +39,7 @@ def NVGPU_WGMMAFenceOp : NVGPU_Op<"wgmma_fence", []> { let assemblyFormat = "attr-dict"; } + def NVGPU_WGMMACommitGroupOp : NVGPU_Op<"wgmma_commit_group", []> { let assemblyFormat = "attr-dict"; } @@ -52,6 +53,32 @@ def NVGPU_WGMMAWaitGroupOp : NVGPU_Op<"wgmma_wait_group", let assemblyFormat = "$input attr-dict `:` type($input)"; } +def MBarrier_ArriveTypeAttr : I32EnumAttr<"MBarriveType", + "mbarrier arrive type, either 'normal', 'expect_tx', 'cp_async'", + [ + I32EnumAttrCase<"normal", 0>, + I32EnumAttrCase<"cp_async", 1>, + I32EnumAttrCase<"expect_tx", 2>, + I32EnumAttrCase<"remote", 3>, + ]>{ + let cppNamespace = "::mlir::triton::nvgpu"; +} + +def NVGPU_MBarrierArriveOp : NVGPU_Op<"mbarrier_arrive", []> { + let arguments = (ins LLVM_PointerShared:$mbarrier, I1:$pred, Optional:$ctaId, MBarrier_ArriveTypeAttr:$arriveType, DefaultValuedAttr:$txCount); + let assemblyFormat = "$mbarrier `,` $pred (`,` $ctaId^)? attr-dict `:` type($mbarrier)"; +} + +def NVGPU_NamedBarrierArriveOp : NVGPU_Op<"bar_arrive", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + +def NVGPU_NamedBarrierWaitOp : NVGPU_Op<"bar_wait", []> { + let arguments = (ins I32:$bar, I32:$numThreads); + let assemblyFormat = "$bar `,` $numThreads attr-dict `:` type(operands)"; +} + def WGMMA_LayoutAttr : I32EnumAttr<"WGMMALayout", "wgmma layout, either 'row' or 'col'", [ @@ -112,4 +139,19 @@ def NVGPU_ClusterCTAIdOp : NVGPU_Op<"cluster_id", [Pure]> { let assemblyFormat = "attr-dict"; } +def NVGPU_CanonicalWarpIdOp : NVGPU_Op<"canonical_warp_id", [Pure]> { + let results = (outs I32:$result); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegAllocOp : NVGPU_Op<"reg_alloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + +def NVGPU_RegDeallocOp : NVGPU_Op<"reg_dealloc", []> { + let arguments = (ins I32Attr: $regCount); + let assemblyFormat = "attr-dict"; +} + #endif diff --git a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp index 8de0efefc..5a461fb72 100644 --- a/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp +++ b/third_party/nvidia/lib/NVGPUToLLVM/NVGPUToLLVMPass.cpp @@ -38,6 +38,28 @@ const std::string Cluster_Cta_Id_Op = "{\n" "mad.lo.u32 a1, a2, a4, a1; \n" "mad.lo.u32 $0, a1, a3, a0; \n" "}"; +const std::string Reg_Alloc_Op = "setmaxnreg.inc.sync.aligned.u32 #regCount;"; +const std::string Reg_Dealloc_Op = "setmaxnreg.dec.sync.aligned.u32 #regCount;"; + +const std::string Named_Barrier_Arrive_Op = "bar.arrive $0, $1;"; +const std::string Named_Barrier_Wait_Op = "bar.sync $0, $1;"; +const std::string Canonical_Warp_Id_Op = + "{\n" + ".reg .u32 a<5>; \n" + "mov.u32 a0, %tid.x; \n" // x + "mov.u32 a1, %tid.y; \n" // y + "mov.u32 a2, %tid.z; \n" // z + "mov.u32 a3, %ntid.x; \n" // nx + "mov.u32 a4, %ntid.y; \n" // ny + "mad.lo.u32 a1, a2, a4, a1; \n" + "mad.lo.u32 a0, a1, a3, a0; \n" + "shr.u32 a0, a0, 5; \n" + ".reg .b32 %tmp<3>; \n" + "mov.u32 %tmp0, -1; \n" + "mov.u32 %tmp1, 31; \n" + "mov.u32 %tmp2, 0; \n" + "shfl.sync.idx.b32 $0, a0, %tmp2, %tmp1, %tmp0; \n" + "}"; bool isNumber(const std::string &s) { return !s.empty() && std::find_if(s.begin(), s.end(), [](unsigned char c) { @@ -278,6 +300,77 @@ class StoreMatrixOpPattern : public OpRewritePattern { } }; +class MBarrierArriveOpPattern : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ttn::MBarrierArriveOp op, + PatternRewriter &rewriter) const override { + return rewriteAsPtxAsm(op, rewriter, getPtxAsm(op), + getOperandsAndConstraints(op)); + } + + OperandsAndConstraints + getOperandsAndConstraints(ttn::MBarrierArriveOp op) const { + OperandsAndConstraints operandsAndTypes; + Value mbarrier = op.getMbarrier(); + Value pred = op.getPred(); + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + + switch (arriveType) { + case ttn::MBarriveType::normal: + case ttn::MBarriveType::cp_async: + case ttn::MBarriveType::expect_tx: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + case ttn::MBarriveType::remote: + operandsAndTypes.push_back({mbarrier, "r"}); + operandsAndTypes.push_back({ctaId, "r"}); + operandsAndTypes.push_back({pred, "b"}); + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return operandsAndTypes; + } + + std::string getPtxAsm(ttn::MBarrierArriveOp op) const { + Value ctaId = op.getCtaId(); + auto arriveType = op.getArriveType(); + uint32_t txCount = op.getTxCount(); + std::string ptxAsm; + switch (arriveType) { + case ttn::MBarriveType::normal: + ptxAsm = "@$1 mbarrier.arrive.shared.b64 _, [$0];"; + break; + case ttn::MBarriveType::cp_async: + ptxAsm = "@$1 cp.async.mbarrier.arrive.noinc.shared.b64 [$0];"; + break; + case ttn::MBarriveType::expect_tx: + assert(txCount > 0 && "txCount should be valid"); + ptxAsm = "@$1 mbarrier.arrive.expect_tx.shared.b64 _, [$0], " + + std::to_string(txCount) + ";"; + break; + case ttn::MBarriveType::remote: + assert(ctaId && "ctaId should have a valid value"); + ptxAsm = + " { .reg .b32 remAddr32; \n" + " @$2 mapa.shared::cluster.u32 remAddr32, $0, $1; \n" + " @$2 mbarrier.arrive.shared::cluster.b64 _, [remAddr32]; } \n"; + break; + default: + llvm::errs() << "Unsupported mbarrier arrive type " << arriveType << "\n"; + llvm_unreachable(""); + break; + } + return ptxAsm; + } +}; + class WGMMAWaitGroupOpPattern : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -507,17 +600,25 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase { #define POPULATE_NVGPU_OP(SRC_OP, ASM) \ patterns.add>(context, ASM, Constraints(), \ Constraints()); + POPULATE_NVGPU_OP(ttn::RegAllocOp, Reg_Alloc_Op) POPULATE_NVGPU_OP(ttn::WGMMAFenceOp, Wgmma_Fence_Op) POPULATE_NVGPU_OP(ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op) POPULATE_NVGPU_OP(ttn::ClusterWaitOp, Cluster_Wait_Op) + POPULATE_NVGPU_OP(ttn::RegDeallocOp, Reg_Dealloc_Op) #undef POPULATE_NVGPU_OP + patterns.add>( + context, Named_Barrier_Arrive_Op, Constraints(), + Constraints({"r", "r"})); + patterns.add>( + context, Named_Barrier_Wait_Op, Constraints(), Constraints({"r", "r"})); patterns.add>( context, Cluster_Cta_Id_Op, Constraints({"=r"}), Constraints()); + patterns.add>( + context, Canonical_Warp_Id_Op, Constraints({"=r"}), Constraints()); - patterns - .add( - context); + patterns.add(context); if (applyPatternsAndFoldGreedily(mod, std::move(patterns)).failed()) signalPassFailure(); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp index 746b910e1..268d1dbf6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/BarrierOpToLLVM.cpp @@ -55,6 +55,77 @@ struct BarrierOpConversion } }; +// -------------------------------------------------------------------------- +// -- MBarrier related Ops lowering, to be moved to a separate file --------- +// -------------------------------------------------------------------------- +struct MBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::MBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::MBarrierArriveOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto mbarrier = LLVM::getSharedMemoryObjectFromStruct( + op.getLoc(), adaptor.getMbarrier(), + typeConverter->convertType(op.getMbarrier().getType().getElementType()), + rewriter); + + bool trackAsyncOp = op.getTrackAsyncOp(); + triton::nvgpu::MBarriveType type = triton::nvgpu::MBarriveType::normal; + uint32_t txCount = op.getTxCount(); + auto remoteCtaId = adaptor.getRemoteCtaId(); + if (trackAsyncOp) { + type = triton::nvgpu::MBarriveType::cp_async; + } else if (remoteCtaId) { + assert(txCount == 0 && + "remote arrive of transaction mbarrier is not implemented yet"); + type = triton::nvgpu::MBarriveType::remote; + } else if (txCount > 0) { + type = triton::nvgpu::MBarriveType::expect_tx; + } + Value pred = adaptor.getPred(); + if (pred == nullptr) { + pred = int_val(/*width*/ 1, 1); + } + rewriter.replaceOpWithNewOp( + op, mbarrier.getBase(), pred, remoteCtaId, type, txCount); + return success(); + } +}; + +struct NamedBarrierArriveOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierArriveOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierArriveOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + +struct NamedBarrierWaitOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::NamedBarrierWaitOp>::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(triton::nvidia_gpu::NamedBarrierWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + rewriter.replaceOpWithNewOp( + op, adaptor.getBar(), adaptor.getNumThreads()); + return success(); + } +}; + struct FenceAsyncSharedOpConversion : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern< @@ -83,8 +154,18 @@ struct InitBarrierOpConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + auto id = getThreadId(rewriter, loc); - auto pred = icmp_eq(id, i32_val(0)); + auto pred = icmp_eq(id, i32_val(executingThreadId)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.init.shared::cta.b64 [$1], " + std::to_string(op.getCount()) + ";"; @@ -112,8 +193,17 @@ struct InvalBarrierOpConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); + Value pred = icmp_eq(id, i32_val(executingThreadId)); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = "@$0 mbarrier.inval.shared::cta.b64 [$1];"; auto &barSyncOp = *ptxBuilder.create<>(ptx); @@ -140,8 +230,17 @@ struct BarrierExpectConversion typeConverter->convertType(op.getAlloc().getType().getElementType()), rewriter); + auto asyncTaskIds = getAsyncTaskIds(op); + int executingThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + auto mod = op->getParentOfType(); + int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod); + int warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + executingThreadId = asyncTaskIds[0] * numWarps * warpSize; + } auto id = getThreadId(rewriter, loc); - Value pred = icmp_eq(id, i32_val(0)); + Value pred = icmp_eq(id, i32_val(executingThreadId)); pred = and_(pred, adaptor.getPred()); ::mlir::triton::PTXBuilder ptxBuilder; const std::string ptx = @@ -194,6 +293,9 @@ void mlir::triton::NVIDIA::populateBarrierOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); patterns.add(typeConverter, benefit); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 8fb44ce64..37702f1d6 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -505,7 +505,7 @@ struct ConvertLayoutOpConversion auto multiDimRepId = getMultiDimIndex(repId, numReplicates, outOrd); if (repId != 0) { - barrier(); + insertBarrier(rewriter, op); } if (isLayoutMmaV1(srcLayout)) @@ -517,7 +517,7 @@ struct ConvertLayoutOpConversion multiDimRepId, inVec, paddedRepShape, origRepShape, outOrd, vals, smemBase); - barrier(); + insertBarrier(rewriter, op); if (isLayoutMmaV1(dstLayout)) processReplicaForMMAV1(loc, rewriter, /*stNotRd*/ false, dstTy, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp index d950e0157..73e31104c 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -9,6 +9,8 @@ #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" using namespace mlir; using namespace mlir::triton; @@ -490,10 +492,21 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern, } }; -void createBarrier(ConversionPatternRewriter &rewriter, Location loc, +void createBarrier(ConversionPatternRewriter &rewriter, Operation *op, int numCTAs) { + auto loc = op->getLoc(); if (numCTAs == 1) { - barrier(); + auto barrierOp = barrier(); + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() == 1) { + int asyncTaskId = asyncTaskIds[0]; + int barId = asyncTaskId + nameBarrierIdBegin; + assert(barId < nameBarrierIdEnd); + // TODO: Change hard code style of numThreads. + const int numThreads = 128; + barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId)); + barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads)); + } } else { rewriter.create(loc, false); rewriter.create(loc); @@ -606,7 +619,7 @@ struct AtomicCASOpConversion st(dstOprStore, valOprStore).predicate(mask); auto ASMReturnTy = void_ty(ctx); ptxBuilderStore.launch(rewriter, loc, ASMReturnTy); - createBarrier(rewriter, loc, numCTAs); + createBarrier(rewriter, op, numCTAs); Value ret = load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } @@ -778,7 +791,7 @@ struct AtomicRMWOpConversion auto *valOpr = ptxBuilderStore.newOperand(old, tyId); storeShared(ptrOpr, valOpr).predicate(rmwMask); ptxBuilderStore.launch(rewriter, loc, void_ty(ctx)); - createBarrier(rewriter, loc, numCTAs); + createBarrier(rewriter, op, numCTAs); Value ret = load(valueElemTy, atomPtr); rewriter.replaceOp(op, {ret}); } @@ -988,6 +1001,13 @@ struct AsyncTMACopyGlobalToLocalOpConversion if (rank > 1) numCopies = ceil(contigDimSizeInByte, 128); + auto asyncTaskIds = getAsyncTaskIds(op); + int firstThreadId = 0; + if (!asyncTaskIds.empty()) { + assert(asyncTaskIds.size() == 1 && "only support single async task"); + firstThreadId = asyncTaskIds[0] * numWarps * warpSize; + } + // The bounding box inner dimension must be less than or equal to the // swizzle size. // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html#group__CUDA__TENSOR__MEMORY_1ga7c7d2aaac9e49294304e755e6f341d7 @@ -997,8 +1017,9 @@ struct AsyncTMACopyGlobalToLocalOpConversion int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) warpID = i32_val(0); - Value boxPred = - and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); + Value boxPred = and_( + pred, + icmp_ult(id, i32_val(numWarpsToCopy * warpSize + firstThreadId))); ::mlir::triton::PTXBuilder ptxBuilderTMA; Type elemPtrTy = ptr_ty(rewriter.getContext(), 3); Value copyIdxVal = add(warpID, i32_val(copyIdx)); @@ -1037,6 +1058,14 @@ struct AsyncTMACopyGlobalToLocalOpConversion } }; +int getWarpOffset(Operation *op) { + auto asyncTaskIds = getAsyncTaskIds(op); + if (asyncTaskIds.size() > 0) { + return 4 * *std::min_element(asyncTaskIds.begin(), asyncTaskIds.end()); + } + return 0; +} + struct AsyncTMACopyLocalToGlobalOpConversion : public ConvertOpToLLVMPattern< triton::nvidia_gpu::AsyncTMACopyLocalToGlobalOp> { @@ -1082,6 +1111,9 @@ struct AsyncTMACopyLocalToGlobalOpConversion int numWarpsToCopy = std::min(numCopies - copyIdx, numWarps); if (numWarpsToCopy == 1) warpID = i32_val(0); + auto warpOffset = getWarpOffset(op); + warpID = sub(warpID, i32_val(warpOffset)); + id = sub(id, i32_val(warpOffset * warpSize)); Value boxPred = and_(pred, icmp_ult(id, i32_val(numWarpsToCopy * warpSize))); ::mlir::triton::PTXBuilder ptxBuilderTMA; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp index 93ad46971..8bc55e187 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/SPMDOpToLLVM.cpp @@ -1,5 +1,6 @@ #include "PatternTritonGPUOpToLLVM.h" #include "Utility.h" +#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h" namespace { @@ -33,10 +34,23 @@ struct GetNumProgramsOpConversion } }; +struct GetCanonicalWarpIdConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern< + triton::nvidia_gpu::GetCanonicalWarpIdOp>::ConvertOpToLLVMPattern; + LogicalResult + matchAndRewrite(triton::nvidia_gpu::GetCanonicalWarpIdOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.replaceOp(op, GetCanonicalWarpId(rewriter, op->getLoc())); + return success(); + } +}; } // namespace void mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, benefit); + patterns.add(typeConverter, benefit); } diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 0a35176ec..1e1e7c488 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -96,6 +96,13 @@ struct ConvertTritonGPUToLLVM int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + // Hack: WSLowering may have changed the effective number of warps, + // in a way that isn't reflected in triton_gpu.num-warps. If so, we have to + // respect that here. + if (Attribute attr = mod->getAttr("triton_gpu.num-warp-groups-per-cta")) { + numWarps *= cast(attr).getInt(); + } + // Allocate shared memory and set barrier ModuleAllocation allocation(mod); ModuleMembarAnalysis membarPass(&allocation, NVIDIA::canSkipBarSync); @@ -175,6 +182,8 @@ struct ConvertTritonGPUToLLVM patterns, benefit); mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo, patterns, benefit); + mlir::triton::populateRegReallocOpToLLVMPatterns(typeConverter, patterns, + benefit); if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) return signalPassFailure();