From e56913312f3e43ba0110c5caa391d4fb35fa2069 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 15 Mar 2024 13:15:54 -0400 Subject: [PATCH 1/7] Adding new rocm triton flash attention kernel Co-authored-by: Vinayak Gokhale --- Dockerfile.rocm | 14 + .../layers/attention/attention.py | 2 +- .../layers/attention/backends/flash_attn.py | 40 +- .../attention/ops/flash_attention_triton.py | 541 ++++++++++++++++++ 4 files changed, 586 insertions(+), 11 deletions(-) create mode 100644 vllm/model_executor/layers/attention/ops/flash_attention_triton.py diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a45265d79a6ac..a7640f6841ad9 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -26,6 +26,9 @@ ARG BUILD_FA="1" # whether to build cupy on rocm ARG BUILD_CUPY="1" +# whether to build triton on rocm +ARG BUILD_TRITON="1" + # Install some basic utilities RUN apt-get update && apt-get install python3 python3-pip -y @@ -95,6 +98,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \ && cd ..; \ fi +# build triton +RUN if [ "$BUILD_TRITON" = "1"]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCmSoftwarePlatform/triton.git + && cd triton/python \ + && pip3 install -e . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 830e82e10f7ad..39b66fb1fb2db 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -30,7 +30,7 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if (not is_hip() and torch.cuda.get_device_capability()[0] >= 8 and + if (torch.cuda.get_device_capability()[0] >= 8 and torch.get_default_dtype() in (torch.float16, torch.bfloat16)): # Ampere or later NVIDIA GPUs. # NOTE(woosuk): FlashAttention does not support FP32. diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 512f4e49c7eb2..c7543e2d54d12 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -2,12 +2,21 @@ from typing import List, Optional # NOTE(woosuk): This imports flash_attn under vllm/thirdparty_files/. -from flash_attn import flash_attn_func +from vllm.utils import is_hip +try: + from flash_attn import flash_attn_func +except ImportError: + if is_hip(): + pass + else: + raise + import torch from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) +from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention class FlashAttentionBackend: @@ -86,15 +95,26 @@ def forward( query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if is_hip(): + output, _ = attention( + query, + key, + value, + None, + input_metadata, + True, + self.scale, + ) + else: + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py new file mode 100644 index 0000000000000..37c15e0e6fa36 --- /dev/null +++ b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py @@ -0,0 +1,541 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. + +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype:tl.constexpr = torch.float16 + +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype:tl.constexpr = torch.float8_e5m2fnuz + +@triton.jit +def cdiv_fn(x,y): + return (x + y - 1) // y + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + +@triton.jit +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr +): + # loop over k, v, and update accumulator + for start_n in range (block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + m_ij = tl.maximum(m_i, tl.max(qk,1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. + # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['hq', 'hk', 'IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, + hq, hk, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, + BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # This is a > check because mask being 0 blocks the store. + l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0,1)) + +def check_args(q, k, v, o, max_seqlens): + assert q.dim() == k.dim() and q.dim() == v.dim() + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + check_args(q, k, v, o, metadata.max_seq_len) + + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + + grid = lambda META: ( + triton.cdiv(metadata.max_seq_len, META['BLOCK_M']), + nheads_q, + batch + ) + + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = (bias.stride(0), bias.stride(1), + bias.stride(2), bias.stride(3)) + else: + bias_strides = (0,0,0,0) + + attn_fwd[grid]( + q, k, v, bias, sm_scale, M, o, + *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + None, None, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seq_len, + MAX_SEQLENS_K=metadata.max_seq_len, + IS_CAUSAL=causal, + VARLEN=False, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False + ) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + +attention = _attention.apply From 0e63661b0cdac0c60857f7cc277819c5b88ed2f6 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 19 Mar 2024 11:49:21 -0400 Subject: [PATCH 2/7] Small fix on dockerfile --- Dockerfile.rocm | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a7640f6841ad9..080e5b04d28bc 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -99,11 +99,11 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \ fi # build triton -RUN if [ "$BUILD_TRITON" = "1"]; then \ +RUN if [ "$BUILD_TRITON" = "1" ]; then \ mkdir -p libs \ && cd libs \ && pip uninstall -y triton \ - && git clone https://github.com/ROCmSoftwarePlatform/triton.git + && git clone https://github.com/ROCmSoftwarePlatform/triton.git \ && cd triton/python \ && pip3 install -e . \ && cd ../..; \ From d4cb905dfec5abdbbe9f585ff4ee9da59efedbfe Mon Sep 17 00:00:00 2001 From: jpvillam Date: Tue, 19 Mar 2024 19:41:43 -0400 Subject: [PATCH 3/7] Rebase updates and PR review changes Added Flag for controlling triton vs default flow. More small changes to dockerfile --- Dockerfile.rocm | 2 +- .../layers/attention/attention.py | 47 ++++++++++++------- .../layers/attention/backends/flash_attn.py | 36 +++++++++----- .../attention/ops/flash_attention_triton.py | 33 ++++++------- 4 files changed, 70 insertions(+), 48 deletions(-) diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 080e5b04d28bc..e7f52307a6aa2 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -105,7 +105,7 @@ RUN if [ "$BUILD_TRITON" = "1" ]; then \ && pip uninstall -y triton \ && git clone https://github.com/ROCmSoftwarePlatform/triton.git \ && cd triton/python \ - && pip3 install -e . \ + && pip3 install . \ && cd ../..; \ fi diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..89b5816f7a47a 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.input_metadata import InputMetadata from vllm.utils import is_hip +import os logger = init_logger(__name__) @@ -34,11 +35,12 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if _use_flash_attn(): + if use_triton := _use_flash_attn(): from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 self.backend = FlashAttentionBackend(num_heads, head_size, scale, num_kv_heads, alibi_slopes, - sliding_window) + sliding_window, + use_triton == 2) else: from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501 self.backend = XFormersBackend(num_heads, head_size, scale, @@ -59,26 +61,37 @@ def forward( @lru_cache(maxsize=1) -def _use_flash_attn() -> bool: - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info("flash_attn is not found. Using xformers backend.") - return False - - if is_hip(): - # AMD GPUs. - return False - if torch.cuda.get_device_capability()[0] < 8: +def _use_flash_attn() -> int: + """Returns if and which flash attention to use. + + Returns: + int: 0 for none, 1 for default implementation, 2 for triton implementation. + """ + if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()): + # AMD GPUs can use flash_attn package or triton impl. + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info("flash_attn is not found. Using xformers backend.") + return 0 + + if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("flash_attn is not supported on Turing or older GPUs. " "Using xformers backend.") - return False + return 0 + + if is_hip() and torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs. " + "Using xformers backend.") + return 0 + if torch.get_default_dtype() not in (torch.float16, torch.bfloat16): logger.info( "flash_attn only supports torch.float16 or torch.bfloat16. " "Using xformers backend.") - return False + return 0 - logger.info("Using flash_attn backend.") - return True + logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.") + return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1 diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index c2d7b5acc467e..726b42cad9e3f 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -8,7 +8,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) -from vllm.model_executor.layers.attention.ops.flash_attention_triton import attention +from vllm.model_executor.layers.attention.ops.flash_attention_triton import triton_attention class FlashAttentionBackend: @@ -21,6 +21,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + use_triton: Optional[bool] = False, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -30,6 +31,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.use_triton = use_triton assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -87,8 +89,8 @@ def forward( query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - if is_hip(): - output, _ = attention( + if self.use_triton: + output, _ = triton_attention( query, key, value, @@ -98,15 +100,25 @@ def forward( self.scale, ) else: - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if is_hip(): + #XXX: window_size and alibi_slopes not supported + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + ) + else: + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py index 37c15e0e6fa36..80962e4cf9d9a 100644 --- a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py +++ b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py @@ -251,12 +251,12 @@ def attn_fwd( ) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) # We still need to write 0s to the result - tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) - l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + #tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # We store inf to LSE, not -inf because in the bwd pass, we subtract this # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. - l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) - tl.store(l_ptrs, l) + #l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + #tl.store(l_ptrs, l) # TODO: Should dropout and return encoded softmax be handled here too? return @@ -417,17 +417,17 @@ def attn_fwd( z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. # This is only true for the last M block. For others, overflow_size will be -ve - overflow_size = end_m_idx - seqlen_q - if overflow_size > 0: - boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) - # This is a > check because mask being 0 blocks the store. - l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) - tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) - else: - tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + #overflow_size = end_m_idx - seqlen_q + #if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + #else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh @@ -494,8 +494,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): encoded_softmax = None - M = torch.empty((batch, nheads_q, metadata.max_seq_len), device=q.device, dtype=torch.float32) - # Seed the RNG so we get reproducible results for testing. philox_seed = 0x1BF52 philox_offset = 0x1D4B42 @@ -507,7 +505,7 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): bias_strides = (0,0,0,0) attn_fwd[grid]( - q, k, v, bias, sm_scale, M, o, + q, k, v, bias, sm_scale, None, o, *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, None, None, dropout_p=0.0, @@ -526,7 +524,6 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): RETURN_ENCODED_SOFTMAX=False ) - ctx.save_for_backward(q, k, v, o, M) ctx.grid = grid ctx.sm_scale = sm_scale ctx.BLOCK_DMODEL = head_size @@ -538,4 +535,4 @@ def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): ctx.return_encoded_softmax = False return o, encoded_softmax -attention = _attention.apply +triton_attention = _attention.apply From 1fff99a0fa4e2be0af8bf335ad1fb6e3c156edf5 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 22 Mar 2024 16:48:37 +0000 Subject: [PATCH 4/7] Added interleaving for MQA for triton kernel --- .../layers/attention/backends/flash_attn.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 726b42cad9e3f..c5fba494fdbce 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -44,6 +44,15 @@ def __init__( self.sliding_window = ((self.sliding_window, self.sliding_window) if self.sliding_window is not None else (-1, -1)) + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return ( + x[:, :, None, :] + .expand(tokens, n_kv_heads, n_rep, head_dim) + .reshape(tokens, n_kv_heads * n_rep, head_dim) + ) + def forward( self, query: torch.Tensor, @@ -85,6 +94,11 @@ def forward( # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + if self.num_kv_heads != self.num_heads: + # Interleave for MQA + key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + # normal attention query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) From a3bc6c83179e326a8c78ab0d9e83f9d4b420cfaf Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 22 Mar 2024 15:07:02 -0400 Subject: [PATCH 5/7] Only if VLLM_USE_FLASH_ATTN_TRITON is to 1 or true --- vllm/model_executor/layers/attention/attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 89b5816f7a47a..a554911b30b5d 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -67,7 +67,8 @@ def _use_flash_attn() -> int: Returns: int: 0 for none, 1 for default implementation, 2 for triton implementation. """ - if not (os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') and is_hip()): + use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "False").lower() in ("true", "1") + if not ( use_flash_attn_triton and is_hip()): # AMD GPUs can use flash_attn package or triton impl. try: import flash_attn # noqa: F401 @@ -93,5 +94,5 @@ def _use_flash_attn() -> int: "Using xformers backend.") return 0 - logger.info(f"Using {'Triton' if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else ''} flash_attn backend.") - return 2 if os.environ.get('VLLM_USE_FLASH_ATTN_TRITON') else 1 + logger.info(f"Using {'Triton' if use_flash_attn_triton else ''} flash_attn backend.") + return 2 if use_flash_attn_triton else 1 From 8cd9653b1646c2ccd5cc22e46bc00d9287a058c9 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 22 Mar 2024 15:56:35 -0400 Subject: [PATCH 6/7] Make triton the default FA --- vllm/model_executor/layers/attention/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index a554911b30b5d..252a8192cb052 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -67,7 +67,7 @@ def _use_flash_attn() -> int: Returns: int: 0 for none, 1 for default implementation, 2 for triton implementation. """ - use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "False").lower() in ("true", "1") + use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "True").lower() in ("true", "1") if not ( use_flash_attn_triton and is_hip()): # AMD GPUs can use flash_attn package or triton impl. try: From 52af554a1d59eef2ecf742a4873a041f2d330fb4 Mon Sep 17 00:00:00 2001 From: jpvillam Date: Fri, 22 Mar 2024 15:59:58 -0400 Subject: [PATCH 7/7] Make workaround only applicable to triton path --- vllm/model_executor/layers/attention/backends/flash_attn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index c5fba494fdbce..036013afc0850 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -94,8 +94,8 @@ def forward( # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): - if self.num_kv_heads != self.num_heads: - # Interleave for MQA + if self.use_triton and (self.num_kv_heads != self.num_heads): + # Interleave for MQA workaround. key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv)