From cea0524c74a678797bd61c26a5e367174dffa167 Mon Sep 17 00:00:00 2001
From: Li Zhang
Date: Thu, 7 Nov 2024 07:32:53 +0000
Subject: [PATCH] eliminate `inter_size_` from ffn layer

---
 src/turbomind/models/llama/LlamaDenseWeight.h | 13 +++--
 src/turbomind/models/llama/LlamaFfnLayer.cc   | 47 ++++++++++---------
 src/turbomind/models/llama/LlamaFfnLayer.h    | 11 ++---
 3 files changed, 37 insertions(+), 34 deletions(-)

diff --git a/src/turbomind/models/llama/LlamaDenseWeight.h b/src/turbomind/models/llama/LlamaDenseWeight.h
index 382d0dfc65..21c1c666c3 100644
--- a/src/turbomind/models/llama/LlamaDenseWeight.h
+++ b/src/turbomind/models/llama/LlamaDenseWeight.h
@@ -145,24 +145,28 @@ struct LlamaFfnWeight {
     LlamaFfnWeight(
         size_t hidden_dim, size_t inter_size, size_t tp, WeightType weight_type, int group_size, bool fuse_silu_act)
     {
+        inter_size /= tp;
+
+        this->inter_size = inter_size;
+
         gating.input_dims = hidden_dim;
-        gating.output_dims = inter_size / tp;
+        gating.output_dims = inter_size;
         gating.type = weight_type;
         gating.group_size = group_size;
 
         intermediate.input_dims = hidden_dim;
-        intermediate.output_dims = inter_size / tp;
+        intermediate.output_dims = inter_size;
         intermediate.type = weight_type;
         intermediate.group_size = group_size;
 
         fused_gating_intermediate.input_dims = hidden_dim;
-        fused_gating_intermediate.output_dims = inter_size / tp * 2;
+        fused_gating_intermediate.output_dims = inter_size * 2;
         fused_gating_intermediate.type = weight_type;
         fused_gating_intermediate.group_size = group_size;
 
         is_fused_silu = fuse_silu_act;
 
-        output.input_dims = inter_size / tp;
+        output.input_dims = inter_size;
         output.output_dims = hidden_dim;
         output.type = weight_type;
         output.group_size = group_size;
@@ -173,6 +177,7 @@ struct LlamaFfnWeight {
     LlamaDenseWeight output;
     LlamaDenseWeight fused_gating_intermediate;
 
+    int inter_size{};
     bool is_fused_silu{};
 };
 
diff --git a/src/turbomind/models/llama/LlamaFfnLayer.cc b/src/turbomind/models/llama/LlamaFfnLayer.cc
index f9ee0c4ad4..8cce207203 100644
--- a/src/turbomind/models/llama/LlamaFfnLayer.cc
+++ b/src/turbomind/models/llama/LlamaFfnLayer.cc
@@ -28,10 +28,11 @@ namespace turbomind {
 
 template
 void LlamaFfnLayer::allocateBuffer(size_t token_num,
+                                   int inter_size,
                                    const LlamaDenseWeight* gating,
                                    const LlamaDenseWeight* inter)
 {
-    const size_t sz = token_num * inter_size_;
+    const size_t sz = token_num * inter_size;
 
     const size_t sz_gate = token_num * gating->lora.r;
     const size_t sz_inter = token_num * inter->lora.r;
@@ -51,24 +52,24 @@ template
 void LlamaFfnLayer::freeBuffer()
 {
     if (is_allocate_buffer_) {
-        // allocator_->free((void**)&inter_buf_);
         allocator_->free((void**)&gating_buf_);
         is_allocate_buffer_ = false;
     }
 }
 
 template
-void LlamaFfnLayer::activation(int token_num, bool is_chunked)
+void LlamaFfnLayer::activation(int token_num, int inter_size, bool is_chunked)
 {
     NvtxScope scope("activation");
     if (is_chunked) {
+        // gate & up are in the SAME buffer
         invokeGenericActivation_v2(
-            gating_buf_, gating_buf_ + inter_size_, inter_size_ * 2, token_num, inter_size_, stream_);
+            gating_buf_, gating_buf_ + inter_size, inter_size * 2, token_num, inter_size, stream_);
         sync_check_cuda_error();
     }
     else {
-        invokeGenericActivation_v2(
-            gating_buf_, inter_buf_, inter_size_, token_num, inter_size_, stream_);
+        // gate & up are in separate buffers
+        invokeGenericActivation_v2(gating_buf_, inter_buf_, inter_size, token_num, inter_size, stream_);
         sync_check_cuda_error();
     }
 }
@@ -88,11 +89,11 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors,
     NvtxScope scope("ffn");
 
-    const size_t num_token = input_tensors->at("ffn_input").shape[0];
-    const int layer_id = input_tensors->getVal("layer_id");
-    // LOG(WARNING);
+    const size_t token_num = input_tensors->at("ffn_input").shape[0];
+    const int layer_id = input_tensors->getVal("layer_id");
+    const int inter_size = weights->inter_size;
 
-    allocateBuffer(num_token, &weights->gating, &weights->intermediate);
+    allocateBuffer(token_num, inter_size, &weights->gating, &weights->intermediate);
 
     const T* ffn_input_data = input_tensors->at("ffn_input").getPtr();
     T* ffn_output_data = output_tensors->at("ffn_output").getPtr();
 
@@ -103,50 +104,50 @@ void LlamaFfnLayer::forward(TensorMap* output_tensors,
         const auto type = weights->is_fused_silu ? LlamaLinear::kFusedSiluFfn : LlamaLinear::kGemm;
 
-        linear_->forward(gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, type);
+        linear_->forward(gating_buf_, ffn_input_data, token_num, weights->fused_gating_intermediate, type);
         sync_check_cuda_error();
 
         if (!weights->is_fused_silu) {
-            activation(num_token, true);
+            activation(token_num, inter_size, true);
         }
 
-        count_and_fix(gating_buf_, num_token * weights->output.input_dims, Concat("w1_w3_silu", layer_id), 3);
+        count_and_fix(gating_buf_, token_num * weights->output.input_dims, Concat("w1_w3_silu", layer_id), 3);
     }
     else {
         { // w1(x)
             NvtxScope scope("w1");
-            linear_->forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear::kGemm, lora_mask);
+            linear_->forward(gating_buf_, ffn_input_data, token_num, weights->gating, LlamaLinear::kGemm, lora_mask);
             sync_check_cuda_error();
         }
 
-        count_and_fix(gating_buf_, num_token * weights->gating.output_dims, Concat("w1", layer_id), 3);
+        count_and_fix(gating_buf_, token_num * weights->gating.output_dims, Concat("w1", layer_id), 3);
 
         { // w3(x)
             NvtxScope scope("w3");
             linear_->forward(
-                inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear::kGemm, lora_mask);
+                inter_buf_, ffn_input_data, token_num, weights->intermediate, LlamaLinear::kGemm, lora_mask);
             sync_check_cuda_error();
         }
 
-        count_and_fix(inter_buf_, num_token * weights->intermediate.output_dims, Concat("w3", layer_id), 3);
+        count_and_fix(inter_buf_, token_num * weights->intermediate.output_dims, Concat("w3", layer_id), 3);
 
         // silu(w1(x)) * w3(x)
-        activation(num_token, false);
+        activation(token_num, inter_size, false);
 
-        count_and_fix(gating_buf_, num_token * weights->output.input_dims, Concat("act", layer_id), 3);
+        count_and_fix(gating_buf_, token_num * weights->output.input_dims, Concat("act", layer_id), 3);
     }
 
     { // w2(x)
         NvtxScope scope("w2");
-        const int pitch = (weights->fused_gating_intermediate.kernel && !weights->is_fused_silu) ? inter_size_ * 2 : 0;
+        const int pitch = (weights->fused_gating_intermediate.kernel && !weights->is_fused_silu) ? inter_size * 2 : 0;
         linear_->forward(
-            ffn_output_data, {gating_buf_, pitch}, num_token, weights->output, LlamaLinear::kGemm, lora_mask);
+            ffn_output_data, {gating_buf_, pitch}, token_num, weights->output, LlamaLinear::kGemm, lora_mask);
         sync_check_cuda_error();
     }
 
-    count_and_fix(ffn_output_data, num_token * weights->output.output_dims, Concat("w2", layer_id), 3);
+    count_and_fix(ffn_output_data, token_num * weights->output.output_dims, Concat("w2", layer_id), 3);
 
     if (all_reduce_ && tensor_para_.world_size_ > 1) {
         NcclGuard nccl_guard(tensor_para_, stream_);
-        ftNcclAllReduceSum(ffn_output_data, ffn_output_data, num_token * hidden_units_, tensor_para_, stream_);
+        ftNcclAllReduceSum(ffn_output_data, ffn_output_data, token_num * hidden_units_, tensor_para_, stream_);
         sync_check_cuda_error();
     }
 
diff --git a/src/turbomind/models/llama/LlamaFfnLayer.h b/src/turbomind/models/llama/LlamaFfnLayer.h
index 75ced5f9ac..2daca2cc95 100644
--- a/src/turbomind/models/llama/LlamaFfnLayer.h
+++ b/src/turbomind/models/llama/LlamaFfnLayer.h
@@ -19,12 +19,11 @@
 
 #pragma once
 
-#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
 #include "src/turbomind/models/llama/LlamaLinear.h"
 #include "src/turbomind/models/llama/context.h"
-#include "src/turbomind/utils/custom_ar_comm.h"
+#include "src/turbomind/models/llama/llama_params.h"
+#include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/nccl_utils.h"
-#include
 
 namespace turbomind {
 
@@ -32,7 +31,6 @@ template
 class LlamaFfnLayer {
 public:
     LlamaFfnLayer(const ModelParam& model, const NcclParam& tp, const Context& ctx, bool all_reduce):
-        inter_size_(model.inter_size / tp.world_size_),
         hidden_units_(model.hidden_units),
         tensor_para_(tp),
         stream_(ctx.stream),
@@ -50,13 +48,12 @@ class LlamaFfnLayer {
     void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight* weights);
 
 private:
-    void allocateBuffer(size_t token_num, const LlamaDenseWeight*, const LlamaDenseWeight*);
+    void allocateBuffer(size_t token_num, int inter_size, const LlamaDenseWeight*, const LlamaDenseWeight*);
     void freeBuffer();
 
-    void activation(int token_num, bool is_chunked);
+    void activation(int token_num, int inter_size, bool is_chunked);
 
-    const size_t inter_size_;
     const size_t hidden_units_;
     const NcclParam tensor_para_;
 
     cudaStream_t const stream_;
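
For reference, a minimal standalone sketch of the fused gate/up ("chunked") layout the patch relies on: per token, gate(x) and up(x) sit side by side in one row of width inter_size * 2, the activation writes silu(gate) * up back into the gate half, and the down projection then reads that half with a row pitch of inter_size * 2 (matching `pitch = inter_size * 2` above). All names and values below are illustrative only, not TurboMind code.

// Illustrative sketch only -- a made-up, standalone example, not TurboMind code.
#include <cmath>
#include <cstdio>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

int main()
{
    const int full_inter_size = 8;                     // model intermediate size
    const int tp              = 2;                     // tensor-parallel world size
    const int inter_size      = full_inter_size / tp;  // per-rank slice, as in LlamaFfnWeight
    const int token_num       = 3;

    // Fused buffer: one row per token, laid out as [gate | up], width inter_size * 2.
    std::vector<float> buf(token_num * inter_size * 2, 0.25f);

    // "Chunked" activation: silu(gate) * up, written back into the gate half.
    for (int t = 0; t < token_num; ++t) {
        float* gate = buf.data() + t * inter_size * 2;
        float* up   = gate + inter_size;
        for (int i = 0; i < inter_size; ++i) {
            gate[i] = silu(gate[i]) * up[i];
        }
    }

    // A down projection would now read rows of length inter_size with a
    // leading dimension (pitch) of inter_size * 2.
    std::printf("activated[token 0][0] = %f\n", buf[0]);
    return 0;
}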