[warpspec] Add experimental support for warp specialization with user annotations

This commit is a squash generated by:
```
git diff --stat b62b221a...9755e293a -- . ':(exclude)python/gemmbench' ':(exclude)python/hstuBench' ':(exclude)third_party/proton'
```

Additional modifications:

- Update README.md with Warp Specialization Support
- Propagate mma layout to following elementwise operations.
bertmaher authored and htyu committed Nov 18, 2024
1 parent f4c48a9 commit 56df264
Showing 58 changed files with 6,819 additions and 88 deletions.
211 changes: 203 additions & 8 deletions README.md
@@ -1,16 +1,211 @@
<div align="center">
<img src="https://lh5.googleusercontent.com/wzQKEsTFkrgNQO9JjhGH5wFvslJr1saLtLaJ_a6Fp_gNENpvt3VG7BmztwngU9hFJaU4CPwGiw1opQtDvTkLrxWRbO_a12Q-pdESWHgtmheIHcPbOL5ZMC4TSiJVe5ty1w=w3517" alt="Triton logo">
</div>

The Triton Conference is happening again on September 17th, 2024 in Fremont (CA)!
# Warp Specialization Support


Warp specialization enhances kernel performance by using an asynchronous execution model, where different parts of the kernel are handled by separate hardware units. Data communication between these units, via shared memory on the H100, operates with high efficiency. With this in mind, we’ve developed a Triton DSL extension that allows users to partition their kernel into asynchronous tasks (which map to warp groups on NVIDIA GPUs) that naturally execute concurrently, leveraging the hardware’s multitasking warp scheduler. The following sections provide a breakdown of the compiler features developed to enable warp specialization.


## Asynchronous Tasks

Warp specialization is built on top of the concept of partitioning the user’s program into asynchronous tasks (referred to as "async tasks" or “tasks” in the following sections). Each async task is executed by a standalone warp group on supported hardware to achieve instruction-level parallelism. Optimally and automatically partitioning async tasks is quite a challenge for the compiler, so the Triton DSL has been extended to allow users to perform the partitioning manually.

The language extension is built around a Python context manager, designed to be simple and intuitive. The extension is platform-agnostic: on platforms where warp specialization is not supported, the user annotations are ignored, with no impact on correctness or performance.

For instance, a warp-specialized GEMM implementation might look like this:

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_m
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        # Use tl.async_task to specify warp-specialized code
        with tl.async_task([0]):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    # Use tl.async_task to specify warp-specialized code
    with tl.async_task([1]):
        tl.store(c_ptrs, c)
```

By wrapping a code block in a **tl.async_task** statement, the user specifies which async tasks (and thus warp groups) will execute that block. In the example above, the load operations are assigned to task 0, while the store operation is handled by task 1. Operations explicitly assigned a task id are known as anchor operations, and they determine the task assignment for the remaining operations.


The non-anchor operations are assigned to a task by the compiler in the following way:

- Control dependencies exclusive to an anchor operation are included in the same task as the anchor operation.
- Data dependencies exclusive to an anchor operation are included in the same task as the anchor operation, unless they are themselves anchor operations.
- Control or data dependencies shared between tasks are included in all those tasks.

For the GEMM example above, the compiler computes a task scheme and annotates it in the IR using MLIR attributes. To illustrate this more clearly, let's use source code annotations. After task propagation:

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # async_task 0, 1
    num_pid_m = tl.cdiv(M, BLOCK_M)  # async_task 0, 1
    num_pid_n = tl.cdiv(N, BLOCK_N)  # async_task 0, 1
    pid_m = pid // num_pid_m  # async_task 0, 1
    pid_n = pid % num_pid_n  # async_task 0, 1
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)  # async_task 0, 1
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # async_task 0, 1
    offs_k = tl.arange(0, BLOCK_K)  # async_task 0
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)  # async_task 1
    for k in range(0, tl.cdiv(K, BLOCK_K)):  # async_task 0, 1
        a = tl.load(a_ptrs)  # async_task 0
        b = tl.load(b_ptrs)  # async_task 0
        acc += tl.dot(a, b)  # async_task 1
        a_ptrs += BLOCK_K * stride_ak  # async_task 0
        b_ptrs += BLOCK_K * stride_bk  # async_task 0
    c = acc.to(tl.float16)  # async_task 1
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]  # async_task 1
    tl.store(c_ptrs, c)  # async_task 1
```

## Data Partitioning

To further improve performance, the user may choose to split the same workload across two async tasks. This way, when one task is blocked on a heavy computation (e.g., the dot operation), the other task can execute other operations in parallel. This can be easily achieved by annotating the store operation with two tasks:

```python
with tl.async_task([1, 2]):
    tl.store(c_ptrs, c)
```

The compiler determines how to divide the work between the two tasks to maximize performance. On the H100 GPU, the compiler will, by default, attempt to split the input tensor A along the M dimension so that each consumer computes half of the output tensor independently. This approach is known as cooperative partitioning. If this split is not advantageous—for instance, if it results in a smaller-than-native `wgmma` instruction—the compiler will instead attempt to split along the N dimension.
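
To make the M-versus-N trade-off concrete, here is a minimal sketch of that decision, assuming (for illustration only) a native `wgmma` tile height of 64 and a hypothetical helper `choose_split_dim`; the actual heuristic lives in the compiler passes and may differ.

```python
def choose_split_dim(block_m: int, block_n: int,
                     wgmma_m: int = 64, wgmma_n_min: int = 8) -> str:
    """Illustrative only: pick the dimension to split across two consumer tasks."""
    # Prefer splitting M: each consumer then owns a [block_m // 2, block_n] tile.
    if block_m // 2 >= wgmma_m:
        return "M"
    # Otherwise fall back to splitting N, provided the halves still map to
    # reasonably sized wgmma instructions.
    if block_n // 2 >= wgmma_n_min:
        return "N"
    return "none"


# For the [BLOCK_M, BLOCK_N, BLOCK_K] = [128, 256, 64] tile used below:
print(choose_split_dim(128, 256))  # -> "M": each consumer computes a [64, 256] half
```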

The transformed code for the above GEMM kernel with a configured tile size of [128, 256, 64] looks like the following (using source annotations instead of IR for illustration):

```python
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid = tl.program_id(axis=0)  # async_task 0, 1, 2
    num_pid_m = tl.cdiv(M, BLOCK_M)  # async_task 0, 1, 2
    num_pid_n = tl.cdiv(N, BLOCK_N)  # async_task 0, 1, 2
    pid_m = pid // num_pid_m  # async_task 0, 1, 2
    pid_n = pid % num_pid_n  # async_task 0, 1, 2
    offs_m_1 = pid_m * BLOCK_M + tl.arange(0, BLOCK_M // 2)  # async_task 0, 1, 2
    offs_m_2 = pid_m * BLOCK_M + tl.arange(BLOCK_M // 2, BLOCK_M)  # async_task 0, 1, 2
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)  # async_task 0, 1, 2
    offs_k = tl.arange(0, BLOCK_K)  # async_task 0
    a_ptrs_1 = a_ptr + (offs_m_1[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    a_ptrs_2 = a_ptr + (offs_m_2[:, None] * stride_am + offs_k[None, :] * stride_ak)  # async_task 0
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)  # async_task 0
    acc_1 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)  # async_task 1
    acc_2 = tl.zeros((BLOCK_M // 2, BLOCK_N), dtype=tl.float32)  # async_task 2
    for k in range(0, tl.cdiv(K, BLOCK_K)):  # async_task 0, 1, 2
        a_1 = tl.load(a_ptrs_1)  # async_task 0
        a_2 = tl.load(a_ptrs_2)  # async_task 0
        b = tl.load(b_ptrs)  # async_task 0
        acc_1 += tl.dot(a_1, b)  # async_task 1
        acc_2 += tl.dot(a_2, b)  # async_task 2
        a_ptrs_1 += BLOCK_K * stride_ak  # async_task 0
        a_ptrs_2 += BLOCK_K * stride_ak  # async_task 0
        b_ptrs += BLOCK_K * stride_bk  # async_task 0
    c_1 = acc_1.to(tl.float16)  # async_task 1
    c_2 = acc_2.to(tl.float16)  # async_task 2
    c_ptrs_1 = c_ptr + stride_cm * offs_m_1[:, None] + stride_cn * offs_n[None, :]  # async_task 1
    c_ptrs_2 = c_ptr + stride_cm * offs_m_2[:, None] + stride_cn * offs_n[None, :]  # async_task 2
    tl.store(c_ptrs_1, c_1)  # async_task 1
    tl.store(c_ptrs_2, c_2)  # async_task 2
```

## Code Partitioning

We assume all operations are already marked with a list of taskIds. We first find all communication required between warp groups. Each communication channel starts at a load operation with a single taskId and ends at a direct user of the load that belongs to a different taskId. For `ForOp`s containing a communication channel, we add additional arguments: `phase` and `bufferIndex`.

We introduce a tuning configuration: `num_buffers_warp_spec`. For each communication channel inside a `ForOp`, we use an array of buffers in SMEM to save the results, and the size of the array is determined by `num_buffers_warp_spec`. We also use an array of barriers for each such channel. In this pass, four new operations are introduced to correctly synchronize between the producer and the consumer: `ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`. Each of the four new ops takes a token and a buffer index. `ProducerAcquire` and `ConsumerWait` take an additional phase operand.
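
As a rough mental model (not the compiler's actual implementation), the buffer index cycles through the `num_buffers_warp_spec` SMEM buffers while the phase flips each time the index wraps around; the sketch below uses hypothetical helper names, omits the token operand, and simply prints the hand-off sequence for one channel.

```python
def channel_schedule(num_iters: int, num_buffers: int = 2):
    """Illustrative producer/consumer synchronization schedule for one channel."""
    schedule = []
    for k in range(num_iters):
        buf = k % num_buffers           # bufferIndex carried by the ForOp
        phase = (k // num_buffers) % 2  # phase flips when bufferIndex wraps around
        # Producer side (load warp group).
        schedule.append(f"iter {k}: ProducerAcquire(buf={buf}, phase={phase})")
        schedule.append(f"iter {k}: write SMEM buffer {buf}; ProducerCommit(buf={buf})")
        # Consumer side (compute warp group).
        schedule.append(f"iter {k}: ConsumerWait(buf={buf}, phase={phase})")
        schedule.append(f"iter {k}: read SMEM buffer {buf}; ConsumerRelease(buf={buf})")
    return schedule


for step in channel_schedule(num_iters=4, num_buffers=2):
    print(step)
```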


For `ForOp`s with multiple taskIds, we clone one copy for each taskId, where each copy contains only the operations with that specific taskId. In the end, we create multiple `IfOp`s, one for each possible taskId. We go through the body of the function, clone each op for each attached taskId, and put the cloned op in the corresponding `IfOp`.
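
As a sketch of the resulting control structure (illustrative Python, not real IR; `producer_loop`, `consumer_loop`, and `dispatch` are stand-ins, not real APIs):

```python
def producer_loop():
    print("task 0: loads, pointer advances, ProducerAcquire/ProducerCommit")


def consumer_loop(task_id: int):
    print(f"task {task_id}: dot, epilogue, store, ConsumerWait/ConsumerRelease")


def dispatch(task_id: int):
    # One IfOp branch per possible taskId; each branch holds the ops cloned for
    # that task. `task_id` stands in for the warp-group id read at runtime.
    if task_id == 0:
        producer_loop()
    elif task_id == 1:
        consumer_loop(1)
    elif task_id == 2:
        consumer_loop(2)


for task_id in (0, 1, 2):  # on hardware, each warp group executes exactly one branch
    dispatch(task_id)
```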

To adjust register usage, we introduce two new ops: `RegAllocOp` and `RegDeallocOp`, both taking an integer operand. For each warp group, we decide whether to insert a `RegAllocOp` or a `RegDeallocOp`. The current heuristic is simple: if the taskId is 0, we add a `RegDeallocOp`; otherwise we use a `RegAllocOp`. The amount of register adjustment can be tuned via `reg_dec_producer` and `reg_inc_consumer`.
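
A minimal sketch of that heuristic is below; the adjustment amounts are illustrative placeholders, with the real values supplied by the `reg_dec_producer` and `reg_inc_consumer` tuning knobs.

```python
def reg_adjustment_for_task(task_id: int,
                            reg_dec_producer: int = 40,
                            reg_inc_consumer: int = 232):
    """Sketch only: choose the register-adjustment op for a warp group."""
    if task_id == 0:
        # The producer warp group mostly issues async copies and needs few registers.
        return ("RegDeallocOp", reg_dec_producer)
    # Consumer warp groups run the MMA pipeline and benefit from extra registers.
    return ("RegAllocOp", reg_inc_consumer)


for task_id in (0, 1, 2):
    print(task_id, reg_adjustment_for_task(task_id))
```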

This pass also lowers `loadOp`s to `AsyncTMACopyGlobalToLocalOp` or `AsyncCopyGlobalToLocalOp`, so the communication can be expressed via SMEM. For TMA, the producer becomes `ProducerAcquire` -> `barrier_expect` -> `AsyncTMACopyGlobalToLocalOp`, and the consumer contains `wait_barrier` -> ops -> `ConsumerRelease`. For non-TMA loads, the producer becomes `ProducerAcquire` -> `AsyncCopyGlobalToLocalOp` -> `ProducerCommitOp`, and the consumer contains `ConsumerWaitOp` -> ops -> `ConsumerRelease`.
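
Written out as a sketch (op names as in the text above; illustrative pseudocode, not real IR):

```python
def lowered_sequences(use_tma: bool):
    """Illustrative per-iteration op sequences for one communication channel."""
    if use_tma:
        # The TMA copy completes by arriving on the barrier, so no explicit ProducerCommit.
        producer = ["ProducerAcquire", "barrier_expect", "AsyncTMACopyGlobalToLocalOp"]
        consumer = ["wait_barrier", "<consumer ops>", "ConsumerRelease"]
    else:
        producer = ["ProducerAcquire", "AsyncCopyGlobalToLocalOp", "ProducerCommitOp"]
        consumer = ["ConsumerWaitOp", "<consumer ops>", "ConsumerRelease"]
    return producer, consumer


for use_tma in (True, False):
    producer, consumer = lowered_sequences(use_tma)
    label = "TMA" if use_tma else "non-TMA"
    print(f"{label} producer: " + " -> ".join(producer))
    print(f"{label} consumer: " + " -> ".join(consumer))
```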


# Performance

## Flash Attention

```
(Batch, Heads, SeqLen, Dhead) triton_tutorial_flash_v2_ws-tflops triton_tutorial_flash_v2_tma_ws-tflops triton_tutorial_flash_v2-tflops
------------------------------- ------------------------------------ ---------------------------------------- ---------------------------------
(8, 16, 8192, 128) 548.783 561.539 482.664
```

Benchmarking instructions:
```
git clone https://github.com/pytorch-labs/tritonbench.git
cd tritonbench
TORCH_CUDA_ARCH_LIST=9.0a python run.py --op flash_attention --only triton_tutorial_flash_v2_ws,triton_tutorial_flash_v2_tma_ws,triton_tutorial_flash_v2 --num-inputs 1 --seq-len 13 --metrics tflops --batch 8 --n-heads 16 --d-head 128
```

## FP8 Rowwise GEMM with FP8_FAST_ACC=False


With warp specialization:

```
(M, N, K) _triton-tflops cutlass-tflops cutlass-speedup cutlass-accuracy cublas_12.1-tflops cublas_12.1-speedup cublas_12.1-accuracy
------------------ ---------------- ---------------- ----------------- ------------------ -------------------- --------------------- ----------------------
(8192, 8192, 8192) 1028.33 380.182 0.369707 1 1033.47 1.00499 1
```

Without warp specialization:

```
(M, N, K) _triton-tflops cutlass-tflops cutlass-speedup cutlass-accuracy cublas_12.1-tflops cublas_12.1-speedup cublas_12.1-accuracy
------------------ ---------------- ---------------- ----------------- ------------------ -------------------- --------------------- ----------------------
(8192, 8192, 8192) 919.595 384.423 0.418035 1 1035.43 1.12596 1
```


Benchmarking instructions:

```
git clone https://github.com/pytorch-labs/tritonbench.git
cd tritonbench
git checkout hoy/fp8gemmWS
python install.py --fbgemm
python run.py --op fp8_gemm_rowwise --m 8192 --n 8192 --k 8192 --no_fp8_fast_accum --warp_specialization
python run.py --op fp8_gemm_rowwise --m 8192 --n 8192 --k 8192 --no_fp8_fast_accum
```



If you are interested in attending, please fill out [this form](https://docs.google.com/forms/d/e/1FAIpQLSecHC1lkalcm0h3JDUbspekDX5bmBvMxgVTLaK3e-61bzDDbg/viewform).


| **`Documentation`** | **`Nightly Wheels`** |
|-------------------- | -------------------- |
| [![Documentation](https://github.com/triton-lang/triton/actions/workflows/documentation.yml/badge.svg)](https://triton-lang.org/) | [![Wheels](https://github.com/triton-lang/triton/actions/workflows/wheels.yml/badge.svg?branch=release/2.0.x)](https://github.com/triton-lang/triton/actions/workflows/wheels.yml) |

# General information about Triton

# Triton

@@ -102,6 +102,10 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                  const TargetInfoBase &targetInfo,
                                  PatternBenefit benefit);

void populateRegReallocOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                        RewritePatternSet &patterns,
                                        PatternBenefit benefit);

} // namespace triton
} // namespace mlir

42 changes: 42 additions & 0 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -13,6 +13,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
#include "triton/Dialect/TritonNvidiaGPU/Transforms/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "triton/Tools/StrUtil.h"
#include "triton/Tools/Sys/GetEnv.hpp"
@@ -144,6 +145,20 @@ using namespace mlir::triton;
namespace mlir {
namespace triton {

static inline void insertBarrier(PatternRewriter &rewriter, Operation *op) {
  auto barrierOp = rewriter.create<mlir::gpu::BarrierOp>(op->getLoc());
  auto asyncTaskIds = getAsyncTaskIds(op);
  if (asyncTaskIds.size() == 1) {
    int asyncTaskId = asyncTaskIds[0];
    int barId = asyncTaskId + nameBarrierIdBegin;
    assert(barId < nameBarrierIdEnd);
    // TODO: Change hard code style of numThreads.
    const int numThreads = 128;
    barrierOp->setAttr("bar_id", rewriter.getI64IntegerAttr(barId));
    barrierOp->setAttr("num_threads", rewriter.getI64IntegerAttr(numThreads));
  }
}

// Delinearize supposing order is [0, 1, .. , n]
template <typename T>
llvm::SmallVector<T> getMultiDimIndexImpl(T linearIndex,
@@ -371,6 +386,20 @@ inline Value getStackPointer(RewriterBase &rewriter,
  return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

static Operation *getWarpGroupId(Operation *op) {
  auto funcOp = op->getParentOfType<FunctionOpInterface>();
  Operation *getWarpId = nullptr;
  funcOp.walk([&](Operation *op) -> void {
    if (isa<mlir::triton::nvidia_gpu::GetCanonicalWarpIdOp>(op)) {
      assert(getWarpId == nullptr);
      getWarpId = op;
    }
  });
  assert(getWarpId);
  getWarpId->dump();
  return getWarpId;
}

inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                                 Operation *op) {
  auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
@@ -381,6 +410,19 @@ inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
                   .getValue()
                   .getZExtValue();
  Value offVal = i32_val(offset);
  if (op->hasAttr("allocation.copy")) {
    auto copy = cast<IntegerAttr>(op->getAttr("allocation.copy")).getValue().getZExtValue();
    if (copy != 1) {
      Operation *getWarpId = getWarpGroupId(op);
      Value warpsPerWG = i32_val(4);
      Value wgId = udiv(getWarpId->getResult(0), warpsPerWG);
      // (wgId - 1) * allocation.size + offset
      auto singleSize = cast<IntegerAttr>(op->getAttr("allocation.size")).getValue().getZExtValue();
      Value sub1 = sub(wgId, i32_val(1));
      Value temp = mul(sub1, i32_val(singleSize));
      offVal = add(temp, offVal);
    }
  }
  Value base = gep(ptrTy, i8_ty, LLVM::getStackPointer(rewriter, func), offVal);
  return base;
}