Unify prefill & decode passes (#775)
* Unify prefill and decode passes

* dynamic split-fuse (see the scheduling sketch after this list)

* refactor

* correct input count calculation

* remove unused

* lint

* lint

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build

* fix msvc build
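
The "dynamic split-fuse" bullet names the scheduling idea behind the unified pass: instead of alternating whole prefill batches and decode batches, each iteration packs every pending decode token together with a budget-limited chunk of queued prefill tokens into a single forward pass. A minimal sketch of that policy follows; the names and the budget rule are illustrative assumptions, not this commit's actual scheduler.

#include <algorithm>
#include <vector>

struct Sequence {
    int input_len;  // prompt tokens not yet prefilled
    int cache_len;  // tokens already in the KV cache
};

// Returns per-sequence token counts for the next unified forward pass.
std::vector<int> schedule_step(std::vector<Sequence>& seqs, int token_budget)
{
    std::vector<int> counts(seqs.size(), 0);
    // Decoding sequences contribute exactly one token each and take priority.
    for (size_t i = 0; i < seqs.size(); ++i) {
        if (seqs[i].input_len == 0) {
            counts[i] = 1;
            --token_budget;
        }
    }
    // Remaining budget is filled with chunks of pending prefill ("split").
    for (size_t i = 0; i < seqs.size() && token_budget > 0; ++i) {
        if (seqs[i].input_len > 0) {
            counts[i] = std::min(seqs[i].input_len, token_budget);
            token_budget -= counts[i];
            seqs[i].input_len -= counts[i];
            seqs[i].cache_len += counts[i];
        }
    }
    return counts;
}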
lzhangzz authored Dec 4, 2023
1 parent 2ba9082 commit 7f943a2
Showing 37 changed files with 2,273 additions and 2,960 deletions.
16 changes: 11 additions & 5 deletions src/turbomind/kernels/bert_preprocess_kernels.cu
@@ -48,7 +48,9 @@ __global__ void getPaddingOffsetAndCuSeqLensKernel(size_t* h_valid_word_num,
     if (calculate_cu_seqlens) {
         cu_seqlens[batch_size] = total_seq_len;
     }
-    h_valid_word_num[0] = (size_t)total_seq_len;
+    if (h_valid_word_num) {
+        h_valid_word_num[0] = (size_t)total_seq_len;
+    }
 }
 
 void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
@@ -60,15 +62,19 @@ void invokeGetPaddingOffsetAndCuSeqLens(size_t* h_pinned_token_num,
                                         const int    max_seq_len,
                                         cudaStream_t stream)
 {
-    h_pinned_token_num[0] = 0;
+    if (h_pinned_token_num) {
+        h_pinned_token_num[0] = 0;
+    }
     getPaddingOffsetAndCuSeqLensKernel<<<1, 1, 0, stream>>>(
         h_pinned_token_num, tmp_mask_offset, cu_seqlens, sequence_lengths, batch_size, max_seq_len);
+    if (h_pinned_token_num) {
 #ifdef _MSC_VER
-    cudaStreamSynchronize(stream);
+        cudaStreamSynchronize(stream);
 #else
-    while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
+        while (((volatile size_t*)h_pinned_token_num)[0] == 0) {};
 #endif
-    h_token_num[0] = h_pinned_token_num[0];
+        h_token_num[0] = h_pinned_token_num[0];
+    }
     sync_check_cuda_error();
 }
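
The hunk above preserves a host-side spin: the kernel writes a nonzero token count straight into pinned host memory, and the host busy-waits on that flag instead of synchronizing the whole stream (MSVC builds fall back to cudaStreamSynchronize, mirroring the #ifdef). A self-contained sketch of the pattern, assuming the written count is always positive and the system supports unified virtual addressing (so the device can store to cudaMallocHost memory directly):

#include <cuda_runtime.h>
#include <cstdio>

__global__ void write_count(size_t* h_flag, int n)
{
    *h_flag = (size_t)n;  // store lands directly in pinned host memory
}

int main()
{
    size_t* h_pinned = nullptr;
    cudaMallocHost(&h_pinned, sizeof(size_t));  // pinned, device-accessible
    *h_pinned = 0;

    cudaStream_t stream;
    cudaStreamCreate(&stream);
    write_count<<<1, 1, 0, stream>>>(h_pinned, 42);  // n must be > 0

#ifdef _MSC_VER
    cudaStreamSynchronize(stream);  // MSVC fallback, as in the patch
#else
    while (((volatile size_t*)h_pinned)[0] == 0) {}  // spin until the store arrives
#endif

    std::printf("token count = %zu\n", *h_pinned);
    cudaFreeHost(h_pinned);
    cudaStreamDestroy(stream);
}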
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_params.h
@@ -20,13 +20,11 @@ struct DecoderMultiHeadAttentionParams {
     T* __restrict__ v_bias;
 
     // sequence-level buffers
-    const int* __restrict__ per_sample_length;
+    const int* __restrict__ context_length;
     const bool* __restrict__ finished;
     const float* __restrict__ rope_theta;
 
     // kv cache
-    void** __restrict__ per_sample_k_cache;  // [H, S, D]
-    void** __restrict__ per_sample_v_cache;  // [H, S, D]
     size_t layer_offset;
 
     /// cache layout M,[N,H,x,D]
src/turbomind/kernels/decoder_multihead_attention/decoder_multihead_attention_template.h
@@ -145,7 +145,7 @@ struct DecoderMultiHeadAttentionKernel {
         kv_head_idx_   = head_idx_ / gqa_group_size;
         is_gqa_leader_ = head_idx_ % gqa_group_size == 0;
 
-        timestep_ = params_.per_sample_length[batch_idx_];
+        timestep_ = params_.context_length[batch_idx_] - 1;
 
         if (kSplitK && params.max_split_k > 1) {
             const int slice_count = (timestep_ + kSliceLen - 1) / kSliceLen;
@@ -815,7 +815,7 @@ struct DecoderMultiHeadAttentionKernel {
     {
         const int batch_idx   = get_batch_idx();
         const int head_idx    = get_head_idx();
-        const int timestep    = params.per_sample_length[batch_idx];
+        const int timestep    = params.context_length[batch_idx] - 1;
         const int max_split_k = params.max_split_k;
         const int slice_count = get_slice_count(timestep);
         const int slice_per_split = (slice_count + max_split_k - 1) / max_split_k;
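
Both hunks encode the same renaming: the kernel now receives context_length, which counts every token of the sequence including the one being decoded this step, so the attention timestep (the number of tokens already in the KV cache) is context_length - 1. A tiny worked example of the convention as I read it, with illustrative names:

#include <cassert>

// A 5-token prompt that has already generated 3 tokens is decoding token 9:
// context_length == 9, and the new query attends to cache slots [0, 8).
int decode_timestep(int context_length)
{
    return context_length - 1;
}

int main()
{
    assert(decode_timestep(9) == 8);
}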
src/turbomind/kernels/decoder_multihead_attention/test_decoder_multihead_attention.cu
@@ -53,7 +53,7 @@ void TestBlocks(thrust::universal_vector<half>& linear,  // linear data
     std::mt19937 g(rd());
     std::shuffle(idxs.begin(), idxs.end(), g);
 
-    for (int i = 0; i < idxs.size(); ++i) {
+    for (size_t i = 0; i < idxs.size(); ++i) {
         ptrs[i] = blocks.data().get() + idxs[i] * head_num * block_size * head_dim;
     }
 
@@ -115,8 +115,8 @@ int main(int argc, char* argv[])
     constexpr int KvHeadNum  = 32;
     constexpr int kBatchSize = 1;
     // constexpr int kContextLen = 7306;
-    constexpr int kContextLen  = 1024;
-    constexpr int kSequenceLen = kContextLen + 1;
+    constexpr int kSequenceLen = 1024;
+    constexpr int kContextLen  = kSequenceLen + 1;
     constexpr int kBlockSz   = 128;
     constexpr int kTestIter  = 10;
     constexpr int kMaxSplitK = 1;
@@ -126,9 +126,10 @@ int main(int argc, char* argv[])
     thrust::universal_vector<half>  output(kBatchSize * kHeadNum * kHeadDim);
     thrust::universal_vector<half>  qkv(kBatchSize * (kHeadNum + KvHeadNum * 2) * kHeadDim);
     thrust::universal_vector<bool>  finished(kBatchSize);
-    thrust::universal_vector<half>  k_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
-    thrust::universal_vector<half>  v_cache(kBatchSize * kSequenceLen * KvHeadNum * kHeadDim);
-    thrust::universal_vector<int>   sequence_lengths(kBatchSize);
+    thrust::universal_vector<half>  k_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
+    thrust::universal_vector<half>  v_cache(kBatchSize * kContextLen * KvHeadNum * kHeadDim);
+    thrust::universal_vector<int>   context_length(kBatchSize);
+    thrust::universal_vector<int>   sequence_length(kBatchSize);
     thrust::universal_vector<void*> k_cache_ptrs(kBatchSize);
     thrust::universal_vector<void*> v_cache_ptrs(kBatchSize);
 
@@ -138,23 +139,23 @@ int main(int argc, char* argv[])
 
     rng.GenerateNormal(qkv.data().get(), qkv.size(), 1.f, 0.f);
 
-    if (kContextLen) {
-        rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
-        rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kSequenceLen * kHeadDim);
+    if (kSequenceLen) {
+        rng.GenerateNormal(k_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
+        rng.GenerateNormal(v_cache.data().get(), kBatchSize * KvHeadNum * kContextLen * kHeadDim);
 
-        cudaMemset2DAsync(k_cache.data().get() + kContextLen * kHeadDim,
-                          sizeof(half) * kSequenceLen * kHeadDim,
+        cudaMemset2DAsync(k_cache.data().get() + kSequenceLen * kHeadDim,
+                          sizeof(half) * kContextLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
                           kBatchSize * KvHeadNum);
         if constexpr (0) {
             for (int b = 0; b < kBatchSize; ++b) {
                 for (int h = 0; h < KvHeadNum; ++h) {
-                    for (int s = 0; s < kSequenceLen; ++s) {
+                    for (int s = 0; s < kContextLen; ++s) {
                         for (int d = 0; d < kHeadDim; ++d) {
                             std::cout << std::setw(7) << std::setprecision(4) << std::fixed
-                                      << (float)k_cache[b * KvHeadNum * kSequenceLen * kHeadDim
-                                                        + h * kSequenceLen * kHeadDim + s * kHeadDim + d]
+                                      << (float)k_cache[b * KvHeadNum * kContextLen * kHeadDim
+                                                        + h * kContextLen * kHeadDim + s * kHeadDim + d]
                                       << " ";
                         }
                         std::cout << "\n";
@@ -166,8 +167,8 @@ int main(int argc, char* argv[])
             std::exit(0);
         }
 
-        cudaMemset2DAsync(v_cache.data().get() + kContextLen * kHeadDim,
-                          sizeof(half) * kSequenceLen * kHeadDim,
+        cudaMemset2DAsync(v_cache.data().get() + kSequenceLen * kHeadDim,
+                          sizeof(half) * kContextLen * kHeadDim,
                           0,
                           sizeof(half) * kHeadDim,
                           kBatchSize * KvHeadNum);
@@ -193,7 +194,8 @@ int main(int argc, char* argv[])
     cudaDeviceSynchronize();
 
     for (int i = 0; i < kBatchSize; ++i) {
-        sequence_lengths[i] = kContextLen;
+        sequence_length[i] = kSequenceLen;
+        context_length[i]  = kContextLen;
         k_cache_ptrs[i] = k_cache.data().get() + i * k_cache.size() / kBatchSize;
         v_cache_ptrs[i] = v_cache.data().get() + i * v_cache.size() / kBatchSize;
         k_cache_ref_ptrs[i] = k_cache_ref.data().get() + i * k_cache_ref.size() / kBatchSize;
@@ -212,19 +214,17 @@ int main(int argc, char* argv[])
     params.stride = (kHeadNum + 2 * KvHeadNum) * kHeadDim;
 
     params.batch_size    = kBatchSize;
-    params.max_seq_len   = kContextLen + 1;
+    params.max_seq_len   = kSequenceLen;
     params.cu_block_cnts = cu_block_cnts.data().get();
 
     printf("%d %d\n", (int)k_ptrs.size(), (int)v_ptrs.size());
     params.k_cache_block_ptrs  = (void**)k_ptrs.data().get();
     params.v_cache_block_ptrs  = (void**)v_ptrs.data().get();
     params.kv_cache_block_size = kBlockSz;
 
-    params.finished           = finished.data().get();
-    params.per_sample_length  = sequence_lengths.data().get();
-    params.per_sample_k_cache = k_cache_ref_ptrs.data().get();
-    params.per_sample_v_cache = v_cache_ref_ptrs.data().get();
-    params.layer_offset       = 0;
+    params.finished       = finished.data().get();
+    params.context_length = context_length.data().get();
+    params.layer_offset   = 0;
 
     params.num_heads    = kHeadNum;
     params.num_kv_heads = KvHeadNum;
@@ -238,8 +238,16 @@ int main(int argc, char* argv[])
     params.partial_M = partial_M.data().get();
     params.partial_O = partial_O.data().get();
 
+    params.max_split_k = kMaxSplitK;
+    params.arch        = 80;
+
     for (int i = 0; i < kTestIter; ++i) {
-        mmha_ft_reference(params, cudaStream_t{});
+        mmha_ft_reference(params,
+                          (half**)k_cache_ref_ptrs.data().get(),
+                          (half**)v_cache_ref_ptrs.data().get(),
+                          sequence_length.data().get(),
+                          kContextLen,
+                          cudaStream_t{});
     }
 
     cudaDeviceSynchronize();
@@ -249,14 +257,7 @@ int main(int argc, char* argv[])
     }
     std::cout << "---------------------------------------------------\n";
 
-    params.out                = output.data().get();
-    params.per_sample_k_cache = k_cache_ptrs.data().get();
-    params.per_sample_v_cache = v_cache_ptrs.data().get();
-
-    params.max_split_k = kMaxSplitK;
-    params.max_seq_len = kContextLen;
-
-    params.arch = 80;
+    params.out = output.data().get();
 
     std::vector<thrust::universal_vector<half>> outputs;
 
@@ -271,30 +272,25 @@ int main(int argc, char* argv[])
         }
     }
 
-    thrust::universal_vector<int> seq_lens(kBatchSize);
-    for (auto& x : seq_lens) {
-        x = kContextLen + 1;
-    }
-
     if (1) {
         ConvertBlocksToLinear((const half**)k_ptrs.data().get(),
                               k_cache.data().get(),
                               cu_block_cnts.data().get(),
-                              seq_lens.data().get(),
+                              context_length.data().get(),
                               0,
                               kBlockSz,
-                              kSequenceLen,
+                              kContextLen,
                               KvHeadNum,
                               kHeadDim,
                               kBatchSize,
                               0);
         ConvertBlocksToLinear((const half**)v_ptrs.data().get(),
                               v_cache.data().get(),
                               cu_block_cnts.data().get(),
-                              seq_lens.data().get(),
+                              context_length.data().get(),
                               0,
                               kBlockSz,
-                              kSequenceLen,
+                              kContextLen,
                               KvHeadNum,
                               kHeadDim,
                               kBatchSize,
@@ -316,15 +312,15 @@ int main(int argc, char* argv[])
 
     // [H, S, D]
 
-    Compare(k_cache.data().get() + kContextLen * kHeadDim,
-            k_cache_ref.data().get() + kContextLen * kHeadDim,
-            kSequenceLen * kHeadDim,
+    Compare(k_cache.data().get() + kSequenceLen * kHeadDim,
+            k_cache_ref.data().get() + kSequenceLen * kHeadDim,
+            kContextLen * kHeadDim,
             kHeadDim,
             KvHeadNum);
 
-    Compare(v_cache.data().get() + kContextLen * kHeadDim,
-            v_cache_ref.data().get() + kContextLen * kHeadDim,
-            kSequenceLen * kHeadDim,
+    Compare(v_cache.data().get() + kSequenceLen * kHeadDim,
+            v_cache_ref.data().get() + kSequenceLen * kHeadDim,
+            kContextLen * kHeadDim,
             kHeadDim,
             KvHeadNum);
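
Before the Compare calls, the test gathers the blocked (paged) KV cache back into one contiguous [H, S, D] buffer. The following is my own reconstruction of what that gather has to do for a single sequence with no layer offset; it is not the ConvertBlocksToLinear kernel, and the float element type and helper name are illustrative:

#include <cstring>

// Each sequence owns a list of [head_num, block_len, head_dim] blocks;
// step s of head h lives in block s / block_len at row s % block_len.
void blocks_to_linear(const float* const* block_ptrs,
                      float*              linear,  // [head_num, seq_len, head_dim]
                      int block_len, int seq_len, int head_num, int head_dim)
{
    for (int s = 0; s < seq_len; ++s) {
        const float* block = block_ptrs[s / block_len];
        for (int h = 0; h < head_num; ++h) {
            std::memcpy(linear + (h * seq_len + s) * head_dim,
                        block + (h * block_len + s % block_len) * head_dim,
                        sizeof(float) * head_dim);
        }
    }
}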
24 changes: 17 additions & 7 deletions src/turbomind/kernels/decoder_multihead_attention/test_utils.cu
@@ -182,7 +182,12 @@ struct SATypeConverter<half> {
 };
 
 template<typename T>
-void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t st)
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p,
+                       T**          per_sample_k_cache,
+                       T**          per_sample_v_cache,
+                       const int*   sequence_length,
+                       int          max_memory_len,
+                       cudaStream_t st)
 {
     using DataType = typename SATypeConverter<T>::Type;
 
@@ -204,18 +209,18 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     params.stride   = p.stride;
     params.finished = (bool*)p.finished;
 
-    params.k_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_k_cache);
-    params.v_cache_per_sample = reinterpret_cast<DataType**>(p.per_sample_v_cache);
+    params.k_cache_per_sample = reinterpret_cast<DataType**>(per_sample_k_cache);
+    params.v_cache_per_sample = reinterpret_cast<DataType**>(per_sample_v_cache);
     params.kv_cache_per_sample_offset = p.layer_offset;
     params.batch_size                 = p.batch_size;
     params.beam_width                 = 1;
-    params.memory_max_len             = p.max_seq_len;
+    params.memory_max_len             = max_memory_len;
     params.prefix_prompt_lengths      = 0;
     params.max_prefix_prompt_length   = 0;
-    params.length_per_sample          = p.per_sample_length;  // max_input_length + current output length
+    params.length_per_sample          = sequence_length;  // max_input_length + current output length
 
     for (int i = 0; i < p.batch_size; ++i) {
-        params.timestep = std::max(p.per_sample_length[i], params.timestep);
+        params.timestep = std::max(sequence_length[i], params.timestep);
     }
 
     std::cout << "timestep = " << params.timestep << "\n";
@@ -237,6 +242,11 @@ void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& p, cudaStream_t
     masked_multihead_attention(params, st);
 }
 
-template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params, cudaStream_t st);
+template void mmha_ft_reference(const DecoderMultiHeadAttentionParams<half>& params,
+                                half**       per_sample_k_cache,
+                                half**       per_sample_v_cache,
+                                const int*   sequence_length,
+                                int          max_memory_len,
+                                cudaStream_t st);
 
 }  // namespace turbomind
src/turbomind/kernels/decoder_multihead_attention/test_utils.h
@@ -33,6 +33,11 @@ class RNG {
 };
 
 template<typename T>
-void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params, cudaStream_t st);
+void mmha_ft_reference(const DecoderMultiHeadAttentionParams<T>& params,
+                       T**          per_sample_k_cache,
+                       T**          per_sample_v_cache,
+                       const int*   sequence_length,
+                       int          max_memory_len,
+                       cudaStream_t st);
 
 }  // namespace turbomind
16 changes: 11 additions & 5 deletions src/turbomind/models/llama/Barrier.h
@@ -34,10 +34,11 @@ class Barrier {
 
 class Barrier {
 public:
-    Barrier(unsigned count)
+    Barrier(unsigned count): count_(count)
     {
         TM_LOG_INFO("Barrier(%d)", (int)count);
-        pthread_barrier_init(&barrier_, nullptr, count);
+        if (count_ > 1) {
+            pthread_barrier_init(&barrier_, nullptr, count);
+        }
     }
 
     Barrier(const Barrier&) = delete;
@@ -47,15 +48,20 @@ class Barrier {
 
     void wait()
     {
-        pthread_barrier_wait(&barrier_);
+        if (count_ > 1) {
+            pthread_barrier_wait(&barrier_);
+        }
    }
 
     ~Barrier()
     {
-        pthread_barrier_destroy(&barrier_);
+        if (count_ > 1) {
+            pthread_barrier_destroy(&barrier_);
+        }
     }
 
 private:
+    const int         count_;
    pthread_barrier_t barrier_{};
 };
 
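
With the guard, a count of 1 skips pthread barrier creation, waiting, and destruction altogether, so a single participant (e.g. tensor-parallel size 1) never touches the barrier. A self-contained usage sketch of the guarded class on POSIX, trimmed from the hunk above (the TM_LOG_INFO call is dropped and the main function is mine):

#include <pthread.h>
#include <cstdio>

class Barrier {
public:
    explicit Barrier(unsigned count): count_(count)
    {
        if (count_ > 1) {
            pthread_barrier_init(&barrier_, nullptr, count);
        }
    }
    void wait()
    {
        if (count_ > 1) {
            pthread_barrier_wait(&barrier_);
        }
    }
    ~Barrier()
    {
        if (count_ > 1) {
            pthread_barrier_destroy(&barrier_);
        }
    }
private:
    const int         count_;
    pthread_barrier_t barrier_{};
};

int main()
{
    Barrier single(1);
    single.wait();  // no-op: no pthread barrier was ever created
    std::printf("single-rank wait returned immediately\n");
}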
7 changes: 2 additions & 5 deletions src/turbomind/models/llama/CMakeLists.txt
@@ -9,16 +9,13 @@ find_package(CUDAToolkit REQUIRED)
 add_library(Llama STATIC
     LlamaV2.cc
     LlamaBatch.cc
-    LlamaCacheManager.cc
     BlockManager.cc
     SequenceManager.cc
-    LlamaContextDecoder.cc
-    LlamaContextAttentionLayer.cc
-    LlamaDecoderSelfAttentionLayer.cc
-    LlamaDecoder.cc
     LlamaWeight.cc
     LlamaDecoderLayerWeight.cc
     LlamaFfnLayer.cc
+    unified_decoder.cc
+    unified_attention_layer.cc
     llama_kernels.cu
     llama_decoder_kernels.cu
     llama_utils.cu)