support image_embs input #799

Merged — 11 commits, merged Dec 15, 2023. Showing changes from 9 commits.
25 changes: 25 additions & 0 deletions lmdeploy/serve/turbomind/triton_models/interactive/config.pbtxt
@@ -59,6 +59,31 @@ input [
data_type: TYPE_UINT32
dims: [ -1 ]
},
{
name: "input_embeddings"
data_type: TYPE_INT8
dims: [ -1 ]
optional: true
},
{
name: "embedding_counts"
data_type: TYPE_UINT32
dims: [ 1 ]
reshape: { shape: [ ] }
optional: true
},
{
name: "embedding_begins"
data_type: TYPE_UINT32
dims: [ -1 ]
optional: true
},
{
name: "embedding_ends"
data_type: TYPE_UINT32
dims: [ -1 ]
optional: true
},
{
name: "step"
data_type: TYPE_INT32
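For context, a minimal numpy-only sketch of how a client might lay out these four optional tensors for one request (the feature shapes, hidden size, and offsets below are illustrative assumptions, not values from this PR):

import numpy as np

# Two hypothetical visual features for one prompt, each [num_tokens, hidden] in fp16.
feat_a = np.random.rand(576, 4096).astype(np.float16)
feat_b = np.random.rand(576, 4096).astype(np.float16)

# "input_embeddings" is TYPE_INT8 with dims [-1]: the fp16 payload is flattened and
# reinterpreted as raw bytes (one row per request in the batch).
input_embeddings = np.concatenate([feat_a, feat_b]).reshape(1, -1).view(np.int8)

# [begin, end) token offsets where each feature is inserted into the prompt.
embedding_begins = np.array([[5, 700]], dtype=np.uint32)
embedding_ends = np.array([[5 + 576, 700 + 576]], dtype=np.uint32)

# Number of features carried by this request ("embedding_counts" has dims [1]).
embedding_counts = np.array([[2]], dtype=np.uint32)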
43 changes: 43 additions & 0 deletions lmdeploy/turbomind/turbomind.py
@@ -459,6 +459,9 @@ async def async_stream_infer(self, *args, **kwargs):
def stream_infer(self,
session_id,
input_ids,
input_embeddings=None,
embedding_begins=None,
embedding_ends=None,
request_output_len: int = 512,
sequence_start: bool = True,
sequence_end: bool = False,
@@ -476,6 +479,9 @@ def stream_infer(self,
Args:
session_id (int): the id of a session
input_ids (numpy.ndarray): the token ids of a prompt
input_embeddings (List[numpy.ndarray]): embedding features
embedding_begins (List[int]): the begin offsets of input_embeddings
embedding_ends (List[int]): the end offsets of input_embeddings
request_output_len (int): the max number of to-be-generated tokens
sequence_start (bool): indicator for starting a sequence
sequence_end (bool): indicator for ending a sequence
@@ -544,6 +550,43 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
CORRID=np.array(session_id, dtype=np.uint64),
STOP=_broadcast_np((1 if stop else 0), np.int32))

if input_embeddings is not None:
assert len(input_embeddings) == len(embedding_begins) == len(
embedding_ends)
if isinstance(embedding_begins[0], int):
embedding_begins = [embedding_begins]
embedding_ends = [embedding_ends]
input_embeddings = [input_embeddings]
# convert to lookup table type
if self.tm_model.config.weight_type == 'fp32':
Collaborator: weight_type may be 'int4'.

@irexyc (Collaborator, author), Dec 8, 2023: For int4, the lookup table is still half precision, isn't it?
input_embeddings = [[x.astype(np.float32) for x in y]
for y in input_embeddings]
elif self.tm_model.config.weight_type == 'bf16':
input_embeddings = [[
torch.from_numpy(x).bfloat16().view(torch.half).numpy()
for x in y
] for y in input_embeddings]
else:
input_embeddings = [[x.astype(np.float16) for x in y]
for y in input_embeddings]

embedding_counts = torch.IntTensor(
[len(embs) for embs in input_embeddings])
input_embeddings = [[torch.from_numpy(x).squeeze() for x in y]
for y in input_embeddings]
input_embeddings = [torch.cat(x) for x in input_embeddings]
input_embeddings = pad_sequence(input_embeddings, batch_first=True)
input_embeddings = input_embeddings.reshape(
input_embeddings.shape[0], -1).view(torch.int8)
embedding_begins = [torch.IntTensor(x) for x in embedding_begins]
embedding_begins = pad_sequence(embedding_begins, batch_first=True)
embedding_ends = [torch.IntTensor(x) for x in embedding_ends]
embedding_ends = pad_sequence(embedding_ends, batch_first=True)
inputs['input_embeddings'] = input_embeddings
inputs['embedding_counts'] = embedding_counts
inputs['embedding_begins'] = embedding_begins
inputs['embedding_ends'] = embedding_ends

if ignore_eos:
stop_words = None
bad_words = torch.tensor([[[self.eos_id], [1]]], dtype=torch.int32)
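A hedged usage sketch of the new stream_infer arguments; the model loading calls, feature shape, and token layout below are assumptions for illustration and are not part of this diff.

import numpy as np
from lmdeploy import turbomind as tm

# Assumed setup: a converted TurboMind workspace and an inference instance.
tm_model = tm.TurboMind(model_path='./workspace')
generator = tm_model.create_instance()

# One visual feature block covering 576 token positions (shape is illustrative).
image_feat = np.random.rand(576, 4096).astype(np.float32)
input_ids = [1] * 600  # prompt token ids; positions [5, 581) are placeholders for the image

for outputs in generator.stream_infer(
        session_id=0,
        input_ids=input_ids,
        input_embeddings=[image_feat],  # list of feature arrays for this sequence
        embedding_begins=[5],           # feature starts at token position 5
        embedding_ends=[5 + 576],       # end offset is exclusive
        request_output_len=512,
        sequence_start=True,
        sequence_end=True):
    pass  # each iteration yields the tokens generated so far

Per the wrapping logic above, flat per-sequence lists (e.g. embedding_begins=[5]) are promoted to the batched form ([[5]]) when their first element is an int.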
71 changes: 70 additions & 1 deletion src/turbomind/models/llama/LlamaBatch.cc
@@ -60,6 +60,22 @@ void ClearState(BatchState& s)
s.size = s.active_size = 0;
}

void DropEmbeddings(const Sequence& seq)
{
int seq_len = seq.tokens.size();
int num_emb = seq.input_embeddings.size();
size_t sz = num_emb;
for (; sz >= 1; sz--) {
if (seq.embedding_ends[sz - 1] <= seq_len) {
break;
}
}
// should we keep part of embedding?
seq.input_embeddings.resize(sz);
seq.embedding_begins.resize(sz);
seq.embedding_ends.resize(sz);
}

template<typename T>
void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
{
@@ -234,6 +250,7 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
if (step <= seq.tokens.size()) {
seq.tokens.resize(step);
seq.cache_len = std::min(seq.cache_len, step);
DropEmbeddings(seq);
}
else if (rank_ == 0) {
TM_LOG_WARNING(
@@ -258,6 +275,53 @@ void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
output_ids = Copy(input_ids, input_length, output_ids);
}

// copy input embeddings
if (r->inputs[rank_].isExist("embedding_counts")) {
int emb_count = r->inputs[rank_].getVal<int>("embedding_counts");
const auto emb_tensor = r->inputs[rank_].at("input_embeddings");
const auto begin_tensor = r->inputs[rank_].at("embedding_begins");
const auto end_tensor = r->inputs[rank_].at("embedding_ends");
const int* begin = begin_tensor.getPtr<int>();
const int* end = end_tensor.getPtr<int>();

auto check_embeddings = [&]() {
if (emb_count <= 0 || begin_tensor.shape != end_tensor.shape || emb_tensor.shape.size() != 2) {
return false;
}
int emb_len = 0;
for (size_t i = 0; i < emb_count; i++) {
emb_len += (end[i] - begin[i]);
if (begin[i] < 0 || end[i] < 0 || begin[i] >= end[i] || end[i] > input_length
|| emb_len > input_length
|| emb_len * model_->hidden_units_ * sizeof(T) > emb_tensor.shape[1]) {
return false;
}
}
return true;
};

if (!check_embeddings()) {
TM_LOG_WARNING("[ImageFeature] Skip invalid input embeddings, id = %ld, input_length = %d, "
"input embeddings = %s, embedding_counts = %d, begins = %s, ends = %s",
(long)seq.id,
input_length,
emb_tensor.toString().c_str(),
emb_count,
begin_tensor.toString().c_str(),
end_tensor.toString().c_str());
}
else {
char* emb_tensor_ptr = emb_tensor.getPtr<char>();
for (size_t i = 0; i < emb_count; i++) {
size_t count = (end[i] - begin[i]) * model_->hidden_units_ * sizeof(T);
seq.input_embeddings.emplace_back((std::byte*)emb_tensor_ptr, (std::byte*)(emb_tensor_ptr + count));
seq.embedding_begins.emplace_back(begin[i] + seq.tokens.size());
seq.embedding_ends.emplace_back(end[i] + seq.tokens.size());
emb_tensor_ptr += count;
}
}
}

// total context length (history + input)
state.h_context_length[idx] = output_ids - output_ids_base;
state.h_finished[idx] = false;
@@ -1422,6 +1486,8 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
std::vector<int> decode_indices{};
std::vector<int> decode_lengths{};

std::vector<const Sequence*> sequences;

BatchedCopy batched_copy;
for (int i = first; i < last; ++i) {
input_ids = batched_copy.Add(input_d_ptrs[i], h_input_length_buf_[i], input_ids);
@@ -1438,6 +1504,7 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
}
decode_indices.push_back(i);
decode_lengths.push_back(h_input_length_buf_[i]);
sequences.push_back(state_->sequences[i]);
max_input_len = std::max(max_input_len, h_input_length_buf_[i]);
}
int token_count = input_ids - context_decoder_ids_buf_;
@@ -1484,7 +1551,9 @@ bool LlamaBatch<T>::Forward(GenerationState& g, int iter)
pf_batch_size,
max_input_len,
max_context_cnts[p],
max_context_cnts[p]);
max_context_cnts[p],
h_input_length_buf_ + first,
sequences.data());

if (iter == 0) {
// compute logits of inputs if requested
75 changes: 53 additions & 22 deletions src/turbomind/models/llama/LlamaV2.cc
@@ -166,28 +166,56 @@ void LlamaV2<T>::embeddingLookup(T* embeddings, const int* token_ids_buf, int ba
}

template<typename T>
void LlamaV2<T>::forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len)
void LlamaV2<T>::updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

for (int i = 0; i < bsz; i++) {
const auto& seq = *sequences[i];
const auto& embeddings = seq.input_embeddings;
const auto& begins = seq.embedding_begins;
const auto& ends = seq.embedding_ends;
for (int j = embeddings.size() - 1; j >= 0; j--) {
if (ends[j] <= seq.cache_len) {
break;
}
int off_dst = std::max(0, begins[j] - seq.cache_len);
int off_src = std::max(0, seq.cache_len - begins[j]);
size_t byte_size = (ends[j] - begins[j]) * hidden_units_ * sizeof(T);
T* dst_ptr = decoder_input + off_dst * hidden_units_;
auto src_ptr = embeddings[j].data() + off_src * hidden_units_ * sizeof(T);
cudaMemcpyAsync(dst_ptr, src_ptr, byte_size, cudaMemcpyDefault, stream_);
}
decoder_input += h_input_length[i] * hidden_units_;
}
sync_check_cuda_error();
}

template<typename T>
void LlamaV2<T>::forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len,
const int* h_input_length,
const Sequence** sequences)
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -203,6 +231,9 @@ void LlamaV2<T>::forwardUnified(T* out,
1,
hidden_units_,
stream_);

updateEmbedding(decoder_input, dc_batch_size + pf_batch_size, h_input_length, sequences);

sync_check_cuda_error();

const auto dtype = getTensorType<T>();
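As a reading aid for updateEmbedding above, a small Python paraphrase (not code from this PR) of how the copy offsets are derived when a feature's token span partially overlaps the already-cached prefix:

def embedding_copy_offsets(begin: int, end: int, cache_len: int):
    # Mirrors off_dst / off_src in updateEmbedding (illustrative only).
    off_dst = max(0, begin - cache_len)  # start position inside the new decoder input chunk
    off_src = max(0, cache_len - begin)  # rows of the feature that are already in the KV cache
    return off_dst, off_src

# A feature spanning tokens [5, 581) with 100 tokens already cached:
# it is written at the start of the new chunk, skipping its first 95 rows.
print(embedding_copy_offsets(5, 581, 100))  # -> (0, 95)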
48 changes: 26 additions & 22 deletions src/turbomind/models/llama/LlamaV2.h
@@ -107,28 +107,32 @@ class LlamaV2 {

void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);

void forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len);
void updateEmbedding(T* decoder_input, const int bsz, const int* h_input_length, const Sequence** sequences);

void forwardUnified(T* out,
T* decoder_output,
T* decoder_input,
void** k_block_ptrs,
void** v_block_ptrs,
const int* input_ids,
const int* cu_block_cnts,
const float* rope_theta,
const bool* dc_finished,
const int* pf_input_length,
const int* pf_context_length,
T** pf_tmp_k_ptrs,
T** pf_tmp_v_ptrs,
size_t token_num,
int dc_batch_size,
int dc_step,
int dc_sum_seq_len,
int dc_max_seq_len,
int pf_batch_size,
int pf_max_input_len,
int pf_max_context_len,
int pf_session_len,
const int* h_input_length,
const Sequence** sequences);

void postDecodeEmbedding(float* logits, float* local_logits, const T* decoder_output, int batch_size);

5 changes: 5 additions & 0 deletions src/turbomind/models/llama/SequenceManager.h
@@ -33,6 +33,11 @@ struct Sequence {

mutable float rope_theta = 0.f;

// embedding data
mutable std::vector<std::vector<std::byte>> input_embeddings;
mutable std::vector<int> embedding_begins;
mutable std::vector<int> embedding_ends;

explicit Sequence(uint64_t _id): id(_id) {}

friend std::ostream& operator<<(std::ostream& os, const Sequence& seq);