Add flashattention2 (#196)
* first

* fix causal mask

* disable flash attention2 on sm70

* fix 2

* update readme

* clang-format

* disable ft2 on windows

* fix lint

* fix build

* fix build

* fix long kv seq

* fix lint

* sync copy output

---------

Co-authored-by: grimoire <[email protected]>
Co-authored-by: irexyc <[email protected]>
3 people authored Aug 29, 2023
1 parent d4d609b commit 452822a
Showing 24 changed files with 1,923 additions and 134 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -47,7 +47,7 @@ include(FetchContent)
FetchContent_Declare(
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git
GIT_TAG cc85b64cf676c45f98a17e3a47c0aafcf817f088
GIT_TAG 6f47420213f757831fae65c686aa471749fa8d60
)

set(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
@@ -312,6 +312,7 @@ add_library(transformer-shared SHARED
$<TARGET_OBJECTS:BaseSamplingLayer>
$<TARGET_OBJECTS:DynamicDecodeLayer>
$<TARGET_OBJECTS:llama_fmha>
$<TARGET_OBJECTS:flash_attention2>
$<TARGET_OBJECTS:Llama>
$<TARGET_OBJECTS:LlamaTritonBackend>
$<TARGET_OBJECTS:gemm_s4_f16>
1 change: 1 addition & 0 deletions README.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## News 🎉

- \[2023/08\] TurboMind supports flash-attention2.
- \[2023/08\] TurboMind supports Qwen-7B, dynamic NTK-RoPE scaling and dynamic logN scaling
- \[2023/08\] TurboMind supports Windows (tp=1)
- \[2023/08\] TurboMind supports 4-bit inference, 2.4x faster than FP16, the fastest open-source implementation🚀. Check [this](./docs/en/w4a16.md) guide for detailed info
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -20,6 +20,7 @@ ______________________________________________________________________

## 更新 🎉

- \[2023/08\] TurboMind 支持 flash-attention2
- \[2023/08\] TurboMind 支持 Qwen-7B,动态NTK-RoPE缩放,动态logN缩放
- \[2023/08\] TurboMind 支持 Windows (tp=1)
- \[2023/08\] TurboMind 支持 4-bit 推理,速度是 FP16 的 2.4 倍,是目前最快的开源实现🚀。部署方式请看[这里](./docs/zh_cn/w4a16.md)
5 changes: 5 additions & 0 deletions src/turbomind/models/llama/CMakeLists.txt
@@ -41,5 +41,10 @@ target_link_libraries(Llama PUBLIC CUDA::cudart
logger
llama_fmha)

if (NOT MSVC)
add_subdirectory(flash_attention2)
target_link_libraries(Llama PUBLIC flash_attention2)
endif()

add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
1 change: 1 addition & 0 deletions src/turbomind/models/llama/LlamaContextAttentionLayer.cc
@@ -264,6 +264,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
{
//////////////////////////////////////////////
// flash attention
// flash attention 2 only supports half inputs
using AttentionOp = FlashAttentionOp<T>;
using Layout = typename AttentionOp::AttentionLayout;
Layout layout_q{
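Note: the commit messages above ("disable flash attention2 on sm70", "disable ft2 on windows") together with the half-only comment suggest a capability guard around the v2 path. The sketch below is illustrative only and uses a hypothetical helper name; it is not code from this commit.

```cpp
// Illustrative sketch of the kind of guard implied by the commit messages:
// flash-attention2 only for half inputs, sm80+, and non-MSVC builds.
// The helper name use_flash_attention2 is hypothetical.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <type_traits>

template<typename T>
bool use_flash_attention2(int device_id)
{
#if defined(_MSC_VER)
    return false;  // the build links flash_attention2 only when NOT MSVC
#else
    if (!std::is_same<T, half>::value) {
        return false;  // the v2 kernels here are half-precision only
    }
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device_id);
    return prop.major >= 8;  // sm70 falls back to the original FMHA path
#endif
}
```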
15 changes: 15 additions & 0 deletions src/turbomind/models/llama/flash_attention2/CMakeLists.txt
@@ -0,0 +1,15 @@

cmake_minimum_required(VERSION 3.8)
project(flash_attention2)

add_library(${PROJECT_NAME} STATIC
flash_api.cpp
flash_fwd_hdim32_fp16_sm80.cu
flash_fwd_hdim64_fp16_sm80.cu
flash_fwd_hdim128_fp16_sm80.cu
flash_fwd_hdim256_fp16_sm80.cu
)
target_include_directories(${PROJECT_NAME} PRIVATE ${CUTLASS_DIR}/include)
target_link_libraries(${PROJECT_NAME} PRIVATE nvidia::cutlass::cutlass)
set_property(TARGET ${PROJECT_NAME} PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET ${PROJECT_NAME} PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
7 changes: 7 additions & 0 deletions src/turbomind/models/llama/flash_attention2/README.md
@@ -0,0 +1,7 @@
# Flash Attention 2

This is a flash-attention2 implementation modified from https://github.com/Dao-AILab/flash-attention

- dropout removed
- backward pass removed
- built against cutlass 3.1.0
45 changes: 45 additions & 0 deletions src/turbomind/models/llama/flash_attention2/block_info.h
@@ -0,0 +1,45 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/

#pragma once

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

template<bool Varlen = true>
struct BlockInfo {

template<typename Params>
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
{
}

template<typename index_t>
inline __device__ index_t q_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_q == -1 ? bidb * batch_stride : uint32_t(sum_s_q) * row_stride;
}

template<typename index_t>
inline __device__ index_t k_offset(const index_t batch_stride, const index_t row_stride, const int bidb) const
{
return sum_s_k == -1 ? bidb * batch_stride : uint32_t(sum_s_k) * row_stride;
}

const int sum_s_q;
const int sum_s_k;
const int actual_seqlen_q;
const int actual_seqlen_k;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace flash
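To make the varlen bookkeeping above concrete, here is a small host-side sketch (illustrative only, not part of the commit) of the same arithmetic: cu_seqlens holds prefix sums of sequence lengths, so sequence bidb spans rows cu_seqlens[bidb] .. cu_seqlens[bidb + 1] of the packed tensor, and q_offset scales the starting row by the row stride.

```cpp
// Host-side illustration of BlockInfo's offset math for the varlen layout.
// The lengths and stride are made up; only the arithmetic mirrors the struct.
#include <cstdint>
#include <cstdio>

int main()
{
    // Prefix sums over 3 sequences of lengths 5, 2 and 9 (array length b + 1 = 4).
    const int      cu_seqlens_q[4] = {0, 5, 7, 16};
    const uint32_t row_stride      = 128;  // elements per packed row

    for (int bidb = 0; bidb < 3; ++bidb) {
        const int      sum_s_q         = cu_seqlens_q[bidb];                // BlockInfo::sum_s_q
        const int      actual_seqlen_q = cu_seqlens_q[bidb + 1] - sum_s_q;  // BlockInfo::actual_seqlen_q
        const uint32_t q_offset        = uint32_t(sum_s_q) * row_stride;    // BlockInfo::q_offset(...)
        std::printf("batch %d: len=%d, offset=%u\n", bidb, actual_seqlen_q, q_offset);
    }
    return 0;
}
```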
88 changes: 88 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash.h
@@ -0,0 +1,88 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention

#pragma once

#include <cuda.h>
#include <cuda_runtime.h>
#include <vector>

constexpr int TOTAL_DIM = 0;
constexpr int H_DIM = 1;
constexpr int D_DIM = 2;

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Qkv_params {
using index_t = uint32_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
void* __restrict__ v_ptr;

// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;

// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
index_t k_batch_stride;
index_t v_batch_stride;
index_t q_row_stride;
index_t k_row_stride;
index_t v_row_stride;
index_t q_head_stride;
index_t k_head_stride;
index_t v_head_stride;

// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precomputed h / h_k
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params: public Qkv_params {

// The O matrix (output).
void* __restrict__ o_ptr;

// The stride between rows of O.
index_t o_batch_stride;
index_t o_row_stride;
index_t o_head_stride;

// The pointer to the P matrix.
void* __restrict__ p_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, d, seqlen_q_rounded, seqlen_k_rounded;

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;

// array of length b+1 holding starting offset of each sequence.
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;

void* __restrict__ blockmask;

bool is_bf16;
bool is_causal;

// whether Q / O use the packed variable-seqlen layout
bool q_enable_seqlen;
bool o_enable_seqlen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Headdim>
void run_mha_fwd_(Flash_fwd_params& params, cudaStream_t stream);
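A note on the two softmax scales in Flash_fwd_params: the kernels evaluate exponentials in base 2, so scale_softmax_log2 is scale_softmax multiplied by log2(e), which is how flash_api.cpp fills it below. A quick illustrative check of that identity (not code from the commit):

```cpp
// exp(scale * x) == exp2(scale * log2(e) * x), the rewrite behind scale_softmax_log2.
#include <cassert>
#include <cmath>
#include <math.h>  // M_LOG2E

int main()
{
    const float scale      = 1.0f / std::sqrt(128.0f);  // e.g. 1 / sqrt(head_dim)
    const float scale_log2 = scale * float(M_LOG2E);    // what scale_softmax_log2 would hold
    const float x          = 3.7f;                      // arbitrary logit
    assert(std::fabs(std::exp(scale * x) - std::exp2(scale_log2 * x)) < 1e-5f);
    return 0;
}
```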
164 changes: 164 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_api.cpp
@@ -0,0 +1,164 @@
/******************************************************************************
* Copyright (c) 2023, Tri Dao.
******************************************************************************/
// modify from: https://github.com/Dao-AILab/flash-attention

#include "flash.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "static_switch.h"
#include <cuda_runtime.h>
#include <cutlass/numeric_types.h>
#include <math.h>

void run_mha_fwd(Flash_fwd_params& params, cudaStream_t stream)
{
FP16_SWITCH(true,
[&] { FWD_HEADDIM_SWITCH(params.d, [&] { run_mha_fwd_<elem_type, kHeadDim>(params, stream); }); });
}

namespace turbomind {

static constexpr int FMHA_VERSION = 2;

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION> {

public:
using AttentionLayout = BaseAttentionLayout<T>;
using Params = BaseAttentionParams<T>;

public:
FlashAttentionOpImpl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head);
~FlashAttentionOpImpl();

int get_workspace_size() const;

void operator()(Params& params, cudaStream_t st) const;

private:
class impl;
std::unique_ptr<impl> pimpl;
};

template<typename T>
class FlashAttentionOpImpl<T, FMHA_VERSION>::impl {

private:
using scalar_t =
typename std::conditional_t<std::is_same<half, typename std::decay<T>::type>::value, cutlass::half_t, T>;
using Params = typename FlashAttentionOpImpl<T, FMHA_VERSION>::Params;

int batch_size_;
int head_num_;
int key_len_;
int seq_len_;
int size_per_head_;

public:
impl(int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
batch_size_(batch_size),
head_num_(head_num),
key_len_(key_len),
seq_len_(seq_len),
size_per_head_(size_per_head)
{
}

~impl() {}

int get_workspace_size() const
{
return 0;
}

void operator()(Params& params, cudaStream_t st) const
{
const float qk_scale = static_cast<float>(1.f / sqrtf(size_per_head_ * 1.f));
Flash_fwd_params fwd_params;
memset(&fwd_params, 0, sizeof(fwd_params));

fwd_params.q_ptr = reinterpret_cast<void*>(params.query);
fwd_params.k_ptr = reinterpret_cast<void*>(params.key);
fwd_params.v_ptr = reinterpret_cast<void*>(params.val);

fwd_params.k_batched_ptr = reinterpret_cast<void**>(params.layout_k.batch_seqs);
fwd_params.v_batched_ptr = reinterpret_cast<void**>(params.layout_v.batch_seqs);
fwd_params.k_batched_offset = params.layout_k.batch_seqs_offset;
fwd_params.v_batched_offset = params.layout_v.batch_seqs_offset;

fwd_params.q_batch_stride = params.layout_q.stride_batch;
fwd_params.k_batch_stride = params.layout_k.stride_batch;
fwd_params.v_batch_stride = params.layout_v.stride_batch;
fwd_params.q_row_stride = params.layout_q.stride_seq;
fwd_params.k_row_stride = params.layout_k.stride_seq;
fwd_params.v_row_stride = params.layout_v.stride_seq;
fwd_params.q_head_stride = params.layout_q.stride_head;
fwd_params.v_head_stride = params.layout_v.stride_head;
fwd_params.k_head_stride = params.layout_k.stride_head;

fwd_params.h = head_num_;
fwd_params.h_k = head_num_ / params.group_size;
fwd_params.h_h_k_ratio = params.group_size;

fwd_params.o_ptr = reinterpret_cast<void*>(params.attn_out);

fwd_params.o_batch_stride = params.layout_o.stride_batch;
fwd_params.o_row_stride = params.layout_o.stride_seq;
fwd_params.o_head_stride = params.layout_o.stride_head;

fwd_params.p_ptr = nullptr;

fwd_params.b = batch_size_;
fwd_params.seqlen_q = seq_len_;
fwd_params.seqlen_k = key_len_;
fwd_params.d = size_per_head_;
fwd_params.seqlen_q_rounded = 0;
fwd_params.seqlen_k_rounded = 0;

fwd_params.scale_softmax = qk_scale;
fwd_params.scale_softmax_log2 = qk_scale * M_LOG2E;

fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;

fwd_params.blockmask = reinterpret_cast<void*>(params.mask);

fwd_params.is_bf16 = false;
fwd_params.is_causal = true;

fwd_params.q_enable_seqlen = params.layout_q.use_seqlens;
fwd_params.o_enable_seqlen = params.layout_o.use_seqlens;

run_mha_fwd(fwd_params, st);
}
};

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::FlashAttentionOpImpl(
int batch_size, int head_num, int key_len, int seq_len, int size_per_head):
pimpl{std::make_unique<FlashAttentionOpImpl<T, FMHA_VERSION>::impl>(
batch_size, head_num, key_len, seq_len, size_per_head)}
{
}

template<typename T>
FlashAttentionOpImpl<T, FMHA_VERSION>::~FlashAttentionOpImpl()
{
}

template<typename T>
int FlashAttentionOpImpl<T, FMHA_VERSION>::get_workspace_size() const
{
return pimpl->get_workspace_size();
}

template<typename T>
void FlashAttentionOpImpl<T, FMHA_VERSION>::operator()(Params& params, cudaStream_t st) const
{
pimpl->operator()(params, st);
}

template class FlashAttentionOpImpl<float, FMHA_VERSION>;
template class FlashAttentionOpImpl<half, FMHA_VERSION>;

} // namespace turbomind
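For readers unfamiliar with the static_switch.h macros, run_mha_fwd above is a dispatcher: FP16_SWITCH(true, ...) pins the element type to cutlass::half_t, and FWD_HEADDIM_SWITCH selects one of the instantiated head dimensions (32, 64, 128 and 256 in this port, per the new CMakeLists). A rough hand-expanded sketch, assuming those four specializations and upstream-style thresholds (illustrative, not the actual macro expansion):

```cpp
// Hand-expanded sketch of what the dispatch in run_mha_fwd roughly resolves to here.
// The head-dim thresholds follow the upstream switch style and are an assumption.
#include "flash.h"
#include <cutlass/numeric_types.h>

void run_mha_fwd_sketch(Flash_fwd_params& params, cudaStream_t stream)
{
    using elem_type = cutlass::half_t;  // half precision only in this port
    if (params.d <= 32) {
        run_mha_fwd_<elem_type, 32>(params, stream);
    }
    else if (params.d <= 64) {
        run_mha_fwd_<elem_type, 64>(params, stream);
    }
    else if (params.d <= 128) {
        run_mha_fwd_<elem_type, 128>(params, stream);
    }
    else {
        run_mha_fwd_<elem_type, 256>(params, stream);
    }
}
```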
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_fwd_hdim128_fp16_sm80.cu
@@ -0,0 +1,11 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 128>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim128<cutlass::half_t>(params, stream);
}
11 changes: 11 additions & 0 deletions src/turbomind/models/llama/flash_attention2/flash_fwd_hdim256_fp16_sm80.cu
@@ -0,0 +1,11 @@
// Copyright (c) 2023, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::half_t, 256>(Flash_fwd_params& params, cudaStream_t stream)
{
run_mha_fwd_hdim256<cutlass::half_t>(params, stream);
}