support turbomind head_dim 64 #2715

Merged 6 commits on Nov 6, 2024

Changes from 4 commits
6 changes: 6 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/internvl.py
@@ -80,8 +80,14 @@ def model_info(self):
scaling_factor = model_arg['rope_scaling'].get('factor', '')
if scaling_type == 'dynamic':
use_dynamic_ntk = 1
addition_kwargs = {}
if model_arg['architectures'][0] == 'Qwen2ForCausalLM':
addition_kwargs['attn_bias'] = 1

return dict(num_layer=num_layer,
size_per_head=hidden_units // attn_head_num,
rotary_embedding=hidden_units // attn_head_num,
**addition_kwargs,
norm_eps=norm_eps,
hidden_units=hidden_units,
inter_size=inter_size,
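A note on the attn_bias flag introduced above: Qwen2's attention projections carry bias terms, so when an InternVL checkpoint wraps a Qwen2ForCausalLM language model the converter has to request bias loading, while other backbones keep the default. The sketch below performs the same check outside the converter; it assumes the usual InternVL config layout with an llm_config sub-config, and the checkpoint path is a placeholder.

```python
from transformers import AutoConfig

# Placeholder path to a Qwen2-based InternVL2 checkpoint.
cfg = AutoConfig.from_pretrained('/path/to/InternVL2-1B', trust_remote_code=True)

llm_arch = cfg.llm_config.architectures[0]
addition_kwargs = {'attn_bias': 1} if llm_arch == 'Qwen2ForCausalLM' else {}
print(llm_arch, addition_kwargs)
```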
2 changes: 2 additions & 0 deletions lmdeploy/turbomind/deploy/source_model/llama.py
@@ -189,6 +189,8 @@ def model_info(self):
beta_slow = rope_scaling.get('beta_slow', 1.0)

return dict(
size_per_head=hidden_units // attn_head_num,
rotary_embedding=hidden_units // attn_head_num,
num_layer=num_layer,
norm_eps=norm_eps,
head_num=attn_head_num,
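The two fields added above are both hidden_units // attn_head_num, i.e. the head dimension, which is where the value 64 enters. A quick illustration with config values as published for these models (quoted from memory, so treat them as indicative rather than authoritative):

```python
# head_dim = hidden_size // num_attention_heads
configs = {
    'Qwen2-0.5B': (896, 14),     # -> 64
    'Llama-3.2-1B': (2048, 32),  # -> 64
    'Llama-3-8B': (4096, 32),    # -> 128 (already supported)
}
for name, (hidden_size, num_heads) in configs.items():
    print(f'{name}: head_dim = {hidden_size // num_heads}')
```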
8 changes: 4 additions & 4 deletions lmdeploy/turbomind/supported_models.py
@@ -65,11 +65,11 @@ def is_supported(model_path: str):
""" # noqa: E501
import os

def _is_head_dim_128(cfg):
def _is_head_dim_128_64(cfg):
num_attn_head = cfg.num_attention_heads
hidden_size = cfg.hidden_size
# turbomind support head_dim=128
return (hidden_size // num_attn_head) == 128
return (hidden_size // num_attn_head) in [128, 64]

support_by_turbomind = False
triton_model_path = os.path.join(model_path, 'triton_models')
@@ -89,7 +89,7 @@ def _is_head_dim_128(cfg):
elif arch in ['Qwen2ForCausalLM', 'LlamaForCausalLM']:
# the head_dim of qwen2 0.5b and llama3.2-1b is 64, which
# hasn't been supported by turbomind yet
support_by_turbomind = _is_head_dim_128(cfg)
support_by_turbomind = _is_head_dim_128_64(cfg)
elif arch in ('ChatGLMModel', 'ChatGLMForConditionalGeneration'):
# chatglm1/2/3 is not working yet
support_by_turbomind = cfg.num_layers == 40
@@ -98,6 +98,6 @@ def _is_head_dim_128(cfg):
support_by_turbomind = False
elif arch == 'InternVLChatModel':
# internvl2-4b,internlm2-1b are not working yet
support_by_turbomind = _is_head_dim_128(cfg.llm_config)
support_by_turbomind = _is_head_dim_128_64(cfg.llm_config)

return support_by_turbomind
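With the relaxed check, head_dim-64 models such as Qwen2-0.5B and Llama-3.2-1B are reported as supported and can be served by turbomind rather than falling back to the PyTorch engine. A minimal usage sketch; the import path follows the file location above and the checkpoint path is a placeholder:

```python
from lmdeploy.turbomind.supported_models import is_supported

# Expected to return True for a head_dim-64 checkpoint such as Qwen2-0.5B after this change.
print(is_supported('/path/to/Qwen2-0.5B-Instruct'))
```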
16 changes: 16 additions & 0 deletions src/turbomind/kernels/attention/CMakeLists.txt
@@ -22,6 +22,22 @@ add_library(attention STATIC
codegen/decoding_sm80_128_f16_f16.cu
codegen/decoding_sm80_128_f16_u4.cu
codegen/decoding_sm80_128_f16_u8.cu
codegen/attention_sm70_64_f16.cu
codegen/attention_sm75_64_f16.cu
codegen/attention_sm80_64_bf16.cu
codegen/attention_sm80_64_f16.cu
codegen/decoding_sm70_64_f16_f16.cu
codegen/decoding_sm70_64_f16_u4.cu
codegen/decoding_sm70_64_f16_u8.cu
codegen/decoding_sm75_64_f16_f16.cu
codegen/decoding_sm75_64_f16_u4.cu
codegen/decoding_sm75_64_f16_u8.cu
codegen/decoding_sm80_64_bf16_bf16.cu
codegen/decoding_sm80_64_bf16_u4.cu
codegen/decoding_sm80_64_bf16_u8.cu
codegen/decoding_sm80_64_f16_f16.cu
codegen/decoding_sm80_64_f16_u4.cu
codegen/decoding_sm80_64_f16_u8.cu
)
set_property(TARGET attention PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET attention PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
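For reference, the sixteen new translation units follow a fixed naming scheme, {kind}_{arch}_{head_dim}_{qtype}[_{kvtype}].cu: the attention_* files vary by (arch, query dtype), and the decoding_* files additionally by the KV-cache element type (f16/bf16, u4, u8). A short sketch that reproduces the list above from that convention:

```python
# Enumerate the head_dim-64 codegen files added in this CMakeLists.txt hunk.
archs = [('sm70', ['f16']), ('sm75', ['f16']), ('sm80', ['bf16', 'f16'])]

attention_files = [f'attention_{arch}_64_{q}.cu' for arch, qtypes in archs for q in qtypes]
decoding_files = [f'decoding_{arch}_64_{q}_{kv}.cu'
                  for arch, qtypes in archs
                  for q in qtypes
                  for kv in (q, 'u4', 'u8')]

print(len(attention_files), len(decoding_files))  # 4 + 12 = the 16 files added here
```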
19 changes: 13 additions & 6 deletions src/turbomind/kernels/attention/attention.cu
@@ -14,20 +14,19 @@ template<class T>
void dispatchAttention(const AttentionParams<T>& params)
{
using namespace attention;
if (params.size_per_head == 128) {

auto dispatch = [&](const auto dim) {
constexpr int kHeadDim = dim;
if (params.arch >= 80) {
using Config = AttentionConfig<arch::Sm80, T, 128, CacheType::kLinear>;
using Config = AttentionConfig<arch::Sm80, T, kHeadDim, CacheType::kLinear>;
return invokeAttention<typename Config::Kernel>(params);
}

if constexpr (!std::is_same_v<T, nv_bfloat16>) {
if (params.arch == 75) {
return invokeAttention<typename AttentionConfig<arch::Sm75, T, 128, CacheType::kLinear>::Kernel>(
return invokeAttention<typename AttentionConfig<arch::Sm75, T, kHeadDim, CacheType::kLinear>::Kernel>(
params);
}
else if (params.arch >= 70) {
return invokeAttention<typename AttentionConfig<arch::Sm70, T, 128, CacheType::kLinear>::Kernel>(
return invokeAttention<typename AttentionConfig<arch::Sm70, T, kHeadDim, CacheType::kLinear>::Kernel>(
params);
}
}
@@ -38,6 +37,14 @@ void dispatchAttention(const AttentionParams<T>& params)
params.arch);
}
}
FT_CHECK(0);
};

if (params.size_per_head == 64) {
return dispatch(std::integral_constant<int, 64>{});
}
else if (params.size_per_head == 128) {
return dispatch(std::integral_constant<int, 128>{});
}
FT_CHECK(0);
}
16 changes: 16 additions & 0 deletions src/turbomind/kernels/attention/codegen/attention_sm70_64_f16.cu
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_config.h"
#include "../attention_template.h"

namespace turbomind {

using namespace attention;

template void invokeAttention<typename AttentionConfig<arch::Sm70, half, 64, CacheType::kLinear>::Kernel>(
const AttentionParams<half>& params);

template void invokeAttention<typename AttentionConfig<arch::Sm70, half, 64, CacheType::kBlock>::Kernel>(
const AttentionParams<half>& params);

} // namespace turbomind
17 changes: 17 additions & 0 deletions src/turbomind/kernels/attention/codegen/attention_sm75_64_f16.cu
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_config.h"
#include "../attention_template.h"

namespace turbomind {

using namespace attention;

template void invokeAttention<typename AttentionConfig<arch::Sm75, half, 64, CacheType::kLinear>::Kernel>(
const AttentionParams<half>& params);

// ! register spill
// template void invokeAttention<typename AttentionConfig<arch::Sm75, half, 64, CacheType::kBlock>::Kernel>(
// const AttentionParams<half>& params);

} // namespace turbomind
16 changes: 16 additions & 0 deletions src/turbomind/kernels/attention/codegen/attention_sm80_64_bf16.cu
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_config.h"
#include "../attention_template.h"

namespace turbomind {

using namespace attention;

template void invokeAttention<typename AttentionConfig<arch::Sm80, nv_bfloat16, 64, CacheType::kLinear>::Kernel>(
const AttentionParams<nv_bfloat16>& params);

template void invokeAttention<typename AttentionConfig<arch::Sm80, nv_bfloat16, 64, CacheType::kBlock>::Kernel>(
const AttentionParams<nv_bfloat16>& params);

} // namespace turbomind
16 changes: 16 additions & 0 deletions src/turbomind/kernels/attention/codegen/attention_sm80_64_f16.cu
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_config.h"
#include "../attention_template.h"

namespace turbomind {

using namespace attention;

template void invokeAttention<typename AttentionConfig<arch::Sm80, half, 64, CacheType::kLinear>::Kernel>(
const AttentionParams<half>& params);

template void invokeAttention<typename AttentionConfig<arch::Sm80, half, 64, CacheType::kBlock>::Kernel>(
const AttentionParams<half>& params);

} // namespace turbomind
16 changes: 16 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_f16.cu
@@ -0,0 +1,16 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm70, half, half, 1, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, half, 2, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, half, 3, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
17 changes: 17 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u4.cu
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_params.h"
#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm70, half, uint4_t, 1, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, uint4_t, 2, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, uint4_t, 3, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
17 changes: 17 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm70_64_f16_u8.cu
@@ -0,0 +1,17 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../attention_params.h"
#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm70, half, uint8_t, 1, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, uint8_t, 2, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm70, half, uint8_t, 3, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_f16.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm75, half, half, 8, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm75, half, half, 16, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u4.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm75, half, uint4_t, 8, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm75, half, uint4_t, 16, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm75_64_f16_u8.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm75, half, uint8_t, 8, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm75, half, uint8_t, 16, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
22 changes: 22 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_bf16.cu
@@ -0,0 +1,22 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 1, 64>>(const AttentionParams<nv_bfloat16>& params);

template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 2, 64>>(const AttentionParams<nv_bfloat16>& params);

template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 8, 64>>(const AttentionParams<nv_bfloat16>& params);

template bool
invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, nv_bfloat16, 16, 64>>(const AttentionParams<nv_bfloat16>& params);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u4.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 8, 64>>(const AttentionParams<nv_bfloat16>&);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint4_t, 16, 64>>(const AttentionParams<nv_bfloat16>&);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_bf16_u8.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 8, 64>>(const AttentionParams<nv_bfloat16>&);

template bool invokeDecoding<Decoding<arch::Sm80, nv_bfloat16, uint8_t, 16, 64>>(const AttentionParams<nv_bfloat16>&);

} // namespace turbomind
18 changes: 18 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_f16.cu
@@ -0,0 +1,18 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, half, half, 1, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm80, half, half, 2, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm80, half, half, 8, 64>>(const AttentionParams<half>& params);

template bool invokeDecoding<Decoding<arch::Sm80, half, half, 16, 64>>(const AttentionParams<half>& params);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u4.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, half, uint4_t, 8, 64>>(const AttentionParams<half>&);

template bool invokeDecoding<Decoding<arch::Sm80, half, uint4_t, 16, 64>>(const AttentionParams<half>&);

} // namespace turbomind
14 changes: 14 additions & 0 deletions src/turbomind/kernels/attention/codegen/decoding_sm80_64_f16_u8.cu
@@ -0,0 +1,14 @@
// Copyright (c) OpenMMLab. All rights reserved.

#include "../decoding_config.h"
#include "../decoding_template.h"

namespace turbomind {

using namespace attention;

template bool invokeDecoding<Decoding<arch::Sm80, half, uint8_t, 8, 64>>(const AttentionParams<half>&);

template bool invokeDecoding<Decoding<arch::Sm80, half, uint8_t, 16, 64>>(const AttentionParams<half>&);

} // namespace turbomind