diff --git a/lmdeploy/pytorch/backends/default/rotary_embedding.py b/lmdeploy/pytorch/backends/default/rotary_embedding.py
index bc209be5e1..3cecbefbc2 100644
--- a/lmdeploy/pytorch/backends/default/rotary_embedding.py
+++ b/lmdeploy/pytorch/backends/default/rotary_embedding.py
@@ -232,9 +232,12 @@ def __init__(self,
         self.register_buffer('inv_freq', inv_freq, persistent=False)

         # get mscale
-        self.mscale = float(
-            yarn_get_mscale(self.scaling_factor, self.mscale) /
-            yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
+        if yarn_params.attention_factor is not None:
+            self.mscale = yarn_params.attention_factor
+        else:
+            self.mscale = float(
+                yarn_get_mscale(self.scaling_factor, self.mscale) /
+                yarn_get_mscale(self.scaling_factor, self.mscale_all_dim))
         if self.mscale == 1.0:
             self.mscale = None
@@ -334,10 +337,10 @@ def build(
             return LlamaDynamicNTKScalingRotaryEmbedding(
                 dim, base, scaling_factor, max_position_embeddings)
         elif emb_type == RopeType.Llama3:
-            return Llama3RotaryEmbeddingImpl(dim, base, scaling_factor,
-                                             llama3_params.low_freq_factor,
-                                             llama3_params.high_freq_factor,
-                                             max_position_embeddings)
+            return Llama3RotaryEmbeddingImpl(
+                dim, base, scaling_factor, llama3_params.low_freq_factor,
+                llama3_params.high_freq_factor,
+                llama3_params.original_max_position_embeddings)
         elif emb_type == RopeType.Yarn:
             return YarnRotaryEmbeddingImpl(dim,
                                            base,
diff --git a/lmdeploy/pytorch/backends/rotary_embedding.py b/lmdeploy/pytorch/backends/rotary_embedding.py
index 6fa6abbdf9..5718d822f0 100644
--- a/lmdeploy/pytorch/backends/rotary_embedding.py
+++ b/lmdeploy/pytorch/backends/rotary_embedding.py
@@ -22,6 +22,7 @@ class YarnParameters:
     beta_slow: float = 1
     mscale: int = 1
     mscale_all_dim: int = 0
+    attention_factor: float = None


 @dataclass
@@ -39,6 +40,7 @@ class Llama3Parameters:
     """llama3 rope parameters."""
     low_freq_factor: float = 1.0
     high_freq_factor: float = 4.0
+    original_max_position_embeddings: int = 8192


 class RotaryEmbeddingImpl(ABC):
diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index 5ffa06665b..de6a7a58e1 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -6,8 +6,9 @@
 from transformers.configuration_utils import PretrainedConfig

 from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
-from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
-                                 SiluAndMul, build_rotary_embedding)
+from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm,
+                                 SiluAndMul, build_rotary_embedding,
+                                 build_rotary_params)
 from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
                                         build_qkv_proj, build_rowwise_linear)
 from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
@@ -245,7 +246,8 @@ def __init__(self,
                               device=device)

         # build rotary embedding
-        emb_type = RopeType.LinearScaling
+        # emb_type = RopeType.LinearScaling
+        rope_params = build_rotary_params(config)
         rope_dim = config.hidden_size // config.num_attention_heads
         rope_max_pos_emb = config.max_position_embeddings
         rope_base = config.rope_theta
@@ -253,7 +255,7 @@ def __init__(self,
             rope_dim,
             rope_max_pos_emb,
             rope_base,
-            emb_type=emb_type,
+            **rope_params,
         )

     def forward(
diff --git a/lmdeploy/pytorch/nn/__init__.py b/lmdeploy/pytorch/nn/__init__.py
index 2b90f40298..63df9a5ae9 100644
--- a/lmdeploy/pytorch/nn/__init__.py
+++ b/lmdeploy/pytorch/nn/__init__.py
@@ -8,3 +8,4 @@
 from .rotary_embedding import RopeType  # noqa: F401
 from .rotary_embedding import YarnParameters  # noqa: F401
 from .rotary_embedding import build_rotary_embedding  # noqa: F401
+from .rotary_embedding import build_rotary_params  # noqa: F401
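Editor's note: the mscale hunk above resolves to a single scalar at build time. The sketch below (not part of the patch) reproduces that resolution in isolation, assuming `yarn_get_mscale` follows the standard YaRN definition `0.1 * mscale * ln(scale) + 1` for `scale > 1`; the helper name `resolve_attention_scale` is invented for illustration.

```python
# Standalone sketch of the patched mscale logic, assuming yarn_get_mscale
# follows the standard YaRN definition. Not lmdeploy code.
import math


def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1.0:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0


def resolve_attention_scale(scaling_factor, mscale, mscale_all_dim,
                            attention_factor=None):
    """Mirror of the patched __init__: an explicit attention_factor wins,
    otherwise fall back to the yarn mscale ratio."""
    if attention_factor is not None:
        scale = attention_factor
    else:
        scale = float(
            yarn_get_mscale(scaling_factor, mscale) /
            yarn_get_mscale(scaling_factor, mscale_all_dim))
    # a ratio of exactly 1.0 means no extra scaling is applied
    return None if scale == 1.0 else scale


print(resolve_attention_scale(4.0, 1.0, 0.0))        # ~1.1386 = 0.1*ln(4)+1
print(resolve_attention_scale(4.0, 1.0, 0.0, 1.25))  # 1.25 (explicit override)
```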
diff --git a/lmdeploy/pytorch/nn/rotary_embedding.py b/lmdeploy/pytorch/nn/rotary_embedding.py
index 35a7de7144..43eb1f806d 100644
--- a/lmdeploy/pytorch/nn/rotary_embedding.py
+++ b/lmdeploy/pytorch/nn/rotary_embedding.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from torch import Tensor, nn
+from transformers import PretrainedConfig

 from ..backends import OpType, get_backend
 from ..backends.rotary_embedding import (Llama3Parameters,
@@ -7,6 +8,84 @@
                                          YarnParameters)


+def _get_default_rope_parameters(config: PretrainedConfig):
+    """get default rope parameters."""
+    return dict(emb_type=RopeType.Default, scaling_factor=1.0)
+
+
+def _get_linear_scaling_rope_parameters(config: PretrainedConfig):
+    """get linear rope parameters."""
+    rope_scaling = config.rope_scaling
+    scaling_factor = rope_scaling['factor']
+    return dict(emb_type=RopeType.LinearScaling, scaling_factor=scaling_factor)
+
+
+def _get_dynamic_ntk_parameters(config: PretrainedConfig):
+    """get dynamic ntk parameters."""
+    rope_scaling = config.rope_scaling
+    scaling_factor = rope_scaling['factor']
+    return dict(emb_type=RopeType.DynamicNTKScaling,
+                scaling_factor=scaling_factor)
+
+
+def _get_yarn_parameters(config: PretrainedConfig):
+    """get yarn parameters."""
+    rope_scaling = config.rope_scaling
+    scaling_factor = rope_scaling['factor']
+    params = YarnParameters()
+    params.attention_factor = rope_scaling.get('attention_factor',
+                                               params.attention_factor)
+    params.beta_fast = rope_scaling.get('beta_fast', params.beta_fast)
+    params.beta_slow = rope_scaling.get('beta_slow', params.beta_slow)
+    return dict(emb_type=RopeType.Yarn,
+                scaling_factor=scaling_factor,
+                yarn_params=params)
+
+
+def _get_longrope_parameters(config: PretrainedConfig):
+    """get longrope parameters."""
+    rope_scaling = config.rope_scaling
+    params = LongRoPEScalingParameters()
+    scaling_factor = rope_scaling['factor']
+    params.long_factor = rope_scaling['long_factor']
+    params.short_factor = rope_scaling['short_factor']
+    params.original_max_position_embeddings = rope_scaling.get(
+        'original_max_position_embeddings', config.max_position_embeddings)
+    return dict(emb_type=RopeType.LongRoPEScaling,
+                scaling_factor=scaling_factor,
+                longrope_params=params)
+
+
+def _get_llama3_parameters(config: PretrainedConfig):
+    """get llama3 rope parameters."""
+    rope_scaling = config.rope_scaling
+    params = Llama3Parameters()
+    scaling_factor = rope_scaling['factor']
+    params.low_freq_factor = rope_scaling['low_freq_factor']
+    params.high_freq_factor = rope_scaling['high_freq_factor']
+    params.original_max_position_embeddings = rope_scaling.get(
+        'original_max_position_embeddings',
+        params.original_max_position_embeddings)
+    return dict(emb_type=RopeType.Llama3,
+                scaling_factor=scaling_factor,
+                llama3_params=params)
+
+
+def build_rotary_params(config: PretrainedConfig):
+    """get emb_type, scaling_factor and per-type rotary params."""
+    params = dict(emb_type=RopeType.Default)
+    if config.rope_scaling is not None:
+        rope_type_str = config.rope_scaling.get('rope_type', 'default')
+        build_funcs = dict(default=_get_default_rope_parameters,
+                           linear=_get_linear_scaling_rope_parameters,
+                           dynamic=_get_dynamic_ntk_parameters,
+                           yarn=_get_yarn_parameters,
+                           longrope=_get_longrope_parameters,
+                           llama3=_get_llama3_parameters)
+        params.update(build_funcs[rope_type_str](config))
+    return params
+
+
 def build_rotary_embedding(dim: int,
                            max_position_embeddings: int = 2048,
                            base: int = 10000,
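Editor's note: a rough usage sketch of the new `build_rotary_params` helper, mirroring the qwen2.py change above. The config values are hypothetical, and whether it runs standalone depends on lmdeploy's backend having been initialized, so treat it as an illustration of the expected call flow rather than a test.

```python
# Sketch (assumed config values): a HuggingFace-style rope_scaling block flows
# through build_rotary_params() into build_rotary_embedding() after this change.
from transformers import PretrainedConfig

from lmdeploy.pytorch.nn import build_rotary_embedding, build_rotary_params

config = PretrainedConfig(
    hidden_size=4096,
    num_attention_heads=32,
    max_position_embeddings=32768,
    rope_theta=1000000.0,
    rope_scaling=dict(rope_type='yarn', factor=4.0, beta_fast=32, beta_slow=1),
)

rope_params = build_rotary_params(config)
# -> dict(emb_type=RopeType.Yarn, scaling_factor=4.0, yarn_params=YarnParameters(...))

rotary_emb = build_rotary_embedding(
    config.hidden_size // config.num_attention_heads,  # rope_dim
    config.max_position_embeddings,
    config.rope_theta,
    **rope_params,
)
```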
diff --git a/lmdeploy/turbomind/deploy/config.py b/lmdeploy/turbomind/deploy/config.py
index 6652650949..7e8ebf7b47 100644
--- a/lmdeploy/turbomind/deploy/config.py
+++ b/lmdeploy/turbomind/deploy/config.py
@@ -63,6 +63,7 @@ def verify(self):
 class AttentionConfig:
     rotary_embedding: int = 128
     rope_theta: float = 10000.0
+    attention_factor: float = None
     max_position_embeddings: int = 0
     original_max_position_embeddings: int = 0
     rope_scaling_type: str = ''
@@ -70,6 +71,8 @@ class AttentionConfig:
     use_dynamic_ntk: int = 0
     low_freq_factor: float = 1.0
     high_freq_factor: float = 1.0
+    beta_fast: float = 32.0
+    beta_slow: float = 1.0
     use_logn_attn: int = 0
     cache_block_seq_len: int = 64
diff --git a/lmdeploy/turbomind/deploy/source_model/llama.py b/lmdeploy/turbomind/deploy/source_model/llama.py
index d61d1906e1..8e19fa8d87 100644
--- a/lmdeploy/turbomind/deploy/source_model/llama.py
+++ b/lmdeploy/turbomind/deploy/source_model/llama.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import json
+import math
 import os.path as osp
 import re

@@ -157,25 +158,35 @@ def model_info(self):
         scaling_type = ''
         low_freq_factor = 1.0
         high_freq_factor = 1.0
+        attention_factor = -1.0
+        beta_fast = 32.0
+        beta_slow = 1.0
         original_max_position_embeddings = 0
         if isinstance(rope_scaling, dict):
-            llama2_scaling_type = model_arg['rope_scaling'].get('type', '')
-            llama3_scaling_type = model_arg['rope_scaling'].get(
-                'rope_type', '')
-            scaling_factor = model_arg['rope_scaling'].get('factor', '')
-            low_freq_factor = model_arg['rope_scaling'].get(
-                'low_freq_factor', 1.0)
-            high_freq_factor = model_arg['rope_scaling'].get(
-                'high_freq_factor', 1.0)
-            original_max_position_embeddings = model_arg[
-                'rope_scaling'].get('original_max_position_embeddings', 0)
+            llama2_scaling_type = rope_scaling.get('type', '')
+            llama3_scaling_type = rope_scaling.get('rope_type', '')
             if llama2_scaling_type and llama3_scaling_type:
                 raise ValueError(
                     f'Ambiguous rope_scaling in config: {model_arg}')
             scaling_type = llama2_scaling_type if llama2_scaling_type \
                 else llama3_scaling_type
+            scaling_factor = rope_scaling.get('factor', 0.0)
             if scaling_type == 'dynamic':
                 use_dynamic_ntk = 1
+            elif scaling_type == 'llama3':
+                low_freq_factor = rope_scaling.get('low_freq_factor', 1.0)
+                high_freq_factor = rope_scaling.get(
+                    'high_freq_factor', 1.0)
+                original_max_position_embeddings = model_arg[
+                    'rope_scaling'].get('original_max_position_embeddings',
+                                        0)
+            elif scaling_type == 'yarn':
+                attention_factor = rope_scaling.get(
+                    'attention_factor', None)
+                if attention_factor is None:
+                    attention_factor = 0.1 * math.log(scaling_factor) + 1.0
+                beta_fast = rope_scaling.get('beta_fast', 32.0)
+                beta_slow = rope_scaling.get('beta_slow', 1.0)

         return dict(
             num_layer=num_layer,
@@ -192,4 +203,7 @@ def model_info(self):
             rope_scaling_type=scaling_type,
             rope_scaling_factor=scaling_factor,
             low_freq_factor=low_freq_factor,
-            high_freq_factor=high_freq_factor)
+            high_freq_factor=high_freq_factor,
+            attention_factor=attention_factor,
+            beta_fast=beta_fast,
+            beta_slow=beta_slow)
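Editor's note: the yarn branch added to `model_info()` above falls back to `0.1 * ln(factor) + 1.0` when `attention_factor` is absent. A small worked example, using a hypothetical `rope_scaling` block as it would appear in a model's config.json:

```python
# Worked example of the yarn defaults above; the rope_scaling values are
# illustrative only, not taken from any particular model.
import math

rope_scaling = {
    'rope_type': 'yarn',
    'factor': 4.0,
    'original_max_position_embeddings': 32768,
    'beta_fast': 32,
    'beta_slow': 1,
}

attention_factor = rope_scaling.get('attention_factor', None)
if attention_factor is None:
    attention_factor = 0.1 * math.log(rope_scaling['factor']) + 1.0
beta_fast = rope_scaling.get('beta_fast', 32.0)
beta_slow = rope_scaling.get('beta_slow', 1.0)

print(attention_factor)  # 1.1386... i.e. 0.1 * ln(4) + 1
```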
diff --git a/src/turbomind/kernels/attention/attention_params.h b/src/turbomind/kernels/attention/attention_params.h
index 8e0e52195d..b6dfaa596c 100644
--- a/src/turbomind/kernels/attention/attention_params.h
+++ b/src/turbomind/kernels/attention/attention_params.h
@@ -59,12 +59,18 @@ struct AttentionParams {
     // rotary embedding
     int   rotary_embedding_dim;
     float rotary_embedding_base;
+    float rope_scaling_factor;
+    float attention_scaling;
     int   max_position_embeddings;
     float rope_ti_scale;  // used for linear RoPE scaling
     // the following 3 parameters are used by llama3
     float llama3_inv_scaling_factor;
     float llama3_alpha;
     float llama3_beta;
+    // the following are used by yarn
+    float yarn_ramp_inv_factor_div_2;
+    float yarn_ramp_inv_factor_mul_min;
+    float yarn_inv_scaling_factor;

     // log(n) attention
     bool use_logn_attn;
diff --git a/src/turbomind/kernels/attention/attention_universal.h b/src/turbomind/kernels/attention/attention_universal.h
index 352cc14725..5fb583bd1f 100644
--- a/src/turbomind/kernels/attention/attention_universal.h
+++ b/src/turbomind/kernels/attention/attention_universal.h
@@ -231,9 +231,14 @@ struct AttentionUniversal {
                          params.rotary_embedding_dim,
                          rope_base,
                          params.rope_ti_scale,
+                         params.rope_scaling_factor,
                          params.llama3_inv_scaling_factor,
                          params.llama3_alpha,
                          params.llama3_beta,
+                         params.yarn_ramp_inv_factor_div_2,
+                         params.yarn_ramp_inv_factor_mul_min,
+                         params.yarn_inv_scaling_factor,
+                         params.attention_scaling,
                          std::integral_constant{});
             PRAGMA_UNROLL
             for (int s = 0; s < ITER_S; ++s) {
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
index 1edb92f374..9f28a17b83 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.cu
@@ -23,9 +23,14 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                        float        llama3_inv_scaling_factor,
                         float        llama3_alpha,
                         float        llama3_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -128,9 +133,14 @@ __global__ void __launch_bounds__(128) ProcessKV_v2(char** blocks,
                         rope_dim,
                         base,
                         rope_ti_scale,
+                        rope_scaling_factor,
                         llama3_inv_scaling_factor,
                         llama3_alpha,
                         llama3_beta,
+                        yarn_ramp_inv_factor_div_2,
+                        yarn_ramp_inv_factor_mul_min,
+                        yarn_inv_scaling_factor,
+                        attention_scaling,
                         std::integral_constant{});
            PRAGMA_UNROLL
            for (int s = 0; s < ITER_S; ++s) {
@@ -204,9 +214,14 @@ void invokeProcessKV_v2(char** blocks,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                         float        llama3_inv_scaling_factor,
                         float        llama3_1_alpha,
                         float        llama3_1_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -245,9 +260,14 @@ void invokeProcessKV_v2(char** blocks,
                         rope_base,
                         rope_dim,
                         rope_ti_scale,
+                        rope_scaling_factor,
                         llama3_inv_scaling_factor,
                         llama3_1_alpha,
                         llama3_1_beta,
+                        yarn_ramp_inv_factor_div_2,
+                        yarn_ramp_inv_factor_mul_min,
+                        yarn_inv_scaling_factor,
+                        attention_scaling,
                         stride_b,
                         stride_c,
                         stride_h,
@@ -279,9 +299,14 @@ void invokeProcessKV_v2(char** blocks,
                         const float* rope_base,                    \
                         int          rope_dim,                     \
                         float        rope_ti_scale,                \
+                        float        rope_scaling_factor,          \
                         float        llama3_inv_scaling_factor,    \
                         float        llama3_1_alpha,               \
                         float        llama3_1_beta,                \
+                        float        yarn_ramp_inv_factor_div_2,   \
+                        float        yarn_ramp_inv_factor_mul_min, \
+                        float        yarn_inv_scaling_factor,      \
+                        float        attention_scaling,            \
                         int64_t      stride_b,                     \
                         int64_t      stride_c,                     \
                         int64_t      stride_h,                     \
@@ -309,9 +334,14 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                         float        llama3_inv_scaling_factor,
                         float        llama3_alpha,
                         float        llama3_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -397,9 +427,14 @@ __global__ void __launch_bounds__(128) flattenKV_v2(T* k,
                         rope_dim,
                         base,
                         rope_ti_scale,
+                        rope_scaling_factor,
                         llama3_inv_scaling_factor,
                         llama3_alpha,
                         llama3_beta,
+                        yarn_ramp_inv_factor_div_2,
+                        yarn_ramp_inv_factor_mul_min,
+                        yarn_inv_scaling_factor,
+                        attention_scaling,
                         std::integral_constant{});
        PRAGMA_UNROLL
        for (int s = 0; s < ITER_S; ++s) {
@@ -434,9 +469,14 @@ void invokeFlattenKV_v2(T* k,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                         float        llama3_inv_scaling_factor,
                         float        llama3_alpha,
                         float        llama3_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -472,9 +512,14 @@ void invokeFlattenKV_v2(T* k,
                         rope_base,
                         rope_dim,
                         rope_ti_scale,
+                        rope_scaling_factor,
                         llama3_inv_scaling_factor,
                         llama3_alpha,
                         llama3_beta,
+                        yarn_ramp_inv_factor_div_2,
+                        yarn_ramp_inv_factor_mul_min,
+                        yarn_inv_scaling_factor,
+                        attention_scaling,
                         stride_b,
                         stride_c,
                         stride_h,
@@ -503,9 +548,14 @@ void invokeFlattenKV_v2(T* k,
                         const float* rope_base,                    \
                         int          rope_dim,                     \
                         float        rope_ti_scale,                \
+                        float        rope_scaling_factor,          \
                         float        llama3_inv_scaling_factor,    \
                         float        llama3_alpha,                 \
                         float        llama3_beta,                  \
+                        float        yarn_ramp_inv_factor_div_2,   \
+                        float        yarn_ramp_inv_factor_mul_min, \
+                        float        yarn_inv_scaling_factor,      \
+                        float        attention_scaling,            \
                         int64_t      stride_b,                     \
                         int64_t      stride_c,                     \
                         int64_t      stride_h,                     \
diff --git a/src/turbomind/kernels/attention/kv_cache_utils_v2.h b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
index 74ba7fafb0..fe45ad7be7 100644
--- a/src/turbomind/kernels/attention/kv_cache_utils_v2.h
+++ b/src/turbomind/kernels/attention/kv_cache_utils_v2.h
@@ -19,9 +19,14 @@ void invokeProcessKV_v2(char** blocks,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                         float        llama3_inv_scaling_factor,
                         float        llama3_alpha,
                         float        llama3_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -49,9 +54,14 @@ void invokeProcessKV_v2_(const AttentionParams& params)
                         params.rope_theta,
                         params.rotary_embedding_dim,
                         params.rope_ti_scale,
+                        params.rope_scaling_factor,
                         params.llama3_inv_scaling_factor,
                         params.llama3_alpha,
                         params.llama3_beta,
+                        params.yarn_ramp_inv_factor_div_2,
+                        params.yarn_ramp_inv_factor_mul_min,
+                        params.yarn_inv_scaling_factor,
+                        params.attention_scaling,
                         0,                                      // stride b
                         params.stride / params.size_per_head,   // stride c
                         1,                                      // stride h
@@ -75,9 +85,14 @@ void invokeFlattenKV_v2(T* k,
                         const float* rope_base,
                         int          rope_dim,
                         float        rope_ti_scale,
+                        float        rope_scaling_factor,
                         float        llama3_inv_scaling_factor,
                         float        llama3_alpha,
                         float        llama3_beta,
+                        float        yarn_ramp_inv_factor_div_2,
+                        float        yarn_ramp_inv_factor_mul_min,
+                        float        yarn_inv_scaling_factor,
+                        float        attention_scaling,
                         int64_t      stride_b,
                         int64_t      stride_c,
                         int64_t      stride_h,
@@ -104,9 +119,14 @@ void invokeFlattenKV_v2_(const AttentionParams& params, int sum_k_len)
                         nullptr,  // params.rope_theta,
                         params.rotary_embedding_dim,
                         params.rope_ti_scale,
+                        params.rope_scaling_factor,
                         params.llama3_inv_scaling_factor,
                         params.llama3_alpha,
                         params.llama3_beta,
+                        params.yarn_ramp_inv_factor_div_2,
+                        params.yarn_ramp_inv_factor_mul_min,
+                        params.yarn_inv_scaling_factor,
+                        params.attention_scaling,
                         0,
                         1,
                         2 * sum_k_len,
diff --git a/src/turbomind/kernels/attention/rotary_embedding.h b/src/turbomind/kernels/attention/rotary_embedding.h
index 8bc54ad268..8e09da22cd 100644
--- a/src/turbomind/kernels/attention/rotary_embedding.h
+++ b/src/turbomind/kernels/attention/rotary_embedding.h
@@ -74,17 +74,24 @@ struct FastRoPE {
    Array inv_freq_;
    bool  is_valid_;
+   float attention_scaling_;

    __device__ FastRoPE(int idx,
                        D   dims,
                        float base,
                        float ti_scale,
+                       float factor,
                        float llama3_inv_scaling_factor,
                        float llama3_alpha,
                        float llama3_beta,
+                       float yarn_ramp_inv_factor_div_2,
+                       float yarn_ramp_inv_factor_mul_min,
+                       float yarn_inv_scaling_factor,
+                       float attention_scaling,
                        std::integral_constant)
    {
-       is_valid_ = idx < dims;
+       is_valid_          = idx < dims;
+       attention_scaling_ = attention_scaling;
        /// TODO: Take this away from device code
        const float scale_factor = -log2f(base) / dims;
        PRAGMA_UNROLL
@@ -110,6 +117,15 @@ struct FastRoPE {
                inv_freq_[i / 2] = (1 - smooth) * freq * llama3_inv_scaling_factor + smooth * freq;
            }
        }
+       if (yarn_ramp_inv_factor_div_2) {
+           PRAGMA_UNROLL
+           for (int i = 0; i < N; i += 2) {
+               auto  freq  = inv_freq_[i / 2];
+               float alpha = (idx + i) * yarn_ramp_inv_factor_div_2 - yarn_ramp_inv_factor_mul_min;
+               alpha       = fmaxf(0.f, fminf(1.f, alpha));
+               inv_freq_[i / 2] = freq - freq * alpha * yarn_inv_scaling_factor;
+           }
+       }
    }

    template
@@ -119,6 +135,8 @@ struct FastRoPE {
        for (int i = 0; i < N; i += 2) {
            float c, s;
            sincosf(timestep * inv_freq_[i / 2], &s, &c);
+           s *= attention_scaling_;
+           c *= attention_scaling_;
            float tmp0 = c * (float)x[i] - s * (float)x[i + 1];
            float tmp1 = c * (float)x[i + 1] + s * (float)x[i];
            if (is_valid_) {
diff --git a/src/turbomind/kernels/attention/test_attention.cu b/src/turbomind/kernels/attention/test_attention.cu
index 4496b8b4a1..c6d7b40637 100644
--- a/src/turbomind/kernels/attention/test_attention.cu
+++ b/src/turbomind/kernels/attention/test_attention.cu
@@ -150,7 +150,12 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S,
                        rope_dim,
                        1.,
                        0.,
+                       0.,
+                       1.0,
                        1.0,
+                       0.0,
+                       0.0,
+                       0.0,
                        1.0,
                        2 * head_num * seq_len,
                        0,
@@ -179,8 +184,13 @@ void TestBlocks(const thrust::universal_vector& k_cache,  // [B, H, S,
                        rope_dim,
                        1.,
                        0.,
+                       0.,
                        1.0,
                        1.0,
+                       0.0,
+                       0.0,
+                       0.0,
+                       1.0,
                        2 * head_num * seq_len,
                        0,
                        seq_len,
@@ -538,7 +548,12 @@ int test_attention()
                        kRoPEDim,
                        1.,
                        0.,
+                       0.,
+                       1.0,
                        1.0,
+                       0.0,
+                       0.0,
+                       0.0,
                        1.0,
                        KvHeadNum * kContextLen,
                        0,
diff --git a/src/turbomind/models/llama/llama_params.h b/src/turbomind/models/llama/llama_params.h
index 4cb9e27e13..1c039ca66a 100644
--- a/src/turbomind/models/llama/llama_params.h
+++ b/src/turbomind/models/llama/llama_params.h
@@ -45,6 +45,9 @@ struct AttentionParam {
     float rope_scaling_factor;
     float low_freq_factor;
     float high_freq_factor;
+    float attention_factor;
+    float beta_fast;
+    float beta_slow;
     bool  use_dynamic_ntk;
     bool  use_logn_attn;
     int   cache_block_seq_len;
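Editor's note: a NumPy reference (a sketch, not the kernel) of what the patched `FastRoPE` computes per head dimension, assuming the preceding constructor code left `inv_freq_` at the plain RoPE value. The three `ramp`/`scaling` constants are the precomputed host-side values (see unified_attention_layer.cc below); the concrete numbers in the usage line are only illustrative.

```python
# NumPy sketch of the FastRoPE yarn path: ramp-interpolate each inverse
# frequency between its original and "/ factor" value, then fold the
# attention scaling into sin/cos.
import numpy as np


def yarn_inv_freq(dim, base, low, high, scaling_factor):
    """Mirrors the device loop added to the FastRoPE constructor."""
    idx = np.arange(0, dim, 2)                 # (idx + i) in the kernel
    inv_freq = base ** (-idx / dim)            # plain RoPE frequencies
    ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0
    ramp_inv_factor_mul_min = 1.0 / (high - low) * low
    inv_scaling_factor = 1.0 - 1.0 / scaling_factor
    alpha = np.clip(idx * ramp_inv_factor_div_2 - ramp_inv_factor_mul_min, 0.0, 1.0)
    # alpha == 0 keeps the original frequency, alpha == 1 divides it by the factor
    return inv_freq - inv_freq * alpha * inv_scaling_factor


def apply_rope_pair(x0, x1, timestep, inv_freq_i, attention_scaling):
    """Rotate one (even, odd) channel pair with the scaling folded into sin/cos."""
    s = np.sin(timestep * inv_freq_i) * attention_scaling
    c = np.cos(timestep * inv_freq_i) * attention_scaling
    return c * x0 - s * x1, c * x1 + s * x0


# low/high come from the host-side correction range; illustrative values only.
inv_freq = yarn_inv_freq(dim=128, base=1e6, low=23.0, high=40.0, scaling_factor=4.0)
```

Because the same scaled sin/cos multiplies both the query and the key, the attention logits effectively pick up attention_scaling squared, which mirrors the common YaRN practice of folding the softmax temperature into the RoPE cache.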
diff --git a/src/turbomind/models/llama/unified_attention_layer.cc b/src/turbomind/models/llama/unified_attention_layer.cc
index f38151a1a5..2f99b0c2ce 100644
--- a/src/turbomind/models/llama/unified_attention_layer.cc
+++ b/src/turbomind/models/llama/unified_attention_layer.cc
@@ -296,6 +296,8 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa
     params.rotary_embedding_dim    = param_.rotary_embedding_dim;
     params.rotary_embedding_base   = param_.rotary_embedding_base;
     params.max_position_embeddings = param_.max_position_embeddings;
+    params.rope_scaling_factor     = param_.rope_scaling_factor;
+    params.attention_scaling       = 1.0;
     params.rope_ti_scale           = 1.f;
     if (param_.rope_scaling_type == "linear") {
         params.rope_ti_scale /= param_.rope_scaling_factor;
@@ -307,6 +309,34 @@ inline void UnifiedAttentionLayer::forward(TensorMap* outputs, const TensorMa
         params.llama3_alpha = param_.original_max_position_embeddings / (2 * PI) * inv_diff_freq_factor;
         params.llama3_beta  = param_.low_freq_factor * inv_diff_freq_factor;
     }
+    if (param_.rope_scaling_type == "yarn") {
+        const double PI = 3.14159265358979323846;
+        auto find_correction_dim = [&](float num_rotations) {
+            return (param_.rotary_embedding_dim
+                    * std::log(param_.max_position_embeddings / (num_rotations * 2 * PI)))
+                   / (2 * std::log(param_.rotary_embedding_base));
+        };
+        auto find_correction_range = [&](float low_rot, float high_rot, float& low, float& high) {
+            low  = std::floor(find_correction_dim(low_rot));
+            high = std::ceil(find_correction_dim(high_rot));
+            low  = std::max(low, 0.f);
+            high = std::min(high, param_.rotary_embedding_dim - 1.f);
+        };
+        float low, high;
+        find_correction_range(param_.beta_fast, param_.beta_slow, low, high);
+        if (low == high) {
+            high += 0.01f;
+        }
+        params.yarn_ramp_inv_factor_div_2  = 1.0 / (high - low) / 2.0;
+        params.yarn_ramp_inv_factor_mul_min = 1.0 / (high - low) * low;
+        params.yarn_inv_scaling_factor      = (1 - 1.0 / param_.rope_scaling_factor);
+        if (param_.attention_factor < 0) {
+            params.attention_scaling = 0.1 * std::log(param_.rope_scaling_factor) + 1.0;
+        }
+        else {
+            params.attention_scaling = param_.attention_factor;
+        }
+    }

     params.use_logn_attn = param_.use_logn_attn;
diff --git a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
index 44f73370da..5fbd4287a8 100644
--- a/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
+++ b/src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -265,6 +265,9 @@ LlamaTritonModel::LlamaTritonModel(size_t tensor_para_size,
     // rotary embedding parameters
     attn_param_.rotary_embedding_dim  = attention_reader["rotary_embedding"].as<int>();
     attn_param_.rotary_embedding_base = attention_reader["rope_theta"].as<float>(10000.0f);
+    attn_param_.attention_factor      = attention_reader["attention_factor"].as<float>(-1.f);
+    attn_param_.beta_fast             = attention_reader["beta_fast"].as<float>(32.f);
+    attn_param_.beta_slow             = attention_reader["beta_slow"].as<float>(1.f);
     attn_param_.rope_scaling_type     = attention_reader["rope_scaling_type"].as<std::string>("");
     attn_param_.rope_scaling_factor   = attention_reader["rope_scaling_factor"].as<float>(0.f);
     attn_param_.low_freq_factor       = attention_reader["low_freq_factor"].as<float>(1.0);
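Editor's note: a Python transcription (for sanity checking, not part of the patch) of the host-side yarn precompute added to unified_attention_layer.cc. The correction range is the standard YaRN one, and the three returned ramp/scaling constants are exactly what the FastRoPE sketch above consumes; the numbers in the final call are illustrative.

```python
# Python mirror of the yarn precompute in unified_attention_layer.cc.
import math


def yarn_precompute(rotary_dim, max_pos, base, scaling_factor,
                    beta_fast=32.0, beta_slow=1.0, attention_factor=-1.0):
    def find_correction_dim(num_rotations):
        return (rotary_dim *
                math.log(max_pos / (num_rotations * 2 * math.pi))) / \
               (2 * math.log(base))

    low = max(math.floor(find_correction_dim(beta_fast)), 0.0)
    high = min(math.ceil(find_correction_dim(beta_slow)), rotary_dim - 1.0)
    if low == high:
        high += 0.01

    ramp_inv_factor_div_2 = 1.0 / (high - low) / 2.0
    ramp_inv_factor_mul_min = 1.0 / (high - low) * low
    inv_scaling_factor = 1.0 - 1.0 / scaling_factor
    # a negative attention_factor means "use the YaRN default"
    attention_scaling = (0.1 * math.log(scaling_factor) + 1.0
                         if attention_factor < 0 else attention_factor)
    return (ramp_inv_factor_div_2, ramp_inv_factor_mul_min,
            inv_scaling_factor, attention_scaling)


# e.g. a 128-dim head, 32k training window, theta 1e6, 4x extension:
print(yarn_precompute(128, 32768, 1e6, 4.0))
```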