diff --git a/kernels/__init__.py b/kernels/__init__.py index dd492d3..8335b20 100644 --- a/kernels/__init__.py +++ b/kernels/__init__.py @@ -2,6 +2,7 @@ from . import blocksparse from .cross_entropy import _cross_entropy, cross_entropy from .flash_attention import attention +from .sage_attention import sageattn from .matmul import _matmul, get_higher_dtype, matmul __all__ = [ @@ -11,5 +12,6 @@ "_matmul", "matmul", "attention", + "sageattn", "get_higher_dtype", ] diff --git a/kernels/sage_attention.py b/kernels/sage_attention.py new file mode 100644 index 0000000..f99d974 --- /dev/null +++ b/kernels/sage_attention.py @@ -0,0 +1,287 @@ +import torch, math +import triton +import triton.language as tl + +@triton.jit +def quant_per_block_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + sm_scale, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127. + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_block_int8(q, k, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ, 1), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK, 1), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + sm_scale=(sm_scale * 1.44269504), + C=head_dim, BLK=BLKQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + sm_scale=1.0, + C=head_dim, BLK=BLKK + ) + + return q_int8, q_scale, k_int8, k_scale + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, + K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + ): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += lo // BLOCK_N + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + elif STAGE == 3: + lo, hi = 0, kv_len + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask = k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += 1 + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + +@triton.jit +def _attn_fwd(Q, K, V, Q_scale, K_scale, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_vz, stride_vh, stride_vn, + stride_oz, stride_oh, stride_on, + qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + start_m = tl.program_id(0) + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) + k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :] + Q_scale_ptr = Q_scale + q_scale_offset + start_m + K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None] + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :] + O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :] + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n + ) + if STAGE != 1: + acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 2, offs_m, offs_n + ) + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len)) + + +def forward(q, k, v, q_scale, k_scale, is_casual, tensor_layout="HND", output_dtype=torch.float16): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 if is_casual else 1 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1) + else: + raise ValueError(f"tensor_layout {tensor_layout} not supported") + + assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention" + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b ) + _attn_fwd[grid]( + q, k, v, q_scale, k_scale, o, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_v, stride_h_v, stride_seq_v, + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, + h_qo, num_kv_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4) + return o + + +from typing import Any, List, Literal, Optional, Tuple, Union + +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str ="HND", + is_causal=False, + sm_scale: Optional[float] = None, + smooth_k: bool =True, + **kwargs: Any, +) -> torch.Tensor: + """ + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + + dtype = q.dtype + headdim = q.size(-1) + assert headdim in [64, 128], "headdim should be in [64, 128]." + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous." + seq_dim = 1 if tensor_layout == "NHD" else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + k -= km + else: + km = None + if dtype == torch.bfloat16: + v = v.to(torch.float16) + + q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, sm_scale=sm_scale, tensor_layout=tensor_layout) + o = forward(q_int8, k_int8, v, q_scale, k_scale, is_causal = is_causal, tensor_layout=tensor_layout, output_dtype=dtype) + + return o + diff --git a/test/test_sage_attention.py b/test/test_sage_attention.py new file mode 100644 index 0000000..716c6bf --- /dev/null +++ b/test/test_sage_attention.py @@ -0,0 +1,349 @@ +import torch, math +import triton +import triton.language as tl +import torch.nn.functional as F + +@triton.jit +def quant_per_block_int8_kernel(Input, Output, Scale, L, + stride_iz, stride_ih, stride_in, + stride_oz, stride_oh, stride_on, + stride_sz, stride_sh, + sm_scale, + C: tl.constexpr, BLK: tl.constexpr): + off_blk = tl.program_id(0) + off_h = tl.program_id(1) + off_b = tl.program_id(2) + offs_n = off_blk * BLK + tl.arange(0, BLK) + offs_k = tl.arange(0, C) + input_ptrs = Input + off_b * stride_iz + off_h * stride_ih + offs_n[:, None] * stride_in + offs_k[None, :] + output_ptrs = Output + off_b * stride_oz + off_h * stride_oh + offs_n[:, None] * stride_on + offs_k[None, :] + scale_ptrs = Scale + off_b * stride_sz + off_h * stride_sh + off_blk + x = tl.load(input_ptrs, mask=offs_n[:, None] < L) + x = x.to(tl.float32) + x *= sm_scale + scale = tl.max(tl.abs(x)) / 127. + x_int8 = x / scale + x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1) + x_int8 = x_int8.to(tl.int8) + tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L) + tl.store(scale_ptrs, scale) + +def per_block_int8(q, k, BLKQ=128, BLKK=64, sm_scale=None, tensor_layout="HND"): + q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device) + k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device) + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(1), q_int8.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(1), k_int8.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_qo, stride_h_qo, stride_seq_qo = q_int8.stride(0), q_int8.stride(2), q_int8.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_ko, stride_h_ko, stride_seq_ko = k_int8.stride(0), k_int8.stride(2), k_int8.stride(1) + else: + raise ValueError(f"Unknown tensor layout: {tensor_layout}") + + q_scale = torch.empty((b, h_qo, (qo_len + BLKQ - 1) // BLKQ, 1), device=q.device, dtype=torch.float32) + k_scale = torch.empty((b, h_kv, (kv_len + BLKK - 1) // BLKK, 1), device=q.device, dtype=torch.float32) + + if sm_scale is None: + sm_scale = head_dim**-0.5 + + grid = ((qo_len + BLKQ - 1) // BLKQ, h_qo, b) + quant_per_block_int8_kernel[grid]( + q, q_int8, q_scale, qo_len, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_qo, stride_h_qo, stride_seq_qo, + q_scale.stride(0), q_scale.stride(1), + sm_scale=(sm_scale * 1.44269504), + C=head_dim, BLK=BLKQ + ) + + grid = ((kv_len + BLKK - 1) // BLKK, h_kv, b) + quant_per_block_int8_kernel[grid]( + k, k_int8, k_scale, kv_len, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_ko, stride_h_ko, stride_seq_ko, + k_scale.stride(0), k_scale.stride(1), + sm_scale=1.0, + C=head_dim, BLK=BLKK + ) + + return q_int8, q_scale, k_int8, k_scale + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, + K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, + STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, + ): + if STAGE == 1: + lo, hi = 0, start_m * BLOCK_M + elif STAGE == 2: + lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M + lo = tl.multiple_of(lo, BLOCK_M) + K_scale_ptr += lo // BLOCK_N + K_ptrs += stride_kn * lo + V_ptrs += stride_vn * lo + elif STAGE == 3: + lo, hi = 0, kv_len + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_mask = offs_n[None, :] < (kv_len - start_n) + k = tl.load(K_ptrs, mask = k_mask) + k_scale = tl.load(K_scale_ptr) + qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale + + if STAGE == 2: + mask = offs_m[:, None] >= (start_n + offs_n[None, :]) + qk = qk + tl.where(mask, 0, -1.0e6) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + else: + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + v = tl.load(V_ptrs, mask = offs_n[:, None] < (kv_len - start_n)) + p = p.to(tl.float16) + acc += tl.dot(p, v, out_dtype=tl.float16) + m_i = m_ij + K_ptrs += BLOCK_N * stride_kn + K_scale_ptr += 1 + V_ptrs += BLOCK_N * stride_vn + return acc, l_i, m_i + +@triton.jit +def _attn_fwd(Q, K, V, Q_scale, K_scale, Out, + stride_qz, stride_qh, stride_qn, + stride_kz, stride_kh, stride_kn, + stride_vz, stride_vh, stride_vn, + stride_oz, stride_oh, stride_on, + qo_len, kv_len, H:tl.constexpr, num_kv_groups:tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + STAGE: tl.constexpr + ): + start_m = tl.program_id(0) + off_z = tl.program_id(2).to(tl.int64) + off_h = tl.program_id(1).to(tl.int64) + q_scale_offset = (off_z * H + off_h) * tl.cdiv(qo_len, BLOCK_M) + k_scale_offset = (off_z * (H // num_kv_groups) + off_h // num_kv_groups) * tl.cdiv(kv_len, BLOCK_N) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, HEAD_DIM) + Q_ptrs = Q + (off_z * stride_qz + off_h * stride_qh) + offs_m[:, None] * stride_qn + offs_k[None, :] + Q_scale_ptr = Q_scale + q_scale_offset + start_m + K_ptrs = K + (off_z * stride_kz + (off_h // num_kv_groups) * stride_kh) + offs_n[None, :] * stride_kn + offs_k[:, None] + K_scale_ptr = K_scale + k_scale_offset + V_ptrs = V + (off_z * stride_vz + (off_h // num_kv_groups) * stride_vh) + offs_n[:, None] * stride_vn + offs_k[None, :] + O_block_ptr = Out + (off_z * stride_oz + off_h * stride_oh) + offs_m[:, None] * stride_on + offs_k[None, :] + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32) + q = tl.load(Q_ptrs, mask = offs_m[:, None] < qo_len) + q_scale = tl.load(Q_scale_ptr) + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 4 - STAGE, offs_m, offs_n + ) + if STAGE != 1: + acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, kv_len, K_ptrs, K_scale_ptr, V_ptrs, stride_kn, stride_vn, + start_m, + BLOCK_M, HEAD_DIM, BLOCK_N, + 2, offs_m, offs_n + ) + acc = acc / l_i[:, None] + tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < qo_len)) + + +def forward(q, k, v, q_scale, k_scale, is_casual, tensor_layout="HND", output_dtype=torch.float16): + BLOCK_M = 128 + BLOCK_N = 64 + stage = 3 if is_casual else 1 + + o = torch.empty(q.shape, dtype=output_dtype, device=q.device) + + if tensor_layout == "HND": + b, h_qo, qo_len, head_dim = q.shape + _, h_kv, kv_len, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(1), q.stride(2) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(1), k.stride(2) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(1), v.stride(2) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(1), o.stride(2) + elif tensor_layout == "NHD": + b, qo_len, h_qo, head_dim = q.shape + _, kv_len, h_kv, _ = k.shape + stride_bz_q, stride_h_q, stride_seq_q = q.stride(0), q.stride(2), q.stride(1) + stride_bz_k, stride_h_k, stride_seq_k = k.stride(0), k.stride(2), k.stride(1) + stride_bz_v, stride_h_v, stride_seq_v = v.stride(0), v.stride(2), v.stride(1) + stride_bz_o, stride_h_o, stride_seq_o = o.stride(0), o.stride(2), o.stride(1) + else: + raise ValueError(f"tensor_layout {tensor_layout} not supported") + + assert qo_len == kv_len, "qo_len and kv_len must be equal for causal attention" + + HEAD_DIM_K = head_dim + num_kv_groups = h_qo // h_kv + + grid = (triton.cdiv(qo_len, BLOCK_M), h_qo, b ) + _attn_fwd[grid]( + q, k, v, q_scale, k_scale, o, + stride_bz_q, stride_h_q, stride_seq_q, + stride_bz_k, stride_h_k, stride_seq_k, + stride_bz_v, stride_h_v, stride_seq_v, + stride_bz_o, stride_h_o, stride_seq_o, + qo_len, kv_len, + h_qo, num_kv_groups, + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, + STAGE=stage, + num_warps=4 if head_dim == 64 else 8, + num_stages=4) + return o + + +from typing import Any, List, Literal, Optional, Tuple, Union +def sageattn( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + tensor_layout: str ="HND", + is_causal=False, + sm_scale: Optional[float] = None, + smooth_k: bool =True, + **kwargs: Any, +) -> torch.Tensor: + """ + Parameters + ---------- + q : torch.Tensor + The query tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + k : torch.Tensor + The key tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + v : torch.Tensor + The value tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_kv_heads, kv_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, kv_len, num_kv_heads, head_dim]``. + tensor_layout : str + The tensor layout, either "HND" or "NHD". + Default: "HND". + is_causal : bool + Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len. + Default: False. + sm_scale : Optional[float] + The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``. + smooth_k : bool + Whether to smooth the key tensor by subtracting the mean along the sequence dimension. + Default: True. + Returns + ------- + torch.Tensor + The output tensor. Shape: + - If `tensor_layout` is "HND": ``[batch_size, num_qo_heads, qo_len, head_dim]``. + - If `tensor_layout` is "NHD": ``[batch_size, qo_len, num_qo_heads, head_dim]``. + Note + ---- + - ``num_qo_heads`` must be divisible by ``num_kv_heads``. + - The tensors `q`, `k`, and `v` must have the dtype ``torch.float16`` or ``torch.bfloat16`` + - All tensors must be on the same cuda device. + """ + + dtype = q.dtype + headdim = q.size(-1) + assert headdim in [64, 128], "headdim should be in [64, 128]." + + # assert last dim is contiguous + assert q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1, "Last dim of qkv must be contiguous." + seq_dim = 1 if tensor_layout == "NHD" else 2 + + if smooth_k: + km = k.mean(dim=seq_dim, keepdim=True) + k -= km + else: + km = None + if dtype == torch.bfloat16: + v = v.to(torch.float16) + + q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, sm_scale=sm_scale, tensor_layout=tensor_layout) + o = forward(q_int8, k_int8, v, q_scale, k_scale, is_causal = is_causal, tensor_layout=tensor_layout, output_dtype=dtype) + + return o + + +def precision_metric(quant_o, fa2_o): + x, xx = quant_o.float(), fa2_o.float() + sim = F.cosine_similarity(x.reshape(1, -1), xx.reshape(1, -1)).item() + l1 = ( (x - xx).abs().sum() / xx.abs().sum() ).item() + rmse = torch.sqrt(torch.mean((x -xx) ** 2)).item() + print(f'Cossim: {sim:.6f}, L1: {l1:.6f}, RMSE:{rmse:.6f}\n') + +def test_accuracy(N_CTX, dtype, causal, tensor_layout="HND"): + q = (torch.empty((BATCH, N_HEADS, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((BATCH, N_HEADS, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((BATCH, N_HEADS, N_CTX, HEAD_DIM), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + torch_stand = torch.nn.functional.scaled_dot_product_attention(q, k, v,is_causal=causal) + q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, tensor_layout=tensor_layout) + quant_o = forward(q_int8, k_int8, v, q_scale, k_scale, causal) + print(torch.allclose(torch_stand, quant_o, atol=1e-2, rtol=1e-3)) + precision_metric(quant_o, torch_stand) + +BATCH, N_HEADS, HEAD_DIM = 4, 32, 64 +configs = [] +for causal in [True, False]: + configs.append( + triton.testing.Benchmark( + x_names=["N_CTX"], + x_vals=[2**i for i in range(10, 16)], + line_arg="provider", + line_vals=(["triton-sageattn"]), + line_names=(["triton-sageattn"]), + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"SageAttention-batch{BATCH}-head{N_HEADS}-d{HEAD_DIM}-causal={causal}", + args={ + "H": N_HEADS, + "BATCH": BATCH, + "HEAD_DIM": HEAD_DIM, + "causal": causal, + }, + )) +@triton.testing.perf_report(configs) +def bench_sage_attention(BATCH, H, N_CTX, HEAD_DIM, causal, provider, device="cuda"): + warmup = 25 + rep = 100 + dtype = torch.float16 + tensor_layout = "HND" + if "triton" in provider: + q = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True).contiguous() + k = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True).contiguous() + v = torch.randn((BATCH, H, N_CTX, HEAD_DIM), dtype=dtype, device=device, requires_grad=True).contiguous() + q_int8, q_scale, k_int8, k_scale = per_block_int8(q, k, tensor_layout=tensor_layout) + fn = lambda: forward(q_int8, k_int8, v, q_scale, k_scale, causal) + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * HEAD_DIM + total_flops = 2 * flops_per_matmul + if causal: + total_flops *= 0.5 + return total_flops / ms * 1e-9 + + +if __name__ == "__main__": + test_accuracy(4000, torch.float16, False, ) + bench_sage_attention.run(print_data=True) + \ No newline at end of file