eliminate inter_size_ from ffn layer
lzhangzz committed Nov 7, 2024
1 parent e0b221e commit cea0524
Showing 3 changed files with 37 additions and 34 deletions.
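
In short: LlamaFfnLayer no longer keeps its own tensor-parallel-sharded inter_size_ member. The per-rank intermediate size is computed once in the LlamaFfnWeight constructor (inter_size /= tp), stored on the weight struct, and read back by the layer on every forward call, so buffer sizes, activation, and the w2 pitch all follow the weights they are given. A condensed sketch of the pattern, with simplified stand-in types rather than the actual turbomind classes:

    // Minimal illustration of the refactor (FfnWeight/FfnLayer are stand-ins, not turbomind types).
    #include <cstddef>

    struct FfnWeight {
        FfnWeight(size_t inter, size_t tp)
        {
            inter /= tp;                             // shard once, at weight-construction time
            inter_size = static_cast<int>(inter);
        }
        int inter_size{};                            // per-rank size now lives with the weights
    };

    struct FfnLayer {
        // Before: the layer cached `const size_t inter_size_ = model.inter_size / tp.world_size_`.
        // After: no cached size; each call reads it from the weights it was handed.
        void forward(const FfnWeight* w) const
        {
            const int inter_size = w->inter_size;    // passed down to allocateBuffer()/activation()
            (void)inter_size;
        }
    };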
13 changes: 9 additions & 4 deletions src/turbomind/models/llama/LlamaDenseWeight.h
@@ -145,24 +145,28 @@ struct LlamaFfnWeight {
LlamaFfnWeight(
size_t hidden_dim, size_t inter_size, size_t tp, WeightType weight_type, int group_size, bool fuse_silu_act)
{
inter_size /= tp;

this->inter_size = inter_size;

gating.input_dims = hidden_dim;
gating.output_dims = inter_size / tp;
gating.output_dims = inter_size;
gating.type = weight_type;
gating.group_size = group_size;

intermediate.input_dims = hidden_dim;
intermediate.output_dims = inter_size / tp;
intermediate.output_dims = inter_size;
intermediate.type = weight_type;
intermediate.group_size = group_size;

fused_gating_intermediate.input_dims = hidden_dim;
fused_gating_intermediate.output_dims = inter_size / tp * 2;
fused_gating_intermediate.output_dims = inter_size * 2;
fused_gating_intermediate.type = weight_type;
fused_gating_intermediate.group_size = group_size;

is_fused_silu = fuse_silu_act;

output.input_dims = inter_size / tp;
output.input_dims = inter_size;
output.output_dims = hidden_dim;
output.type = weight_type;
output.group_size = group_size;
@@ -173,6 +177,7 @@ struct LlamaFfnWeight {
LlamaDenseWeight<T> output;
LlamaDenseWeight<T> fused_gating_intermediate;

int inter_size{};
bool is_fused_silu{};
};
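
With the sharding moved into the constructor, the projection shapes come out the same as the old inter_size / tp expressions, just computed in one place. For illustration (assumed sizes, not taken from the diff): inter_size = 11008 and tp = 2 give a stored per-rank inter_size of 5504, so gating and intermediate each map hidden_dim -> 5504, fused_gating_intermediate maps hidden_dim -> 11008, and output maps 5504 -> hidden_dim:

    // Hypothetical shape check with the assumed sizes above.
    #include <cassert>
    #include <cstddef>

    int main()
    {
        size_t inter_size = 11008, tp = 2;
        inter_size /= tp;                    // as in the LlamaFfnWeight constructor
        assert(inter_size == 5504);          // gating / intermediate output_dims
        assert(inter_size * 2 == 11008);     // fused_gating_intermediate output_dims
        return 0;
    }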

47 changes: 24 additions & 23 deletions src/turbomind/models/llama/LlamaFfnLayer.cc
@@ -28,10 +28,11 @@ namespace turbomind {

template<typename T>
void LlamaFfnLayer<T>::allocateBuffer(size_t token_num,
int inter_size,
const LlamaDenseWeight<T>* gating,
const LlamaDenseWeight<T>* inter)
{
const size_t sz = token_num * inter_size_;
const size_t sz = token_num * inter_size;

const size_t sz_gate = token_num * gating->lora.r;
const size_t sz_inter = token_num * inter->lora.r;
@@ -51,24 +52,24 @@ template<typename T>
void LlamaFfnLayer<T>::freeBuffer()
{
if (is_allocate_buffer_) {
// allocator_->free((void**)&inter_buf_);
allocator_->free((void**)&gating_buf_);
is_allocate_buffer_ = false;
}
}

template<typename T>
void LlamaFfnLayer<T>::activation(int token_num, bool is_chunked)
void LlamaFfnLayer<T>::activation(int token_num, int inter_size, bool is_chunked)
{
NvtxScope scope("activation");
if (is_chunked) {
// gate & up are in the SAME buffer
invokeGenericActivation_v2<SiluActivation>(
gating_buf_, gating_buf_ + inter_size_, inter_size_ * 2, token_num, inter_size_, stream_);
gating_buf_, gating_buf_ + inter_size, inter_size * 2, token_num, inter_size, stream_);
sync_check_cuda_error();
}
else {
invokeGenericActivation_v2<SiluActivation>(
gating_buf_, inter_buf_, inter_size_, token_num, inter_size_, stream_);
// gate & up are in separate buffers
invokeGenericActivation_v2<SiluActivation>(gating_buf_, inter_buf_, inter_size, token_num, inter_size, stream_);
sync_check_cuda_error();
}
}
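
The is_chunked branch assumes the fused w1/w3 GEMM has written gate and up for each token into the same row of gating_buf_, giving a row stride of inter_size * 2. A CPU sketch of the element-wise step this is expected to perform — an assumption about invokeGenericActivation_v2's semantics written out for clarity, not the actual CUDA kernel:

    // In-place SiLU(gate) * up over a chunked [gate | up] buffer.
    // `buf` holds token_num rows of 2 * inter_size floats: gate in [0, inter_size),
    // up in [inter_size, 2 * inter_size). The product overwrites the gate half.
    #include <cmath>
    #include <cstddef>

    void silu_mul_chunked(float* buf, size_t token_num, size_t inter_size)
    {
        const size_t stride = 2 * inter_size;
        for (size_t t = 0; t < token_num; ++t) {
            float* gate = buf + t * stride;
            float* up   = gate + inter_size;
            for (size_t i = 0; i < inter_size; ++i) {
                const float g = gate[i];
                gate[i] = g / (1.0f + std::exp(-g)) * up[i];   // SiLU(g) * up
            }
        }
    }

In the non-chunked branch the same product is formed from two separate, densely packed buffers (gating_buf_ and inter_buf_), each with stride inter_size.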
@@ -88,11 +89,11 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,

NvtxScope scope("ffn");

const size_t num_token = input_tensors->at("ffn_input").shape[0];
const int layer_id = input_tensors->getVal<int>("layer_id");
// LOG(WARNING);
const size_t token_num = input_tensors->at("ffn_input").shape[0];
const int layer_id = input_tensors->getVal<int>("layer_id");
const int inter_size = weights->inter_size;

allocateBuffer(num_token, &weights->gating, &weights->intermediate);
allocateBuffer(token_num, inter_size, &weights->gating, &weights->intermediate);

const T* ffn_input_data = input_tensors->at("ffn_input").getPtr<T>();
T* ffn_output_data = output_tensors->at("ffn_output").getPtr<T>();
@@ -103,50 +104,50 @@

const auto type = weights->is_fused_silu ? LlamaLinear<T>::kFusedSiluFfn : LlamaLinear<T>::kGemm;

linear_->forward(gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, type);
linear_->forward(gating_buf_, ffn_input_data, token_num, weights->fused_gating_intermediate, type);
sync_check_cuda_error();

if (!weights->is_fused_silu) {
activation(num_token, true);
activation(token_num, inter_size, true);
}

count_and_fix(gating_buf_, num_token * weights->output.input_dims, Concat("w1_w3_silu", layer_id), 3);
count_and_fix(gating_buf_, token_num * weights->output.input_dims, Concat("w1_w3_silu", layer_id), 3);
}
else {
{ // w1(x)
NvtxScope scope("w1");
linear_->forward(gating_buf_, ffn_input_data, num_token, weights->gating, LlamaLinear<T>::kGemm, lora_mask);
linear_->forward(gating_buf_, ffn_input_data, token_num, weights->gating, LlamaLinear<T>::kGemm, lora_mask);
sync_check_cuda_error();
}
count_and_fix(gating_buf_, num_token * weights->gating.output_dims, Concat("w1", layer_id), 3);
count_and_fix(gating_buf_, token_num * weights->gating.output_dims, Concat("w1", layer_id), 3);

{ // w3(x)
NvtxScope scope("w3");
linear_->forward(
inter_buf_, ffn_input_data, num_token, weights->intermediate, LlamaLinear<T>::kGemm, lora_mask);
inter_buf_, ffn_input_data, token_num, weights->intermediate, LlamaLinear<T>::kGemm, lora_mask);
sync_check_cuda_error();
}
count_and_fix(inter_buf_, num_token * weights->intermediate.output_dims, Concat("w3", layer_id), 3);
count_and_fix(inter_buf_, token_num * weights->intermediate.output_dims, Concat("w3", layer_id), 3);

// silu(w1(x)) * w3(x)
activation(num_token, false);
activation(token_num, inter_size, false);

count_and_fix(gating_buf_, num_token * weights->output.input_dims, Concat("act", layer_id), 3);
count_and_fix(gating_buf_, token_num * weights->output.input_dims, Concat("act", layer_id), 3);
}

{ // w2(x)
NvtxScope scope("w2");
const int pitch = (weights->fused_gating_intermediate.kernel && !weights->is_fused_silu) ? inter_size_ * 2 : 0;
const int pitch = (weights->fused_gating_intermediate.kernel && !weights->is_fused_silu) ? inter_size * 2 : 0;
linear_->forward(
ffn_output_data, {gating_buf_, pitch}, num_token, weights->output, LlamaLinear<T>::kGemm, lora_mask);
ffn_output_data, {gating_buf_, pitch}, token_num, weights->output, LlamaLinear<T>::kGemm, lora_mask);
sync_check_cuda_error();
}

count_and_fix(ffn_output_data, num_token * weights->output.output_dims, Concat("w2", layer_id), 3);
count_and_fix(ffn_output_data, token_num * weights->output.output_dims, Concat("w2", layer_id), 3);

if (all_reduce_ && tensor_para_.world_size_ > 1) {
NcclGuard nccl_guard(tensor_para_, stream_);
ftNcclAllReduceSum(ffn_output_data, ffn_output_data, num_token * hidden_units_, tensor_para_, stream_);
ftNcclAllReduceSum(ffn_output_data, ffn_output_data, token_num * hidden_units_, tensor_para_, stream_);
sync_check_cuda_error();
}
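
The pitch handed to the w2 GEMM ties back to that chunked layout: when the fused w1/w3 kernel wrote [gate | up] rows but the SiLU was not fused into the GEMM, the activated values occupy only the first inter_size elements of each inter_size * 2-wide row, so w2 must read gating_buf_ with that leading dimension; a pitch of 0 is the densely packed case. A sketch of the indexing this implies — an illustration of the assumed addressing, not LlamaLinear's implementation:

    // Row addressing for a (token_num x inter_size) activation stored with leading
    // dimension `pitch` (pitch == inter_size * 2 for the chunked layout; 0 means packed).
    #include <cstddef>

    inline const float* act_row(const float* gating_buf, size_t token, size_t inter_size, size_t pitch)
    {
        const size_t ld = pitch ? pitch : inter_size;  // 0 falls back to a packed layout
        return gating_buf + token * ld;
    }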

11 changes: 4 additions & 7 deletions src/turbomind/models/llama/LlamaFfnLayer.h
@@ -19,20 +19,18 @@

#pragma once

#include "src/turbomind/models/llama/LlamaDecoderLayerWeight.h"
#include "src/turbomind/models/llama/LlamaLinear.h"
#include "src/turbomind/models/llama/context.h"
#include "src/turbomind/utils/custom_ar_comm.h"
#include "src/turbomind/models/llama/llama_params.h"
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/nccl_utils.h"
#include <functional>

namespace turbomind {

template<typename T>
class LlamaFfnLayer {
public:
LlamaFfnLayer(const ModelParam& model, const NcclParam& tp, const Context<T>& ctx, bool all_reduce):
inter_size_(model.inter_size / tp.world_size_),
hidden_units_(model.hidden_units),
tensor_para_(tp),
stream_(ctx.stream),
@@ -50,13 +48,12 @@ class LlamaFfnLayer {
void forward(TensorMap* output_tensors, const TensorMap* input_tensors, const LlamaFfnWeight<T>* weights);

private:
void allocateBuffer(size_t token_num, const LlamaDenseWeight<T>*, const LlamaDenseWeight<T>*);
void allocateBuffer(size_t token_num, int inter_size, const LlamaDenseWeight<T>*, const LlamaDenseWeight<T>*);

void freeBuffer();

void activation(int token_num, bool is_chunked);
void activation(int token_num, int inter_size, bool is_chunked);

const size_t inter_size_;
const size_t hidden_units_;
const NcclParam tensor_para_;
cudaStream_t const stream_;
