support eos_token list in turbomind #3044

Open · wants to merge 9 commits into main (showing changes from 7 commits)
2 changes: 1 addition & 1 deletion lmdeploy/turbomind/deploy/config.py
@@ -48,7 +48,7 @@ class ModelConfig:
norm_eps: float = None
attn_bias: int = 0
start_id: int = None
- end_id: int = None
+ end_id: List[int] = None
size_per_head: int = 128
group_size: int = 64
weight_type: str = None
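For reference, a stripped-down sketch (not from the diff) of what the changed field looks like in isolation; the class name and surrounding fields are hypothetical stand-ins, and it assumes `List` comes from `typing`:

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ModelConfigSketch:  # hypothetical stand-in for the real ModelConfig
    start_id: Optional[int] = None
    end_id: Optional[List[int]] = None  # previously a single int


cfg = ModelConfigSketch(start_id=1, end_id=[128001, 128009])
```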
13 changes: 12 additions & 1 deletion lmdeploy/turbomind/deploy/target_model/base.py
@@ -111,7 +111,18 @@ def update_model_config(self):
"""Update `self.model_config` according to the input_model's
`tokenizer_info` and `model_info`"""
_, bos_id, eos_id = self.input_model_tokenizer_info

+ try:
+     from transformers import GenerationConfig
+     cfg = GenerationConfig.from_pretrained(self.input_model.model_path)
+     if isinstance(cfg.eos_token_id, int):
+         eos_id = [cfg.eos_token_id]
+     elif isinstance(cfg.eos_token_id, list):
+         eos_id = cfg.eos_token_id
+     elif cfg.eos_token_id is None:
+         eos_id = [eos_id]
+ except OSError:
+     if isinstance(eos_id, int):
+         eos_id = [eos_id]
final_cfg = config_to_dict(self.model_config)
final_cfg.update(dict(start_id=bos_id, end_id=eos_id))
final_cfg.update(self.input_model_info)
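For intuition, here is a minimal standalone sketch (not part of the diff) of the normalization this hunk performs; the model path and id values are hypothetical, and it assumes a transformers-style checkpoint that may or may not ship a generation_config.json:

```python
from transformers import GenerationConfig


def resolve_eos_ids(model_path: str, tokenizer_eos_id: int) -> list:
    """Return EOS ids as a list, preferring generation_config.json when present."""
    try:
        cfg = GenerationConfig.from_pretrained(model_path)
    except OSError:
        # No generation_config.json next to the weights: fall back to the tokenizer id.
        return [tokenizer_eos_id]
    eos = cfg.eos_token_id
    if eos is None:
        return [tokenizer_eos_id]
    return [eos] if isinstance(eos, int) else list(eos)


# e.g. a Llama-3-style checkpoint may yield [128001, 128009]; most models yield one id.
```

Whatever the source, the value written into `end_id` is always a list, which is what the turbomind side now expects.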
9 changes: 5 additions & 4 deletions lmdeploy/turbomind/turbomind.py
@@ -130,7 +130,7 @@ def __init__(self,
pass

self.session_len = self.config.session_len
- self.eos_id = self.tokenizer.eos_token_id
+ self.eos_id = self.config.model_config.end_id

def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""
@@ -531,11 +531,12 @@ def prepare_inputs(self,
bad_words.extend(gen_config.bad_token_ids)
if gen_config.ignore_eos:
stop_words = None
- bad_words.append(self.eos_id)
+ bad_words.extend(self.eos_id)
else:
stop_words = gen_config.stop_token_ids or []
- if self.eos_id not in stop_words:
-     stop_words.append(self.eos_id)
+ for eos_id in self.eos_id:
+     if eos_id not in stop_words:
+         stop_words.append(eos_id)
stop_words = _construct_stop_or_bad_words(stop_words)
bad_words = _construct_stop_or_bad_words(bad_words)

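The effect of the loop above, shown as a small self-contained sketch with hypothetical token ids (2 as a user stop token, 128001/128009 as the model's EOS ids):

```python
eos_ids = [128001, 128009]   # self.eos_id is now a list
ignore_eos = False
user_stop_ids = [2]          # gen_config.stop_token_ids

bad_words = []
if ignore_eos:
    stop_words = None
    bad_words.extend(eos_ids)          # ban every EOS id from being sampled
else:
    stop_words = list(user_stop_ids)
    for eos_id in eos_ids:             # every EOS id should terminate generation
        if eos_id not in stop_words:
            stop_words.append(eos_id)

assert stop_words == [2, 128001, 128009]
```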
27 changes: 16 additions & 11 deletions src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -579,51 +579,56 @@ template void invokeBatchApplyRepetitionPenalty(half* logits,
#endif
template<typename T>
__global__ void batchApplyMinLengthPenalty(T* logits,
+ const int batch_size,
const int* min_lengths,
- const int* end_ids,
+ const int* __restrict__ end_ids,
+ const int end_ids_len,
const int* sequence_lengths,
- const int max_input_length,
const int vocab_size_padded)
{
int bid = threadIdx.x + blockIdx.x * blockDim.x; // batch index
+ if (bid >= batch_size) {
+     return;
+ }
// In decoder, sequence_lengths means length of sequence that has kv cache already computed
if (sequence_lengths[bid] + 1 < min_lengths[bid]) {
- T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
- logits[bid * vocab_size_padded + end_ids[bid]] = mask_val;
+ T mask_val = (std::is_same<T, half>::value) ? -65504.0f : -FLT_MAX;
+ int end_id = __ldg(end_ids + blockIdx.y);
+ logits[bid * vocab_size_padded + end_id] = mask_val;
}
}

template<typename T>
void invokeMinLengthPenalty(T* logits,
const int* min_lengths,
const int* end_ids,
+ const int end_ids_len,
const int* sequnece_lengths,
const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream)

{
- const int block_size = min(batch_size, 1024);
- const int grid_size = (batch_size + block_size - 1) / block_size;
- batchApplyMinLengthPenalty<<<grid_size, block_size, 0, stream>>>(
-     logits, min_lengths, end_ids, sequnece_lengths, max_input_length, vocab_size_padded);
+ const dim3 block(std::min(batch_size, 1024));
+ const dim3 grid((batch_size + block.x - 1) / block.x, end_ids_len);
+ batchApplyMinLengthPenalty<<<grid, block, 0, stream>>>(
+     logits, batch_size, min_lengths, end_ids, end_ids_len, sequnece_lengths, vocab_size_padded);
}

template void invokeMinLengthPenalty(float* logits,
const int* min_lengths,
const int* end_ids,
+ const int end_ids_len,
const int* sequnece_lengths,
- const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
#if 0
template void invokeMinLengthPenalty(half* logits,
const int* min_lengths,
const int* end_ids,
+ const int end_ids_len,
const int* sequnece_lengths,
- const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
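The launch geometry is now two-dimensional: grid.x tiles the batch and grid.y enumerates the EOS ids, so each (sequence, EOS id) pair is masked independently. A NumPy sketch of the same masking logic, with made-up shapes and ids purely for illustration:

```python
import numpy as np

batch_size, vocab_size_padded = 4, 32000
end_ids = np.array([2, 7])                # hypothetical EOS ids (end_ids_len == 2)
min_lengths = np.array([8, 8, 8, 8])
sequence_lengths = np.array([3, 10, 1, 7])
logits = np.zeros((batch_size, vocab_size_padded), dtype=np.float32)
mask_val = np.finfo(np.float32).min      # stands in for -FLT_MAX / -65504.0f

# Equivalent of iterating over (blockIdx.x * blockDim.x + threadIdx.x, blockIdx.y):
for bid in range(batch_size):            # grid.x * block.x covers the batch
    if sequence_lengths[bid] + 1 < min_lengths[bid]:
        for end_id in end_ids:           # grid.y covers each EOS id
            logits[bid, end_id] = mask_val   # forbid ending the sequence too early
```

Putting the EOS ids on grid.y keeps the per-thread work the same as before while scaling to however many EOS tokens the model declares.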
2 changes: 1 addition & 1 deletion src/turbomind/kernels/sampling_penalty_kernels.h
@@ -81,8 +81,8 @@ template<typename T>
void invokeMinLengthPenalty(T* logits,
const int* min_lengths,
const int* end_ids,
+ const int end_ids_len,
const int* sequnece_lengths,
- const int max_input_length,
const int batch_size,
const int vocab_size_padded,
cudaStream_t stream);
5 changes: 4 additions & 1 deletion src/turbomind/layers/sampling_layers/LogitsProcessorLayer.cc
@@ -163,11 +163,14 @@ void LogitsProcessorLayer<T>::forward(TensorMap* output_tensors, TensorMap* inpu
});
if (invoke_min_length_penalty) {
FT_CHECK_WITH_INFO(input_tensors->isExist("end_id"), "Need end_id to apply min length penlaty");
const Tensor end_id = input_tensors->at("end_id");
FT_CHECK(end_id.shape.size() == 1);
const size_t end_id_len = end_id.shape[0];
invokeMinLengthPenalty(logits,
min_lengths_buf_,
input_tensors->getPtr<const int>("end_id"),
+ end_id_len,
output_tensors->getPtr<const int>("sequence_length"),
- max_input_length,
batch_size,
args_.vocab_size_padded,
stream_);
13 changes: 7 additions & 6 deletions src/turbomind/models/llama/LlamaBatch.cc
@@ -777,8 +777,8 @@ void LlamaBatch<T>::AllocatePersistantBuffer(size_t max_batch_size, int cache_bl
d_curand_state_ =
(curandState_t*)allocator_->reMalloc(d_curand_state_, sizeof(curandState_t) * max_batch_size, true, false);

- d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * max_batch_size, false);
- h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * max_batch_size, false, true);
+ d_end_ids_buf_ = (int*)allocator_->reMalloc(d_end_ids_buf_, sizeof(int) * model_->end_id_.size(), false);
+ h_end_ids_buf_ = (int*)allocator_->reMalloc(h_end_ids_buf_, sizeof(int) * model_->end_id_.size(), false, true);

sampling_params_ = {
{"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
@@ -1151,9 +1151,10 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
}

// init for eos
- std::fill_n(h_end_ids_buf_, batch_size, model_->end_id_);
- Copy(h_end_ids_buf_, batch_size, d_end_ids_buf_);
- inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)batch_size}, d_end_ids_buf_}});
+ size_t eos_size = model_->end_id_.size();
+ std::copy_n(model_->end_id_.begin(), eos_size, h_end_ids_buf_);
+ Copy(h_end_ids_buf_, eos_size, d_end_ids_buf_);
+ inputs.insert({"end_id", {MEMORY_GPU, TYPE_INT32, {(size_t)eos_size}, d_end_ids_buf_}});

inputs_ = std::move(inputs);

@@ -1563,7 +1564,7 @@ void LlamaBatch<T>::InternalThreadEntry()
check_cuda_error(cudaSetDevice(device_id_));

// Initialize `AnomalyHandler`
- AnomalyHandler::instance().Init(rank_, model_->vocab_size_padded_, model_->end_id_, max_batch_size_, stream_);
+ AnomalyHandler::instance().Init(rank_, model_->vocab_size_padded_, model_->end_id_[0], max_batch_size_, stream_);

// auto& request_queue = shared_state_->request_queue;
auto& infer_reqs = shared_state_->infer_reqs;
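For contrast, a rough sketch (hypothetical values) of what the end_id buffer holds before and after this change; it is no longer batch-sized, it simply carries the EOS id list once:

```python
batch_size = 4
eos_ids = [128001, 128009]               # model_->end_id_ after the change

old_end_id_buf = [128001] * batch_size   # before: one scalar end id per batch slot
new_end_id_buf = list(eos_ids)           # after: the EOS list itself

assert len(old_end_id_buf) == batch_size
assert len(new_end_id_buf) == len(eos_ids)
```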
2 changes: 1 addition & 1 deletion src/turbomind/models/llama/LlamaV2.cc
@@ -390,7 +390,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
{"sequence_limit_length", {MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len}},
{"input_lengths", {MEMORY_GPU, TYPE_INT32, {batch_size, 1}, context_length}},
{"ite", {MEMORY_CPU, TYPE_UINT32, {1}, &ite}},
{"end_id", {MEMORY_GPU, TYPE_INT32, {batch_size}, end_ids}},
{"end_id", {MEMORY_GPU, TYPE_INT32, {end_id_.size()}, end_ids}},
{"local_batch_size", {MEMORY_CPU, TYPE_INT32, {1}, &local_batch_size}},
};

24 changes: 12 additions & 12 deletions src/turbomind/models/llama/LlamaV2.h
@@ -105,18 +105,18 @@ class LlamaV2 {
const AttentionParam attn_param_;
const LoraParam lora_param_;

- const size_t head_num_;
- const size_t size_per_head_;
- const size_t hidden_units_;
- const size_t layer_num_;
- const size_t vocab_size_;
- const size_t vocab_size_padded_;
- const float rmsnorm_eps_;
- const int start_id_;
- const int end_id_;
- const NcclParam tensor_para_;
- const size_t local_head_num_;
- const size_t local_kv_head_num_;
+ const size_t head_num_;
+ const size_t size_per_head_;
+ const size_t hidden_units_;
+ const size_t layer_num_;
+ const size_t vocab_size_;
+ const size_t vocab_size_padded_;
+ const float rmsnorm_eps_;
+ const int start_id_;
+ const std::vector<int> end_id_;
+ const NcclParam tensor_para_;
+ const size_t local_head_num_;
+ const size_t local_kv_head_num_;

const std::shared_ptr<LlamaWeight<T>> weights_{};

32 changes: 16 additions & 16 deletions src/turbomind/models/llama/llama_params.h
@@ -19,22 +19,22 @@ struct MLAParam {
};

struct ModelParam {
- size_t head_num;
- size_t head_dim;
- size_t kv_head_num;
- size_t hidden_units;
- size_t layer_num;
- size_t vocab_size;
- size_t embedding_size;
- float norm_eps;
- int quant_policy;
- bool attn_bias;
- WeightType weight_type;
- int group_size;
- int start_id;
- int end_id;
- MLAParam mla;
- int tune_layer_num;
+ size_t head_num;
+ size_t head_dim;
+ size_t kv_head_num;
+ size_t hidden_units;
+ size_t layer_num;
+ size_t vocab_size;
+ size_t embedding_size;
+ float norm_eps;
+ int quant_policy;
+ bool attn_bias;
+ WeightType weight_type;
+ int group_size;
+ int start_id;
+ std::vector<int> end_id;
+ MLAParam mla;
+ int tune_layer_num;

std::vector<int> inter_size;
};
2 changes: 1 addition & 1 deletion src/turbomind/triton_backend/llama/LlamaTritonModel.cc
@@ -221,7 +221,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t ten
model_param_.embedding_size = model_reader["embedding_size"].as<int>();
model_param_.norm_eps = model_reader["norm_eps"].as<float>();
model_param_.start_id = model_reader["start_id"].as<int>();
model_param_.end_id = model_reader["end_id"].as<int>();
model_param_.end_id = model_reader["end_id"].as<std::vector<int>>();
model_param_.tune_layer_num = model_reader["tune_layer_num"].as<int>(1);
model_param_.mla.q_lora_rank = model_reader["q_lora_rank"].as<int>();
model_param_.mla.kv_lora_rank = model_reader["kv_lora_rank"].as<int>();