[enhancement] support llama #575
base: main
Conversation
This implementation of llama is very meaningful. Have you tested its performance? How fast is it compared with the vanilla transformers API?
I've been super busy lately and don't quite have the time for a performance comparison; hopefully someone will do the favor and compare FT with the transformers API. :-)
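For anyone who wants to run that comparison, here is a minimal timing sketch for the vanilla transformers baseline (the model path, prompt, and token counts are placeholders, not values from this PR); the FT side can be timed the same way around llama_example or the Triton backend.

```python
# Minimal sketch: time greedy generation with the vanilla transformers API as a baseline.
# The model path, prompt, and token counts below are placeholders.
import time
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_path = "/path/to/llama-7b-hf"
tokenizer = LlamaTokenizer.from_pretrained(model_path)
model = LlamaForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).cuda().eval()

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")

with torch.no_grad():
    # Warm-up run, then a timed fixed-length greedy generation.
    model.generate(**inputs, max_new_tokens=32, do_sample=False)
    torch.cuda.synchronize()
    start = time.time()
    out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    torch.cuda.synchronize()
elapsed = time.time() - start
new_tokens = out.shape[1] - inputs["input_ids"].shape[1]
print(f"{new_tokens / elapsed:.1f} tokens/s")
```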
Does this implement int8 (or even 4bit) by any chance?
src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h
…tance.h Co-authored-by: Bram Wasti <[email protected]>
Some updates:
What are the parameters for kernel auto-tuning for the llama model?
FasterTransformer doesn't seem to support int4 at all right now. I would be interested in helping with int8, though; that should enable the 65B model to run tensor-parallel on my 2x A6000 GPUs.
+1, happy to contribute to this
@void-main Have you compared with ggml's llama.cpp with cuBLAS support?
num_layer = 32
rotary_embedding = 128
vocab_size = 32000
start_id = 0
Adding layernorm_eps = 1e-06 here will help new beginners.
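For reference, extending the keys shown in the diff above, the relevant part of the converted config would then look roughly like this (the exact key name is an assumption; 1e-06 is the value suggested in this comment):

num_layer = 32
rotary_embedding = 128
layernorm_eps = 1e-06
vocab_size = 32000
start_id = 0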
Hi. I've recently tested this implementation on blip2_vicuna_instruct. It uses vit_qformer's embedding as a prefix_soft_embedding, which is fed into vicuna together with the prompt's token_ids. According to my test results, the outputs differ. For example, pytorch output:
FT output:
Does anyone have experience using FasterTransformer's prefix soft prompt feature? What problem might cause this issue? Could it be a usage mistake? I need some hints to debug it. Thanks in advance! [EDITED]: issue solved
@void-main Hi, I'm also in Beijing and I'm a developer in AI inference. Could I have your WeChat?
Sure, try sending me an email. :-)
I found that the rotary embedding results differ between FT and huggingface. Has anyone met similar problems?
@void-main Hello, I found a bug: after many (thousands of) inferences with batch size 20, some batches may produce random output. But if the Triton service is restarted, inference works normally again. With batch size 5, I haven't seen it yet. Prompts mix Chinese and English.
Compiled based on the FasterTransformer backend. Devices: V100 / A100, 4 GPUs.
Another problem: when batch inference is used, the results generated by the same prompt are different. Parameters: top_k=1, random_seed=1, output_len=500
size_t rotary_embedding_dim_;
float layernorm_eps_;

static constexpr bool neox_rotary_style_ = true;
This should be false.
@prnake I tested it, and only true outputs normal values.
You are right, the model should use the basic type of rotary.
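For anyone hitting the FT-vs-huggingface rotary mismatch mentioned earlier in this thread, here is a standalone sketch (not FT's or huggingface's actual kernels) of the two common rotary conventions that the neox_rotary_style_ flag appears to select between; applying one convention to weights converted or permuted for the other is a typical cause of diverging attention outputs.

```python
# Sketch of the two common rotary-embedding conventions, for debugging mismatches.
# Shapes and theta are illustrative; this is not FT's or huggingface's actual code.
import torch

def rotary_interleaved(q, theta=10000.0):
    # "Basic" (GPT-J / original LLaMA) style: rotate adjacent pairs (q0, q1), (q2, q3), ...
    seq_len, dim = q.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # [seq_len, dim/2]
    cos, sin = angles.cos(), angles.sin()
    q_even, q_odd = q[..., 0::2], q[..., 1::2]
    out = torch.empty_like(q)
    out[..., 0::2] = q_even * cos - q_odd * sin
    out[..., 1::2] = q_even * sin + q_odd * cos
    return out

def rotary_half_split(q, theta=10000.0):
    # NeoX / rotate_half style: pair q[i] with q[i + dim/2] instead of adjacent elements.
    seq_len, dim = q.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)  # [seq_len, dim/2]
    cos = torch.cat([angles.cos(), angles.cos()], dim=-1)
    sin = torch.cat([angles.sin(), angles.sin()], dim=-1)
    rotated = torch.cat([-q[..., dim // 2:], q[..., :dim // 2]], dim=-1)
    return q * cos + rotated * sin

q = torch.randn(4, 128)
# The two conventions give different results for the same unpermuted input.
print(torch.allclose(rotary_interleaved(q), rotary_half_split(q)))
```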
First of all, thanks for your implementation of FT LlaMa. @void-main I pushed a PR to support int8 and shared context. Can anyone help me check it?
Hi @CN-COTER, thanks for the contribution! Really appreciate it!
support int8 & share context
upload code: llama int8&share_context triton backend
fix bug: ft-llama-int8 output is incorrect
fix llama
My start_ids.csv: decode out: It looks like there's a problem with the out and the decoded output. Please give me some suggestions.
Same issue. Have you found the solution?
@void-main Please give some suggestions. Thank you!
Will the llama-2 70b architecture be supported in the future? @void-main Thanks
Implement LlaMa as requested in issue #506.
Steps to use
First, convert the llama-7b-hf weights from huggingface with huggingface_llama_convert.py:
python3 huggingface_llama_convert.py -saved_dir=/path/to/export/folder/ -in_file=/path/to/llama-7b-hf -infer_gpu_num=1 -weight_data_type=fp16 -model_name=llama_7b
Next, compile and run llama_example.
Test case
start_ids.csv:
[0, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973]
out:
[0,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366,1136,455,2470,29973,1815,366,5193,304,592,29973,18637,29892,526,366]
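As a quick sanity check of the test case above, the token IDs can be decoded with the huggingface tokenizer (the model path is a placeholder); the decoded text makes it easy to see that the output simply continues by repeating the prompt tokens.

```python
# Decode the example start/output IDs with the huggingface tokenizer (path is a placeholder).
from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("/path/to/llama-7b-hf")

start_ids = [0, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973]
out_ids = [0, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973,
           18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973,
           18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973,
           18637, 29892, 526, 366]

print(tokenizer.decode(start_ids, skip_special_tokens=True))
print(tokenizer.decode(out_ids, skip_special_tokens=True))
```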