diff --git a/.github/scripts/eval_chat_config.py b/.github/scripts/eval_chat_config.py
index e2463c0f39..74ae7a8968 100644
--- a/.github/scripts/eval_chat_config.py
+++ b/.github/scripts/eval_chat_config.py
@@ -1,7 +1,7 @@
from copy import deepcopy
from mmengine.config import read_base
-from opencompass.models import TurboMindModel, TurboMindModelwithChatTemplate
+from opencompass.models import TurboMindModelwithChatTemplate
with read_base():
# choose a list of datasets
@@ -84,6 +84,8 @@
models as hf_mistral_chat_7b # noqa: F401, E501
from opencompass.configs.models.mistral.hf_mixtral_8x7b_instruct_v0_1 import \
models as hf_mixtral_chat_8x7b # noqa: F401, E501
+ from opencompass.configs.models.qwen2_5.lmdeploy_qwen2_5_7b_instruct import \
+ models as lmdeploy_qwen2_5_7b_instruct # noqa: F401, E501
from opencompass.configs.models.qwen.hf_qwen1_5_7b_chat import \
models as hf_qwen1_5_chat_7b # noqa: F401, E501
from opencompass.configs.models.qwen.hf_qwen1_5_moe_a2_7b_chat import \
@@ -146,10 +148,8 @@
turbomind_internlm2_5_7b_chat_4bits = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint4 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
turbomind_internlm2_5_7b_chat_kvint8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
-turbomind_internlm2_5_7b_chat_batch1 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
-turbomind_internlm2_5_7b_chat_batch1_4bits = deepcopy(
- *lmdeploy_internlm2_5_7b_chat)
pytorch_internlm2_5_7b_chat = deepcopy(*lmdeploy_internlm2_5_7b_chat)
+pytorch_internlm2_5_7b_chat_w8a8 = deepcopy(*lmdeploy_internlm2_5_7b_chat)
# ===== Configs for internlm/internlm2_5_20b_chat =====
turbomind_internlm2_5_20b_chat = deepcopy(*lmdeploy_internlm2_5_20b_chat)
@@ -181,26 +181,6 @@
turbomind_qwen_7b_chat_kvint8 = deepcopy(*lmdeploy_qwen_7b_chat)
pytorch_qwen_7b_chat = deepcopy(*lmdeploy_qwen_7b_chat)
-# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====
-turbomind_llama2_7b_chat = dict(type=TurboMindModel,
- abbr='tb_llama2_chat_7b',
- path='meta-llama/Llama-2-7b-chat-hf',
- engine_config=dict(session_len=MAX_SESSION_LEN,
- max_batch_size=128),
- gen_config=dict(top_k=1,
- top_p=0.8,
- temperature=1.0,
- max_new_tokens=MAX_NEW_TOKENS),
- max_out_len=MAX_NEW_TOKENS,
- max_seq_len=MAX_SESSION_LEN,
- batch_size=128,
- meta_template=llama2_meta_template,
- run_cfg=dict(num_gpus=1),
- end_str='[INST]')
-turbomind_llama2_7b_chat_4bits = deepcopy(turbomind_llama2_7b_chat)
-turbomind_llama2_7b_chat_kvint4 = deepcopy(turbomind_llama2_7b_chat)
-turbomind_llama2_7b_chat_kvint8 = deepcopy(turbomind_llama2_7b_chat)
-
# ===== Configs for meta-llama/Meta-Llama-3-8B-Instruct =====
turbomind_llama3_8b_instruct = deepcopy(*lmdeploy_llama3_8b_instruct)
turbomind_llama3_8b_instruct_4bits = deepcopy(*lmdeploy_llama3_8b_instruct)
@@ -218,6 +198,7 @@
turbomind_llama3_1_8b_instruct_kvint8 = deepcopy(
turbomind_llama3_1_8b_instruct)
pytorch_llama3_1_8b_instruct = deepcopy(turbomind_llama3_1_8b_instruct)
+pytorch_llama3_1_8b_instruct_w8a8 = deepcopy(turbomind_llama3_1_8b_instruct)
# ===== Configs for Qwen/Qwen2-7B-Instruct =====
turbomind_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
@@ -225,17 +206,36 @@
turbomind_qwen2_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_7b_instruct)
turbomind_qwen2_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_7b_instruct)
pytorch_qwen2_7b_instruct = deepcopy(*lmdeploy_qwen2_7b_instruct)
+pytorch_qwen2_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_7b_instruct)
+
+# ===== Configs for Qwen/Qwen2.5-7B-Instruct =====
+turbomind_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_4bits = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_kvint4 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+turbomind_qwen2_5_7b_instruct_kvint8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+pytorch_qwen2_5_7b_instruct = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+pytorch_qwen2_5_7b_instruct_w8a8 = deepcopy(*lmdeploy_qwen2_5_7b_instruct)
+
+# ===== Configs for meta-llama/Llama-2-7b-chat-hf =====
+turbomind_llama2_7b_chat = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_4bits = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_kvint4 = deepcopy(*lmdeploy_llama2_7b_chat)
+turbomind_llama2_7b_chat_kvint8 = deepcopy(*lmdeploy_llama2_7b_chat)
for model in [v for k, v in locals().items() if k.startswith('turbomind_')]:
- model['engine_config']['max_batch_size'] = 128
+ model['engine_config']['max_batch_size'] = 1
model['gen_config']['do_sample'] = False
- model['batch_size'] = 128
+ model['batch_size'] = 100
for model in [v for k, v in locals().items() if k.endswith('_4bits')]:
model['engine_config']['model_format'] = 'awq'
model['abbr'] = model['abbr'] + '_4bits'
model['path'] = model['path'] + '-inner-4bits'
+for model in [v for k, v in locals().items() if k.endswith('_w8a8')]:
+ model['abbr'] = model['abbr'] + '_w8a8'
+ model['path'] = model['path'] + '-inner-w8a8'
+
for model in [v for k, v in locals().items() if k.endswith('_kvint4')]:
model['engine_config']['quant_policy'] = 4
model['abbr'] = model['abbr'] + '_kvint4'
@@ -247,24 +247,19 @@
for model in [v for k, v in locals().items() if k.startswith('pytorch_')]:
model['abbr'] = model['abbr'].replace('turbomind', 'pytorch')
model['backend'] = 'pytorch'
- model['engine_config']['max_batch_size'] = 64
- model['gen_config']['do_sample'] = False
- model['batch_size'] = 64
-
-for model in [v for k, v in locals().items() if '_batch1' in k]:
- model['abbr'] = model['abbr'] + '_batch1'
model['engine_config']['max_batch_size'] = 1
- model['batch_size'] = 1
+ model['gen_config']['do_sample'] = False
+ model['batch_size'] = 100
basic_pytorch_chat_tp1 = dict(type=TurboMindModelwithChatTemplate,
engine_config=dict(session_len=MAX_SESSION_LEN,
- max_batch_size=64,
+ max_batch_size=1,
tp=1),
gen_config=dict(do_sample=False,
max_new_tokens=MAX_NEW_TOKENS),
max_out_len=MAX_NEW_TOKENS,
max_seq_len=MAX_SESSION_LEN,
- batch_size=64,
+ batch_size=100,
run_cfg=dict(num_gpus=1))
# ===== Configs for Qwen/Qwen1.5-MoE-A2.7B-Chat =====
@@ -277,6 +272,13 @@
pytorch_gemma_2_9b_it['abbr'] = 'pytorch_gemma_2_9b_it'
pytorch_gemma_2_9b_it['path'] = 'google/gemma-2-9b-it'
+# ===== Configs for google/gemma-2-27b-it =====
+pytorch_gemma_2_27b_it = deepcopy(basic_pytorch_chat_tp1)
+pytorch_gemma_2_27b_it['abbr'] = 'pytorch_gemma_2_27b_it'
+pytorch_gemma_2_27b_it['path'] = 'google/gemma-2-27b-it'
+pytorch_gemma_2_27b_it['run_cfg']['num_gpus'] = 2
+pytorch_gemma_2_27b_it['engine_config']['tp'] = 2
+
race_datasets = [race_datasets[1]]
# Summarizer
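
Note on the config edits above: every new variant (for example the `_w8a8` and `pytorch_` entries) is built by deep-copying a base OpenCompass model config and then mutating it in a suffix/prefix-driven sweep over module `locals()`. A minimal sketch of that pattern, using a hypothetical `lmdeploy_demo_7b_chat` list in place of the real `lmdeploy_*` imports:

```python
from copy import deepcopy

# Hypothetical stand-in for the imported lmdeploy_* model lists: each is a
# one-element list holding a config dict.
lmdeploy_demo_7b_chat = [dict(abbr='turbomind_demo_7b_chat',
                              path='org/demo-7b-chat',
                              engine_config=dict(max_batch_size=1),
                              gen_config=dict())]

# Every variant starts as a plain deep copy; the *name* of the variable is
# what drives the post-processing below.
turbomind_demo_7b_chat = deepcopy(*lmdeploy_demo_7b_chat)
pytorch_demo_7b_chat_w8a8 = deepcopy(*lmdeploy_demo_7b_chat)

# Suffix-driven sweep over module locals, mirroring the _w8a8 loop added above.
for model in [v for k, v in locals().items() if k.endswith('_w8a8')]:
    model['abbr'] = model['abbr'] + '_w8a8'
    model['path'] = model['path'] + '-inner-w8a8'

# Prefix-driven sweep for the pytorch backend, mirroring the pytorch_ loop.
for model in [v for k, v in locals().items() if k.startswith('pytorch_')]:
    model['abbr'] = model['abbr'].replace('turbomind', 'pytorch')
    model['backend'] = 'pytorch'

print(pytorch_demo_7b_chat_w8a8['abbr'])  # pytorch_demo_7b_chat_w8a8
print(pytorch_demo_7b_chat_w8a8['path'])  # org/demo-7b-chat-inner-w8a8
```
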
diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index bd3876f9ed..e75f728783 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -88,7 +88,7 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
@@ -105,10 +105,8 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r /nvme/qa_test_models/offline_pkg/requirements.txt
- name: Install lmdeploy
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
@@ -148,9 +146,15 @@ jobs:
needs: [benchmark]
timeout-minutes: 5
runs-on: [self-hosted, linux-a100]
+ container:
+ image: openmmlab/lmdeploy:latest-cu11
+ options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never"
+ volumes:
+ - /nvme/qa_test_models:/nvme/qa_test_models
+ - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
ref: ${{github.event.inputs.repo_ref || 'main'}}
diff --git a/.github/workflows/daily_ete_test.yml b/.github/workflows/daily_ete_test.yml
index dbacfc32f5..d6299e163a 100644
--- a/.github/workflows/daily_ete_test.yml
+++ b/.github/workflows/daily_ete_test.yml
@@ -130,7 +130,7 @@ jobs:
needs: download_pkgs
if: ${{!cancelled() && (github.event_name == 'schedule' || contains(fromJSON(github.event.inputs.regression_func), 'quant') )}}
runs-on: [self-hosted, linux-a100]
- timeout-minutes: 120
+ timeout-minutes: 150
env:
PYTHONPATH: /nvme/qa_test_models/offline_pkg/LLaVA
MODELSCOPE_CACHE: /root/modelscope_hub
@@ -149,15 +149,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -166,7 +165,6 @@ jobs:
pip install ${{env.DEEPSEEK_VL}} --no-deps
- name: Check env
run: |
- pip install transformers
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -244,20 +242,20 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
python3 -m pip install lmdeploy-*.whl --no-deps
python3 -m pip install -r requirements/test.txt
+ rm -rf ${{env.DEEPSEEK_VL}}/build
pip install ${{env.DEEPSEEK_VL}} --no-deps
- name: Check env
run: |
@@ -286,6 +284,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/chat/test_command_chat_hf_${{matrix.backend}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - pipeline
continue-on-error: true
if: matrix.function == 'pipeline'
@@ -294,6 +294,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/pipeline/test_pipeline_chat_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - restful
continue-on-error: true
if: matrix.function == 'restful'
@@ -302,6 +304,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/restful/test_restful_chat_hf_${{matrix.backend}}_${{matrix.model}}.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - restful workspace
continue-on-error: true
if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'restful'
@@ -310,6 +314,8 @@ jobs:
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S') || true
pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_2 and not pr_test' -n 4 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
+ pytest autotest/tools/restful/test_restful_chat_workspace.py -m 'gpu_num_4 and not pr_test' -n 2 --alluredir=${{env.REPORT_DIR}} ${{env.COV_PARAM}} || true
+ mv .coverage ${{env.REPORT_DIR}}/.coverage.$(date +'%Y%m%d%H%M%S')
- name: Test lmdeploy - local testcase
if: matrix.backend == 'turbomind' && matrix.model == 'llm' && matrix.function == 'local_case'
run: |
@@ -344,15 +350,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -436,15 +441,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -497,15 +501,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -560,15 +563,14 @@ jobs:
- /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
steps:
- name: Copy repository and Artifacts
- run: cp -r ${{env.TEST_CODE_PATH}}/. .
+ run: |
+ cp -r ${{env.TEST_CODE_PATH}}/. .
- name: Install lmdeploy - dependency
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install -e /root/packages/AutoAWQ_kernels
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -600,7 +602,7 @@ jobs:
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct_w8a8, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
- name: Evaluate base models
if: matrix.evaluate_type == 'base'
run: |
@@ -622,11 +624,17 @@ jobs:
needs: [test_benchmark]
timeout-minutes: 5
runs-on: [self-hosted, linux-a100]
+ container:
+ image: openmmlab/lmdeploy:latest-cu11
+ options: "--gpus=all --ipc=host --user root -e PIP_CACHE_DIR=/root/.cache/pip -e NVIDIA_DISABLE_REQUIRE=1 --pull never"
+ volumes:
+ - /nvme/qa_test_models:/nvme/qa_test_models
+ - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime:ro
env:
BENCHMARK_REPORT_DIR: /nvme/qa_test_models/benchmark-reports/${{ github.run_id }}
steps:
- name: Clone repository
- uses: actions/checkout@v3
+ uses: actions/checkout@v2
with:
repository: ${{ github.event.inputs.repo_org || 'InternLM/lmdeploy' }}
ref: ${{github.event.inputs.repo_ref || 'main'}}
diff --git a/.github/workflows/daily_ete_test_v100.yml b/.github/workflows/daily_ete_test_v100.yml
index 8a662b85f5..343cfdea50 100644
--- a/.github/workflows/daily_ete_test_v100.yml
+++ b/.github/workflows/daily_ete_test_v100.yml
@@ -158,8 +158,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -167,7 +166,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -245,8 +243,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -254,7 +251,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -345,8 +341,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -354,7 +349,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -437,8 +431,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -446,7 +439,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -498,8 +490,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -507,7 +498,6 @@ jobs:
python3 -m pip install -r requirements/test.txt
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -560,8 +550,7 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/autoawq-0.2.6-cp310-cp310-manylinux2014_x86_64.whl --no-deps
- python3 -m pip install /root/packages/xformers-0.0.28.post1-cp310-cp310-manylinux_2_28_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r ${{env.OFFLINE_REQUIREMENTS}}
- name: Install lmdeploy
run: |
@@ -575,7 +564,6 @@ jobs:
echo "OPENCOMPASS_DIR=$(pwd)" >> $GITHUB_ENV
- name: Check env
run: |
- pip install triton==3.0.0
pip uninstall -y nvidia-nccl-cu11
python3 -m pip list
lmdeploy check_env
@@ -593,13 +581,13 @@ jobs:
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it]" "[*race_datasets, *gsm8k_datasets, *ifeval_datasets]" /root/evaluation-reports/${{ github.run_id }} chat true
- name: Evaluate base models
if: matrix.evaluate_type == 'base'
run: |
export LMDEPLOY_DIR=$(pwd)
- python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b, turbomind_internlm2_5_7b_batch1]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true
+ python3 .github/scripts/action_tools.py evaluate "[turbomind_internlm2_5_7b, turbomind_qwen2_5_14b]" "[*race_datasets, *gsm8k_datasets, *gpqa_datasets, *winogrande_datasets]" /root/evaluation-reports/${{ github.run_id }} base true
- name: Clear workspace
if: always()
run: |
diff --git a/.github/workflows/evaluate.yml b/.github/workflows/evaluate.yml
index b6ab89f595..dbfff04fe2 100644
--- a/.github/workflows/evaluate.yml
+++ b/.github/workflows/evaluate.yml
@@ -17,7 +17,7 @@ on:
required: true
description: 'Tested TurboMind models list. eg. [internlm_chat_7b,internlm_chat_7b_w8a16]'
type: string
- default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_7b_chat_batch1, turbomind_internlm2_5_7b_chat_batch1_4bits, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, turbomind_internlm2_chat_7b_4bits, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8]'
+ default: '[turbomind_internlm2_chat_7b, pytorch_internlm2_chat_7b, turbomind_internlm2_5_7b_chat, pytorch_internlm2_5_7b_chat, turbomind_internlm2_5_20b_chat, pytorch_internlm2_5_20b_chat, turbomind_qwen1_5_7b_chat, pytorch_qwen1_5_7b_chat, turbomind_llama2_7b_chat, turbomind_llama3_8b_instruct, pytorch_llama3_8b_instruct, turbomind_llama3_1_8b_instruct, pytorch_llama3_1_8b_instruct, turbomind_qwen2_7b_instruct, pytorch_qwen2_7b_instruct, turbomind_qwen2_5_7b_instruct, pytorch_qwen2_5_7b_instruct, pytorch_qwen1_5_moe_2_7b_chat, pytorch_gemma_2_9b_it, pytorch_gemma_2_27b_it, turbomind_internlm2_chat_7b_kvint4, turbomind_internlm2_chat_7b_kvint8, turbomind_internlm2_5_7b_chat_4bits, turbomind_internlm2_5_7b_chat_kvint4, turbomind_internlm2_5_7b_chat_kvint8, pytorch_internlm2_5_7b_chat_w8a8, turbomind_internlm2_5_20b_chat_4bits, turbomind_internlm2_5_20b_chat_kvint4, turbomind_internlm2_5_20b_chat_kvint8, turbomind_qwen1_5_7b_chat_4bits, turbomind_qwen1_5_7b_chat_kvint4, turbomind_qwen1_5_7b_chat_kvint8, turbomind_llama2_7b_chat_4bits, turbomind_llama2_7b_chat_kvint4, turbomind_llama2_7b_chat_kvint8, turbomind_llama3_8b_instruct_4bits, turbomind_llama3_8b_instruct_kvint4, turbomind_llama3_8b_instruct_kvint8, turbomind_llama3_1_8b_instruct_4bits, turbomind_llama3_1_8b_instruct_kvint4, turbomind_llama3_1_8b_instruct_kvint8, pytorch_llama3_1_8b_instruct_w8a8, turbomind_qwen2_7b_instruct_4bits, turbomind_qwen2_7b_instruct_kvint8, turbomind_qwen2_5_7b_instruct_4bits, turbomind_qwen2_5_7b_instruct_kvint8, pytorch_qwen2_5_7b_instruct_w8a8]'
chat_datasets:
required: true
description: 'Tested datasets list. eg. [*bbh_datasets,*ceval_datasets,*cmmlu_datasets,*GaokaoBench_datasets,*gpqa_datasets,*gsm8k_datasets,*hellaswag_datasets,*humaneval_datasets,*ifeval_datasets,*math_datasets,*sanitized_mbpp_datasets,*mmlu_datasets,*nq_datasets,*race_datasets,*TheoremQA_datasets,*triviaqa_datasets,*winogrande_datasets,*crowspairs_datasets]'
@@ -25,7 +25,7 @@ on:
default: '[*mmlu_datasets, *gsm8k_datasets, *ifeval_datasets]'
base_models:
required: true
- description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b, turbomind_internlm2_5_7b_batch1]'
+ description: 'Tested TurboMind models list. eg. [turbomind_internlm2_5_7b, turbomind_qwen2_7b]'
type: string
default: '[turbomind_internlm2_5_7b, turbomind_internlm2_5_7b_4bits, turbomind_internlm2_5_7b_batch1, turbomind_internlm2_5_7b_batch1_4bits, turbomind_qwen2_7b, turbomind_qwen2_5_7b, turbomind_qwen2_5_14b]'
baes_datasets:
@@ -133,8 +133,10 @@ jobs:
run: |
# manually install flash attn
# the install packeage from. https://github.com/Dao-AILab/flash-attention/releases
- python3 -m pip install /root/packages/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp310-cp310-linux_x86_64.whl
- python3 -m pip install /root/packages/xformers-0.0.27+cu118-cp310-cp310-manylinux2014_x86_64.whl --no-deps
+ python3 -m pip install /root/packages/flash_attn-*.whl
+ python3 -m pip install -e /root/packages/AutoAWQ_kernels
+ python3 -m pip install /root/packages/autoawq-*.whl --no-deps
+ python3 -m pip install /root/packages/xformers-*.whl --no-deps
python3 -m pip install -r /root/models/offline_pkg/requirements.txt
- name: Install lmdeploy
if: ${{github.event_name == 'schedule' || !inputs.offline_mode}}
diff --git a/autotest/config.yaml b/autotest/config.yaml
index b4fd4e1712..d92e32a595 100644
--- a/autotest/config.yaml
+++ b/autotest/config.yaml
@@ -17,15 +17,21 @@ tp_config:
Meta-Llama-3-1-70B-Instruct: 4
internlm2_5-7b-chat-1m: 4
Qwen2-7B-Instruct-GPTQ-Int4: 2
- InternVL2-40B: 2
+ InternVL2-26B: 2
+ InternVL2-40B: 4
+ InternVL2_5-26B: 2
+ InternVL2_5-38B: 4
MiniCPM-V-2_6: 2
Qwen2.5-72B-Instruct: 4
+ gemma-2-27b-it: 2
+ DeepSeek-V2-Lite-Chat: 2
turbomind_chat_model:
- meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-3.2-3B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct-AWQ
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
- meta-llama/Meta-Llama-3-8B-Instruct
- meta-llama/Llama-2-7b-chat-hf
- internlm/internlm2_5-7b-chat
@@ -35,6 +41,10 @@ turbomind_chat_model:
- internlm/internlm-chat-20b
- internlm/internlm-xcomposer2-4khd-7b
- internlm/internlm-xcomposer2d5-7b
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-8B
@@ -42,6 +52,7 @@ turbomind_chat_model:
- OpenGVLab/InternVL2-40B
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2-Llama3-76B-AWQ
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
@@ -51,6 +62,7 @@ turbomind_chat_model:
- Qwen/Qwen-VL-Chat
- Qwen/Qwen2.5-0.5B-Instruct
- Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2-7B-Instruct-GPTQ-Int4
- Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4
- mistralai/Mistral-7B-Instruct-v0.3
@@ -69,10 +81,12 @@ turbomind_chat_model:
- THUDM/glm-4-9b-chat
- openbmb/MiniCPM-Llama3-V-2_5
- openbmb/MiniCPM-V-2_6
+ - allenai/Molmo-7B-D-0924
pytorch_chat_model:
- meta-llama/Meta-Llama-3-8B-Instruct
- meta-llama/Meta-Llama-3-1-8B-Instruct
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
- meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-3.2-3B-Instruct
- meta-llama/Llama-3.2-11B-Vision-Instruct
@@ -81,6 +95,10 @@ pytorch_chat_model:
- internlm/internlm2_5-20b-chat
- internlm/internlm2-chat-20b
- internlm/internlm-chat-20b
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-4B
@@ -92,10 +110,11 @@ pytorch_chat_model:
- baichuan-inc/Baichuan2-7B-Chat
- baichuan-inc/Baichuan2-13B-Chat
- 01-ai/Yi-6B-Chat
- - liuhaotian/llava-v1.5-13b
- - liuhaotian/llava-v1.6-vicuna-7b
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-0.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen1.5-7B-Chat
- Qwen/Qwen1.5-MoE-A2.7B-Chat
- Qwen/Qwen2-VL-2B-Instruct
@@ -104,6 +123,7 @@ pytorch_chat_model:
- mistralai/Mixtral-8x7B-Instruct-v0.1
- google/gemma-7b-it
- google/gemma-2-9b-it
+ - google/gemma-2-27b-it
- deepseek-ai/deepseek-moe-16b-chat
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
@@ -111,6 +131,7 @@ pytorch_chat_model:
- THUDM/cogvlm2-llama3-chinese-chat-19B
- THUDM/glm-4v-9b
- THUDM/glm-4-9b-chat
+ - openbmb/MiniCPM-V-2_6
- microsoft/Phi-3-mini-4k-instruct
- microsoft/Phi-3-vision-128k-instruct
@@ -122,11 +143,16 @@ turbomind_vl_model:
- deepseek-ai/deepseek-vl-1.3b-chat
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-8B
- OpenGVLab/InternVL2-26B
- OpenGVLab/InternVL2-40B
+ - OpenGVLab/InternVL2-Llama3-76B-AWQ
- internlm/internlm-xcomposer2d5-7b
- internlm/internlm-xcomposer2-4khd-7b
- openbmb/MiniCPM-Llama3-V-2_5
@@ -136,6 +162,10 @@ pytorch_vl_model:
- meta-llama/Llama-3.2-11B-Vision-Instruct
- OpenGVLab/InternVL-Chat-V1-5
- OpenGVLab/Mini-InternVL-Chat-2B-V1-5
+ - OpenGVLab/InternVL2_5-1B
+ - OpenGVLab/InternVL2_5-8B
+ - OpenGVLab/InternVL2_5-26B
+ - OpenGVLab/InternVL2_5-38B
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-2B
- OpenGVLab/InternVL2-4B
@@ -148,6 +178,7 @@ pytorch_vl_model:
- THUDM/cogvlm-chat-hf
- THUDM/cogvlm2-llama3-chinese-chat-19B
- THUDM/glm-4v-9b
+ - openbmb/MiniCPM-V-2_6
- microsoft/Phi-3-vision-128k-instruct
- microsoft/Phi-3.5-vision-instruct
@@ -166,6 +197,8 @@ pytorch_base_model:
turbomind_quatization:
no_awq:
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen1.5-MoE-A2.7B-Chat
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
@@ -174,18 +207,28 @@ turbomind_quatization:
- deepseek-ai/deepseek-coder-1.3b-instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
- codellama/CodeLlama-7b-Instruct-hf
+ - allenai/Molmo-7B-D-0924
gptq:
- internlm/internlm2_5-7b-chat
no_kvint4:
+ - meta-llama/Llama-3.2-1B-Instruct
+ - OpenGVLab/InternVL2-1B
+ - OpenGVLab/InternVL2_5-1B
- openbmb/MiniCPM-V-2_6
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
- Qwen/Qwen2.5-0.5B-Instruct
- Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
- Qwen/Qwen2-7B-Instruct-GPTQ-Int4
+ - allenai/Molmo-7B-D-0924
no_kvint8:
+ - deepseek-ai/DeepSeek-V2-Chat
+ no_converted:
- deepseek-ai/DeepSeek-V2-Lite-Chat
+ - Qwen/Qwen2.5-72B-Instruct
+ - meta-llama/Meta-Llama-3-1-70B-Instruct
pytorch_quatization:
awq:
@@ -200,23 +243,39 @@ pytorch_quatization:
- Qwen/Qwen1.5-7B-Chat
- Qwen/Qwen2-7B-Instruct
- Qwen/Qwen2-1.5B-Instruct
- - microsoft/Phi-3-mini-4k-instruct
+ - Qwen/Qwen2.5-7B-Instruct
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
+ - microsoft/Phi-3-mini-4k-instruct
w8a8:
- meta-llama/Meta-Llama-3-8B-Instruct
+ - meta-llama/Llama-3.2-1B-Instruct
- meta-llama/Llama-2-7b-chat-hf
- internlm/internlm2-chat-20b
- internlm/internlm2_5-7b-chat
- internlm/internlm2_5-20b-chat
- 01-ai/Yi-6B-Chat
+ - mistralai/Mistral-7B-Instruct-v0.3
+ - Qwen/Qwen1.5-7B-Chat
+ - Qwen/Qwen2-7B-Instruct
+ - Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - microsoft/Phi-3-mini-4k-instruct
- internlm/internlm2_5-20b
- internlm/internlm2_5-7b
+ - meta-llama/Meta-Llama-3-1-8B-Instruct
no_kvint4:
+ - meta-llama/Llama-3.2-1B-Instruct
- OpenGVLab/InternVL2-1B
- OpenGVLab/InternVL2-4B
+ - OpenGVLab/InternVL2_5-1B
- Qwen/Qwen2-7B-Instruct
+ - Qwen/Qwen2-7B-Instruct-AWQ
- Qwen/Qwen2-1.5B-Instruct
+ - Qwen/Qwen2.5-0.5B-Instruct
+ - Qwen/Qwen2.5-7B-Instruct
+ - Qwen/Qwen2.5-72B-Instruct
+ - Qwen/Qwen2-7B-Instruct-GPTQ-Int4
- Qwen/Qwen2-VL-2B-Instruct
- Qwen/Qwen2-VL-7B-Instruct
- deepseek-ai/DeepSeek-V2-Lite-Chat
@@ -247,3 +306,4 @@ benchmark_model:
- mistralai/Mistral-7B-Instruct-v0.3
- mistralai/Mixtral-8x7B-Instruct-v0.1
- deepseek-ai/DeepSeek-V2-Lite-Chat
+ - allenai/Molmo-7B-D-0924
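
For context on the `autotest/config.yaml` edits above: `tp_config` maps a model's short name to the tensor-parallel degree it needs (new entries raise InternVL2-40B to 4 GPUs and add gemma-2-27b-it at 2), while the `*_chat_model` lists gate which models each backend's tests pick up. A hypothetical reader illustrating how these sections are typically consumed; the real helpers (`get_turbomind_model_list` etc.) live under `autotest/utils` and may consult more keys than shown here:

```python
import yaml


def load_config(path='autotest/config.yaml'):
    with open(path) as f:
        return yaml.safe_load(f)


def get_tp_num(config, model):
    """Models absent from tp_config default to tp=1; keys use short names."""
    return config.get('tp_config', {}).get(model.split('/')[-1], 1)


def turbomind_models_for_tp(config, tp_num):
    return [m for m in config.get('turbomind_chat_model', [])
            if get_tp_num(config, m) == tp_num]


if __name__ == '__main__':
    cfg = load_config()
    # With the edits above, gemma-2-27b-it maps to tp=2 and InternVL2-40B to tp=4.
    print(get_tp_num(cfg, 'google/gemma-2-27b-it'))
    print(turbomind_models_for_tp(cfg, 4))
```
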
diff --git a/autotest/interface/pipeline/test_pipeline_func.py b/autotest/interface/pipeline/test_pipeline_func.py
index 87a0719bcb..0696684890 100644
--- a/autotest/interface/pipeline/test_pipeline_func.py
+++ b/autotest/interface/pipeline/test_pipeline_func.py
@@ -408,7 +408,7 @@ def run_pipeline_testcase(config, model, backend, file_name):
result = True
for i in range(2):
result &= response[i].finish_reason == 'length'
- result &= response[i].session_id == i
+ result &= response[i].index == i
save_pipeline_common_log(config, file_name, result, response)
del pipe
torch.cuda.empty_cache()
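
The test fix above switches the per-prompt assertion from `session_id` to `index`, i.e. the position of each prompt in the input batch. A minimal sketch of the check being exercised, assuming a hypothetical local model path and that batched pipeline calls return Response objects carrying `finish_reason` and `index` as the test asserts:

```python
from lmdeploy import GenerationConfig, pipeline

# Hypothetical local path; any chat model supported by lmdeploy would do.
pipe = pipeline('/nvme/qa_test_models/internlm/internlm2_5-7b-chat')

# Force truncation so finish_reason is 'length' for every prompt.
gen_config = GenerationConfig(max_new_tokens=10)
response = pipe(['Hi, pls introduce shanghai', 'Shanghai is'],
                gen_config=gen_config)

ok = True
for i in range(2):
    ok &= response[i].finish_reason == 'length'
    ok &= response[i].index == i  # index = position in the input batch
assert ok
```
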
diff --git a/autotest/prompt_case.yaml b/autotest/prompt_case.yaml
index 9a5ed45724..468f3e49d6 100644
--- a/autotest/prompt_case.yaml
+++ b/autotest/prompt_case.yaml
@@ -54,6 +54,7 @@ emoji_case:
- 好
- '!'
- u1f44d
+ - 🌟
traditional_chinese_case:
- 介紹澳門景點,使用繁體:
- contain:
diff --git a/autotest/pytest.ini b/autotest/pytest.ini
index 4c963d5bbd..69dc47fa58 100644
--- a/autotest/pytest.ini
+++ b/autotest/pytest.ini
@@ -5,4 +5,4 @@ python_functions = test_* # test function
pytest_runtest_call.tryfirst = True
filterwarnings = ignore::UserWarning
reruns = 2
-reruns_delay = 10
+reruns_delay = 1
diff --git a/autotest/tools/chat/test_command_chat_hf_pytorch.py b/autotest/tools/chat/test_command_chat_hf_pytorch.py
index 1ae3be338b..e6986ec614 100644
--- a/autotest/tools/chat/test_command_chat_hf_pytorch.py
+++ b/autotest/tools/chat/test_command_chat_hf_pytorch.py
@@ -51,6 +51,27 @@ def test_hf_pytorch_chat_tp2(config, model, cli_case_config, worker_id):
assert result, msg
+@pytest.mark.order(10)
+@pytest.mark.usefixtures('cli_case_config')
+@pytest.mark.hf_pytorch_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model', get_torch_model_list(tp_num=4))
+def test_hf_pytorch_chat_tp4(config, model, cli_case_config, worker_id):
+ usercase = 'chat_testcase'
+ result, chat_log, msg = hf_command_line_test(
+ config,
+ usercase,
+ cli_case_config.get(usercase),
+ model,
+ 'pytorch',
+ cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4))
+ if chat_log is not None:
+ allure.attach.file(chat_log,
+ attachment_type=allure.attachment_type.TEXT)
+
+ assert result, msg
+
+
@pytest.mark.order(10)
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.hf_turbomind_chat
diff --git a/autotest/tools/chat/test_command_chat_hf_turbomind.py b/autotest/tools/chat/test_command_chat_hf_turbomind.py
index 2f13898fec..935a21ee86 100644
--- a/autotest/tools/chat/test_command_chat_hf_turbomind.py
+++ b/autotest/tools/chat/test_command_chat_hf_turbomind.py
@@ -53,6 +53,28 @@ def test_hf_turbomind_chat_tp2(config, model, cli_case_config, worker_id):
assert result, msg
+@pytest.mark.order(10)
+@pytest.mark.usefixtures('cli_case_config')
+@pytest.mark.hf_turbomind_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=4))
+def test_hf_turbomind_chat_tp4(config, model, cli_case_config, worker_id):
+ usercase = 'chat_testcase'
+ result, chat_log, msg = hf_command_line_test(
+ config,
+ usercase,
+ cli_case_config.get(usercase),
+ model,
+ 'turbomind',
+ cuda_prefix=get_cuda_prefix_by_workerid(worker_id, tp_num=4))
+
+ if chat_log is not None:
+ allure.attach.file(chat_log,
+ attachment_type=allure.attachment_type.TEXT)
+
+ assert result, msg
+
+
@pytest.mark.order(10)
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.hf_turbomind_chat
diff --git a/autotest/tools/chat/test_command_chat_workspace.py b/autotest/tools/chat/test_command_chat_workspace.py
index a16d4e32f6..415a1c528c 100644
--- a/autotest/tools/chat/test_command_chat_workspace.py
+++ b/autotest/tools/chat/test_command_chat_workspace.py
@@ -9,7 +9,8 @@
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.command_chat
@pytest.mark.gpu_num_1
-@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=1))
+@pytest.mark.parametrize('model',
+ get_turbomind_model_list(tp_num=1, is_converted=True))
def test_workspace_chat_tp1(config, cli_case_config, model, worker_id):
usercase = 'chat_testcase'
# cannot convert with rop-scale params, so the case should be skipped
@@ -32,7 +33,8 @@ def test_workspace_chat_tp1(config, cli_case_config, model, worker_id):
@pytest.mark.usefixtures('cli_case_config')
@pytest.mark.command_chat
@pytest.mark.gpu_num_2
-@pytest.mark.parametrize('model', get_turbomind_model_list(tp_num=2))
+@pytest.mark.parametrize('model',
+ get_turbomind_model_list(tp_num=2, is_converted=True))
def test_workspace_chat_tp2(config, cli_case_config, model, worker_id):
usercase = 'chat_testcase'
result, chat_log, msg = command_line_test(
@@ -54,7 +56,8 @@ def test_workspace_chat_tp2(config, cli_case_config, model, worker_id):
@pytest.mark.gpu_num_1
@pytest.mark.parametrize('model',
get_turbomind_model_list(tp_num=1,
- model_type='base_model'))
+ model_type='base_model',
+ is_converted=True))
def test_workspace_base_tp1(config, cli_case_config, model, worker_id):
usercase = 'base_testcase'
result, chat_log, msg = command_line_test(
@@ -76,7 +79,8 @@ def test_workspace_base_tp1(config, cli_case_config, model, worker_id):
@pytest.mark.gpu_num_2
@pytest.mark.parametrize('model',
get_turbomind_model_list(tp_num=2,
- model_type='base_model'))
+ model_type='base_model',
+ is_converted=True))
def test_workspace_base_tp2(config, cli_case_config, model, worker_id):
usercase = 'base_testcase'
result, chat_log, msg = command_line_test(
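
The workspace chat tests above now request only models that have a converted (workspace) artifact via the new `is_converted` flag. The actual filtering lives in the shared list helpers; a hypothetical sketch of what that flag implies, keyed on the `no_converted` list added to `autotest/config.yaml` (the real `get_turbomind_model_list` may differ in naming and the keys it consults):

```python
# Hypothetical filtering logic for illustration only.
def filter_converted(models, config, is_converted=False):
    """Drop models listed under turbomind_quatization.no_converted when the
    caller asks for workspace (converted) models only."""
    if not is_converted:
        return models
    no_converted = config.get('turbomind_quatization', {}).get('no_converted', [])
    return [m for m in models if m not in no_converted]


config = {
    'turbomind_quatization': {
        'no_converted': [
            'deepseek-ai/DeepSeek-V2-Lite-Chat',
            'Qwen/Qwen2.5-72B-Instruct',
            'meta-llama/Meta-Llama-3-1-70B-Instruct',
        ]
    }
}
models = ['internlm/internlm2_5-7b-chat', 'Qwen/Qwen2.5-72B-Instruct']
print(filter_converted(models, config, is_converted=True))
# ['internlm/internlm2_5-7b-chat']
```
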
diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
index 58674fa173..c0348ec500 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_llm.py
@@ -56,6 +56,32 @@ def test_pipeline_chat_pytorch_tp2(config, common_case_config, model,
worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat_pytorch
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4, exclude_dup=True))
+def test_pipeline_chat_pytorch_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch', worker_id))
+ p.start()
+ p.join()
+
+ # assert script
+ assert_pipeline_chat_log(config, common_case_config, model, 'pytorch',
+ worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -109,6 +135,34 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model,
'pytorch-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=4,
+ exclude_dup=True))
+def test_pipeline_chat_kvint4_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch-kvint', worker_id, {
+ 'quant_policy': 4
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'pytorch-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -162,6 +216,34 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model,
'pytorch-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=8,
+ exclude_dup=True))
+def test_pipeline_chat_kvint8_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'pytorch-kvint', worker_id, {
+ 'quant_policy': 8
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'pytorch-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat_pytorch
diff --git a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
index 8403ced94f..8735b8e937 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_pytorch_mllm.py
@@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4, model_type='vl_model'))
+def test_pipeline_chat_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id))
@@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 4))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=4,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint4_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 4))
@@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 8))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_torch_model_list(tp_num=4,
+ quant_policy=8,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint8_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 8))
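
Each new tp4 test pins its GPUs and its rendezvous port per pytest-xdist worker so parallel 4-GPU runs do not collide. The MASTER_PORT derivation comes straight from the diff; the GPU mapping below is an assumption about `get_cuda_id_by_workerid` (contiguous blocks per `gwN` worker) and may not match the real helper:

```python
def cuda_ids_for_worker(worker_id, tp_num):
    """Assumed mapping: gw0 -> '0,1,2,3', gw1 -> '4,5,6,7', ... for tp_num=4."""
    idx = int(worker_id.replace('gw', ''))
    first = idx * tp_num
    return ','.join(str(first + i) for i in range(tp_num))


def master_port_for_worker(worker_id, base=29500):
    # Mirrors the MASTER_PORT assignment added in the tests above.
    return str(int(worker_id.replace('gw', '')) + base)


print(cuda_ids_for_worker('gw1', tp_num=4))   # 4,5,6,7
print(master_port_for_worker('gw1'))          # 29501
```
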
diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
index d1865175cf..58eab0de76 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_llm.py
@@ -48,6 +48,28 @@ def test_pipeline_chat_tp2(config, common_case_config, model, worker_id):
worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4))
+def test_pipeline_chat_tp4(config, common_case_config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind', worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model, 'turbomind',
+ worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -95,6 +117,31 @@ def test_pipeline_chat_kvint4_tp2(config, common_case_config, model,
'turbomind-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4, quant_policy=4))
+def test_pipeline_chat_kvint4_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind-kvint', worker_id, {
+ 'quant_policy': 4
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'turbomind-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
@@ -142,6 +189,31 @@ def test_pipeline_chat_kvint8_tp2(config, common_case_config, model,
'turbomind-kvint', worker_id)
+@pytest.mark.order(6)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.flaky(reruns=0)
+@pytest.mark.parametrize('model', get_all_model_list(tp_num=4, quant_policy=8))
+def test_pipeline_chat_kvint8_tp4(config, common_case_config, model,
+ worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_chat_test,
+ args=(config, common_case_config, model,
+ 'turbomind-kvint', worker_id, {
+ 'quant_policy': 8
+ }))
+ p.start()
+ p.join()
+ assert_pipeline_chat_log(config, common_case_config, model,
+ 'turbomind-kvint', worker_id)
+
+
@pytest.mark.order(6)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.pipeline_chat
diff --git a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
index 8c845fa77a..c62bfc5e8e 100644
--- a/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
+++ b/autotest/tools/pipeline/test_pipeline_chat_turbomind_mllm.py
@@ -34,6 +34,27 @@ def test_pipeline_chat_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4, model_type='vl_model'))
+def test_pipeline_chat_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id))
@@ -71,6 +92,29 @@ def test_pipeline_chat_kvint4_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 4))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4,
+ quant_policy=4,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint4_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 4))
@@ -108,6 +152,29 @@ def test_pipeline_chat_kvint8_tp2(config, model, worker_id):
if 'gw' in worker_id:
os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
tp_num=2)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
+ spawn_context = get_context('spawn')
+ p = spawn_context.Process(target=run_pipeline_vl_chat_test,
+ args=(config, model, BACKEND, worker_id, 8))
+ p.start()
+ p.join()
+ assert_pipeline_vl_chat_log(config, model, worker_id)
+
+
+@pytest.mark.order(6)
+@pytest.mark.pipeline_chat
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('model',
+ get_all_model_list(tp_num=4,
+ quant_policy=8,
+ model_type='vl_model'))
+def test_pipeline_chat_kvint8_tp4(config, model, worker_id):
+ if 'gw' in worker_id:
+ os.environ['CUDA_VISIBLE_DEVICES'] = get_cuda_id_by_workerid(worker_id,
+ tp_num=4)
+ os.environ['MASTER_PORT'] = str(
+ int(worker_id.replace('gw', '')) + 29500)
spawn_context = get_context('spawn')
p = spawn_context.Process(target=run_pipeline_vl_chat_test,
args=(config, model, BACKEND, worker_id, 8))
diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
index fc95e288ca..bc0ea3996a 100644
--- a/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_llm.py
@@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api_pytorch
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy):
return [{
'model': item,
@@ -104,6 +121,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
@@ -138,6 +172,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
diff --git a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
index bf20c45e6e..cc85d35d09 100644
--- a/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_pytorch_mllm.py
@@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy: int = None):
return [{
'model': item,
@@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.restful_api_vl
@pytest.mark.gpu_num_1
@@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id):
run_vl_testcase(config)
else:
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
index 1c9131b32e..435ffc4ae3 100644
--- a/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_llm.py
@@ -60,6 +60,23 @@ def test_restful_chat_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy):
return [{
'model': item,
@@ -103,6 +120,23 @@ def test_restful_chat_kvint4_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
@@ -137,6 +171,23 @@ def test_restful_chat_kvint8_tp2(config, common_case_config, worker_id):
port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.usefixtures('common_case_config')
+@pytest.mark.restful_api
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, common_case_config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_all_step(config, common_case_config)
+ else:
+ run_all_step(config,
+ common_case_config,
+ worker_id=worker_id,
+ port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.usefixtures('common_case_config')
@pytest.mark.restful_api
diff --git a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
index 641f2f760f..bbb8718366 100644
--- a/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
+++ b/autotest/tools/restful/test_restful_chat_hf_turbomind_mllm.py
@@ -53,6 +53,19 @@ def test_restful_chat_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getModelList(tp_num=4),
+ indirect=True)
+def test_restful_chat_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
def getKvintModelList(tp_num, quant_policy: int = None):
return [{
'model': item,
@@ -89,6 +102,19 @@ def test_restful_chat_kvint4_tp2(config, worker_id):
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=4),
+ indirect=True)
+def test_restful_chat_kvint4_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
@pytest.mark.order(7)
@pytest.mark.restful_api_vl
@pytest.mark.gpu_num_1
@@ -113,3 +139,16 @@ def test_restful_chat_kvint8_tp2(config, worker_id):
run_vl_testcase(config)
else:
run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
+
+
+@pytest.mark.order(7)
+@pytest.mark.restful_api_vl
+@pytest.mark.gpu_num_4
+@pytest.mark.parametrize('prepare_environment',
+ getKvintModelList(tp_num=4, quant_policy=8),
+ indirect=True)
+def test_restful_chat_kvint8_tp4(config, worker_id):
+ if get_workerid(worker_id) is None:
+ run_vl_testcase(config)
+ else:
+ run_vl_testcase(config, port=DEFAULT_PORT + get_workerid(worker_id))
diff --git a/autotest/tools/restful/test_restful_chat_workspace.py b/autotest/tools/restful/test_restful_chat_workspace.py
index 798a43d7b0..cf69007cca 100644
--- a/autotest/tools/restful/test_restful_chat_workspace.py
+++ b/autotest/tools/restful/test_restful_chat_workspace.py
@@ -23,8 +23,7 @@ def getModelList(tp_num):
'model': item,
'cuda_prefix': None,
'tp_num': tp_num
- } for item in get_turbomind_model_list(tp_num)
- if item not in ('deepseek-ai/deepseek-coder-1.3b-instruct')]
+ } for item in get_turbomind_model_list(tp_num, is_converted=True)]
@pytest.mark.order(7)
diff --git a/autotest/utils/config_utils.py b/autotest/utils/config_utils.py
index dd8db1ccc4..87d5d73f10 100644
--- a/autotest/utils/config_utils.py
+++ b/autotest/utils/config_utils.py
@@ -9,17 +9,33 @@
def get_turbomind_model_list(tp_num: int = None,
model_type: str = 'chat_model',
- quant_policy: int = None):
+ quant_policy: int = None,
+ is_converted: bool = False):
config = get_config()
if quant_policy is None:
- case_list = copy.deepcopy(config.get('turbomind_' + model_type))
+ if is_converted:
+ case_list = [
+ x for x in copy.deepcopy(config.get('turbomind_' + model_type))
+ if x not in config.get('turbomind_quatization').get(
+ 'no_converted')
+ ]
+ else:
+ case_list = copy.deepcopy(config.get('turbomind_' + model_type))
else:
- case_list = [
- x for x in config.get('turbomind_' + model_type)
- if x not in config.get('turbomind_quatization').get(
- 'no_kvint' + str(quant_policy))
- ]
+ if is_converted:
+ case_list = [
+ x for x in config.get('turbomind_' + model_type)
+ if x not in config.get('turbomind_quatization').get(
+                    'no_kvint' + str(quant_policy)) and x not in config.get(
+                        'turbomind_quatization').get('no_converted')
+ ]
+ else:
+ case_list = [
+ x for x in config.get('turbomind_' + model_type)
+ if x not in config.get('turbomind_quatization').get(
+ 'no_kvint' + str(quant_policy))
+ ]
quatization_case_config = config.get('turbomind_quatization')
for key in config.get('turbomind_' + model_type):
@@ -202,6 +218,7 @@ def get_benchmark_model_list(tp_num,
else:
case_list_base = config.get('benchmark_model')
quatization_case_config = config.get('turbomind_quatization')
+ pytorch_quatization_case_config = config.get('pytorch_quatization')
case_list = copy.deepcopy(case_list_base)
for key in case_list_base:
@@ -210,6 +227,12 @@ def get_benchmark_model_list(tp_num,
'no_awq') and not is_quantization_model(key):
case_list.append(key + '-inner-4bits')
+ for key in case_list_base:
+ if key in config.get('pytorch_chat_model'
+ ) and key in pytorch_quatization_case_config.get(
+ 'w8a8') and not is_quantization_model(key):
+ case_list.append(key + '-inner-w8a8')
+
model_list = [
item for item in case_list if get_tp_num(config, item) == tp_num
]
@@ -228,15 +251,18 @@ def get_benchmark_model_list(tp_num,
'backend': 'pytorch',
'tp_num': tp_num
} for item in model_list if '4bits' not in item and (
- item in config.get('pytorch_chat_model') or tp_num == 4)]
+ item.replace('-inner-w8a8', '') in config.get('pytorch_chat_model')
+ or tp_num == 4)]
for kvint in kvint_list:
result += [{
'model': item,
'backend': 'turbomind',
'quant_policy': kvint,
'tp_num': tp_num
- } for item in model_list if item.replace('-inner-4bits', '') in
- config.get('turbomind_chat_model')]
+ } for item in model_list if item.replace(
+ '-inner-4bits', '') in config.get('turbomind_chat_model')
+ and item.replace('-inner-4bits', '') not in
+ quatization_case_config.get('no_kvint' + str(kvint))]
return result
diff --git a/autotest/utils/pipeline_chat.py b/autotest/utils/pipeline_chat.py
index 023e4ac142..5dcb358319 100644
--- a/autotest/utils/pipeline_chat.py
+++ b/autotest/utils/pipeline_chat.py
@@ -235,7 +235,7 @@ def assert_pipeline_single_stream_return(output, logprobs_num: int = 0):
def assert_pipeline_batch_stream_return(output, size: int = 1):
for i in range(size):
- output_list = [item for item in output if item.session_id == i]
+ output_list = [item for item in output if item.index == i]
result, msg = assert_pipeline_single_stream_return(output_list)
if not result:
return result, msg
@@ -249,7 +249,7 @@ def assert_pipeline_single_element(output,
result = True
result &= output.generate_token_len > 0
result &= output.input_token_len > 0
- result &= output.session_id >= 0
+ result &= output.index >= 0
if is_last:
result &= len(output.text) >= 0
result &= output.finish_reason in ['stop', 'length']
@@ -277,14 +277,14 @@ def assert_pipeline_single_element(output,
return result
-PIC1 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg' # noqa E501
-PIC2 = 'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg' # noqa E501
-PIC_BEIJING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Beijing_Small.jpeg' # noqa E501
-PIC_CHONGQING = 'https://raw.githubusercontent.com/QwenLM/Qwen-VL/master/assets/mm_tutorial/Chongqing_Small.jpeg' # noqa E501
-PIC_REDPANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image1.jpg' # noqa E501
-PIC_PANDA = 'https://raw.githubusercontent.com/OpenGVLab/InternVL/main/internvl_chat/examples/image2.jpg' # noqa E501
-DESC = 'What are the similarities and differences between these two images.' # noqa E501
-DESC_ZH = '两张图有什么相同和不同的地方.' # noqa E501
+PIC1 = 'tiger.jpeg'
+PIC2 = 'human-pose.jpg'
+PIC_BEIJING = 'Beijing_Small.jpeg'
+PIC_CHONGQING = 'Chongqing_Small.jpeg'
+PIC_REDPANDA = 'redpanda.jpg'
+PIC_PANDA = 'panda.jpg'
+DESC = 'What are the similarities and differences between these two images.'
+DESC_ZH = '两张图有什么相同和不同的地方.'
def run_pipeline_vl_chat_test(config,
@@ -296,6 +296,7 @@ def run_pipeline_vl_chat_test(config,
tp = get_tp_num(config, model_case)
model_path = config.get('model_path')
hf_path = model_path + '/' + model_case
+ resource_path = config.get('resource_path')
if 'pytorch' in backend:
backend_config = PytorchEngineConfig(tp=tp, session_len=8192)
@@ -320,7 +321,7 @@ def run_pipeline_vl_chat_test(config,
'pipeline_vl_chat_' + model_case.split('/')[1] + worker_id + '.log')
file = open(pipeline_chat_log, 'w')
- image = load_image(PIC1)
+ image = load_image(f'{resource_path}/{PIC1}')
if 'deepseek' in model_case:
prompt = f'describe this image{IMAGE_TOKEN}'
@@ -352,7 +353,7 @@ def run_pipeline_vl_chat_test(config,
}, {
'type': 'image_url',
'image_url': {
- 'url': PIC1
+ 'url': f'{resource_path}/{PIC1}'
}
}]
}]
@@ -362,7 +363,7 @@ def run_pipeline_vl_chat_test(config,
', reason: OpenAI format example: tiger not in ' +
response.text + '\n')
- image_urls = [PIC2, PIC1]
+ image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']
images = [load_image(img_url) for img_url in image_urls]
response = pipe((prompt, images))
result = 'tiger' in response.text.lower() or 'ski' in response.text.lower(
@@ -371,7 +372,7 @@ def run_pipeline_vl_chat_test(config,
', reason: Multi-images example: tiger or ski not in ' +
response.text + '\n')
- image_urls = [PIC2, PIC1]
+ image_urls = [f'{resource_path}/{PIC2}', f'{resource_path}/{PIC1}']
prompts = [(prompt, load_image(img_url)) for img_url in image_urls]
response = pipe(prompts)
result = ('ski' in response[0].text.lower()
@@ -382,7 +383,7 @@ def run_pipeline_vl_chat_test(config,
', reason: Batch example: ski or tiger not in ' +
str(response) + '\n')
- image = load_image(PIC2)
+ image = load_image(f'{resource_path}/{PIC2}')
sess = pipe.chat((prompt, image))
result = 'ski' in sess.response.text.lower(
) or '滑雪' in sess.response.text.lower()
@@ -397,12 +398,12 @@ def run_pipeline_vl_chat_test(config,
sess.response.text + '\n')
if 'internvl' in model_case.lower():
- internvl_vl_testcase(config, pipe, file)
- internvl_vl_testcase(config, pipe, file, 'cn')
+ internvl_vl_testcase(pipe, file, resource_path)
+ internvl_vl_testcase(pipe, file, resource_path, 'cn')
if 'minicpm' in model_case.lower():
- MiniCPM_vl_testcase(config, pipe, file)
+ MiniCPM_vl_testcase(pipe, file, resource_path)
if 'qwen' in model_case.lower():
- Qwen_vl_testcase(config, pipe, file)
+ Qwen_vl_testcase(pipe, file, resource_path)
file.close()
@@ -410,7 +411,7 @@ def run_pipeline_vl_chat_test(config,
torch.cuda.empty_cache()
-def internvl_vl_testcase(config, pipe, file, lang='en'):
+def internvl_vl_testcase(pipe, file, resource_path, lang='en'):
if lang == 'cn':
description = DESC_ZH
else:
@@ -422,9 +423,11 @@ def internvl_vl_testcase(config, pipe, file, lang='en'):
dict(type='text',
text=f'{IMAGE_TOKEN}{IMAGE_TOKEN}\n{description}'),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)),
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_PANDA))
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -452,9 +455,11 @@ def internvl_vl_testcase(config, pipe, file, lang='en'):
+ # noqa E251,E501
description),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_REDPANDA)),
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_dynamic_patch=12, url=PIC_PANDA))
+ image_url=dict(max_dynamic_patch=12,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -501,8 +506,7 @@ def load_video(video_path, bound=None, num_segments=32):
imgs.append(img)
return imgs
- resource_path = config.get('resource_path')
- video_path = resource_path + '/red-panda.mp4'
+ video_path = f'{resource_path}/red-panda.mp4'
imgs = load_video(video_path, num_segments=8)
question = ''
@@ -546,14 +550,16 @@ def load_video(video_path, bound=None, num_segments=32):
response.text + '\n')
-def llava_vl_testcase(config, pipe, file):
+def llava_vl_testcase(pipe, file, resource_path):
# multi-image multi-round conversation, combined images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
- dict(type='image_url', image_url=dict(url=PIC_BEIJING)),
- dict(type='image_url', image_url=dict(url=PIC_CHONGQING))
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
@@ -575,16 +581,18 @@ def llava_vl_testcase(config, pipe, file):
response.text + '\n')
-def MiniCPM_vl_testcase(config, pipe, file):
+def MiniCPM_vl_testcase(pipe, file, resource_path):
# Chat with multiple images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
dict(type='image_url',
- image_url=dict(max_slice_nums=9, url=PIC_REDPANDA)),
+ image_url=dict(max_slice_nums=9,
+ url=f'{resource_path}/{PIC_REDPANDA}')),
dict(type='image_url',
- image_url=dict(max_slice_nums=9, url=PIC_PANDA))
+ image_url=dict(max_slice_nums=9,
+ url=f'{resource_path}/{PIC_PANDA}'))
])
]
response = pipe(messages)
@@ -602,27 +610,27 @@ def MiniCPM_vl_testcase(config, pipe, file):
response.text + '\n')
# In-context few-shot learning
- EXAMPLE1 = 'https://github.com/user-attachments/assets/405d9147-95f6-4f78-8879-606a0aed6707' # noqa E251,E501
- EXAMPLE2 = 'https://github.com/user-attachments/assets/9f2c6ed9-2aa5-4189-9c4f-0b9753024ba1' # noqa E251,E501
- EXAMPLE3 = 'https://github.com/user-attachments/assets/f335b507-1957-4c22-84ae-ed69ff79df38' # noqa E251,E501
question = 'production date'
messages = [
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE1)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data1.jpeg')),
]),
dict(role='assistant', content='2021.08.29'),
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE2)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data2.jpeg')),
]),
dict(role='assistant', content='1999.05.15'),
dict(role='user',
content=[
dict(type='text', text=question),
- dict(type='image_url', image_url=dict(url=EXAMPLE3)),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/data3.jpeg')),
])
]
response = pipe(messages)
@@ -651,8 +659,7 @@ def uniform_sample(length, n):
print('num frames:', len(frames))
return frames
- resource_path = config.get('resource_path')
- video_path = resource_path + '/red-panda.mp4'
+    video_path = f'{resource_path}/red-panda.mp4'
frames = encode_video(video_path)
question = 'Describe the video'
@@ -675,14 +682,16 @@ def uniform_sample(length, n):
'\n')
-def Qwen_vl_testcase(config, pipe, file):
+def Qwen_vl_testcase(pipe, file, resource_path):
# multi-image multi-round conversation, combined images
messages = [
dict(role='user',
content=[
dict(type='text', text='Describe the two images in detail.'),
- dict(type='image_url', image_url=dict(url=PIC_BEIJING)),
- dict(type='image_url', image_url=dict(url=PIC_CHONGQING))
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_BEIJING}')),
+ dict(type='image_url',
+ image_url=dict(url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
@@ -713,11 +722,11 @@ def Qwen_vl_testcase(config, pipe, file):
dict(type='image_url',
image_url=dict(min_pixels=min_pixels,
max_pixels=max_pixels,
- url=PIC_BEIJING)),
+ url=f'{resource_path}/{PIC_BEIJING}')),
dict(type='image_url',
image_url=dict(min_pixels=min_pixels,
max_pixels=max_pixels,
- url=PIC_CHONGQING))
+ url=f'{resource_path}/{PIC_CHONGQING}'))
])
]
response = pipe(messages)
diff --git a/benchmark/profile_generation.py b/benchmark/profile_generation.py
index 952de5d9f7..6c33b8bc4b 100644
--- a/benchmark/profile_generation.py
+++ b/benchmark/profile_generation.py
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
+import asyncio
import csv
import os
import time
from dataclasses import dataclass
-from queue import Queue
-from threading import Thread
from typing import List, Union
import numpy as np
@@ -24,8 +23,9 @@
os.environ['TM_LOG_LEVEL'] = 'ERROR'
-def infer(model, session_id: int, input_ids: List,
- gen_config: GenerationConfig, test_round: int, que: Queue):
+async def infer(model, session_id: int, input_ids: List,
+ gen_config: GenerationConfig, test_round: int,
+ que: asyncio.Queue):
if session_id == 1:
pbar = tqdm(total=test_round)
chatbot = model.create_instance()
@@ -47,12 +47,12 @@ def infer(model, session_id: int, input_ids: List,
The time elapsing in this iteration `now-prev` is set to the latency of first token of
the 5 tokens, i.e. `token_latency_stats[0]`, and `token_latency_stats[1:4]` is set 0`
""" # noqa: E501
- for outputs in chatbot.stream_infer(session_id,
- input_ids,
- gen_config=gen_config,
- sequence_start=True,
- sequence_end=True,
- stream_output=True):
+ async for outputs in chatbot.async_stream_infer(session_id,
+ input_ids,
+ gen_config=gen_config,
+ sequence_start=True,
+ sequence_end=True,
+ stream_output=True):
n_token = outputs.num_token
now = time.perf_counter()
if n_prev_token != n_token:
@@ -61,7 +61,7 @@ def infer(model, session_id: int, input_ids: List,
prev = now
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
- chatbot.end(session_id)
+ await chatbot.async_end(session_id)
if session_id == 1:
pbar.update(1)
@@ -69,39 +69,42 @@ def infer(model, session_id: int, input_ids: List,
f'Error. session_id({session_id}) request {output_seqlen} ' \
f'tokens, but generate {n_token} tokens'
stats.append(token_latency_stats[:output_seqlen])
- que.put((session_id, stats))
+ await que.put((session_id, stats))
def warmup(model, concurrency: int, input_ids: List[int], warmup_round: int,
- gen_config: GenerationConfig):
+ gen_config: GenerationConfig, event_loop: asyncio.BaseEventLoop):
if not warmup_round:
return
print('start to warmup ...')
- def _infer(model, session_id):
+ async def _infer(model, session_id):
chatbot = model.create_instance()
for _ in range(warmup_round):
- for _ in chatbot.stream_infer(session_id,
- input_ids=input_ids,
- sequence_start=True,
- sequence_end=True,
- ignore_eos=True,
- gen_config=gen_config):
+ async for _ in chatbot.async_stream_infer(session_id,
+ input_ids=input_ids,
+ sequence_start=True,
+ sequence_end=True,
+ ignore_eos=True,
+ gen_config=gen_config):
continue
# for pytorch engine to restart a session
if hasattr(chatbot, 'end'):
- chatbot.end(session_id)
+ await chatbot.async_end(session_id)
_start = time.perf_counter()
- procs = []
+
+    # create warmup coroutines (gathered below)
+ tasks = []
for i in range(concurrency):
- proc = Thread(target=_infer, args=(model, i + 1), daemon=True)
- procs.append(proc)
- proc.start()
+ task = _infer(model, i + 1)
+ tasks.append(task)
+
+ async def _gather_tasks(tasks):
+ return await asyncio.gather(*tasks)
- for proc in procs:
- proc.join()
+ event_loop.run_until_complete(_gather_tasks(tasks))
_end = time.perf_counter()
print(f'end warmup, elapsed time: {round(_end - _start, 2)}s')
@@ -125,31 +128,34 @@ def profile_throughput(model_path: str, concurrency: int, input_seqlen: int,
from lmdeploy.pytorch.engine import Engine
tm_model = Engine(model_path, engine_config)
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+
# make up a dummy `input_ids` with the length of `input_seqlen` exactly
assert input_seqlen > 0, 'input_seqlen should > 0'
input_ids = np.random.randint(low=0, high=101, size=input_seqlen).tolist()
- warmup(tm_model, concurrency, input_ids, warmup_round, gen_config)
+ warmup(tm_model, concurrency, input_ids, warmup_round, gen_config,
+ event_loop)
- que = Queue()
- procs = []
+ que = asyncio.Queue()
_start = time.perf_counter()
+ tasks = []
for i in range(concurrency):
- proc = Thread(target=infer,
- args=(tm_model, i + 1, input_ids, gen_config, test_round,
- que))
- procs.append(proc)
- proc.start()
+ task = infer(tm_model, i + 1, input_ids, gen_config, test_round, que)
+ tasks.append(task)
+
+ async def _gather_tasks(tasks):
+ return await asyncio.gather(*tasks)
- for proc in procs:
- proc.join()
+ event_loop.run_until_complete(_gather_tasks(tasks))
_end = time.perf_counter()
elapsed_time = _end - _start
token_latency_stats = []
while not que.empty():
- _, _stats = que.get()
+ _, _stats = que.get_nowait()
token_latency_stats += _stats
# The shape is [concurrency*test_round, output_seqlen]
@@ -426,7 +432,6 @@ def main():
block_size=args.cache_block_seq_len,
session_len=session_len,
tp=args.tp,
- thread_safe=True,
eager_mode=args.eager_mode,
enable_prefix_caching=args.enable_prefix_caching,
dtype=args.dtype,
diff --git a/benchmark/profile_pipeline_api.py b/benchmark/profile_pipeline_api.py
index 764f78399c..be06d32ee2 100644
--- a/benchmark/profile_pipeline_api.py
+++ b/benchmark/profile_pipeline_api.py
@@ -1,11 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
-import csv
import json
import os
import random
-import time
-from collections import OrderedDict
from typing import List, Tuple
from tqdm import tqdm
@@ -14,6 +11,10 @@
from lmdeploy import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig, pipeline)
from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
+from lmdeploy.profiler import Profiler, Session
+from lmdeploy.utils import get_logger
+
+logger = get_logger('lmdeploy')
def sample_requests(dataset_path: str, num_requests: int,
@@ -66,91 +67,70 @@ def __init__(self, model_path: str, engine_config, csv: str):
self.csv = csv
- def process_request(self, requests, concurrency, temperature, top_p, top_k,
- stream_output):
+ def process_request(self, requests, profiler: Profiler, temperature, top_p,
+ top_k, stream_output):
- stats = OrderedDict(
- (session_id, None) for session_id in range(len(requests)))
prompts = [prompt for prompt, _, _ in requests]
gen_configs = [
GenerationConfig(temperature=temperature,
top_p=top_p,
top_k=top_k,
ignore_eos=True,
+ do_sample=False,
max_new_tokens=output_len)
for _, _, output_len in requests
]
- start = time.perf_counter()
+ sess: List[Session] = []
+ for _, input_len, output_len in requests:
+ sess.append(profiler.new_session(input_len, output_len))
+
+ def _to_status(finish_reason):
+ if finish_reason == 'length':
+ return Session.SUCCESS
+ else:
+ return Session.FAIL
+
+ profiler.start()
+
+ for s in sess:
+ s.tick(0)
+
if stream_output:
pbar = tqdm(total=len(requests))
for output in self.pipe.stream_infer(prompts,
gen_configs,
do_preprocess=False):
- session_id = output.session_id
+ index = output.index
n_token = output.generate_token_len
finish_reason = output.finish_reason
- stats[session_id] = (n_token, finish_reason)
+ sess[index].tick(n_token)
if finish_reason is not None:
+ sess[index].finish(_to_status(finish_reason))
pbar.update(1)
+ pbar.close()
else:
for output in self.pipe(prompts,
gen_configs,
do_preprocess=False,
use_tqdm=True):
- session_id = output.session_id
+ index = output.index
n_token = output.generate_token_len
finish_reason = output.finish_reason
- stats[session_id] = (n_token, finish_reason)
-
- elapsed_time = time.perf_counter() - start
-
- completion_tokens = 0
- for session_id, (n_token, finish_reason) in stats.items():
- assert finish_reason == 'length', \
- f'unexpected finish_reason of session_id={session_id}, ' \
- f'prompt={requests[session_id][0]}'
- assert n_token - 1 <= requests[session_id][-1] <= n_token, \
- f'request to generate {requests[session_id][-1]} tokens, ' \
- f'but got {n_token} tokens'
- completion_tokens += n_token
-
- prompt_tokens = 0
- for _, input_len, _ in requests:
- prompt_tokens += input_len
-
- completion_token_throughput = completion_tokens / elapsed_time
- total_token_throughput = (prompt_tokens +
- completion_tokens) / elapsed_time
- rps = len(requests) / elapsed_time
- rpm = rps * 60
-
- print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
- f'elapsed_time: {elapsed_time:.3f}s\n')
-
- print(
- f'number of prompts: {len(requests)}\n'
- f'number of prompt tokens: {prompt_tokens:.0f}\n'
- f'number of completion tokens: {completion_tokens:.0f}\n'
- f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa
- f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa
- f'RPS (request per second): {rps:.3f} req/s\n'
- f'RPM (request per minute): {rpm:.3f} req/min\n'
- f'{"-" * 50}\n')
-
- if self.csv:
- with open(self.csv, 'w') as csvfile:
- writer = csv.writer(csvfile)
- writer.writerow([
- 'batch', 'num_promts', 'RPS', 'RPM',
- 'throughput(out tok/s)', 'throughput(total tok/s)'
- ])
- writer.writerow([
- concurrency,
- len(requests), f'{rps:.3f}', f'{rpm:.3f}',
- f'{completion_token_throughput:.3f}',
- f'{total_token_throughput:.3f}'
- ])
+ sess[index].tick(n_token)
+ sess[index].finish(_to_status(finish_reason))
+
+ profiler.finish()
+
+ # report first failure
+ for i, s in enumerate(sess):
+ if s.status != Session.SUCCESS or s.ns[-1] < s.req_output_len:
+ logger.error(
+ f'Request {i} failed with {s.ns[-1]}/{s.req_output_len} tokens generated' # noqa: E501
+ )
+ logger.error(f'Prompt: {prompts[i]}')
+ logger.warning('Got failed requests, metrics may be invalid')
+ break
def parse_args():
@@ -252,13 +232,25 @@ def main():
requests = sample_requests(args.dataset, args.num_prompts,
engine.tokenizer)
+ profiler = Profiler(args.stream_output, [50, 75, 95, 99])
+
engine.process_request(requests,
+ profiler,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
- concurrency=args.concurrency,
stream_output=args.stream_output)
+ hyperparams = [('Concurrency', args.concurrency),
+ ('Stream output', str(args.stream_output).lower())]
+
+ profiler.compute_metrics()
+ profiler.summarize(title='Profile Pipeline API', hyperparams=hyperparams)
+
+ if args.csv:
+ profiler.save_csv(args.csv, (('batch', args.concurrency),
+ ('num_prompts', args.num_prompts)))
+
if __name__ == '__main__':
main()
diff --git a/benchmark/profile_throughput.py b/benchmark/profile_throughput.py
index 291b1be9b8..2e4d2a3b8c 100644
--- a/benchmark/profile_throughput.py
+++ b/benchmark/profile_throughput.py
@@ -1,20 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import asyncio
-import csv
import json
import os
import random
-import time
from queue import Queue
from typing import List, Tuple, Union
-import numpy as np
from tqdm import tqdm
from lmdeploy.cli.utils import ArgumentHelper, DefaultsAndTypesHelpFormatter
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig,
TurbomindEngineConfig)
+from lmdeploy.profiler import Profiler, Session
from lmdeploy.pytorch.engine import EngineInstance
from lmdeploy.tokenizer import DetokenizeState, Tokenizer
from lmdeploy.utils import get_logger
@@ -71,7 +69,7 @@ class Engine:
def __init__(self, model_path: str,
engine_config: Union[PytorchEngineConfig,
- TurbomindEngineConfig], csv: str):
+ TurbomindEngineConfig]):
if isinstance(engine_config, TurbomindEngineConfig):
from lmdeploy.turbomind import TurboMind
tm_model = TurboMind.from_pretrained(model_path,
@@ -83,166 +81,104 @@ def __init__(self, model_path: str,
self.tm_model = tm_model
self.tokenizer = tm_model.tokenizer
- self.csv = csv
self.pbar = None
- async def _inference(self, req_queue: Queue, res_queue: Queue,
- session_id: int, temperature: float, top_p: float,
- top_k: int, stream_output: bool):
+ async def _inference(self, req_queue: Queue, session_id: int,
+ temperature: float, top_p: float, top_k: int,
+ stream_output: bool, skip_tokenize: bool,
+ skip_detokenize: bool):
model_inst = self.tm_model.create_instance()
- stats = []
- # get each generated token's latency
- per_token_latency_stats = []
- for prompt, input_seqlen, output_seqlen in iter(
- req_queue.get_nowait, [None, None, None]):
- _per_token_latency_stats = [0] * (output_seqlen + 1)
- prev = time.perf_counter()
- n_prev_token = 0
-
- input_ids = self.tokenizer(prompt).input_ids
+ sess: Session = None
+ for prompt, _, output_seqlen, cancel_after, sess in iter(
+ req_queue.get_nowait, None):
+
+ sess.tick(0)
+
+ if skip_tokenize:
+ input_ids = prompt
+ else:
+ input_ids = self.tokenizer(prompt).input_ids
+
state = DetokenizeState(len(input_ids))
- async for outputs in model_inst.async_stream_infer(
- session_id,
- input_ids=input_ids,
- gen_config=GenerationConfig(max_new_tokens=output_seqlen,
- temperature=temperature,
- top_p=top_p,
- top_k=top_k,
- ignore_eos=True),
- sequence_start=True,
- sequence_end=True,
- stream_output=stream_output):
- res, n_token = input_ids + outputs.token_ids, outputs.num_token
- _, state = self.tokenizer.detokenize_incrementally(res, state)
- now = time.perf_counter()
- if n_prev_token != n_token:
- _per_token_latency_stats[n_prev_token] = np.round(
- now - prev, 3)
- n_prev_token = n_token
- prev = now
+ prev_len = 0
+ token_ids = input_ids.copy()
+
+ generator = model_inst.async_stream_infer(
+ session_id,
+ input_ids=input_ids,
+ gen_config=GenerationConfig(max_new_tokens=output_seqlen,
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ ignore_eos=True),
+ sequence_start=True,
+ sequence_end=True,
+ stream_output=stream_output)
+ try:
+ async for outputs in generator:
+ n_token = outputs.num_token
+ if n_token > prev_len:
+ token_ids += outputs.token_ids[prev_len - n_token:]
+ if not skip_detokenize:
+ _, state = self.tokenizer.detokenize_incrementally(
+ token_ids, state)
+ sess.tick(n_token)
+ prev_len = n_token
+ if n_token > cancel_after:
+ break
+ sess.finish(Session.SUCCESS)
+ finally:
+ await generator.aclose()
+
# for pytorch engine to restart a session
if isinstance(model_inst, EngineInstance):
await model_inst.async_end(session_id)
- assert output_seqlen <= n_token <= output_seqlen + 1, \
- f'Error. session_id({session_id}) request {output_seqlen} ' \
- f'tokens, but generate {n_token} tokens.\n' \
- f'prompt: {prompt}'
-
- first_token_latency = _per_token_latency_stats[0]
- completion_tokens = n_token
- total_tokens = n_token + input_seqlen
- stats.append([
- first_token_latency, completion_tokens, output_seqlen,
- total_tokens
- ])
- # skip the first token latency
- per_token_latency_stats.append(_per_token_latency_stats[1:])
+
self.pbar.update(1)
- res_queue.put_nowait((session_id, stats, per_token_latency_stats))
- def process_request(self, requests, concurrency, temperature, top_p, top_k,
- stream_output):
- res_queue = Queue()
+ def process_request(self, requests, profiler: Profiler, concurrency,
+ temperature, top_p, top_k, stream_output,
+ skip_tokenize, skip_detokenize, cancel_rate):
req_queue = Queue()
- self.pbar = tqdm(total=len(requests))
-
# feed request to q
- for req in requests:
+ for prompt, input_len, output_len in requests:
+ cancel_after = output_len + 1
+ if cancel_rate > 0:
+ if random.random() < cancel_rate:
+ cancel_after = random.randint(0, cancel_after)
+ sess = profiler.new_session(input_len, output_len)
+ req = [prompt, input_len, output_len, cancel_after, sess]
+ if skip_tokenize:
+ req[0] = self.tokenizer.encode(prompt)
req_queue.put(req)
for i in range(concurrency):
- req_queue.put([None, None, None])
-
- start = time.time()
-
- event_loop = asyncio.new_event_loop()
- asyncio.set_event_loop(event_loop)
+ req_queue.put(None)
# start threads
tasks = []
for i in range(concurrency):
- task = self._inference(req_queue, res_queue, i, temperature, top_p,
- top_k, stream_output)
+ task = self._inference(req_queue, i, temperature, top_p, top_k,
+ stream_output, skip_tokenize,
+ skip_detokenize)
tasks.append(task)
async def _gather_tasks(tasks):
return await asyncio.gather(*tasks)
- event_loop.run_until_complete(_gather_tasks(tasks))
-
- elapsed_time = time.time() - start
-
- stats = []
- per_token_latency_stats = []
- while not res_queue.empty():
- session_id, _stats, _per_token_latency_stats = res_queue.get()
- stats.append(np.array(_stats))
- per_token_latency_stats += [
- item for sublist in _per_token_latency_stats
- for item in sublist
- ]
- stats = np.concatenate(stats).reshape(-1, 4)
-
- first_token_latency_min = np.min(stats[:, 0], axis=0)
- first_token_latency_max = np.max(stats[:, 0], axis=0)
- first_token_latency_ave = np.mean(stats[:, 0], axis=0)
- completion_tokens = np.sum(stats[:, 1], axis=0)
- total_tokens = np.sum(stats[:, 3], axis=0)
- prompt_tokens = total_tokens - completion_tokens
- completion_token_throughput = completion_tokens / elapsed_time
- total_token_throughput = total_tokens / elapsed_time
- rps = len(requests) / elapsed_time
- rpm = rps * 60
-
- per_token_latency_stats.sort()
- percentiles = [
- np.round(
- per_token_latency_stats[int(percent *
- len(per_token_latency_stats))], 3)
- for percent in [0.5, 0.75, 0.95, 0.99]
- ]
-
- print(f'\n{"-" * 50}\nconcurrency: {concurrency}\n'
- f'elapsed_time: {elapsed_time:.3f}s\n')
- if stream_output:
- print(f'first token latency(s)(min, max, ave): '
- f'{first_token_latency_min:.3f}, '
- f'{first_token_latency_max:.3f}, '
- f'{first_token_latency_ave:.3f}')
- print(f'per-token latency(s) percentile(50, 75, 95, 99): '
- f'{percentiles}\n')
- print(
- f'number of prompt tokens: {prompt_tokens:.0f}\n'
- f'number of completion tokens: {completion_tokens:.0f}\n'
- f'token throughput (completion token): {completion_token_throughput:.3f} token/s\n' # noqa
- f'token throughput (prompt + completion token): {total_token_throughput:.3f} token/s\n' # noqa
- f'RPS (request per second): {rps:.3f} req/s\n'
- f'RPM (request per minute): {rpm:.3f} req/min\n'
- f'{"-" * 50}\n')
-
- if self.csv:
- with open(self.csv, 'w') as csvfile:
- writer = csv.writer(csvfile)
- writer.writerow([
- 'batch', 'num_promts', 'RPS', 'RPM', 'FTL(ave)(s)',
- 'FTL(min)(s)', 'FTL(max)(s)', '50%(s)', '75%(s)', '95%(s)',
- '99%(s)', 'throughput(out tok/s)',
- 'throughput(total tok/s)'
- ])
- writer.writerow([
- concurrency,
- len(requests), f'{rps:.3f}', f'{rpm:.3f}',
- f'{first_token_latency_ave:.3f}' if stream_output else '-',
- f'{first_token_latency_min:.3f}' if stream_output else '-',
- f'{first_token_latency_max:.3f}' if stream_output else '-',
- f'{percentiles[0]:.3f}' if stream_output else '-',
- f'{percentiles[1]:.3f}' if stream_output else '-',
- f'{percentiles[2]:.3f}' if stream_output else '-',
- f'{percentiles[3]:.3f}' if stream_output else '-',
- f'{completion_token_throughput:.3f}',
- f'{total_token_throughput:.3f}'
- ])
+ self.pbar = tqdm(total=len(requests))
+
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+
+ profiler.start()
+
+ asyncio.run(_gather_tasks(tasks))
+
+ profiler.finish()
+
+ self.pbar.close()
def parse_args():
@@ -266,6 +202,20 @@ def parse_args():
type=int,
help='Number of prompts to process',
default=5000)
+ parser.add_argument('--no-stream-output',
+ action='store_true',
+                        help='Disable stream output')
+ parser.add_argument('--skip-tokenize',
+ action='store_true',
+ help='Pre-tokenize input prompts before starting')
+ parser.add_argument('--skip-detokenize',
+ action='store_true',
+ help='Skip detokenizing output tokens')
+ parser.add_argument('--cancel-rate',
+ type=float,
+                        help='Probability of a request being canceled',
+ default=0)
+    parser.add_argument('--use-uvloop',
+                        action='store_true',
+                        help='Use uvloop as the asyncio event loop')
parser.add_argument('--csv',
type=str,
help='Where to save the result.',
@@ -340,19 +290,42 @@ def main():
dtype=args.dtype,
)
- engine = Engine(args.model_path, engine_config, csv=args.csv)
+ if args.use_uvloop:
+ import uvloop
+ asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
+
+ engine = Engine(args.model_path, engine_config)
requests = sample_requests(args.dataset, args.num_prompts,
engine.tokenizer)
+ stream_output = not args.no_stream_output
+
+ profiler = Profiler(stream_output, [50, 75, 95, 99])
+
engine.process_request(
requests,
+ profiler,
temperature=args.temperature,
top_p=args.top_p,
top_k=args.top_k,
concurrency=args.concurrency
if args.concurrency < args.num_prompts else args.num_prompts,
- stream_output=True)
+ stream_output=not args.no_stream_output,
+ skip_tokenize=args.skip_tokenize,
+ skip_detokenize=args.skip_detokenize,
+ cancel_rate=args.cancel_rate)
+
+ hyperparams = [('Concurrency', args.concurrency),
+ ('Cancel rate', args.cancel_rate),
+ ('Stream output', str(stream_output).lower()),
+ ('Skip tokenize', str(args.skip_tokenize).lower()),
+ ('Skip detokenize', str(args.skip_detokenize).lower())]
+ profiler.compute_metrics()
+ profiler.summarize(title='Profile Throughput', hyperparams=hyperparams)
+ if args.csv:
+ profiler.save_csv(args.csv, (('batch', args.concurrency),
+ ('num_prompts', args.num_prompts)))
if __name__ == '__main__':
diff --git a/docker/Dockerfile_aarch64_ascend b/docker/Dockerfile_aarch64_ascend
index ecc2d1334e..5ed842061c 100644
--- a/docker/Dockerfile_aarch64_ascend
+++ b/docker/Dockerfile_aarch64_ascend
@@ -121,5 +121,4 @@ COPY --from=copy_temp /tmp /opt/lmdeploy
WORKDIR /opt/lmdeploy
RUN --mount=type=cache,target=/root/.cache/pip \
- sed -i '/triton/d' requirements/runtime.txt && \
LMDEPLOY_TARGET_DEVICE=ascend pip3 install -v --no-build-isolation -e .
diff --git a/docs/en/advance/pytorch_multithread.md b/docs/en/advance/pytorch_multithread.md
new file mode 100644
index 0000000000..446e0fa769
--- /dev/null
+++ b/docs/en/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
+# PyTorchEngine Multithread
+
+We have removed `thread_safe` mode from PytorchEngine since [PR2907](https://github.com/InternLM/lmdeploy/pull/2907). We encourage users to achieve high concurrency by using **service API** or **coroutines** whenever possible, for example:
+
+```python
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+event_loop = asyncio.new_event_loop()
+asyncio.set_event_loop(event_loop)
+
+model_path = 'Llama-3.2-1B-Instruct'
+pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+
+async def _gather_output():
+ tasks = [
+ pipe.async_batch_infer('Hakuna Matata'),
+ pipe.async_batch_infer('giraffes are heartless creatures'),
+ ]
+ return await asyncio.gather(*tasks)
+
+output = asyncio.run(_gather_output())
+print(output[0].text)
+print(output[1].text)
+```
+
+If you do need multithreading, it is easy to wrap the coroutine-based pipeline as shown below:
+
+```python
+import threading
+from queue import Queue
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+model_path = 'Llama-3.2-1B-Instruct'
+
+
+async def _batch_infer(inque: Queue, outque: Queue, pipe):
+ while True:
+ if inque.empty():
+ await asyncio.sleep(0)
+ continue
+
+ input = inque.get_nowait()
+ output = await pipe.async_batch_infer(input)
+ outque.put(output)
+
+
+def server(inques, outques):
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+ pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+ for inque, outque in zip(inques, outques):
+ event_loop.create_task(_batch_infer(inque, outque, pipe))
+ event_loop.run_forever()
+
+def client(inque, outque, message):
+ inque.put(message)
+ print(outque.get().text)
+
+
+inques = [Queue(), Queue()]
+outques = [Queue(), Queue()]
+
+t_server = threading.Thread(target=server, args=(inques, outques))
+t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
+t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))
+
+t_server.start()
+t_client0.start()
+t_client1.start()
+
+t_client0.join()
+t_client1.join()
+```
+
+> \[!WARNING\]
+> This is NOT recommended, as multithreading introduces additional overhead, leading to unstable inference performance.
diff --git a/docs/en/get_started/ascend/get_started.md b/docs/en/get_started/ascend/get_started.md
index d104477ca1..7da28b5512 100644
--- a/docs/en/get_started/ascend/get_started.md
+++ b/docs/en/get_started/ascend/get_started.md
@@ -18,7 +18,7 @@ cd lmdeploy
### Environment Preparation
-The Docker version is supposed to be no less than `18.03`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html).
+The Docker version is supposed to be no less than `18.09`. And `Ascend Docker Runtime` should be installed by following [the official guide](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/.clusterschedulingig/dlug_installation_012.html).
> \[!CAUTION\]
> If error message `libascend_hal.so: cannot open shared object file` shows, that means **Ascend Docker Runtime** is not installed correctly!
diff --git a/docs/en/get_started/installation.md b/docs/en/get_started/installation.md
index c00111c2ab..8877d510cc 100644
--- a/docs/en/get_started/installation.md
+++ b/docs/en/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
The default prebuilt package is compiled on **CUDA 12**. If CUDA 11+ (>=11.3) is required, you can install lmdeploy by:
```shell
-export LMDEPLOY_VERSION=0.6.4
+export LMDEPLOY_VERSION=0.6.5
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/en/index.rst b/docs/en/index.rst
index 5d49e01c86..54a36c22c8 100644
--- a/docs/en/index.rst
+++ b/docs/en/index.rst
@@ -103,6 +103,7 @@ Documentation
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
+ advance/pytorch_multithread.md
.. toctree::
:maxdepth: 1
diff --git a/docs/en/llm/pipeline.md b/docs/en/llm/pipeline.md
index 887e2e0a3e..5ab9d04c1c 100644
--- a/docs/en/llm/pipeline.md
+++ b/docs/en/llm/pipeline.md
@@ -6,7 +6,7 @@ You can overview the detailed pipeline API in [this](https://lmdeploy.readthedoc
## Usage
-- **An example using default parameters:**
+### A 'Hello, world' example
```python
from lmdeploy import pipeline
@@ -40,7 +40,7 @@ There have been alterations to the strategy for setting the k/v cache ratio thro
The allocation strategy for k/v cache is changed to reserve space from the **GPU free memory** proportionally. The ratio `TurbomindEngineConfig.cache_max_entry_count` has been adjusted to 0.8 by default. If OOM error happens, similar to the method mentioned above, please consider reducing the ratio value to decrease the memory usage of the k/v cache.
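+
+As a rough sketch of lowering that ratio (the 0.5 value below is purely illustrative, not a recommendation):
+
+```python
+from lmdeploy import pipeline, TurbomindEngineConfig
+
+# reserve 50% of the free GPU memory for the k/v cache instead of the default 0.8
+pipe = pipeline('internlm/internlm2_5-7b-chat',
+                backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
+print(response)
+```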
-- **An example showing how to set tensor parallel num**:
+### Set tensor parallelism
```python
from lmdeploy import pipeline, TurbomindEngineConfig
@@ -52,7 +52,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
print(response)
```
-- **An example for setting sampling parameters:**
+### Set sampling parameters
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -69,7 +69,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
print(response)
```
-- **An example for OpenAI format prompt input:**
+### Apply OpenAI format prompt
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -93,7 +93,7 @@ response = pipe(prompts,
print(response)
```
-- **An example for streaming mode:**
+### Apply streaming output
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -116,31 +116,60 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
print(item)
```
-- **An example to cauculate logits & ppl:**
+### Get logits for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config=GenerationConfig(output_logits='generation',
+                            max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+ gen_config=gen_config)
+logits = [x.logits for x in response]
+```
+
+### Get last layer's hidden states for generated tokens
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config=GenerationConfig(output_last_hidden_state='generation',
+ max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+ gen_config=gen_config)
+hidden_states = [x.last_hidden_state for x in response]
+```
+
+### Calculate ppl
```python
from transformers import AutoTokenizer
from lmdeploy import pipeline
-model_repoid_or_path='internlm/internlm2_5-7b-chat'
+
+
+model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)
-
-# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
]
input_ids = tokenizer.apply_chat_template(messages)
-logits = pipe.get_logits(input_ids)
-# ppl
+# ppl is a list of float numbers
ppl = pipe.get_ppl(input_ids)
+print(ppl)
```
```{note}
-get_ppl returns the cross entropy loss without applying the exponential operation afterwards
+- When input_ids is too long, an OOM (Out Of Memory) error may occur. Please use it with caution
+- get_ppl returns the cross entropy loss without applying the exponential operation afterwards
```
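+
+Since the returned values are cross entropy losses, a minimal sketch (using Python's `math.exp`; `ppl` is the list returned above) to turn them into perplexity:
+
+```python
+import math
+
+# exponentiate the cross entropy loss to obtain perplexity per input
+perplexity = [math.exp(loss) for loss in ppl]
+print(perplexity)
+```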
-- **Below is an example for pytorch backend. Please install triton first.**
+### Use PyTorchEngine
```shell
pip install triton>=2.1.0
@@ -167,7 +196,7 @@ response = pipe(prompts, gen_config=gen_config)
print(response)
```
-- **An example for lora.**
+### Inference with LoRA
```python
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
diff --git a/docs/en/multi_modal/vl_pipeline.md b/docs/en/multi_modal/vl_pipeline.md
index 9632c9e6df..1b41e46601 100644
--- a/docs/en/multi_modal/vl_pipeline.md
+++ b/docs/en/multi_modal/vl_pipeline.md
@@ -4,7 +4,7 @@ LMDeploy abstracts the complex inference process of multi-modal Vision-Language
The supported models are listed [here](../supported_models/supported_models.md). We genuinely invite the community to contribute new VLM support to LMDeploy. Your involvement is truly appreciated.
-This article showcases the VLM pipeline using the [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) model as a case study.
+This article showcases the VLM pipeline using the [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) model as a case study.
You'll learn about the simplest ways to leverage the pipeline and how to gradually unlock more advanced features by adjusting engine parameters and generation arguments, such as tensor parallelism, context window sizing, random sampling, and chat template customization.
Moreover, we will provide practical inference examples tailored to scenarios with multiple images, batch prompts etc.
@@ -16,7 +16,7 @@ Using the pipeline interface to infer other VLM models is similar, with the main
from lmdeploy import pipeline
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
@@ -30,7 +30,7 @@ In the above example, the inference prompt is a tuple structure consisting of (p
```python
from lmdeploy import pipeline
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
prompts = [
{
@@ -53,7 +53,7 @@ Tensor parallelism can be activated by setting the engine parameter `tp`
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(tp=2))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -69,7 +69,7 @@ When creating the pipeline, you can customize the size of the context window by
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -85,7 +85,7 @@ You can change the default sampling parameters of pipeline by passing `Generatio
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(tp=2, session_len=8192))
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -139,22 +139,19 @@ response = pipe(('describe this image', image))
print(response)
```
-### Calculate logits
-
-We provide support for custom inputs. Users can utilize 'prepare_inputs' to understand how the inputs are organized.
+### Output logits for generated tokens
```python
-from lmdeploy import pipeline, TurbomindEngineConfig
+from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image
-pipe = pipeline('internlm/internlm-xcomposer2-7b', backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
-# logits
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
-inputs = pipe.prepare_inputs(('describe this image', image))
-input_ids = inputs['input_ids']
-embeddings = inputs['input_embeddings']
-embedding_ranges = inputs['input_embedding_ranges']
-logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
+
+response = pipe(('describe this image', image),
+ gen_config=GenerationConfig(output_logits='generation'))
+logits = response.logits
+print(logits)
```
## Multi-images inference
@@ -165,7 +162,7 @@ When dealing with multiple images, you can put them all in one list. Keep in min
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image_urls=[
@@ -186,7 +183,7 @@ Conducting inference with batch prompts is quite straightforward; just place the
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image_urls=[
@@ -206,7 +203,7 @@ There are two ways to do the multi-turn conversations with the pipeline. One is
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
diff --git a/docs/en/quantization/w8a8.md b/docs/en/quantization/w8a8.md
index 1b1726bd5f..5cdb48f764 100644
--- a/docs/en/quantization/w8a8.md
+++ b/docs/en/quantization/w8a8.md
@@ -1,55 +1,74 @@
# SmoothQuant
-LMDeploy provides functions for quantization and inference of large language models using 8-bit integers.
+LMDeploy provides functions for quantization and inference of large language models using 8-bit integers (INT8). For GPUs such as the NVIDIA H100, LMDeploy also supports 8-bit floating point (FP8).
-Before starting inference, ensure that lmdeploy and openai/triton are correctly installed. Execute the following commands to install these:
+The following NVIDIA GPUs support INT8 and FP8 inference, respectively:
+
+- INT8
+ - V100(sm70): V100
+ - Turing(sm75): 20 series, T4
+ - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+- FP8
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+
+First of all, run the following command to install lmdeploy:
```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
```
-## 8-bit Weight Model Inference
+## 8-bit Weight Quantization
-For performing 8-bit weight model inference, you can directly download the pre-quantized 8-bit weight models from LMDeploy's [model zoo](https://huggingface.co/lmdeploy). For instance, the 8-bit Internlm-chat-7B model is available for direct download from the model zoo:
+Performing 8-bit weight quantization involves three steps:
-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantization.
+2. **Replace Modules**: Locate the DecoderLayers and replace the RMSNorm and nn.Linear modules with QRMSNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file (a rough sketch follows after this list).
+3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
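+
+For illustration only, the snippet below sketches what step 2 looks like. It assumes the `QLinear`/`QRMSNorm` classes from `lmdeploy/pytorch/models/q_modules.py` and matches norm layers by the class name `RMSNorm`; the helper `replace_modules` is hypothetical, and the real logic (including weight smoothing and module skipping) lives in `lmdeploy/lite/apis/smooth_quant.py`.
+
+```python
+import torch
+from torch import nn
+
+from lmdeploy.pytorch.models.q_modules import QLinear, QRMSNorm
+
+
+def replace_modules(model: nn.Module, quant_dtype=torch.int8) -> nn.Module:
+    """Swap nn.Linear -> QLinear and RMSNorm -> QRMSNorm in place (sketch)."""
+    for name, module in list(model.named_modules()):
+        if not name:  # skip the root module itself
+            continue
+        parent_name, _, child_name = name.rpartition('.')
+        parent = model.get_submodule(parent_name) if parent_name else model
+        if isinstance(module, nn.Linear):
+            # quantize the linear layer while keeping the module tree intact
+            setattr(parent, child_name,
+                    QLinear.from_float(module, quant_dtype=quant_dtype))
+        elif type(module).__name__ == 'RMSNorm':
+            # model-specific RMSNorm classes are matched by class name here
+            setattr(parent, child_name,
+                    QRMSNorm.from_float(module, quant_dtype=quant_dtype))
+    return model
+```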
-Alternatively, you can manually convert original 16-bit weights into 8-bit by referring to the content under the ["8bit Weight Quantization"](#8bit-weight-quantization) section. Save them in the internlm-chat-7b-w8 directory, using the command below:
+LMDeploy provides the `lmdeploy lite smooth_quant` command to accomplish all three tasks detailed above. The argument `--quant-dtype` determines whether int8 or fp8 weight quantization is performed. For more information about the CLI usage, run `lmdeploy lite smooth_quant --help`.
-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+Here are two examples:
-Afterwards, use the following command to interact with the model via the terminal:
+- int8
-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+ ```
-## Launching gradio service
+- fp8
-Coming soon...
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+ ```
-## Inference Speed
+## Inference
-Coming soon...
+With the following code, you can perform batched offline inference with the quantized model:
-## 8bit Weight Quantization
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
-Performing 8bit weight quantization involves three steps:
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
-1. **Smooth Weights**: Start by smoothing the weights of the Language Model (LLM). This process makes the weights more amenable to quantizing.
-2. **Replace Modules**: Locate DecoderLayers and replace the modules RSMNorm and nn.Linear with QRSMNorm and QLinear modules respectively. These 'Q' modules are available in the lmdeploy/pytorch/models/q_modules.py file.
-3. **Save the Quantized Model**: Once you've made the necessary replacements, save the new quantized model.
+## Service
+
+LMDeploy's `api_server` enables models to be easily packed into services with a single command. The provided RESTful APIs are compatible with OpenAI's interfaces. Below is an example of starting the service:
+
+```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```
-The script `lmdeploy/lite/apis/smooth_quant.py` accomplishes all three tasks detailed above. For example, you can obtain the model weights of the quantized Internlm-chat-7B model by running the following command:
+The default port of `api_server` is `23333`. After the server is launched, you can communicate with the server from the terminal through `api_client`:
```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+lmdeploy serve api_client http://0.0.0.0:23333
```
-After saving, you can instantiate your quantized model by calling the from_pretrained interface.
+You can overview and try out the `api_server` APIs online via the Swagger UI at `http://0.0.0.0:23333`, or read the API specification [here](../llm/api_server.md).
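+
+Because the RESTful APIs are OpenAI-compatible, you can also query the service with the `openai` Python package. The snippet below is a minimal sketch that assumes the package (v1+) is installed and the server above is running on the default port:
+
+```python
+from openai import OpenAI
+
+# a placeholder API key: the OpenAI client requires one to be set
+client = OpenAI(base_url='http://0.0.0.0:23333/v1', api_key='none')
+
+# discover the served model name from the /v1/models endpoint
+model_name = client.models.list().data[0].id
+response = client.chat.completions.create(
+    model=model_name,
+    messages=[{'role': 'user', 'content': 'Hi, pls intro yourself'}],
+    temperature=0.8,
+)
+print(response.choices[0].message.content)
+```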
diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md
index dd8ceb4ffa..cb9805bb0b 100644
--- a/docs/en/supported_models/supported_models.md
+++ b/docs/en/supported_models/supported_models.md
@@ -4,104 +4,107 @@ The following tables detail the models supported by LMDeploy's TurboMind engine
## TurboMind on CUDA Platform
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
-| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
-| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | No |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
-| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
-| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
-| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
-| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
-| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
-| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
-| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
+| :------------------------------: | :--------------: | :--: | :-------: | :-----: | :-----: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
+| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
+| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
+| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2\[2\] | 1 - 2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
+| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
"-" means not verified yet.
```{note}
-* The TurboMind engine doesn't support window attention. Therefore, for models that have applied window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral, Qwen1.5 and etc., please choose the PyTorch engine for inference.
-* When the head_dim of a model is not 128, such as llama3.2-1B, qwen2-0.5B and internvl2-1B, turbomind doesn't support its kv cache 4/8 bit quantization and inference
+* [1] The TurboMind engine doesn't support window attention. Therefore, for models that apply window attention and have the corresponding switch "use_sliding_window" enabled, such as Mistral and Qwen1.5, please choose the PyTorch engine for inference.
+* [2] When a model's head_dim is not 128, as in llama3.2-1B, qwen2-0.5B and internvl2-1B, TurboMind doesn't support 4/8-bit quantization and inference of its kv cache.
```
## PyTorchEngine on CUDA Platform
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
-| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
-| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
-| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
-| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
-| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
-| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
-| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
-| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
-| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
-| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
-| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
-| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
-| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
-| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
-| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
-| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
-| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
-| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
-| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
+| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
+| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
+| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
+| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
+| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
+| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
+| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
+| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
+| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
+| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
+| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
+| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
+| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - |
+| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - |
+| Mono-InternVL\[1\] | 2B | MLLM | Yes | Yes | Yes | - | - |
+| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
+| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
+| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
+| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
```{note}
-* Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* [1] Currently Mono-InternVL does not support FP16 due to numerical instability. Please use BF16 instead.
+* [2] The PyTorch engine removed support for the original llava models after v0.6.4. Please use the corresponding transformers-format models instead, which can be found at https://huggingface.co/llava-hf
```
## PyTorchEngine on Huawei Ascend Platform
diff --git a/docs/zh_cn/advance/pytorch_multithread.md b/docs/zh_cn/advance/pytorch_multithread.md
new file mode 100644
index 0000000000..ebd68f503e
--- /dev/null
+++ b/docs/zh_cn/advance/pytorch_multithread.md
@@ -0,0 +1,78 @@
+# PyTorchEngine 多线程推理
+
+自 [PR2907](https://github.com/InternLM/lmdeploy/pull/2907) 起,我们废除了 PytorchEngine 的 thread_safe 模式以保证引擎能够更高效的运行。我们鼓励用户尽可能使用**服务接口**或**协程**来实现高并发,比如:
+
+```python
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+event_loop = asyncio.new_event_loop()
+asyncio.set_event_loop(event_loop)
+
+model_path = 'Llama-3.2-1B-Instruct'
+pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+
+async def _gather_output():
+ tasks = [
+ pipe.async_batch_infer('Hakuna Matata'),
+ pipe.async_batch_infer('giraffes are heartless creatures'),
+ ]
+ return await asyncio.gather(*tasks)
+
+output = asyncio.run(_gather_output())
+print(output[0].text)
+print(output[1].text)
+```
+
+如果你确实有多线程推理的需求,那么可以进行简单的封装,来实现类似的效果。
+
+```python
+import threading
+from queue import Queue
+import asyncio
+from lmdeploy import pipeline, PytorchEngineConfig
+
+model_path = 'Llama-3.2-1B-Instruct'
+
+
+async def _batch_infer(inque: Queue, outque: Queue, pipe):
+ while True:
+ if inque.empty():
+ await asyncio.sleep(0)
+ continue
+
+ input = inque.get_nowait()
+ output = await pipe.async_batch_infer(input)
+ outque.put(output)
+
+
+def server(inques, outques):
+ event_loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(event_loop)
+ pipe = pipeline(model_path, backend_config=PytorchEngineConfig())
+ for inque, outque in zip(inques, outques):
+ event_loop.create_task(_batch_infer(inque, outque, pipe))
+ event_loop.run_forever()
+
+def client(inque, outque, message):
+ inque.put(message)
+ print(outque.get().text)
+
+
+inques = [Queue(), Queue()]
+outques = [Queue(), Queue()]
+
+t_server = threading.Thread(target=server, args=(inques, outques))
+t_client0 = threading.Thread(target=client, args=(inques[0], outques[0], 'Hakuna Matata'))
+t_client1 = threading.Thread(target=client, args=(inques[1], outques[1], 'giraffes are heartless creatures'))
+
+t_server.start()
+t_client0.start()
+t_client1.start()
+
+t_client0.join()
+t_client1.join()
+```
+
+> \[!WARNING\]
+> 我们不鼓励这样实现,多线程会带来额外的开销,使得推理性能不稳定
diff --git a/docs/zh_cn/get_started/ascend/get_started.md b/docs/zh_cn/get_started/ascend/get_started.md
index 9f0a7b1f90..e4790253cd 100644
--- a/docs/zh_cn/get_started/ascend/get_started.md
+++ b/docs/zh_cn/get_started/ascend/get_started.md
@@ -17,7 +17,7 @@ cd lmdeploy
### 环境准备
-Docker 版本应不低于 18.03。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。
+Docker 版本应不低于 18.09。并且需按照[官方指南](https://www.hiascend.com/document/detail/zh/mindx-dl/60rc2/clusterscheduling/clusterschedulingig/clusterschedulingig/dlug_installation_012.html)安装 Ascend Docker Runtime。
> \[!CAUTION\]
> 如果在后续容器内出现`libascend_hal.so: cannot open shared object file`错误,说明Ascend Docker Runtime没有被正确安装。
diff --git a/docs/zh_cn/get_started/installation.md b/docs/zh_cn/get_started/installation.md
index 0213fa6d15..501f8a13e8 100644
--- a/docs/zh_cn/get_started/installation.md
+++ b/docs/zh_cn/get_started/installation.md
@@ -23,7 +23,7 @@ pip install lmdeploy
默认的预构建包是在 **CUDA 12** 上编译的。如果需要 CUDA 11+ (>=11.3),你可以使用以下命令安装 lmdeploy:
```shell
-export LMDEPLOY_VERSION=0.6.4
+export LMDEPLOY_VERSION=0.6.5
export PYTHON_VERSION=38
pip install https://github.com/InternLM/lmdeploy/releases/download/v${LMDEPLOY_VERSION}/lmdeploy-${LMDEPLOY_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux2014_x86_64.whl --extra-index-url https://download.pytorch.org/whl/cu118
```
diff --git a/docs/zh_cn/index.rst b/docs/zh_cn/index.rst
index 018a00487f..197e800d58 100644
--- a/docs/zh_cn/index.rst
+++ b/docs/zh_cn/index.rst
@@ -104,6 +104,7 @@ LMDeploy 工具箱提供以下核心功能:
advance/chat_template.md
advance/debug_turbomind.md
advance/structed_output.md
+ advance/pytorch_multithread.md
.. toctree::
:maxdepth: 1
diff --git a/docs/zh_cn/llm/pipeline.md b/docs/zh_cn/llm/pipeline.md
index 40406c85a4..012e6b3abb 100644
--- a/docs/zh_cn/llm/pipeline.md
+++ b/docs/zh_cn/llm/pipeline.md
@@ -6,7 +6,7 @@ pipeline API 详细的接口说明,请阅读[此处](https://lmdeploy.readthed
## 使用方法
-- **使用默认参数的例子:**
+### "Hello, world" 示例
```python
from lmdeploy import pipeline
@@ -40,7 +40,7 @@ LMDeploy 在研发过程中,k/v cache 比例的设定策略有变更,以下
分配策略改为从**空闲显存**中按比例为 k/v cache 开辟空间。默认比例值调整为 0.8。如果遇到 OOM,类似上面的方法,请酌情减少比例值,降低 k/v cache 的内存占用量
-- **如何设置 tp:**
+### 设置多卡并行
```python
from lmdeploy import pipeline, TurbomindEngineConfig
@@ -52,7 +52,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'])
print(response)
```
-- **如何设置 sampling 参数:**
+### 设置随机采样参数
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -69,7 +69,7 @@ response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
print(response)
```
-- **如何设置 OpenAI 格式输入:**
+### 使用 OpenAI 格式的 prompt
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -93,7 +93,7 @@ response = pipe(prompts,
print(response)
```
-- **流式返回处理结果:**
+### 流式输出
```python
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
@@ -116,31 +116,64 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
print(item)
```
-- **计算 logits & ppl:**
+### 获取生成 token 的 logits
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config=GenerationConfig(output_logits='generation',
+ max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+ gen_config=gen_config)
+logits = [x.logits for x in response]
+```
+
+### 获取生成 token 最后一层的 hidden_states
+
+```python
+from lmdeploy import pipeline, GenerationConfig
+
+pipe = pipeline('internlm/internlm2_5-7b-chat')
+
+gen_config=GenerationConfig(output_last_hidden_state='generation',
+ max_new_tokens=10)
+response = pipe(['Hi, pls intro yourself', 'Shanghai is'],
+ gen_config=gen_config)
+hidden_states = [x.last_hidden_state for x in response]
+```
+
+### 计算 ppl
```python
from transformers import AutoTokenizer
from lmdeploy import pipeline
-model_repoid_or_path='internlm/internlm2_5-7b-chat'
+
+
+model_repoid_or_path = 'internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)
-
-# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
]
input_ids = tokenizer.apply_chat_template(messages)
+
+# logits is a list of tensor
logits = pipe.get_logits(input_ids)
+print(logits)
-# ppl
+# ppl is a list of float numbers
ppl = pipe.get_ppl(input_ids)
+print(ppl)
```
```{note}
+当 input_ids 过长时,可能会出现 OOM 错误,请小心应用
get_ppl 返回的是 cross entropy loss,没有在之后加 exp 操作
```
-- **使用 pytorch 后端**
+### 使用 PyTorchEngine
需要先安装 triton
@@ -169,7 +202,7 @@ response = pipe(prompts, gen_config=gen_config)
print(response)
```
-- **一个 lora 的例子**
+### LoRA 模型推理
```python
from lmdeploy import pipeline, GenerationConfig, PytorchEngineConfig
@@ -190,7 +223,7 @@ response = pipe(prompts, gen_config=gen_config, adapter_name='lora_name_1')
print(response)
```
-## FAQs
+## 常见问题
- **RuntimeError: An attempt has been made to start a new process before the current process has finished its bootstrapping phase**.
diff --git a/docs/zh_cn/multi_modal/vl_pipeline.md b/docs/zh_cn/multi_modal/vl_pipeline.md
index 35f647e36c..bac920fb5a 100644
--- a/docs/zh_cn/multi_modal/vl_pipeline.md
+++ b/docs/zh_cn/multi_modal/vl_pipeline.md
@@ -4,7 +4,7 @@ LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单
在[这个列表中](../supported_models/supported_models.md),你可以查阅每个推理引擎支持的 VLM 模型。我们诚挚邀请社区在 LMDeploy 中添加更多 VLM 模型。
-本文将以 [liuhaotian/llava-v1.6-vicuna-7b](https://huggingface.co/liuhaotian/llava-v1.6-vicuna-7b) 模型为例,展示 VLM pipeline 的用法。你将了解它的最基础用法,以及如何通过调整引擎参数和生成条件来逐步解锁更多高级特性,如张量并行,上下文窗口大小调整,随机采样,以及对话模板的定制。
+本文将以 [OpenGVLab/InternVL2_5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) 模型为例,展示 VLM pipeline 的用法。你将了解它的最基础用法,以及如何通过调整引擎参数和生成条件来逐步解锁更多高级特性,如张量并行,上下文窗口大小调整,随机采样,以及对话模板的定制。
此外,我们还提供针对多图、批量提示词等场景的实际推理示例。
@@ -16,7 +16,7 @@ LMDeploy 把视觉-语言模型(VLM)复杂的推理过程,抽象为简单
from lmdeploy import pipeline
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
response = pipe(('describe this image', image))
@@ -30,7 +30,7 @@ print(response)
```python
from lmdeploy import pipeline
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b')
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
prompts = [
{
@@ -53,7 +53,7 @@ print(response)
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(tp=2))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -69,7 +69,7 @@ print(response)
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -85,7 +85,7 @@ print(response)
from lmdeploy import pipeline, GenerationConfig, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(tp=2, session_len=8192))
gen_config = GenerationConfig(top_k=40, top_p=0.8, temperature=0.6)
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
@@ -139,22 +139,19 @@ response = pipe(('describe this image', image))
print(response)
```
-### 计算 logits
-
-LMDeploy 支持用户自定义输入,用户可以调用`prepare_inputs`,了解多模态的输入是如何组织的。
+### 获取生成 token 的 logits
```python
-from lmdeploy import pipeline, TurbomindEngineConfig
+from lmdeploy import pipeline, GenerationConfig
from lmdeploy.vl import load_image
-pipe = pipeline('internlm/internlm-xcomposer2-7b', backend_config=TurbomindEngineConfig(cache_max_entry_count=0.5))
+pipe = pipeline('OpenGVLab/InternVL2_5-8B')
-# logits
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
-inputs = pipe.prepare_inputs(('describe this image', image))
-input_ids = inputs['input_ids']
-embeddings = inputs['input_embeddings']
-embedding_ranges = inputs['input_embedding_ranges']
-logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
+
+response = pipe(('describe this image', image),
+ gen_config=GenerationConfig(output_logits='generation'))
+logits = response.logits
+print(logits)
```
## 多图推理
@@ -165,7 +162,7 @@ logits = pipe.get_logits(input_ids, embeddings, embedding_ranges)
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image_urls=[
@@ -186,7 +183,7 @@ print(response)
from lmdeploy import pipeline, TurbomindEngineConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image_urls=[
@@ -206,7 +203,7 @@ pipeline 进行多轮对话有两种方式,一种是按照 openai 的格式来
from lmdeploy import pipeline, TurbomindEngineConfig, GenerationConfig
from lmdeploy.vl import load_image
-pipe = pipeline('liuhaotian/llava-v1.6-vicuna-7b',
+pipe = pipeline('OpenGVLab/InternVL2_5-8B',
backend_config=TurbomindEngineConfig(session_len=8192))
image = load_image('https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/demo/resources/human-pose.jpg')
diff --git a/docs/zh_cn/quantization/w4a16.md b/docs/zh_cn/quantization/w4a16.md
index 3cea164dd9..1d1467fa77 100644
--- a/docs/zh_cn/quantization/w4a16.md
+++ b/docs/zh_cn/quantization/w4a16.md
@@ -45,7 +45,7 @@ lmdeploy lite auto_awq \
绝大多数情况下,在执行上述命令时,可选参数可不用填写,使用默认的即可。比如量化 [internlm/internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) 模型,命令可以简化为:
```shell
-lmdeploy lite auto_awq internlm/ianternlm2-chat-7b --work-dir internlm2_5-7b-chat-4bit
+lmdeploy lite auto_awq internlm/internlm2_5-7b-chat --work-dir internlm2_5-7b-chat-4bit
```
**Note:**
diff --git a/docs/zh_cn/quantization/w8a8.md b/docs/zh_cn/quantization/w8a8.md
index 302dd538fd..3a63c82f8c 100644
--- a/docs/zh_cn/quantization/w8a8.md
+++ b/docs/zh_cn/quantization/w8a8.md
@@ -1,56 +1,76 @@
# W8A8 LLM 模型部署
-LMDeploy 提供了使用 8 bit 整数对神经网络模型进行量化和推理的功能。
+LMDeploy 提供了使用 8-bit 整数(INT8)和浮点数(FP8)对神经网络模型进行量化和推理的功能。
-在开始推理前,需要确保已经正确安装了 lmdeploy 和 openai/triton。可以通过以下命令进行安装:
+可用于 INT8 和 FP8 推理的 NVIDIA GPU 分别为:
+
+- INT8
+ - V100(sm70): V100
+ - Turing(sm75): 20 series, T4
+ - Ampere(sm80,sm86): 30 series, A10, A16, A30, A100
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+- FP8
+ - Ada Lovelace(sm89): 40 series
+ - Hopper(sm90): H100
+
+首先,执行如下命令安装lmdeploy:
```shell
-pip install lmdeploy
-pip install triton>=2.1.0
+pip install lmdeploy[all]
```
-## 8bit 权重模型推理
+## 8-bit 权重量化
-如果你需要进行 8 bit 权重模型推理,可以直接从 LMDeploy 的 [model zoo](https://huggingface.co/lmdeploy) 下载已经量化好的 8bit 权重模型。以8bit 的 Internlm-chat-7B 模型为例,可以从 model zoo 直接下载:
+进行 8-bit 权重量化需要经历以下三步:
-```shell
-git-lfs install
-git clone https://huggingface.co/lmdeploy/internlm-chat-7b-w8 (coming soon)
-```
+1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。
+2. **模块替换**:使用 `QRMSNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RMSNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
+3. **保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。
-你也可以参考["8bit 权重量化"](#8bit-权重量化)章节的内容手动将原 16bit 权重量化为 8bit,并保存至 `internlm-chat-7b-w8` 目录下,操作命令如下:
+lmdeploy 提供了命令行工具 `lmdeploy lite smooth_quant` 实现了以上三个步骤。并且其中命令行参数 `--quant-dtype` 可以用来控制是进行8-bit整数还是浮点数类型的量化。更多命令行工具使用方式,请执行 `lmdeploy lite smooth_quant --help` 查看。
-```shell
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
-```
+以下示例演示了进行 int8 或 fp8 的量化命令。
-然后,执行以下命令,即可在终端与模型对话:
+- int8
-```shell
-lmdeploy chat ./internlm-chat-7b-w8 --backend pytorch
-```
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-int8 --quant-dtype int8
+ ```
-## 启动 gradio 服务
+- fp8
-Coming soon...
+ ```shell
+ lmdeploy lite smooth_quant internlm/internlm2_5-7b-chat --work-dir ./internlm2_5-7b-chat-fp8 --quant-dtype fp8
+ ```
-## 推理速度
+## 模型推理
-Coming soon...
+量化后的模型,通过以下几行简单的代码,可以实现离线推理:
-## 8bit 权重量化
+```python
+from lmdeploy import pipeline, PytorchEngineConfig
-进行 8bit 权重量化需要经历以下三步:
+engine_config = PytorchEngineConfig(tp=1)
+pipe = pipeline("internlm2_5-7b-chat-int8", backend_config=engine_config)
+response = pipe(["Hi, pls intro yourself", "Shanghai is"])
+print(response)
+```
-1. **权重平滑**:首先对语言模型的权重进行平滑处理,以便更好地进行量化。
-2. **模块替换**:使用 `QRSMNorm` 和 `QLinear` 模块替换原模型 `DecoderLayer` 中的 `RSMNorm` 模块和 `nn.Linear` 模块。`lmdeploy/pytorch/models/q_modules.py` 文件中定义了这些量化模块。
-3. **保存量化模型**:完成上述必要的替换后,我们即可保存新的量化模型。
+关于 pipeline 的详细介绍,请参考[这里](../llm/pipeline.md)
-我们在`lmdeploy/lite/api/smooth_quantity.py`脚本中已经实现了以上三个步骤。例如,可以通过以下命令得到量化后的 Internlm-chat-7B 模型的模型权重:
+## 推理服务
+
+LMDeploy `api_server` 支持把模型一键封装为服务,对外提供的 RESTful API 兼容 openai 的接口。以下为服务启动的示例:
```shell
+lmdeploy serve api_server ./internlm2_5-7b-chat-int8 --backend pytorch
+```
-lmdeploy lite smooth_quant internlm/internlm-chat-7b --work-dir ./internlm-chat-7b-w8
+服务默认端口是23333。在 server 启动后,你可以在终端通过`api_client`与server进行对话:
+
+```shell
+lmdeploy serve api_client http://0.0.0.0:23333
```
-保存之后,你就可以通过调用from_pretrained接口来实例化你的量化模型。
+还可以通过 Swagger UI `http://0.0.0.0:23333` 在线阅读和试用 `api_server` 的各接口,也可直接查阅[文档](../llm/api_server.md),了解各接口的定义和使用方法。
diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md
index 3ec3688e1b..83b7a9ca6f 100644
--- a/docs/zh_cn/supported_models/supported_models.md
+++ b/docs/zh_cn/supported_models/supported_models.md
@@ -4,104 +4,107 @@
## TurboMind CUDA 平台
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
-| :-------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
-| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Qwen1.5 | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2 | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
-| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
-| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | No |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
-| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
-| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
-| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
-| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
-| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
-| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
-| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
-| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
-| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
-| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W4A16 |
+| :------------------------------: | :------------: | :--: | :-------: | :-----: | :-----: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes |
+| Llama3.2\[2\] | 1B, 3B | LLM | Yes | Yes\* | Yes\* | Yes |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2 | 7B, 4khd-7B | MLLM | Yes | Yes | Yes | Yes |
+| InternLM-XComposer2.5 | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Qwen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes |
+| Qwen1.5\[1\] | 1.8B - 110B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Qwen2-MoE | 57BA14B | LLM | Yes | Yes | Yes | Yes |
+| Qwen2.5\[2\] | 0.5B - 72B | LLM | Yes | Yes\* | Yes\* | Yes |
+| Mistral\[1\] | 7B | LLM | Yes | Yes | Yes | No |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | Yes |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | Yes | Yes | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | Yes | Yes | No |
+| Qwen-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| DeepSeek-VL | 7B | MLLM | Yes | Yes | Yes | Yes |
+| Baichuan | 7B | LLM | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes |
+| Code Llama | 7B - 34B | LLM | Yes | Yes | Yes | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes |
+| LLaVA(1.5,1.6) | 7B - 34B | MLLM | Yes | Yes | Yes | Yes |
+| InternVL | v1.1 - v1.5 | MLLM | Yes | Yes | Yes | Yes |
+| InternVL2 | 1-2B, 8B - 76B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| InternVL2.5(MPO)\[2\] | 1 - 78B | MLLM | Yes | Yes\* | Yes\* | Yes |
+| ChemVLM | 8B - 26B | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-Llama3-V-2_5 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniCPM-V-2_6 | - | MLLM | Yes | Yes | Yes | Yes |
+| MiniGeminiLlama | 7B | MLLM | Yes | - | - | Yes |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - |
+| Molmo | 7B-D,72B | MLLM | Yes | Yes | Yes | No |
“-” 表示还没有验证。
```{note}
-* turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
-* 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等
+* [1] turbomind 引擎不支持 window attention。所以,对于应用了 window attention,并开启了对应的开关"use_sliding_window"的模型,比如 Mistral、Qwen1.5 等,在推理时,请选择 pytorch engine
+* [2] 当模型的 head_dim 非 128 时,turbomind 不支持它的 kv cache 4/8 bit 量化和推理。比如,llama3.2-1B,qwen2-0.5B,internvl2-1B 等等
```
## PyTorchEngine CUDA 平台
-| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
-| :------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
-| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
-| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
-| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
-| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
-| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
-| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
-| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
-| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
-| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
-| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
-| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | No |
-| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
-| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
-| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
-| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
-| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
-| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
-| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
-| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
-| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
-| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
-| LLaVA(1.5,1.6) | 7B-34B | MLLM | Yes | Yes | Yes | - | - |
-| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
-| InternVL2 | 1B-40B | MLLM | Yes | Yes | Yes | - | - |
-| Mono-InternVL | 2B | MLLM | Yes\* | Yes | Yes | - | - |
-| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
-| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
-| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
-| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | No |
-| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
-| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
-| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
+| Model | Size | Type | FP16/BF16 | KV INT8 | KV INT4 | W8A8 | W4A16 |
+| :----------------------------: | :---------: | :--: | :-------: | :-----: | :-----: | :--: | :---: |
+| Llama | 7B - 65B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama2 | 7B - 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.1 | 8B, 70B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2 | 1B, 3B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Llama3.2-VL | 11B, 90B | MLLM | Yes | Yes | Yes | - | - |
+| InternLM | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2 | 7B - 20B | LLM | Yes | Yes | Yes | Yes | Yes |
+| InternLM2.5 | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Baichuan2 | 7B | LLM | Yes | Yes | Yes | Yes | No |
+| Baichuan2 | 13B | LLM | Yes | Yes | Yes | No | No |
+| ChatGLM2 | 6B | LLM | Yes | Yes | Yes | No | No |
+| Falcon | 7B - 180B | LLM | Yes | Yes | Yes | No | No |
+| YI | 6B - 34B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mistral | 7B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Mixtral | 8x7B, 8x22B | LLM | Yes | Yes | Yes | No | No |
+| QWen | 1.8B - 72B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5 | 0.5B - 110B | LLM | Yes | Yes | Yes | Yes | Yes |
+| QWen1.5-MoE | A2.7B | LLM | Yes | Yes | Yes | No | No |
+| QWen2 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| Qwen2.5 | 0.5B - 72B | LLM | Yes | Yes | No | Yes | Yes |
+| QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes |
+| DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No |
+| DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No |
+| MiniCPM3 | 4B | LLM | Yes | Yes | Yes | No | No |
+| MiniCPM-V-2_6 | 8B | LLM | Yes | No | No | No | Yes |
+| Gemma | 2B-7B | LLM | Yes | Yes | Yes | No | No |
+| Dbrx | 132B | LLM | Yes | Yes | Yes | No | No |
+| StarCoder2 | 3B-15B | LLM | Yes | Yes | Yes | No | No |
+| Phi-3-mini | 3.8B | LLM | Yes | Yes | Yes | Yes | Yes |
+| Phi-3-vision | 4.2B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM-Chat | 17B | MLLM | Yes | Yes | Yes | - | - |
+| CogVLM2-Chat | 19B | MLLM | Yes | Yes | Yes | - | - |
+| LLaVA(1.5,1.6)\[2\] | 7B-34B | MLLM | No | No | No | No | No |
+| InternVL(v1.5) | 2B-26B | MLLM | Yes | Yes | Yes | No | Yes |
+| InternVL2 | 1B-76B | MLLM | Yes | Yes | Yes | - | - |
+| InternVL2.5(MPO) | 1B-78B | MLLM | Yes | Yes | Yes | - | - |
+| Mono-InternVL\[1\] | 2B | MLLM | Yes\* | Yes | Yes | - | - |
+| ChemVLM | 8B-26B | MLLM | Yes | Yes | No | - | - |
+| Gemma2 | 9B-27B | LLM | Yes | Yes | Yes | - | - |
+| GLM4 | 9B | LLM | Yes | Yes | Yes | No | No |
+| GLM-4V | 9B | MLLM | Yes | Yes | Yes | No | Yes |
+| CodeGeeX4 | 9B | LLM | Yes | Yes | Yes | - | - |
+| Phi-3.5-mini | 3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-MoE | 16x3.8B | LLM | Yes | Yes | No | - | - |
+| Phi-3.5-vision | 4.2B | MLLM | Yes | Yes | No | - | - |
```{note}
-* 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16。
+* [1] 目前,Mono-InternVL不支持FP16,因为数值不稳定。请改用BF16
+* [2] 自 0.6.4 之后,PyTorch 引擎移除了对 llava 模型原始格式的支持。我们建议使用它们对应的 transformers 格式的模型。这些模型可以在 https://huggingface.co/llava-hf 中找到
```
## PyTorchEngine 华为昇腾平台
diff --git a/lmdeploy/cli/lite.py b/lmdeploy/cli/lite.py
index 236e022b34..499bace485 100644
--- a/lmdeploy/cli/lite.py
+++ b/lmdeploy/cli/lite.py
@@ -126,6 +126,7 @@ def add_parser_smooth_quant():
ArgumentHelper.calib_batchsize(parser)
ArgumentHelper.calib_search_scale(parser)
ArgumentHelper.dtype(parser)
+ ArgumentHelper.quant_dtype(parser)
@staticmethod
def auto_awq(args):
diff --git a/lmdeploy/cli/utils.py b/lmdeploy/cli/utils.py
index cf7b6526ec..25fcdd4620 100644
--- a/lmdeploy/cli/utils.py
+++ b/lmdeploy/cli/utils.py
@@ -113,6 +113,16 @@ def dtype(parser, default: str = 'auto'):
'for BF16 models. This option will be ignored if '
'the model is a quantized model')
+ @staticmethod
+ def quant_dtype(parser, default: str = 'int8'):
+ return parser.add_argument(
+ '--quant-dtype',
+ type=str,
+ default=default,
+ choices=['int8', 'float8_e4m3fn', 'float8_e5m2', 'fp8'],
+        help='data type for the quantized model weights and activations. '
+        'Note "fp8" is the short version of "float8_e4m3fn"')
+
@staticmethod
def model_format(parser, default: str = None):
return parser.add_argument(
@@ -367,7 +377,7 @@ def calib_search_scale(parser):
@staticmethod
def device(parser,
default: str = 'cuda',
- choices: List[str] = ['cuda', 'ascend', 'maca']):
+ choices: List[str] = ['cuda', 'ascend', 'maca', 'camb']):
"""Add argument device to parser."""
return parser.add_argument('--device',
diff --git a/lmdeploy/lite/apis/smooth_quant.py b/lmdeploy/lite/apis/smooth_quant.py
index 188eedbd0e..8d67535bcc 100644
--- a/lmdeploy/lite/apis/smooth_quant.py
+++ b/lmdeploy/lite/apis/smooth_quant.py
@@ -24,7 +24,19 @@ def smooth_quant(model: str,
batch_size: int = 1,
w_bits: int = 8,
dtype: Literal['float16', 'bfloat16', 'auto'] = 'auto',
- device: str = 'cuda'):
+ device: str = 'cuda',
+ quant_dtype: Literal['int8', 'fp8', 'float8_e4m3fn',
+ 'float8_e5m2'] = 'int8'):
+    # 'fp8' is shorthand for float8_e4m3fn
+    if quant_dtype == 'fp8':
+        quant_dtype = 'float8_e4m3fn'
+
+    # resolve the dtype name to a torch dtype, falling back to int8
+    quant_dtype = getattr(torch, quant_dtype, torch.int8)
+    if quant_dtype.is_floating_point:
+        q_dtype_info = torch.finfo(quant_dtype)
+    else:
+        q_dtype_info = torch.iinfo(quant_dtype)
+
+    # the requested weight bit-width must match the quantized dtype
+    assert q_dtype_info.bits == w_bits
model_path = model
vl_model, model, tokenizer, work_dir = calibrate(model,
calib_dataset,
@@ -84,7 +96,7 @@ def smooth_quant(model: str,
if skipped_module(name):
continue
linear.to(device)
- q_linear = QLinear.from_float(linear)
+ q_linear = QLinear.from_float(linear, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_linear)
@@ -94,7 +106,7 @@ def smooth_quant(model: str,
if skipped_module(name):
continue
norm.to(device)
- q_norm = QRMSNorm.from_float(norm)
+ q_norm = QRMSNorm.from_float(norm, quant_dtype=quant_dtype)
parent_name, _, child_name = name.rpartition('.')
parent = model.get_submodule(parent_name)
setattr(parent, child_name, q_norm)
@@ -104,8 +116,10 @@ def smooth_quant(model: str,
from .auto_awq import save_vl_model
save_vl_model(vl_model, model_path, work_dir)
else:
+ quant_dtype_s = str(quant_dtype).split('.')[1]
model.config.update(
- dict(quantization_config=dict(quant_method='smooth_quant')))
+ dict(quantization_config=dict(quant_method='smooth_quant',
+ quant_dtype=f'{quant_dtype_s}')))
model.save_pretrained(work_dir,
max_shard_size='2GB',
safe_serialization=False)
diff --git a/lmdeploy/messages.py b/lmdeploy/messages.py
index 2336d10752..cfc146f86d 100644
--- a/lmdeploy/messages.py
+++ b/lmdeploy/messages.py
@@ -7,6 +7,9 @@
from pydantic.dataclasses import dataclass as pydantic_dataclass
from .tokenizer import Tokenizer
+from .utils import get_logger
+
+logger = get_logger('lmdeploy')
LogitsProcessor = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a tensor of input_ids, the logits
@@ -52,6 +55,9 @@ class GenerationConfig:
ignoring the number of tokens in the prompt.
skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
+ spaces_between_special_tokens (bool): Whether or not to add spaces
+              around special tokens during decoding. Fast tokenizers set this
+              to False by default, while slow tokenizers default it to True.
logprobs (int): Number of log probabilities to return per output token.
response_format (Dict): Only pytorch backend support formatting
response. Examples:
@@ -94,9 +100,12 @@ class GenerationConfig:
bad_token_ids: List[int] = None
min_new_tokens: int = None
skip_special_tokens: bool = True
+ spaces_between_special_tokens: bool = True
logprobs: int = None
response_format: Optional[Dict] = None
logits_processors: Optional[List[LogitsProcessor]] = None
+ output_logits: Literal['all', 'generation'] = None
+ output_last_hidden_state: Literal['all', 'generation'] = None
def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
"""convert stop_words/bad_sords to ids and append the ids to
@@ -124,7 +133,7 @@ def __post_init__(self):
"""Check input validation."""
assert type(
self.n) == int and self.n > 0, 'n is not a positive integer'
- assert self.top_p > 0 and self.top_p <= 1 # (0, 1]
+ assert self.top_p >= 0 and self.top_p <= 1 # [0, 1]
assert self.top_k >= 0, 'top_k can not be a negative integer'
assert self.temperature >= 0 and self.temperature <= 2 # [0,2]
assert 0 <= self.min_p <= 1, \
@@ -291,13 +300,18 @@ def __post_init__(self):
assert self.num_gpu_blocks >= 0, 'invalid num_gpu_blocks'
assert self.quant_policy in (0, 4, 8), 'invalid quant_policy'
assert self.device_type in [
- 'cuda', 'ascend', 'maca'
+ 'cuda', 'ascend', 'maca', 'camb'
], (f'invalid device_type: {self.device_type}')
if self.quant_policy > 0 and self.device_type not in [
'cuda', 'ascend'
]:
assert False, \
'kv cache quantization only works for CUDA and ASCEND.'
+ if self.device_type == 'camb' and self.block_size != 16:
+ self.block_size = 16
+ logger.warning(
+                'Currently, camb device requires block size to be 16, '
+                'setting block size to 16')
class ResponseType(enum.Enum):
@@ -338,10 +352,11 @@ class Response:
text: str
generate_token_len: int
input_token_len: int
- session_id: int
finish_reason: Optional[Literal['stop', 'length']] = None
token_ids: List[int] = field(default_factory=list)
logprobs: List[Dict[int, float]] = None
+ logits: torch.Tensor = None
+ last_hidden_state: torch.Tensor = None
index: int = 0
@@ -361,6 +376,8 @@ class EngineOutput:
token_ids: List[int]
num_token: int
logprobs: List[Dict[int, float]] = None
+ logits: torch.Tensor = None
+ last_hidden_state: torch.Tensor = None
@dataclass
diff --git a/lmdeploy/model.py b/lmdeploy/model.py
index a0b0c8e09b..f7b80ed102 100644
--- a/lmdeploy/model.py
+++ b/lmdeploy/model.py
@@ -46,6 +46,8 @@ class ChatTemplateConfig:
eoh (str | None): end of the user prompt
assistant (str | None): begin of the assistant prompt
eoa (str | None): end of the assistant prompt
+ tool (str | None): begin of the tool prompt
+ eotool (str | None): end of the tool prompt
capability: ('completion' | 'infilling' | 'chat' | 'python') = None
""" # noqa: E501
@@ -57,6 +59,8 @@ class ChatTemplateConfig:
eoh: Optional[str] = None
assistant: Optional[str] = None
eoa: Optional[str] = None
+ tool: Optional[str] = None
+ eotool: Optional[str] = None
separator: Optional[str] = None
capability: Optional[Literal['completion', 'infilling', 'chat',
'python']] = None
@@ -173,6 +177,8 @@ def __init__(self,
assistant='',
eoa='',
separator='',
+ tool='',
+ eotool='',
**kwargs):
super().__init__(**kwargs)
self.system = system
@@ -183,6 +189,8 @@ def __init__(self,
self.separator = separator
self.eosys = eosys
self.assistant = assistant
+ self.tool = tool
+ self.eotool = eotool
def get_prompt(self, prompt, sequence_start=True):
"""Return the prompt that is concatenated with other elements in the
@@ -223,10 +231,12 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
return self.get_prompt(messages, sequence_start)
box_map = dict(user=self.user,
assistant=self.assistant,
- system=self.system)
+ system=self.system,
+ tool=self.tool)
eox_map = dict(user=self.eoh,
assistant=self.eoa + self.separator,
- system=self.eosys)
+ system=self.eosys,
+ tool=self.eotool)
ret = ''
if self.meta_instruction is not None and sequence_start:
if len(messages) and messages[0]['role'] != 'system':
@@ -819,7 +829,7 @@ class Llama3_1(Llama3):
def __init__(
self,
- tools="""# Tool Instructions
+ tool="""# Tool Instructions
- Always execute python code in messages that you share.
- When looking for real time information use relevant functions if available else fallback to brave_search
@@ -828,7 +838,7 @@ def __init__(
You have access to the following functions:
""", # noqa
- eotools="""
+ eotool="""
If a you choose to call a function ONLY reply in the following format:
<{start_tag}={function_name}>{parameters}{end_tag}
@@ -858,8 +868,8 @@ def __init__(
**kwargs)
self.ipython = ipython
self.eoi = eoi
- self.tools = tools
- self.eotools = eotools
+ self.tool = tool
+ self.eotool = eotool
self.knowledge = knowledge
def messages2prompt(self,
@@ -899,7 +909,7 @@ def messages2prompt(self,
if tools is None:
ret += f'{self.system}{self.knowledge}{self.meta_instruction}{self.eosys}'
else:
- ret += f'{self.system}{self.knowledge}{self.tools}{tool_prompt}{self.eotools}{self.meta_instruction}{self.eosys}'
+ ret += f'{self.system}{self.knowledge}{self.tool}{tool_prompt}{self.eotool}{self.meta_instruction}{self.eosys}'
for message in messages:
role = message['role']
content = get_text(message['content'])
@@ -907,7 +917,7 @@ def messages2prompt(self,
or '' in content):
ret += f'{box_map[role]}{content}<|eom_id|>'
elif role == 'system' and tools is not None:
- ret += f'{box_map[role]}{self.tools}{tool_prompt}{self.eotools}{content}{eox_map[role]}'
+ ret += f'{box_map[role]}{self.tool}{tool_prompt}{self.eotool}{content}{eox_map[role]}'
else:
ret += f'{box_map[role]}{content}{eox_map[role]}'
if sequence_start and not isinstance(messages, str):
diff --git a/lmdeploy/profiler.py b/lmdeploy/profiler.py
new file mode 100644
index 0000000000..c1bf6b3875
--- /dev/null
+++ b/lmdeploy/profiler.py
@@ -0,0 +1,170 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import csv
+import time
+from typing import List
+
+import numpy as np
+
+
+class Session:
+    """Timeline of one request: cumulative token counts and their timestamps."""
+
+ UNKNOWN = 0
+ SUCCESS = 1
+ FAIL = 2
+
+ def __init__(self, input_len, req_output_len):
+ self.ts = []
+ self.ns = []
+ self.input_len = input_len
+ self.req_output_len = req_output_len
+ self.status = Session.UNKNOWN
+
+ def tick(self, n_token):
+ self.ts.append(time.perf_counter())
+ self.ns.append(n_token)
+
+ def finish(self, status):
+ self.status = status
+
+
+class Profiler:
+    """Aggregates Sessions and derives latency/throughput metrics."""
+
+ def __init__(self, stream_output: bool, percentages: List[int]):
+ self.sessions: List[Session] = []
+ self.stream_output = stream_output
+ self.percentages = percentages
+
+ def new_session(self, *args, **kwargs):
+ sess = Session(*args, **kwargs)
+ self.sessions.append(sess)
+ return sess
+
+ def start(self):
+ self.t_start = time.perf_counter()
+
+ def finish(self):
+ self.elapsed_time = time.perf_counter() - self.t_start
+
+ def compute_metrics(self):
+ self.ttfts: List[float] = []
+ self.tpots: List[float] = []
+ self.e2es: List[float] = []
+ self.itls: List[float] = []
+ self.tpts: List[int] = []
+ self.total_output = 0
+ self.total_input = 0
+ self.success = 0
+
+ for sess in self.sessions:
+ if sess.status != Session.SUCCESS:
+ continue
+ ns = sess.ns
+ ts = sess.ts
+ if ns[-1] < sess.req_output_len:
+ continue
+ self.success += 1
+ self.total_output += ns[-1]
+ self.total_input += sess.input_len
+ self.e2es.append(ts[-1] - ts[0])
+ self.ttfts.append(ts[1] - ts[0])
+ if ns[-1] > ns[1]:
+ self.tpots.append((ts[-1] - ts[1]) / (ns[-1] - ns[1]))
+ else: # no-stream-output
+ self.tpots.append((ts[-1] - ts[0]) / (ns[-1] - ns[0]))
+ t_dif = np.subtract(ts[1:], ts[:-1])
+ n_dif = np.subtract(ns[1:], ns[:-1])
+ self.itls.extend(t_dif[1:])
+ self.tpts.extend(n_dif)
+
+ self.output_throughput = self.total_output / self.elapsed_time
+ self.input_throughput = self.total_input / self.elapsed_time
+
+ qs = self.percentages
+
+ self.e2es = self.e2es or [float('inf')]
+ self.tpots = self.tpots or [float('inf')]
+ self.ttfts = self.ttfts or [float('inf')]
+ self.itls = self.itls or [float('inf')]
+ self.tpts = self.tpts or [0]
+
+ self.tpot_mean = np.mean(self.tpots)
+ self.tpot_stat = tuple(np.percentile(self.tpots, qs))
+ self.e2e_mean = np.mean(self.e2es)
+ self.e2e_stat = tuple(np.percentile(self.e2es, qs))
+
+ if self.stream_output:
+ self.ttft_mean = np.mean(self.ttfts)
+ self.ttft_stat = tuple(np.percentile(self.ttfts, qs))
+ self.itls_mean = np.mean(self.itls)
+ self.itls_stat = tuple(np.percentile(self.itls, qs))
+ self.tpts_mean = np.mean(self.tpts)
+ self.tpts_stat = tuple(np.percentile(self.tpts, qs).astype(int))
+
+ self.rps = self.success / self.elapsed_time
+
+ def summarize(self,
+ title: str,
+ hyperparams: List = None,
+ header=40,
+ digits=10):
+
+ width = header + digits * (1 + len(self.percentages))
+
+ def tab_row(name, *items):
+
+ def fmt(x):
+ return '{:>{d}.3f}'.format(x, d=digits) if isinstance(
+ x, float) else '{:>{d}}'.format(x, d=digits)
+
+ print('{:<{p}}{}'.format(name,
+ ''.join([fmt(x) for x in items]),
+ p=header))
+
+ print('\n{s:{c}^{n}}'.format(s=f' {title} ', n=width, c='='))
+ tab_row('Benchmark duration', self.elapsed_time)
+ tab_row('Total requests', len(self.sessions))
+ tab_row('Successful requests', self.success)
+ if hyperparams:
+ for k, v in hyperparams:
+ tab_row(k, v)
+ tab_row('Total input tokens', self.total_input)
+ tab_row('Total generated tokens', self.total_output)
+ tab_row('Input throughput (tok/s)', self.input_throughput)
+ tab_row('Output throughput (tok/s)', self.output_throughput)
+ tab_row('Request throughput (req/s)', self.rps)
+ print('-' * width)
+ tab_row('', 'mean', *(f'P{q}' for q in self.percentages))
+ tab_row('End-to-end Latency', self.e2e_mean, *self.e2e_stat)
+ if self.stream_output:
+ tab_row('Time to First Token (TTFT)', self.ttft_mean,
+ *self.ttft_stat)
+ tab_row('Time per Output Token (TPOT)', self.tpot_mean,
+ *self.tpot_stat)
+ if self.stream_output:
+ tab_row('Inter-token Latency (ITL)', self.itls_mean,
+ *self.itls_stat)
+ tab_row('Tokens per Tick', self.tpts_mean, *self.tpts_stat)
+ print('=' * width)
+
+ def save_csv(self, csv_file: str, hyperparams):
+ """Export legacy metrics to CSV."""
+ with open(csv_file, 'w') as csvfile:
+ writer = csv.writer(csvfile)
+ keys, vals = zip(*hyperparams)
+ writer.writerow([
+ *keys,
+ 'RPS',
+ 'RPM',
+ 'FTL(ave)(s)',
+ 'throughput(out tok/s)',
+ 'throughput(total tok/s)',
+ ])
+ ttft_mean = f'{self.ttft_mean:.3f}' if self.stream_output else '-'
+ writer.writerow([
+ *vals,
+ f'{self.rps:.3f}',
+ f'{(self.rps * 60):.3f}',
+ ttft_mean,
+ f'{self.output_throughput:.3f}',
+ f'{(self.input_throughput + self.output_throughput):.3f}',
+ ])
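For reference, a minimal driver for the Profiler/Session API added above, feeding it synthetic timings instead of a real LMDeploy pipeline (assumes this branch of lmdeploy is installed):

import time
from lmdeploy.profiler import Profiler, Session

profiler = Profiler(stream_output=True, percentages=[50, 90, 99])
profiler.start()

for _ in range(4):
    sess = profiler.new_session(input_len=32, req_output_len=8)
    sess.tick(0)                      # request sent, no tokens yet
    for n_token in range(1, 9):       # pretend we stream 8 output tokens
        time.sleep(0.001)
        sess.tick(n_token)
    sess.finish(Session.SUCCESS)

profiler.finish()
profiler.compute_metrics()
profiler.summarize(title='toy benchmark', hyperparams=[('concurrency', 1)])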
diff --git a/lmdeploy/pytorch/backends/base.py b/lmdeploy/pytorch/backends/base.py
index c8623666dc..263b419f1a 100644
--- a/lmdeploy/pytorch/backends/base.py
+++ b/lmdeploy/pytorch/backends/base.py
@@ -28,6 +28,9 @@ class OpType(Enum):
LinearW4A16 = auto()
SoftmaxTopK = auto()
FusedMoE = auto()
+ FusedMoEW8A8 = auto()
+ LinearBlockedF8 = auto()
+ FusedMoEBlockedF8 = auto()
class OpsBackend(ABC):
diff --git a/lmdeploy/pytorch/backends/blockedf8_modules.py b/lmdeploy/pytorch/backends/blockedf8_modules.py
new file mode 100644
index 0000000000..d79b41330c
--- /dev/null
+++ b/lmdeploy/pytorch/backends/blockedf8_modules.py
@@ -0,0 +1,39 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import torch
+
+
+class LinearBlockedF8Impl(ABC):
+ """linear BlockedF8 implementation api."""
+
+ def update_weights(self,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None):
+ """update weights."""
+ return weight, scale, bias
+
+ @abstractmethod
+ def forward(self,
+ x,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ all_reduce: bool = False):
+ """forward."""
+ raise NotImplementedError
+
+
+class LinearBlockedF8Builder(ABC):
+ """linear BlockedF8 implementation builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(in_features: int,
+ out_features: int,
+ bias: bool = True,
+ dtype: torch.dtype = None):
+ """build."""
+ raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/cuda/awq_modules.py b/lmdeploy/pytorch/backends/cuda/awq_modules.py
index 8159bbf554..18b0150493 100644
--- a/lmdeploy/pytorch/backends/cuda/awq_modules.py
+++ b/lmdeploy/pytorch/backends/cuda/awq_modules.py
@@ -18,23 +18,14 @@ def wq_gemm_forward(
out_features=0,
):
"""wq gemm forward."""
- from awq.modules.linear.gemm import awq_ext
-
from lmdeploy.pytorch.kernels.cuda.awq_kernels import awq_linear
out_shape = x.shape[:-1] + (out_features, )
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
- FP16_MATMUL_HEURISTIC_CONDITION = x.size(0) * x.size(1) >= 64
-
x = x.flatten(0, -2)
- if FP16_MATMUL_HEURISTIC_CONDITION:
- out = awq_linear(x, qweight, scales, qzeros)
- else:
- if not x.is_contiguous():
- x = x.contiguous()
- out = awq_ext.gemm_forward_cuda(x, qweight, scales, qzeros, 8)
+ out = awq_linear(x, qweight, scales, qzeros)
out = out + bias if bias is not None else out
out = out.reshape(out_shape)
diff --git a/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
new file mode 100644
index 0000000000..8299ac2dfd
--- /dev/null
+++ b/lmdeploy/pytorch/backends/cuda/blockedf8_modules.py
@@ -0,0 +1,65 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Optional
+
+import torch
+import torch.distributed as dist
+
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import (blocked_gemm_fp8,
+ quant_fp8)
+
+from ..blockedf8_modules import LinearBlockedF8Builder, LinearBlockedF8Impl
+
+
+class TritonLinearBlockedF8Impl(LinearBlockedF8Impl):
+ """triton linear blocked f8 implementation."""
+
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ block_size: int,
+ out_dtype: torch.dtype = torch.float16):
+ self.in_features = in_features
+ self.out_features = out_features
+ self.out_dtype = out_dtype
+ self.block_size = block_size
+
+ def forward(self,
+ x,
+ weight: torch.Tensor,
+ scale: torch.Tensor,
+ bias: Optional[torch.Tensor] = None,
+ all_reduce: bool = False):
+ """forward."""
+ x_shape = x.shape
+ x = x.flatten(0, -2)
+ input_quant, input_scale = quant_fp8(x,
+ self.block_size,
+ dtype=weight.dtype)
+
+ out = blocked_gemm_fp8(input_quant,
+ input_scale,
+ weight.t(),
+ scale.t(),
+ out_dtype=x.dtype)
+ if bias is not None:
+ out += bias
+
+ if all_reduce:
+ dist.all_reduce(out)
+
+ out = out.unflatten(0, x_shape[:-1])
+ return out
+
+
+class TritonLinearBlockedF8Builder(LinearBlockedF8Builder):
+ """triton linear blocked f8 implementation builder."""
+
+ @staticmethod
+ def build(in_features: int,
+ out_features: int,
+ block_size: int = 128,
+ bias: bool = True,
+ dtype: torch.dtype = None):
+ """build."""
+ return TritonLinearBlockedF8Impl(in_features, out_features, block_size,
+ dtype)
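As a sanity check on the data layout, here is a pure-PyTorch sketch of the block-wise FP8 quantization that quant_fp8 is assumed to perform (one scale per 128-column block per row); the Triton kernel's rounding and eps handling may differ:

import torch

def quant_fp8_ref(x: torch.Tensor, block_size: int = 128,
                  dtype: torch.dtype = torch.float8_e4m3fn):
    """Quantize [M, K] to FP8 with one scale per (row, K-block)."""
    M, K = x.shape
    assert K % block_size == 0
    fmax = torch.finfo(dtype).max                     # 448 for e4m3
    blocks = x.float().view(M, K // block_size, block_size)
    scale = blocks.abs().amax(dim=-1, keepdim=True).clamp_min(1e-12) / fmax
    q = (blocks / scale).clamp(-fmax, fmax).to(dtype)
    return q.view(M, K), scale.squeeze(-1)            # fp8 data, fp32 scales

x = torch.randn(4, 256, dtype=torch.float16)
xq, xs = quant_fp8_ref(x)
x_rec = (xq.float().view(4, 2, 128) * xs.unsqueeze(-1)).view(4, 256)
print((x.float() - x_rec).abs().max())                # small quantization error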
diff --git a/lmdeploy/pytorch/backends/cuda/lora.py b/lmdeploy/pytorch/backends/cuda/lora.py
index 798d985715..b65a01df14 100644
--- a/lmdeploy/pytorch/backends/cuda/lora.py
+++ b/lmdeploy/pytorch/backends/cuda/lora.py
@@ -50,6 +50,15 @@ def forward(self,
"""forward."""
lora_input = self._make_packed_lora_input(x, ctx_mgr)
+ base_slice = adapter_info.base_slice
+ sliced_base = base_output[..., base_slice]
+
+ if base_output.is_contiguous():
+ kernel_output = sliced_base.flatten(0, -2)
+ cum = True
+ else:
+ kernel_output = None
+ cum = False
lora_out = fused_lora(
lora_input.x,
lora_A,
@@ -62,14 +71,14 @@ def forward(self,
adapter_ids=lora_input.adapter_ids,
max_rank=adapter_info.max_rank,
max_seqlen=lora_input.max_seq_len,
+ output=kernel_output,
+ cum=cum,
)
- base_slice = adapter_info.base_slice
- sliced_base = base_output[..., base_slice]
- lora_out = lora_out.reshape(sliced_base.shape)
- sliced_base.add_(lora_out)
- output = base_output
- return output
+ if not base_output.is_contiguous():
+ lora_out = lora_out.reshape(sliced_base.shape)
+ sliced_base.add_(lora_out)
+ return base_output
class TritonLoRABuilder(LoRABuilder):
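A toy illustration of the two accumulation paths above (pure PyTorch, no Triton): when the base output is contiguous the fused kernel can accumulate into the sliced view directly, otherwise the delta is reshaped and added afterwards:

import torch

base_output = torch.randn(2, 3, 8)          # [bs, seq, out_features]
base_slice = slice(0, 8)                    # this adapter owns all 8 columns here
lora_out = torch.randn(2 * 3, 8)            # flattened LoRA delta from the kernel

sliced_base = base_output[..., base_slice]  # still contiguous in this toy case
if base_output.is_contiguous():
    # the fused kernel receives sliced_base.flatten(0, -2) and accumulates into it
    sliced_base.flatten(0, -2).add_(lora_out)
else:
    # fallback: reshape the kernel output and add it afterwards
    sliced_base.add_(lora_out.reshape(sliced_base.shape))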
diff --git a/lmdeploy/pytorch/backends/cuda/moe.py b/lmdeploy/pytorch/backends/cuda/moe.py
index eb38401211..24143c7061 100644
--- a/lmdeploy/pytorch/backends/cuda/moe.py
+++ b/lmdeploy/pytorch/backends/cuda/moe.py
@@ -4,9 +4,17 @@
import torch
-from lmdeploy.pytorch.kernels.cuda import fused_moe
+from lmdeploy.pytorch.kernels.cuda import fused_moe, fused_moe_w8a8
+from lmdeploy.pytorch.kernels.cuda.blocked_fp8_fused_moe import \
+ fused_moe_blocked_fp8
+from lmdeploy.pytorch.kernels.cuda.blocked_gemm_fp8 import quant_fp8
+from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import \
+ per_token_quant_int8
+from lmdeploy.pytorch.models.q_modules import QTensor
-from ..moe import FusedMoEBuilder, FusedMoEImpl
+from ..moe import (FusedMoEBlockedF8Builder, FusedMoEBlockedF8Impl,
+ FusedMoEBuilder, FusedMoEImpl, FusedMoEW8A8Builder,
+ FusedMoEW8A8Impl)
class TritonFusedMoEImpl(FusedMoEImpl):
@@ -74,3 +82,180 @@ def build(top_k: int, num_experts: int, renormalize: bool = False):
return TritonFusedMoEImpl(top_k=top_k,
num_experts=num_experts,
renormalize=renormalize)
+
+
+class TritonFusedMoEW8A8Impl(FusedMoEW8A8Impl):
+ """triton fused moe w8a8 implementation."""
+
+ def __init__(
+ self,
+ top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8,
+ ):
+ self.num_experts = num_experts
+ self.top_k = top_k
+ self.renormalize = renormalize
+ self.out_dtype = out_dtype
+ self.quant_dtype = quant_dtype
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ # do not transpose weight for int8/fp8
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return True
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ num_experts = self.num_experts
+ expert_per_rank = (num_experts + world_size - 1) // world_size
+ first_expert = rank * expert_per_rank
+ last_expert = min(first_expert + expert_per_rank, num_experts)
+ return list(range(first_expert, last_expert))
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+
+ if isinstance(hidden_states, torch.Tensor):
+ hidden_states = hidden_states.contiguous()
+ input_quant, input_scale = per_token_quant_int8(
+ hidden_states, 1e-7, quant_dtype=self.quant_dtype)
+ else:
+ assert isinstance(hidden_states, QTensor)
+ input_quant, input_scale = (hidden_states.tensor,
+ hidden_states.scale)
+
+ expert_offset = 0
+ num_experts = None
+ if expert_list is not None and len(expert_list) != self.num_experts:
+ expert_offset = expert_list[0]
+ num_experts = self.num_experts
+ return fused_moe_w8a8(input_quant,
+ input_scale,
+ gate_up_weights,
+ gate_up_scale,
+ down_weights,
+ down_scale,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ topk=self.top_k,
+ out_dtype=self.out_dtype,
+ quant_dtype=self.quant_dtype,
+ expert_offset=expert_offset,
+ num_experts=num_experts,
+ renormalize=self.renormalize)
+
+
+class TritonFusedMoEW8A8Builder(FusedMoEW8A8Builder):
+ """triton fused moe w8a8 builder."""
+
+ @staticmethod
+ def build(
+ top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8,
+ ):
+ """build from mlp."""
+ return TritonFusedMoEW8A8Impl(top_k=top_k,
+ num_experts=num_experts,
+ renormalize=renormalize,
+ out_dtype=out_dtype,
+ quant_dtype=quant_dtype)
+
+
+class TritonFusedMoEBlockedF8Impl(FusedMoEBlockedF8Impl):
+ """triton fused moe blocked f8 implementation."""
+
+ def __init__(self,
+ top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ block_size: int = 128,
+ out_dtype: torch.dtype = torch.float16):
+ self.num_experts = num_experts
+ self.top_k = top_k
+ self.renormalize = renormalize
+ self.block_size = block_size
+ self.out_dtype = out_dtype
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return True
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ num_experts = self.num_experts
+ expert_per_rank = (num_experts + world_size - 1) // world_size
+ first_expert = rank * expert_per_rank
+ last_expert = min(first_expert + expert_per_rank, num_experts)
+ return list(range(first_expert, last_expert))
+
+ def forward(self,
+ hidden_states: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ input_size = hidden_states.shape
+ hidden_states = hidden_states.flatten(0, -2)
+ input_quant, input_scale = quant_fp8(hidden_states,
+ self.block_size,
+ dtype=gate_up_weights.dtype)
+
+ expert_offset = 0
+ num_experts = None
+ if expert_list is not None and len(expert_list) != self.num_experts:
+ expert_offset = expert_list[0]
+ num_experts = self.num_experts
+ output = fused_moe_blocked_fp8(input_quant,
+ input_scale,
+ gate_up_weights,
+ gate_up_scale,
+ down_weights,
+ down_scale,
+ topk_weights=topk_weights,
+ topk_ids=topk_ids,
+ topk=self.top_k,
+ out_dtype=hidden_states.dtype,
+ expert_offset=expert_offset,
+ num_experts=num_experts,
+ renormalize=self.renormalize)
+ output = output.unflatten(0, input_size[:-1])
+ return output
+
+
+class TritonFusedMoEBlockedF8Builder(FusedMoEBlockedF8Builder):
+ """triton fused moe blocked f8 builder."""
+
+ @staticmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ block_size: int = 128,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ return TritonFusedMoEBlockedF8Impl(top_k=top_k,
+ num_experts=num_experts,
+ renormalize=renormalize,
+ block_size=block_size,
+ out_dtype=out_dtype)
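The expert-parallel partitioning used by ep_expert_list in both implementations is plain ceil-division chunking; extracted here as a standalone function for illustration:

def ep_expert_list(num_experts: int, world_size: int, rank: int):
    """Split experts into contiguous, near-equal chunks across ranks."""
    expert_per_rank = (num_experts + world_size - 1) // world_size
    first_expert = rank * expert_per_rank
    last_expert = min(first_expert + expert_per_rank, num_experts)
    return list(range(first_expert, last_expert))

# e.g. 10 experts on 4 ranks -> [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
print([ep_expert_list(10, 4, r) for r in range(4)])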
diff --git a/lmdeploy/pytorch/backends/cuda/op_backend.py b/lmdeploy/pytorch/backends/cuda/op_backend.py
index bfe89dc63d..7b2134aeef 100644
--- a/lmdeploy/pytorch/backends/cuda/op_backend.py
+++ b/lmdeploy/pytorch/backends/cuda/op_backend.py
@@ -51,21 +51,20 @@ def get_layer_impl_builder(cls, layer_type: OpType):
from .activation import TritonSiluAndMulBuilder
return TritonSiluAndMulBuilder
elif layer_type == OpType.LinearW4A16:
- try:
- from awq.modules.linear.gemm import awq_ext # noqa: F401
- AWQ_INSTALLED = True
- except Exception:
- AWQ_INSTALLED = False
- if AWQ_INSTALLED:
- from .awq_modules import AwqLinearW4A16Builder
- return AwqLinearW4A16Builder
- else:
- logger.debug(
- f'Op {layer_type} fallback to default implementation.')
- return super().get_layer_impl_builder(layer_type)
+ from .awq_modules import AwqLinearW4A16Builder
+ return AwqLinearW4A16Builder
elif layer_type == OpType.FusedMoE:
from .moe import TritonFusedMoEBuilder
return TritonFusedMoEBuilder
+ elif layer_type == OpType.FusedMoEW8A8:
+ from .moe import TritonFusedMoEW8A8Builder
+ return TritonFusedMoEW8A8Builder
+ elif layer_type == OpType.FusedMoEBlockedF8:
+ from .moe import TritonFusedMoEBlockedF8Builder
+ return TritonFusedMoEBlockedF8Builder
+ elif layer_type == OpType.LinearBlockedF8:
+ from .blockedf8_modules import TritonLinearBlockedF8Builder
+ return TritonLinearBlockedF8Builder
else:
logger.debug(
f'Op {layer_type} fallback to default implementation.')
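A sketch of how a model layer is expected to reach the new builders through the dispatch above, assuming a CUDA environment where get_backend() resolves to the CUDA backend:

import torch

from lmdeploy.pytorch.backends.base import OpType
from lmdeploy.pytorch.backends.selector import get_backend

backend = get_backend()                       # CUDA ops backend on a CUDA device
builder = backend.get_layer_impl_builder(OpType.LinearBlockedF8)
impl = builder.build(in_features=4096, out_features=4096, block_size=128,
                     dtype=torch.float16)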
diff --git a/lmdeploy/pytorch/backends/cuda/qmodules.py b/lmdeploy/pytorch/backends/cuda/qmodules.py
index 30f729a63f..13d9a47ddf 100644
--- a/lmdeploy/pytorch/backends/cuda/qmodules.py
+++ b/lmdeploy/pytorch/backends/cuda/qmodules.py
@@ -15,42 +15,62 @@
class TritonRMSNormW8A8Impl(RMSNormW8A8Impl):
"""triton RMS norm w8a8 implementation api."""
- def __init__(self, hidden_size: int, eps: float = 1e-6):
+ def __init__(self,
+ hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
+ self.quant_dtype = quant_dtype
def forward(self,
x: torch.Tensor,
weight: torch.Tensor,
residual: torch.Tensor = None):
"""forward."""
- if residual is not None:
- x = x + residual
- residual = x
- hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
- x, weight, self.eps)
- x = QTensor(hidden_states_quant, rms_scale)
if residual is None:
+ (x,
+ rms_scale) = rms_norm_dynamic_quant(x,
+ weight,
+ self.eps,
+ quant_dtype=self.quant_dtype)
+ x = QTensor(x, rms_scale)
return x
- return x, residual
+ else:
+ (x, rms_scale,
+ residual) = rms_norm_dynamic_quant(x,
+ weight,
+ self.eps,
+ residual=residual,
+ quant_dtype=self.quant_dtype)
+ x = QTensor(x, rms_scale)
+ return x, residual
class TritonRMSNormBuilder(RMSNormW8A8Builder):
"""triton RMS norm w8a8 implementation builder."""
@staticmethod
- def build(hidden_size: int, eps: float = 1e-6):
+ def build(hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
- return TritonRMSNormW8A8Impl(hidden_size, eps)
+ return TritonRMSNormW8A8Impl(hidden_size, eps, quant_dtype)
class TritonLinearW8A8Impl(LinearW8A8Impl):
"""triton linear w8a8 implementation."""
- def __init__(self, in_features: int, out_features: int):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8):
self.in_features = in_features
self.out_features = out_features
+ self.out_dtype = out_dtype
+ self.quant_dtype = quant_dtype
def forward(self,
x,
@@ -60,8 +80,8 @@ def forward(self,
all_reduce: bool = False):
"""forward."""
if isinstance(x, torch.Tensor):
- x = x.contiguous()
- input_quant, input_scale = per_token_quant_int8(x, 1e-7)
+ input_quant, input_scale = per_token_quant_int8(
+ x, 1e-7, quant_dtype=self.quant_dtype)
else:
assert isinstance(x, QTensor)
input_quant, input_scale = x.tensor, x.scale
@@ -70,7 +90,7 @@ def forward(self,
weight,
input_scale,
scale,
- output_dtype=torch.float16,
+ output_dtype=self.out_dtype,
bias=bias)
if all_reduce:
@@ -85,6 +105,10 @@ class TritonLinearW8A8Builder(LinearW8A8Builder):
def build(in_features: int,
out_features: int,
bias: bool = True,
- dtype: torch.dtype = None):
+ dtype: torch.dtype = None,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
- return TritonLinearW8A8Impl(in_features, out_features)
+ return TritonLinearW8A8Impl(in_features,
+ out_features,
+ dtype,
+ quant_dtype=quant_dtype)
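For reference, a pure-PyTorch sketch of the per-token int8 quantization that per_token_quant_int8 is assumed to implement (one scale per row, scale = amax / 127); the Triton kernel may differ in rounding and eps handling:

import torch

def per_token_quant_int8_ref(x: torch.Tensor, eps: float = 1e-7):
    """One scale per token (row): scale = amax / 127."""
    scale = x.float().abs().amax(dim=-1, keepdim=True).clamp_min(eps) / 127
    q = torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 8, dtype=torch.float16)
q, s = per_token_quant_int8_ref(x)
print((q.float() * s - x.float()).abs().max())  # quantization error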
diff --git a/lmdeploy/pytorch/backends/dlinfer/__init__.py b/lmdeploy/pytorch/backends/dlinfer/__init__.py
index af3ccff085..1cf6eea440 100644
--- a/lmdeploy/pytorch/backends/dlinfer/__init__.py
+++ b/lmdeploy/pytorch/backends/dlinfer/__init__.py
@@ -1,3 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .ascend import AscendOpsBackend # noqa: F401
+from .camb import CambOpsBackend # noqa: F401
from .maca import MacaOpsBackend # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/attention.py b/lmdeploy/pytorch/backends/dlinfer/attention.py
index 6b03403c84..4782ee11ff 100644
--- a/lmdeploy/pytorch/backends/dlinfer/attention.py
+++ b/lmdeploy/pytorch/backends/dlinfer/attention.py
@@ -16,6 +16,7 @@ class DlinferAttentionMetadata(AttentionMetadata):
max_q_seq_len: int = 1
max_kv_seq_len: int = 1
quant_meta: Dict = None
+ cu_seq_lens_kv: Optional[Tensor] = None
class DlinferAttentionImpl(AttentionImpl[DlinferAttentionMetadata]):
@@ -79,6 +80,8 @@ def forward(
max_q_seq_len = attn_metadata.max_q_seq_len
max_kv_seq_len = attn_metadata.max_kv_seq_len
quant_bits = attn_metadata.quant_policy
+ cu_seq_lens_kv = attn_metadata.cu_seq_lens_kv
+
if attn_metadata.quant_meta is not None:
k_scales_zeros = [
next(attn_metadata.quant_meta['k_scales']),
@@ -128,6 +131,7 @@ def forward(
q_start_loc=q_start_loc,
q_seqlens=q_seqlens,
kv_seqlens=kv_seqlens,
+ cu_seq_lens_kv=cu_seq_lens_kv,
max_q_seq_len=max_q_seq_len,
max_kv_seq_len=max_kv_seq_len,
is_decoding=is_decoding,
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
new file mode 100644
index 0000000000..897495c209
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/__init__.py
@@ -0,0 +1,2 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .op_backend import CambOpsBackend # noqa: F401
diff --git a/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
new file mode 100644
index 0000000000..89c71f46fb
--- /dev/null
+++ b/lmdeploy/pytorch/backends/dlinfer/camb/op_backend.py
@@ -0,0 +1,132 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple
+
+import torch
+
+from lmdeploy.pytorch.config import BackendConfig, CacheConfig, ModelConfig
+from lmdeploy.utils import get_logger
+
+from ..op_backend import DlinferOpsBackend
+
+logger = get_logger('lmdeploy')
+
+
+class CambOpsBackend(DlinferOpsBackend):
+ """camb layer backend."""
+ total_slots = None
+
+ @staticmethod
+ def get_name() -> str:
+ """backend name."""
+ return 'camb'
+
+ @staticmethod
+ def get_k_block_shape(
+ block_size: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ ) -> Tuple[int, ...]:
+ return (
+ num_heads,
+ block_size,
+ head_size,
+ )
+
+ @staticmethod
+ def get_v_block_shape(
+ block_size: int,
+ num_heads: int,
+ head_size: int,
+ dtype: torch.dtype,
+ ) -> Tuple[int, ...]:
+ return (
+ num_heads,
+ block_size,
+ head_size,
+ )
+
+ @classmethod
+ def update_step_context(cls, step_context):
+ """update step context."""
+
+ def get_total_slots():
+ if cls.total_slots is None:
+ cls.total_slots = torch.arange(
+ block_num * block_size,
+ dtype=torch.int32,
+ device=step_context.block_offsets.device)
+ cls.total_slots = cls.total_slots.view(block_num, block_size)
+ return cls.total_slots
+
+ kv_start_indices = []
+ block_num, _, block_size, _ = step_context.kv_caches[0][0].shape
+
+ is_unpaged_prefill = False
+ q_start_loc = step_context.q_start_loc
+ q_seqlens = step_context.q_seqlens
+ kv_seqlens = step_context.kv_seqlens.to(torch.int32)
+ block_offsets = step_context.block_offsets.to(torch.int32)
+ max_q_seq_len = torch.max(q_seqlens).cpu().item()
+ max_kv_seq_len = torch.max(kv_seqlens).cpu().item()
+
+ cu_seqlens = torch.cat(
+ (q_start_loc, q_seqlens.sum().unsqueeze(0))).int()
+ cu_seq_lens_kv = None
+
+ q_seqlens_list = step_context.q_seqlens.tolist()
+ kv_seqlens_list = step_context.kv_seqlens.tolist()
+ if not step_context.is_decoding:
+ is_unpaged_prefill = q_seqlens_list == kv_seqlens_list
+ # get kv_indices
+ for i in range(q_start_loc.size(0)):
+ q_seq_len = q_seqlens_list[i]
+ kv_seq_len = kv_seqlens_list[i]
+ # collect kv start indices.
+ history_length = kv_seq_len - q_seq_len
+ total_slots = get_total_slots()
+ slot_tables = total_slots[block_offsets[i]].view(-1)
+ slots = slot_tables[history_length:kv_seq_len]
+ kv_start_indices.append(slots)
+ kv_start_indices = torch.cat(kv_start_indices)
+ if not is_unpaged_prefill:
+ cu_seq_lens_kv = torch.cat(
+ (torch.tensor([0], device=kv_seqlens.device),
+ kv_seqlens.cumsum(0))).int()
+ else:
+ # collect kv_start_indices without using a for-loop,
+ # (fill kv-cache for just ONE token during the decoding phase)
+ idx = (step_context.kv_seqlens - 1) % block_size
+ block_num = (step_context.kv_seqlens - 1) // block_size
+ last_block = block_offsets.gather( # dtype of gather must be int64
+ 1, block_num.view(-1, 1)).view(-1)
+ kv_start_indices = (last_block * block_size + idx).to(torch.int32)
+
+ attn_meta_cls = cls.get_attention_metadata_cls()
+ attn_metadata = attn_meta_cls(
+ step_context.is_decoding,
+ block_offsets,
+ q_start_loc=cu_seqlens,
+ cu_seq_lens_kv=cu_seq_lens_kv,
+ q_seqlens=q_seqlens,
+ kv_seqlens=kv_seqlens,
+ kv_start_indices=kv_start_indices,
+ block_size=block_size,
+ attention_mask=None,
+ is_unpaged_prefill=is_unpaged_prefill,
+ max_q_seq_len=max_q_seq_len,
+ max_kv_seq_len=max_kv_seq_len,
+ )
+
+ step_context.attn_metadata = attn_metadata
+ return step_context
+
+ @staticmethod
+ def build_graph_runner(model: torch.nn.Module, model_config: ModelConfig,
+ cache_config: CacheConfig,
+ backend_config: BackendConfig,
+ device: torch.device):
+ """build graph runner."""
+ from lmdeploy.pytorch.backends.cuda.graph_runner import CUDAGraphRunner
+ return CUDAGraphRunner(model, model_config, cache_config,
+ backend_config, device)
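A numeric illustration of the decoding-phase slot computation above: the new token's cache slot is the last physical block's base address plus the in-block offset:

import torch

block_size = 4
kv_seqlens = torch.tensor([1, 6, 11])             # lengths after the new token
block_offsets = torch.tensor([[0, 0, 0],          # logical -> physical blocks
                              [2, 3, 0],
                              [1, 4, 5]])

idx = (kv_seqlens - 1) % block_size               # offset inside the last block
block_num = (kv_seqlens - 1) // block_size        # which logical block is last
last_block = block_offsets.gather(1, block_num.view(-1, 1)).view(-1)
kv_start_indices = (last_block * block_size + idx).to(torch.int32)
print(kv_start_indices)                           # tensor([ 0, 13, 22], dtype=torch.int32)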
diff --git a/lmdeploy/pytorch/backends/moe.py b/lmdeploy/pytorch/backends/moe.py
index 8e7977625e..0437a0a2b0 100644
--- a/lmdeploy/pytorch/backends/moe.py
+++ b/lmdeploy/pytorch/backends/moe.py
@@ -60,3 +60,94 @@ class FusedMoEBuilder(ABC):
def build(top_k: int, num_experts: int, renormalize: bool = False):
"""build from mlp."""
raise NotImplementedError
+
+
+class FusedMoEW8A8Impl(ABC):
+ """fused moe w8a8 implementation."""
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ """update weights."""
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return False
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ raise NotImplementedError('Not Implemented.')
+
+ @abstractmethod
+ def forward(self,
+ hidden_states: torch.Tensor,
+ input_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FusedMoEW8A8Builder(ABC):
+ """fused moe w8a8 builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8):
+ """build from mlp."""
+ raise NotImplementedError
+
+
+class FusedMoEBlockedF8Impl(ABC):
+ """fused moe blocked f8 implementation."""
+
+ def update_weights(self, gate_up_weights: torch.Tensor,
+ down_weights: torch.Tensor, gate_up_scale: torch.Tensor,
+ down_scale: torch.Tensor):
+ """update weights."""
+ return gate_up_weights, down_weights, gate_up_scale, down_scale
+
+ def support_ep(self):
+ """support expert parallelism."""
+ return False
+
+ def ep_expert_list(self, world_size: int, rank: int):
+ """experts list of current rank."""
+ raise NotImplementedError('Not Implemented.')
+
+ @abstractmethod
+ def forward(self,
+ hidden_states: torch.Tensor,
+ input_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor,
+ gate_up_weights: torch.Tensor,
+ gate_up_scale: torch.Tensor,
+ down_weights: torch.Tensor,
+ down_scale: torch.Tensor,
+ expert_list: List[int] = None):
+ """forward."""
+ raise NotImplementedError
+
+
+class FusedMoEBlockedF8Builder(ABC):
+ """fused moe blocked f8 builder."""
+
+ @staticmethod
+ @abstractmethod
+ def build(top_k: int,
+ num_experts: int,
+ renormalize: bool = False,
+ out_dtype: torch.dtype = torch.float16):
+ """build from mlp."""
+ raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/qmodules.py b/lmdeploy/pytorch/backends/qmodules.py
index a61941b37d..e877a4ca6b 100644
--- a/lmdeploy/pytorch/backends/qmodules.py
+++ b/lmdeploy/pytorch/backends/qmodules.py
@@ -37,7 +37,9 @@ class RMSNormW8A8Builder(ABC):
@staticmethod
@abstractmethod
- def build(hidden_size: int, eps: float = 1e-6):
+ def build(hidden_size: int,
+ eps: float = 1e-6,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError
@@ -71,6 +73,7 @@ class LinearW8A8Builder(ABC):
def build(in_features: int,
out_features: int,
bias: bool = True,
- dtype: torch.dtype = None):
+ dtype: torch.dtype = None,
+ quant_dtype: torch.dtype = torch.int8):
"""build."""
raise NotImplementedError
diff --git a/lmdeploy/pytorch/backends/selector.py b/lmdeploy/pytorch/backends/selector.py
index 987730a981..4db73fa370 100644
--- a/lmdeploy/pytorch/backends/selector.py
+++ b/lmdeploy/pytorch/backends/selector.py
@@ -18,5 +18,8 @@ def get_backend():
if device_type == 'maca':
from .dlinfer import MacaOpsBackend
return MacaOpsBackend
+ if device_type == 'camb':
+ from .dlinfer import CambOpsBackend
+ return CambOpsBackend
else:
raise RuntimeError(f'Unsupported device type: {device_type}')
diff --git a/lmdeploy/pytorch/check_env/deeplink.py b/lmdeploy/pytorch/check_env/deeplink.py
index 74ab5a7b87..00bcfdf77c 100644
--- a/lmdeploy/pytorch/check_env/deeplink.py
+++ b/lmdeploy/pytorch/check_env/deeplink.py
@@ -5,6 +5,7 @@
'ascend',
'npu',
'maca',
+ 'camb',
]
diff --git a/lmdeploy/pytorch/check_env/model.py b/lmdeploy/pytorch/check_env/model.py
index 4b721e50e2..79d8d26e3c 100644
--- a/lmdeploy/pytorch/check_env/model.py
+++ b/lmdeploy/pytorch/check_env/model.py
@@ -72,33 +72,6 @@ def check_dtype(self, config):
'Please send issue to LMDeploy with error logs.')
self.log_and_exit(e, 'Model', message=message)
- def check_awq(self, config):
- """check awq."""
- logger = self.get_logger()
- device_type = self.device_type
- if device_type != 'cuda':
- return
-
- quantization_config = getattr(config, 'quantization_config', dict())
- quant_method = quantization_config.get('quant_method', None)
- if quant_method != 'awq':
- return
- try:
- import awq # noqa
- except Exception as e:
- self.log_and_exit(e, 'autoawq', logger)
-
- try:
- import awq_ext # noqa
- except Exception as e:
- logger.debug('Exception:', exc_info=1)
- self.log_and_exit(
- e,
- 'awq_ext',
- message='Failed to import `awq_ext`. '
- 'Try reinstall it from source: '
- 'https://github.com/casper-hansen/AutoAWQ_kernels')
-
def check(self):
"""check."""
import transformers
@@ -112,6 +85,3 @@ def check(self):
# dtype check
self.check_dtype(config)
-
- # awq
- self.check_awq(config)
diff --git a/lmdeploy/pytorch/configurations/deepseek_v2.py b/lmdeploy/pytorch/configurations/deepseek_v2.py
index d1f0844ad5..bf06ff0c33 100644
--- a/lmdeploy/pytorch/configurations/deepseek_v2.py
+++ b/lmdeploy/pytorch/configurations/deepseek_v2.py
@@ -9,7 +9,7 @@ class DeepseekV2ModelConfigBuilder(AutoModelConfigBuilder):
@classmethod
def condition(cls, hf_config):
"""config."""
- return hf_config.model_type == 'deepseek_v2'
+ return hf_config.model_type in ['deepseek_v3', 'deepseek_v2']
@classmethod
def build(cls, hf_config, model_path: str = None, **kwargs):
diff --git a/lmdeploy/pytorch/engine/engine.py b/lmdeploy/pytorch/engine/engine.py
index e06e0cf80a..a674d609af 100644
--- a/lmdeploy/pytorch/engine/engine.py
+++ b/lmdeploy/pytorch/engine/engine.py
@@ -30,24 +30,13 @@
_EMPTY_TOKEN = np.empty((0, ), dtype=np.int64)
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
- """raise exception on finish."""
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as e:
- raise e
-
-
@dataclass
class InferOutput:
"""The output of the model inference."""
session_id: int
+ resp: Response
token_ids: List[int]
- sender_id: int
- req_id: int
meta: Any = None
finish: bool = False
logits: torch.Tensor = None
@@ -250,7 +239,7 @@ def _build_adapter_manager(self, adapters):
def _bind_request_manager(self):
"""bind request manager."""
- req_manager = RequestManager(self.engine_config.thread_safe)
+ req_manager = RequestManager()
req_manager.bind_func(RequestType.ADD_SESSION, self._on_add_session)
req_manager.bind_func(RequestType.STOP_SESSION, self._on_stop_session)
req_manager.bind_func(RequestType.END_SESSION, self._on_end_session)
@@ -262,18 +251,15 @@ def _start_loop(self):
return self.req_manager.start_loop(self.async_loop)
def _response(self,
+ resp: Response,
resp_type: ResponseType,
- sender_id: int,
- req_id: int,
data: Any = None,
err_msg: str = ''):
"""response."""
- self.req_manager.response(
- Response(type=resp_type,
- sender_id=sender_id,
- req_id=req_id,
- data=data,
- err_msg=err_msg))
+ resp.type = resp_type
+ resp.data = data
+ resp.err_msg = err_msg
+ self.req_manager.response(resp)
def _get_max_session_len(self):
"""get max session len."""
@@ -299,7 +285,7 @@ def _on_add_session(self, reqs: Request, **kwargs):
self.scheduler.add_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_stop_session(self, reqs: Request, **kwargs):
"""on stop session callback."""
@@ -311,7 +297,7 @@ def _on_stop_session(self, reqs: Request, **kwargs):
self.scheduler.stop_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_end_session(self, reqs: Request, **kwargs):
"""on end session callback."""
@@ -323,14 +309,35 @@ def _on_end_session(self, reqs: Request, **kwargs):
self.scheduler.end_session(session_id)
resp_type = ResponseType.SUCCESS
if resp:
- self._response(resp_type, req.sender_id, req.req_id)
+ self._response(req.resp, resp_type)
def _on_add_message(self, reqs: Request, **kwargs):
"""on add message callback."""
+ for req in reqs:
+ req_data = req.data
+ if req_data.get('input_multimodals', None) is None:
+ continue
+ elif self.input_processor is None:
+ logger.warning('Multimodal inputs are not supported.')
+ continue
+ input_ids = req_data['token_ids']
+ input_multimodals = req_data['input_multimodals']
+ if len(input_multimodals) == 0:
+ req_data['input_multimodals'] = None
+ continue
+ result = self.input_processor.preprocess_input(
+ input_ids, input_multimodals)
+
+ input_ids = result.input_ids
+ input_multimodals = result.input_multimodals
+
+ req_data['token_ids'] = input_ids
+ req_data['input_multimodals'] = input_multimodals
- self._msg_preprocess_inque.put_nowait(reqs)
+ if len(reqs) > 0:
+ self._add_message(reqs)
- def _add_message(self, que):
+ def _add_message(self, reqs):
def __update_bad_words(msg):
"""update bad words."""
@@ -353,16 +360,10 @@ def __update_max_new_tokens(msg):
sampling_param.max_new_tokens,
max_session_len - msg.num_all_tokens())
- if que.qsize() == 0:
- return
-
- reqs = que.get_nowait()
-
for req in reqs:
session_id = req.data['session_id']
if session_id not in self.scheduler.sessions:
- self._response(ResponseType.SESSION_NOT_EXIST, req.sender_id,
- req.req_id)
+ self._response(req.resp, ResponseType.SESSION_NOT_EXIST)
continue
session_id = req.data['session_id']
sess = self.scheduler.sessions[session_id]
@@ -396,8 +397,7 @@ def __update_max_new_tokens(msg):
__update_bad_words(msg)
__update_max_new_tokens(msg)
- msg.sender_id = req.sender_id
- msg.req_id = req.req_id
+ msg.resp = req.resp
@property
def model_config(self):
@@ -553,11 +553,12 @@ def _batch_stopping_criteria(self, token_ids: torch.Tensor,
return stopped, num_appendable_ids
@logging_timer('SamplingLogits', logger)
- def async_sampling_logits(self, logits: torch.Tensor,
- all_ids: torch.Tensor,
- guided_input_ids: torch.Tensor,
- sampling_inputs: SamplingInputs,
- inputs: ModelInputs, ignore_eos: torch.Tensor):
+ async def async_sampling_logits(self, logits: torch.Tensor,
+ all_ids: torch.Tensor,
+ guided_input_ids: torch.Tensor,
+ sampling_inputs: SamplingInputs,
+ inputs: ModelInputs,
+ ignore_eos: torch.Tensor):
"""sampling logits."""
def __get_last_logits():
@@ -572,7 +573,8 @@ def __get_last_logits():
split_logits = __get_last_logits()
logits_processor = FusedLogitsProcessor(sampling_inputs, ignore_eos,
self.tokenizer.model.model)
- logits = logits_processor(all_ids, guided_input_ids, split_logits)
+ logits = await logits_processor(all_ids, guided_input_ids,
+ split_logits)
next_token_ids = logits_processor.sampling(logits)
return next_token_ids
@@ -585,13 +587,11 @@ def update_running(self, running: SeqList, next_token_ids: torch.Tensor,
if model_metas is None:
model_metas = [None] * len(running)
next_token_ids = next_token_ids.numpy()
- eos_token_id = self.model_config.eos_token_id
for token, msg, stop, model_meta in zip(next_token_ids, running,
stopped, model_metas):
if msg.status != MessageStatus.RUNNING:
continue
update_token = token
- stop = stop or token in eos_token_id
if stop:
update_token = _EMPTY_TOKEN
else:
@@ -697,15 +697,6 @@ async def _make_infer_outputs(self, next_token_ids: torch.LongTensor,
event: torch.cuda.Event):
"""make infer output."""
- def __get_out_token_ids(token: torch.Tensor, msg: SchedulerSequence,
- stopped: bool):
- """check if output is necessary."""
- if stopped:
- return []
- if token in msg.sampling_param.stop_words:
- return []
- return [token]
-
def __get_q_start_loc():
inputs = self._inputs
seq_length = inputs.seq_length
@@ -733,16 +724,15 @@ def __get_q_start_loc():
for idx, msg in enumerate(running):
if not is_run[idx]:
continue
- token_ids = __get_out_token_ids(next_token_ids[idx], msg,
- stopped[idx])
+ token_ids = msg.all_ids[-msg.num_new_tokens:]
finish = msg.status == MessageStatus.STOPPED
if not finish and len(token_ids) == 0:
continue
session_id = msg.session_id
+ resp = msg.resp
out = InferOutput(
session_id=session_id,
- sender_id=msg.sender_id,
- req_id=msg.req_id,
+ resp=resp,
finish=finish,
token_ids=token_ids,
)
@@ -803,7 +793,7 @@ def __update_inputs(next_token_ids):
logits = logits[0] # [bs, seq, prob] -> [seq, prob]
# sampling
- next_token_ids = self.async_sampling_logits(
+ next_token_ids = await self.async_sampling_logits(
logits, all_ids, guided_input_ids, sampling_inputs, inputs,
num_ignore_eos > 0)
num_ignore_eos = num_ignore_eos - 1
@@ -832,39 +822,28 @@ def __update_inputs(next_token_ids):
swap_out_map = dict()
__update_inputs(next_token_ids)
+ def _set_has_runable_event(self, has_runable_event: asyncio.Event):
+ """set has runable event."""
+ if self.scheduler.has_unfinished():
+ has_runable_event.set()
+ else:
+ has_runable_event.clear()
+
@torch.inference_mode()
- async def _async_loop_preprocess_message(self, inque, outque):
+ async def _async_loop_preprocess_message(self,
+ forward_event: asyncio.Event,
+ has_runable_event: asyncio.Event):
"""preprocess msg."""
while True:
- reqs = await inque.get()
-
- for req in reqs:
- req_data = req.data
- if req_data.get('input_multimodals', None) is None:
- continue
- elif self.input_processor is None:
- logger.warning('Do not support Multimodal inputs.')
- continue
- input_ids = req_data['token_ids']
- input_multimodals = req_data['input_multimodals']
- if len(input_multimodals) == 0:
- req_data['input_multimodals'] = None
- continue
- result = self.input_processor.preprocess_input(
- input_ids, input_multimodals)
-
- input_ids = result.input_ids
- input_multimodals = result.input_multimodals
-
- req_data['token_ids'] = input_ids
- req_data['input_multimodals'] = input_multimodals
-
- if len(reqs) > 0:
- outque.put_nowait(reqs)
+ if self.scheduler.has_unfinished():
+ await forward_event.wait()
+ await self.req_manager.step()
+ self._set_has_runable_event(has_runable_event)
@torch.inference_mode()
async def _async_loop_background(self, in_que: asyncio.Queue,
- out_que: asyncio.Queue):
+ out_que: asyncio.Queue,
+ forward_event: asyncio.Event):
"""async loop background."""
def __gather_all_ids(seqs: SeqList, sampling_inputs: SamplingInputs):
@@ -925,76 +904,52 @@ def __need_logits(seqs: SeqList):
while True:
is_prefill, scheduler_output = await in_que.get()
- try:
- running = scheduler_output.running
- swap_in_map = scheduler_output.swap_in_map
- swap_out_map = scheduler_output.swap_out_map
- prefill_interval = self.scheduler_config.prefill_interval
- loop_count = 1 if is_prefill else (prefill_interval - 1)
- assert len(running) > 0
-
- # create inputs
- inputs = self.create_model_inputs(running, is_prefill)
- sampling_inputs = SamplingInputs.from_sampling_params(running)
- all_ids = __gather_all_ids(running, sampling_inputs)
- guided_input_ids = __gather_guided_input_ids(
- running, sampling_inputs)
- num_appendable_ids = __get_num_appendable_ids(running)
- num_ignore_eos = __get_num_ignore_eos(running)
- return_logits = __need_logits(running)
-
- self._running = running
- self._inputs = inputs
-
- await self._async_step_background(
- inputs=inputs,
- swap_in_map=swap_in_map,
- swap_out_map=swap_out_map,
- all_ids=all_ids,
- guided_input_ids=guided_input_ids,
- sampling_inputs=sampling_inputs,
- num_appendable_ids=num_appendable_ids,
- num_ignore_eos=num_ignore_eos,
- loop_count=loop_count,
- return_logits=return_logits,
- output_que=out_que,
- )
- except Exception as e:
- out_que.put_nowait((True, e))
- finally:
- in_que.task_done()
-
- @torch.inference_mode()
- async def _async_loop(self):
- """Main loop of the engine.
-
- Each engine instance would communicate with the engine by queue.
- """
-
- self._msg_preprocess_inque = asyncio.Queue()
- self._msg_preprocess_outque = asyncio.Queue()
+ running = scheduler_output.running
+ swap_in_map = scheduler_output.swap_in_map
+ swap_out_map = scheduler_output.swap_out_map
+ prefill_interval = self.scheduler_config.prefill_interval
+ loop_count = 1 if is_prefill else (prefill_interval - 1)
+ assert len(running) > 0
+
+ # create inputs
+ inputs = self.create_model_inputs(running, is_prefill)
+ sampling_inputs = SamplingInputs.from_sampling_params(running)
+ all_ids = __gather_all_ids(running, sampling_inputs)
+ guided_input_ids = __gather_guided_input_ids(
+ running, sampling_inputs)
+ num_appendable_ids = __get_num_appendable_ids(running)
+ num_ignore_eos = __get_num_ignore_eos(running)
+ return_logits = __need_logits(running)
+
+ self._running = running
+ self._inputs = inputs
+
+ forward_event.clear()
+ await self._async_step_background(
+ inputs=inputs,
+ swap_in_map=swap_in_map,
+ swap_out_map=swap_out_map,
+ all_ids=all_ids,
+ guided_input_ids=guided_input_ids,
+ sampling_inputs=sampling_inputs,
+ num_appendable_ids=num_appendable_ids,
+ num_ignore_eos=num_ignore_eos,
+ loop_count=loop_count,
+ return_logits=return_logits,
+ output_que=out_que,
+ )
+ forward_event.set()
- prefill_interval = self.scheduler_config.prefill_interval
- in_que = asyncio.Queue()
- out_que = asyncio.Queue()
- loop_background = asyncio.get_event_loop().create_task(
- self._async_loop_background(in_que, out_que),
- name='MainLoopBackground')
- loop_background.add_done_callback(_raise_exception_on_finish)
-
- loop_msg_proc = asyncio.get_event_loop().create_task(
- self._async_loop_preprocess_message(self._msg_preprocess_inque,
- self._msg_preprocess_outque),
- name='MainLoopPreprocessMessage')
- loop_msg_proc.add_done_callback(_raise_exception_on_finish)
+ async def _async_send_responses(self, que: asyncio.Queue,
+ forward_event: asyncio.Event):
+ """send responses."""
def __send_resp(out: InferOutput):
"""send response."""
resp_type = (ResponseType.FINISH
if out.finish else ResponseType.SUCCESS)
- self._response(resp_type,
- sender_id=out.sender_id,
- req_id=out.req_id,
+ self._response(out.resp,
+ resp_type,
data=dict(token_ids=out.token_ids,
logits=out.logits))
@@ -1003,9 +958,89 @@ def __send_resps(step_outputs: Dict[int, InferOutput]):
for out in step_outputs.values():
__send_resp(out)
+ while True:
+ resps = await que.get()
+ if self.scheduler.has_unfinished():
+ await forward_event.wait()
+ __send_resps(resps)
+
+ @staticmethod
+ def _add_loop_tasks_done_callback(tasks: List[asyncio.Task]):
+ """add loop tasks done callback."""
+
+ def __task_callback(task: asyncio.Task) -> None:
+ """raise exception on finish."""
+ task_name = task.get_name()
+ try:
+ task.result()
+ except asyncio.CancelledError:
+ logger.debug(f'Task <{task_name}> cancelled.')
+ return
+ except Exception:
+ logger.exception(f'Task <{task_name}> failed')
+ for task in tasks:
+ if not task.cancelled():
+ task.cancel()
+
+ for task in tasks:
+ task.add_done_callback(__task_callback)
+
+ @torch.inference_mode()
+ async def _async_loop(self):
+ """Main loop of the engine.
+
+ Each engine instance communicates with the engine through queues.
+ """
+ event_loop = asyncio.get_event_loop()
+ prefill_interval = self.scheduler_config.prefill_interval
+
+ # forward task
+ in_que = asyncio.Queue()
+ out_que = asyncio.Queue()
+ forward_event = asyncio.Event()
+ forward_event.set()
+ loop_background = event_loop.create_task(self._async_loop_background(
+ in_que, out_que, forward_event),
+ name='MainLoopBackground')
+
+ # preprocess task
+ has_runable_event = asyncio.Event()
+ loop_msg_proc = event_loop.create_task(
+ self._async_loop_preprocess_message(forward_event,
+ has_runable_event),
+ name='MainLoopPreprocessMessage')
+
+ # response task
+ resp_que = asyncio.Queue()
+ loop_send_resp = event_loop.create_task(self._async_send_responses(
+ resp_que, forward_event),
+ name='MainLoopResponse')
+
+ loop_main = asyncio.current_task()
+ loop_tasks: List[asyncio.Task] = [
+ loop_main, loop_background, loop_msg_proc, loop_send_resp
+ ]
+ self._add_loop_tasks_done_callback(loop_tasks)
+
+ def __do_prefill():
+ # decoding if no waiting
+ if not self.scheduler.has_waiting():
+ return False
+ num_running = self.scheduler.num_running()
+ num_waiting = self.scheduler.num_waiting()
+ max_batches = self.scheduler_config.max_batches
+ # prefill if too many requests are waiting
+ if num_waiting >= 4:
+ return True
+ # prefill if not enough requests are running
+ if num_running < max_batches * 0.5:
+ return True
+ # decoding
+ return False
+
async def __step():
"""step decoding."""
- prefill = self.scheduler.has_waiting()
+ prefill = __do_prefill()
schedule_output = self.scheduler.schedule(
is_prefill=prefill, prealloc_size=prefill_interval)
# schedule decoding if no valid prefill reqs.
@@ -1017,31 +1052,13 @@ async def __step():
in_que.put_nowait((prefill, schedule_output))
finish = False
while not finish:
- if self.req_manager.has_requests():
- self.req_manager.step()
- self._add_message(self._msg_preprocess_outque)
finish, out = await out_que.get()
- try:
- if isinstance(out, Exception):
- raise out
- (next_token_ids, logits, stopped, model_metas, event) = out
- step_outputs = await self._make_infer_outputs(
- next_token_ids, logits, stopped, model_metas, event)
- __send_resps(step_outputs)
- except Exception as e:
- raise e
- finally:
- out_que.task_done()
+ step_outputs = await self._make_infer_outputs(*out)
+ self._set_has_runable_event(has_runable_event)
+ resp_que.put_nowait(step_outputs)
while True:
- if self.req_manager.has_requests():
- self.req_manager.step()
- self._add_message(self._msg_preprocess_outque)
-
- if not self.scheduler.has_unfinished():
- await asyncio.sleep(0.01)
- continue
-
+ await has_runable_event.wait()
await __step()
async def async_loop(self):
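The control flow introduced here, reduced to a standalone asyncio sketch: the main loop blocks on has_runable_event instead of polling with sleep(), and the auxiliary loop waits on forward_event while a forward step is in flight (the model work is faked):

import asyncio

async def main():
    forward_event = asyncio.Event()
    forward_event.set()
    has_runable_event = asyncio.Event()
    pending = list(range(3))                  # fake requests

    async def preprocess_loop():
        while True:
            await forward_event.wait()        # do not race a forward step
            if pending:
                has_runable_event.set()
            await asyncio.sleep(0.01)

    async def step():
        forward_event.clear()                 # forward pass in flight
        await asyncio.sleep(0.02)             # fake model forward
        pending.pop()
        forward_event.set()
        if not pending:
            has_runable_event.clear()

    task = asyncio.create_task(preprocess_loop())
    while pending:
        await has_runable_event.wait()
        await step()
    task.cancel()

asyncio.run(main())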
diff --git a/lmdeploy/pytorch/engine/engine_checker.py b/lmdeploy/pytorch/engine/engine_checker.py
index 7276a51fbc..5b0cc9865c 100644
--- a/lmdeploy/pytorch/engine/engine_checker.py
+++ b/lmdeploy/pytorch/engine/engine_checker.py
@@ -64,11 +64,13 @@ def __init__(self,
def check(self):
"""check."""
engine_config = self.engine_config
- logger = self.get_logger()
if engine_config.thread_safe:
- logger.warning('thread safe mode has been deprecated and'
- ' it would be removed in the future.')
+ self.log_and_exit(
+ mod_name='Engine',
+ message='thread safe mode is no longer supported.\n'
+ 'Read https://github.com/InternLM/lmdeploy/blob/main/docs/en/advance/pytorch_multithread.md for more details.', # noqa: E501
+ )
if engine_config.max_batch_size <= 0:
self.log_and_exit(
diff --git a/lmdeploy/pytorch/engine/engine_instance.py b/lmdeploy/pytorch/engine/engine_instance.py
index dff9667eb4..5cf1366783 100644
--- a/lmdeploy/pytorch/engine/engine_instance.py
+++ b/lmdeploy/pytorch/engine/engine_instance.py
@@ -43,8 +43,8 @@ async def async_try_add_session(req_sender: RequestSender, session_id: int):
async def async_end(req_sender: RequestSender, session_id: int):
"""End the given session."""
- await req_sender.async_send_async(
- RequestType.END_SESSION, dict(session_id=session_id, response=False))
+ req_sender.send_async(RequestType.END_SESSION,
+ dict(session_id=session_id, response=False))
async def async_cancel(req_sender: RequestSender, session_id: int):
@@ -140,9 +140,8 @@ async def async_stream_infer(self,
return
gen_config = gen_config or GenerationConfig()
sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
- await self.req_sender.async_send_async(
- RequestType.ADD_SESSION, dict(session_id=session_id,
- response=False))
+ self.req_sender.send_async(RequestType.ADD_SESSION,
+ dict(session_id=session_id, response=False))
msg = dict(
token_ids=input_ids,
session_id=session_id,
@@ -150,20 +149,16 @@ async def async_stream_infer(self,
adapter_name=adapter_name,
input_multimodals=multimodal,
)
- req_id = await self.req_sender.async_send_async(
- RequestType.ADD_MESSAGE, msg)
+ resp = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
- token_ids = []
while True:
- resp = await self.req_sender.async_recv(req_id)
+ resp = await self.req_sender.async_recv(resp)
- if resp.req_id != req_id:
- continue
if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
+ token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
+ token_ids = resp.data['token_ids'].tolist()
yield EngineOutput(resp.type, token_ids, len(token_ids))
break
else:
@@ -240,39 +235,7 @@ def __call_async():
except StopAsyncIteration:
break
- if not self.req_sender.is_thread_safe():
- yield from __call_async()
- return
-
- gen_config = gen_config or GenerationConfig()
- sampling_param = SamplingParam.from_gen_config(gen_config=gen_config)
- self.req_sender.send_async(RequestType.ADD_SESSION,
- dict(session_id=session_id, response=False))
- msg = dict(
- token_ids=input_ids,
- session_id=session_id,
- sampling_param=sampling_param,
- adapter_name=adapter_name,
- input_multimodals=multimodal,
- )
- req_id = self.req_sender.send_async(RequestType.ADD_MESSAGE, msg)
-
- token_ids = []
- while True:
- resp = self.req_sender.recv(req_id)
-
- if resp.req_id != req_id:
- continue
- if resp.type == ResponseType.SUCCESS:
- token_ids += resp.data['token_ids']
- yield EngineOutput(resp.type, token_ids, len(token_ids))
- elif resp.type == ResponseType.FINISH:
- token_ids += resp.data['token_ids']
- yield EngineOutput(resp.type, token_ids, len(token_ids))
- break
- else:
- yield EngineOutput(resp.type, [], 0)
- break
+ yield from __call_async()
def infer(self,
session_id: int,
@@ -365,9 +328,9 @@ def __add_messages(session_ids, input_ids, adapter_names,
return_logits=True)
add_msgs.append(msg)
req_types = [RequestType.ADD_MESSAGE] * batch_size
- req_ids = self.req_sender.batched_send_async(req_types,
- data=add_msgs)
- return req_ids
+ resps = self.req_sender.batched_send_async(req_types,
+ data=add_msgs)
+ return resps
if steps is not None:
assert batch_size == len(steps)
@@ -384,21 +347,14 @@ def __add_messages(session_ids, input_ids, adapter_names,
dict(session_id=sid))
self._try_add_session(sid)
- req_ids = __add_messages(session_ids, input_ids, adapter_names,
- multimodal)
- req_idx_map = dict(zip(req_ids, range(len(req_ids))))
-
- finish_count = batch_size
- ret = [None] * batch_size
- while finish_count > 0:
- resp = self.req_sender.recv_any()
- if resp.req_id not in req_ids:
- continue
+ resps = __add_messages(session_ids, input_ids, adapter_names,
+ multimodal)
+ ret = []
+ for resp in resps:
+ resp = self.req_sender.recv(resp)
assert resp.type == ResponseType.FINISH
- idx = req_idx_map[resp.req_id]
- ret[idx] = resp.data['logits']
- finish_count -= 1
+ ret.append(resp.data['logits'])
ret = pad_sequence(ret, True)
diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index 24cb336d71..f7ca9c5116 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
import json
from dataclasses import asdict, dataclass
from typing import Dict, List, Optional, Tuple
@@ -298,9 +299,15 @@ def __init__(self,
self.ignore_eos = ignore_eos
self.tokenizer = tokenizer
- def __call__(self, all_ids: torch.LongTensor,
- guided_input_ids: torch.LongTensor,
- scores: torch.FloatTensor) -> torch.FloatTensor:
+ async def _wait_stream_once(self):
+ """wait stream once."""
+ stream = torch.cuda.current_stream()
+ if not stream.query():
+ await asyncio.sleep(0)
+
+ async def __call__(self, all_ids: torch.LongTensor,
+ guided_input_ids: torch.LongTensor,
+ scores: torch.FloatTensor) -> torch.FloatTensor:
r"""
Args:
all_ids (torch.LongTensor): All the token ids.
@@ -320,6 +327,7 @@ def __call__(self, all_ids: torch.LongTensor,
custom_logits_processors = self.sampling_inputs.logits_processors
if any(custom_logits_processors):
+ await self._wait_stream_once()
scores = _apply_custom_logits_processors(custom_logits_processors,
all_ids, scores)
@@ -343,8 +351,10 @@ def __call__(self, all_ids: torch.LongTensor,
stop_mask = torch.where(self.ignore_eos[:, None], stop_mask, False)
scores = _process_bad_words_(scores, stop_words, stop_mask)
- scores = _guided_sampling(sampling_inputs.response_formats, scores,
- guided_input_ids, self.tokenizer)
+ if guided_input_ids is not None:
+ await self._wait_stream_once()
+ scores = _guided_sampling(sampling_inputs.response_formats, scores,
+ guided_input_ids, self.tokenizer)
return scores
@torch.inference_mode()
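The cooperative-wait pattern added above, shown in isolation (requires a CUDA device; the matmul is only there to keep the stream busy):

import asyncio
import torch

async def wait_stream_once():
    stream = torch.cuda.current_stream()
    if not stream.query():        # kernels still in flight
        await asyncio.sleep(0)    # let other coroutines run instead of blocking

async def demo():
    x = torch.randn(4096, 4096, device='cuda')
    y = x @ x                     # queue some GPU work asynchronously
    await wait_stream_once()      # give the event loop a turn if the GPU is busy
    print(y.sum())                # value access synchronizes as usual

asyncio.run(demo())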
diff --git a/lmdeploy/pytorch/engine/request.py b/lmdeploy/pytorch/engine/request.py
index 18bd2193d4..0d20deb907 100644
--- a/lmdeploy/pytorch/engine/request.py
+++ b/lmdeploy/pytorch/engine/request.py
@@ -2,8 +2,6 @@
import asyncio
import enum
from dataclasses import dataclass, field
-from queue import Queue
-from threading import Lock, Thread
from typing import Any, Awaitable, Callable, Dict, List
from lmdeploy.messages import ResponseType
@@ -12,25 +10,6 @@
logger = get_logger('lmdeploy')
-def _raise_exception_on_finish(task: asyncio.Task) -> None:
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as e:
- logger.exception(f'Engine loop failed with error: {e}')
-
-
-def _ignore_exception_on_finish(task: asyncio.Task) -> None:
- try:
- task.result()
- except asyncio.CancelledError:
- return
- except Exception as exc:
- logger.debug(f'task: {task.get_name()} ended.')
- logger.debug(f'task: {task.get_name()} exception: {exc}')
-
-
class RequestType(enum.Enum):
"""Request type."""
@@ -43,24 +22,24 @@ class RequestType(enum.Enum):
@dataclass
-class Request:
- """Request."""
+class Response:
+ """Response."""
- type: RequestType
+ type: ResponseType
sender_id: int
- req_id: int
+ event: asyncio.Event
data: Any = None
+ err_msg: str = ''
@dataclass
-class Response:
- """Response."""
+class Request:
+ """Request."""
- type: ResponseType
+ type: RequestType
sender_id: int
- req_id: int
data: Any = None
- err_msg: str = ''
+ resp: Response = None
ReqList = List[Request]
@@ -85,28 +64,20 @@ class RequestSender:
Args:
sender_id (int): The id of the sender
"""
-
sender_id: int
manager: 'RequestManager'
resp_dict: Dict[int, List[Response]] = field(default_factory=dict)
- _next_req_id: int = 0
_resp_que: asyncio.Queue = None
- _resp_thread_que: Queue = None
- _thread_safe: bool = False
@classmethod
def new(cls, sender_id: int, manager: 'RequestManager'):
"""new."""
obj = cls(sender_id=sender_id, manager=manager)
- obj._thread_safe = manager.is_thread_safe()
return obj
@property
def resp_que(self):
"""response queue."""
- thread_safe = self.is_thread_safe()
- if thread_safe:
- return self.manager.responses
if self._resp_que is not None:
return self._resp_que
if self.manager._loop_task is None:
@@ -119,27 +90,11 @@ def req_que(self):
"""request queue."""
return self.manager.requests
- @property
- def resp_thread_que(self):
- """response threadsafe queue."""
- if self._resp_thread_que is None:
- self._resp_thread_que = Queue()
- return self._resp_thread_que
-
- @property
- def req_thread_que(self):
- """request threadsafe queue."""
- return self.manager.thread_requests
-
@property
def event_loop(self):
"""get event loop."""
return self.manager.event_loop
- def is_thread_safe(self):
- """is thread safe."""
- return self._thread_safe
-
def is_loop_alive(self):
"""is loop alive."""
return self.manager.is_loop_alive()
@@ -148,203 +103,72 @@ def run_until_complete(self, future: Awaitable):
"""run untile complete."""
return self.manager.run_until_complete(future)
- def _resp_get(self):
- """resp_que.get."""
- timeout = 1.0
- que = self.resp_thread_que
- not_empty = que.not_empty
- with not_empty:
- while not que._qsize():
- not_empty.wait(timeout)
- return que.get_nowait()
-
- async def _async_resp_get(self):
- """get resp.
-
- Different behavior in threadsafe mode.
- """
- timeout = 1
-
- async def __no_threadsafe_get():
- while True:
- try:
- return await asyncio.wait_for(self.resp_que.get(), timeout)
- except asyncio.TimeoutError:
- if not self.manager.is_loop_alive():
- logger.debug('Engine loop is not alive.')
- exit(1)
- continue
- except Exception as e:
- logger.exception(
- f'sender[{self.sender_id}] get response failed: {e}')
- raise e
-
- if self.is_thread_safe():
- ret = self._resp_get()
- await asyncio.sleep(0)
- return ret
- else:
- return await __no_threadsafe_get()
-
def _req_put(self, reqs: Any):
- """req put."""
- self.req_thread_que.put(reqs)
-
- async def _async_req_put(self, reqs: Any):
- """async rq_que put.
-
- Different behavior in threadsafe mode.
- """
- if self.is_thread_safe():
- self._req_put(reqs)
- await asyncio.sleep(0)
- else:
- await self.req_que.put(reqs)
-
- def _prefetch_resps(self):
- """prefetch from resp que.
-
- Different behavior in threadsafe mode.
- """
- if self.is_thread_safe():
- resp_que = self.resp_thread_que
- else:
- resp_que = self.resp_que
- num_resps = resp_que.qsize()
- for _ in range(num_resps):
- resp: Response = resp_que.get_nowait()
- req_id = resp.req_id
- self._push_resp(req_id, resp)
-
- def _push_resp(self, req_id: int, resp: Response):
- """push response."""
- self.resp_dict.setdefault(req_id, [])
- self.resp_dict[req_id].append(resp)
-
- def _pop_resp(self, req_id: int, default: Any = None):
- """pop response."""
- if req_id not in self.resp_dict:
- return default
- resps = self.resp_dict[req_id]
- ret = resps.pop(0)
- if len(resps) == 0:
- self.resp_dict.pop(req_id)
- return ret
+ """async rq_que put."""
+ self.req_que.put_nowait(reqs)
def _gather_request(self, req_types: List[RequestType], data: List[Any]):
"""gather requests."""
- if self.manager._loop_task is None and not self.is_thread_safe():
+ if self.manager._loop_task is None:
self.manager.create_loop_task()
assert len(req_types) == len(data)
- batch_size = len(req_types)
-
- req_ids = list(range(self._next_req_id,
- self._next_req_id + batch_size))
- self._next_req_id += batch_size
-
- reqs = [
- Request(type=rtype,
- sender_id=self.sender_id,
- req_id=req_id,
- data=rdata)
- for req_id, rtype, rdata in zip(req_ids, req_types, data)
- ]
- return req_ids, reqs
- async def async_batched_send_async(self, req_types: List[RequestType],
- data: List[Any]):
- """Batched send request asynchronize."""
- req_ids, reqs = self._gather_request(req_types, data)
- await self._async_req_put(reqs)
- return req_ids
-
- async def async_send_async(self, req_type: RequestType, data: Any):
- """send request asynchronize."""
- return (await self.async_batched_send_async(req_types=[req_type],
- data=[data]))[0]
+ reqs = []
+ resps = []
+ for rtype, rdata in zip(req_types, data):
+ event = asyncio.Event()
+ resp = Response(type=ResponseType.HANDLER_NOT_EXIST,
+ sender_id=self.sender_id,
+ event=event,
+ data=None,
+ err_msg=None)
+ req = Request(type=rtype,
+ sender_id=self.sender_id,
+ data=rdata,
+ resp=resp)
+ resps.append(resp)
+ reqs.append(req)
+ return resps, reqs
def batched_send_async(self, req_types: List[RequestType],
- data: List[Any]) -> List[int]:
- """Batched send request asynchronize.
-
- Different behavior in threadsafe mode.
- """
- if not self.is_thread_safe():
- coro = self.async_batched_send_async(req_types, data)
- return self.run_until_complete(coro)
-
- req_ids, reqs = self._gather_request(req_types, data)
+ data: List[Any]):
+ """Batched send request asynchronize."""
+ resps, reqs = self._gather_request(req_types, data)
self._req_put(reqs)
- return req_ids
+ return resps
- def send_async(self, req_type: RequestType, data: Any) -> int:
+ def send_async(self, req_type: RequestType, data: Any):
"""send request asynchronize."""
return self.batched_send_async(req_types=[req_type], data=[data])[0]
- async def async_recv_any(self) -> Response:
- """receive any response."""
- self._prefetch_resps()
- for req_id in self.resp_dict:
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
- return await self._async_resp_get()
-
- def recv_any(self) -> Response:
- """receive any response."""
- coro = self.async_recv_any()
- return self.run_until_complete(coro)
-
- def recv_all(self, req_id: int, block: bool = True):
- """revceive all response with req_id."""
- self._prefetch_resps()
- resps = self.resp_dict.pop(req_id, [])
- return resps
-
- async def async_recv(self, req_id: int) -> Response:
+ async def async_recv(self, resp: Response) -> Response:
"""receive response of given request id async."""
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
-
- # check resp que
- while True:
- resp: Response = await self._async_resp_get()
- if resp.req_id != req_id:
- self._push_resp(req_id, resp)
- else:
- return resp
-
- def recv(self, req_id: int) -> Response:
- """receive response of given request id.
-
- Different behavior in threadsafe mode.
- """
- if not self.is_thread_safe():
- coro = self.async_recv(req_id)
- return self.run_until_complete(coro)
-
- ret = self._pop_resp(req_id, default=None)
- if ret is not None:
- return ret
-
- # check resp que
- while True:
- resp: Response = self._resp_get()
- if resp.req_id != req_id:
- self._push_resp(req_id, resp)
- else:
- return resp
+ event = resp.event
+ while not event.is_set():
+ try:
+ await asyncio.wait_for(event.wait(), 1)
+ except asyncio.TimeoutError:
+ if self.is_loop_alive():
+ continue
+ logger.debug('Engine main loop failed.')
+ break
+ event.clear()
+ return resp
+
+ def recv(self, resp: Response) -> Response:
+ """receive response of given request id."""
+ coro = self.async_recv(resp)
+ return self.run_until_complete(coro)
async def async_send(self, req_type: RequestType, data: Any):
"""send and receive synchronize."""
- req_id = await self.async_send_async(req_type, data)
- return await self.async_recv(req_id)
+ resp = self.send_async(req_type, data)
+ return await self.async_recv(resp)
def send(self, req_type: RequestType, data: Any) -> Response:
"""send and receive synchronize."""
- req_id = self.send_async(req_type, data)
- return self.recv(req_id)
+ resp = self.send_async(req_type, data)
+ return self.recv(resp)
def response_callback(self, resp: Response):
"""response callback."""
@@ -354,7 +178,7 @@ def response_callback(self, resp: Response):
class RequestManager:
"""Request manager."""
- def __init__(self, thread_safe: bool = False):
+ def __init__(self):
self.senders: Dict[int, RequestSender] = dict()
self.callbacks: Dict[RequestType, Callable] = dict()
self.request_priority: List[RequestType] = [
@@ -365,17 +189,7 @@ def __init__(self, thread_safe: bool = False):
self.requests: asyncio.Queue = None
self._loop_task: asyncio.Future = None
self._loop_coro: Callable = None
- self._thread_safe = thread_safe
self._next_sender_id = 0
- self._mutex = Lock()
- self._loop_thread: Thread = None
-
- self.thread_requests: Queue = None
- # every sender has it's own responses, this responses is
- # only used in thread safe mode.
- self.responses: asyncio.Queue = None
- if thread_safe:
- self.thread_requests = Queue()
def create_loop_task(self):
"""create coro task."""
@@ -385,7 +199,6 @@ def create_loop_task(self):
'Please set loop task with manager.start_loop')
loop_unshielded = event_loop.create_task(self._loop_coro(),
name='EngineMainLoop')
- loop_unshielded.add_done_callback(_raise_exception_on_finish)
self._loop_task = asyncio.shield(loop_unshielded)
self.requests = asyncio.Queue()
return self._loop_task
@@ -398,105 +211,17 @@ def event_loop(self):
else:
return self._loop_task.get_loop()
- def is_thread_safe(self):
- """is thread safe."""
- return self._thread_safe
-
def start_loop(self, loop: asyncio.Task):
"""start main loop."""
self._loop_coro = loop
- def __get_thread_reqs():
- """get thread reqs."""
- num_reqs = self.thread_requests.qsize()
- reqs = []
- for _ in range(num_reqs):
- tmp_reqs = self.thread_requests.get_nowait()
- if isinstance(tmp_reqs, Request):
- tmp_reqs = [tmp_reqs]
- reqs += tmp_reqs
- return reqs
-
- async def __async_get_req(event_loop):
- """async get request."""
- que = self.thread_requests
- not_empty = que.not_empty
- with not_empty:
- while not que._qsize():
- await event_loop.run_in_executor(None, not_empty.wait, 1.0)
- reqs = que.get_nowait()
- if isinstance(reqs, Request):
- reqs = [reqs]
- return reqs
-
- async def __req_loop():
- """req loop."""
- event_loop = asyncio.get_event_loop()
- while True:
- # get reqs
- reqs = __get_thread_reqs()
- if len(reqs) == 0:
- reqs = await __async_get_req(event_loop)
- self.requests.put_nowait(reqs)
-
- def __put_thread_resps(resps: List[Response]):
- """put thread resps."""
- for resp in resps:
- sender = self.senders.get(resp.sender_id, None)
- if sender is None:
- continue
- sender.resp_thread_que.put_nowait(resp)
-
- async def __resp_loop():
- """resp loop."""
- while True:
- num_resps = self.responses.qsize()
-
- if num_resps == 0:
- resps = [await self.responses.get()]
- else:
- resps = []
- for _ in range(num_resps):
- resps.append(self.responses.get_nowait())
- __put_thread_resps(resps)
- await asyncio.sleep(0)
-
- def __run_forever(event_loop: asyncio.BaseEventLoop):
- """run forever."""
- logger.debug('start thread run forever.')
- asyncio.set_event_loop(event_loop)
- self.responses = asyncio.Queue()
- self.create_loop_task()
- req_loop = event_loop.create_task(__req_loop(),
- name='RunForeverReqLoop')
- req_loop.add_done_callback(_ignore_exception_on_finish)
- resp_loop = event_loop.create_task(__resp_loop(),
- name='RunForeverRespLoop')
- resp_loop.add_done_callback(_ignore_exception_on_finish)
- self.event_loop.run_forever()
-
- if self.is_thread_safe():
- event_loop = asyncio.new_event_loop()
- self._loop_thread = Thread(target=__run_forever,
- args=(event_loop, ),
- daemon=True)
- self._loop_thread.start()
+ def stop_loop(self):
+ if self.is_loop_alive():
+ self._loop_task.cancel()
def is_loop_alive(self):
"""check if main loop is alive."""
- def __check_threadsafe():
- if self._loop_thread is None:
- return False
- if not self._loop_thread.is_alive():
- return False
- if self._loop_task is None:
- return False
- return not self._loop_task.done()
-
- if self.is_thread_safe():
- return __check_threadsafe()
-
if self._loop_task is None:
logger.debug('loop task has not been created.')
return False
@@ -508,12 +233,11 @@ def __check_threadsafe():
def build_sender(self):
"""create a new sender."""
- with self._mutex:
- sender_id = self._next_sender_id
- self._next_sender_id += 1
- new_sender = RequestSender.new(sender_id, self)
- self.senders[sender_id] = new_sender
- return new_sender
+ sender_id = self._next_sender_id
+ self._next_sender_id += 1
+ new_sender = RequestSender.new(sender_id, self)
+ self.senders[sender_id] = new_sender
+ return new_sender
def has_requests(self):
"""has unprocessed request."""
@@ -521,16 +245,27 @@ def has_requests(self):
return False
return not self.requests.empty()
- def get_all_requests(self) -> Dict[RequestType, Request]:
+ async def get_all_requests(self) -> Dict[RequestType, Request]:
"""get all requests in current queue."""
num_reqs = self.requests.qsize()
reqs: ReqList = []
- for _ in range(num_reqs):
- elem = self.requests.get_nowait()
+
+ def __proc_reqs(elem):
+ """proc reqs."""
+ nonlocal reqs
if isinstance(elem, Request):
elem = [elem]
reqs += elem
+ if num_reqs == 0:
+ elem = await self.requests.get()
+ __proc_reqs(elem)
+ num_reqs = self.requests.qsize()
+
+ for _ in range(num_reqs):
+ elem = self.requests.get_nowait()
+ __proc_reqs(elem)
+
# gather requests
reqs_by_type: Dict[RequestType, Request] = dict(
(t, []) for t in RequestType)
@@ -548,11 +283,7 @@ def set_request_priority(self, priority: List[RequestType]):
def response(self, resp: Response):
"""send response."""
- if resp.sender_id not in self.senders:
- logger.warning(f'sender {resp.sender_id} not exist. '
- f'Send {resp} failed.')
- return
- self.senders[resp.sender_id].response_callback(resp)
+ resp.event.set()
def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
"""process reqs with given req type."""
@@ -563,19 +294,18 @@ def process_request(self, req_type: RequestType, reqs: ReqList, **kwargs):
else:
# TODO: send error message
for req in reqs:
- resp = Response(ResponseType.HANDLER_NOT_EXIST,
- sender_id=req.sender_id,
- req_id=req.req_id,
- err_msg=(f'callback for {req_type}'
- ' not exists.'))
+ resp = req.resp
+ resp.type = ResponseType.HANDLER_NOT_EXIST
+ resp.err_msg = (f'callback for {req_type}'
+ ' not exists.')
self.response(resp)
- def step(self, **kwargs):
+ async def step(self, **kwargs):
"""handle requests.
Should only be called in loop task.
"""
- reqs_by_type = self.get_all_requests()
+ reqs_by_type = await self.get_all_requests()
# handle requests
for req_type in self.request_priority:
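
The request.py rewrite drops the thread-safe queues and req_id bookkeeping: each Request now carries its own Response, and the engine signals completion by setting that response's asyncio.Event, so a sender waits on exactly the object it submitted instead of filtering a shared response queue. A minimal sketch of the pattern with illustrative names (DemoRequest/DemoResponse are not lmdeploy classes):

import asyncio
from dataclasses import dataclass, field
from typing import Any


@dataclass
class DemoResponse:
    event: asyncio.Event = field(default_factory=asyncio.Event)
    data: Any = None


@dataclass
class DemoRequest:
    data: Any
    resp: DemoResponse = field(default_factory=DemoResponse)


async def engine_loop(queue: asyncio.Queue) -> None:
    """Toy engine loop: handle requests and wake exactly the waiting sender."""
    while True:
        req: DemoRequest = await queue.get()
        req.resp.data = req.data * 2   # do the actual work here
        req.resp.event.set()           # no sender/req_id lookup needed


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    loop_task = asyncio.create_task(engine_loop(queue))
    req = DemoRequest(data=21)
    queue.put_nowait(req)
    await req.resp.event.wait()        # wait on the response we already hold
    print(req.resp.data)               # 42
    loop_task.cancel()


if __name__ == '__main__':
    asyncio.run(main())
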
diff --git a/lmdeploy/pytorch/kernels/cuda/__init__.py b/lmdeploy/pytorch/kernels/cuda/__init__.py
index 3790cf0f66..b62ddef80a 100644
--- a/lmdeploy/pytorch/kernels/cuda/__init__.py
+++ b/lmdeploy/pytorch/kernels/cuda/__init__.py
@@ -9,6 +9,7 @@
from .multinomial_sampling import multinomial_sampling
from .pagedattention import paged_attention_fwd
from .rms_norm import rms_norm
+from .w8a8_fused_moe import fused_moe_w8a8
from .w8a8_triton_kernels import (matmul_kernel_dynamic_quant,
per_channel_quant, per_token_quant_int8,
rms_norm_dynamic_quant)
@@ -28,4 +29,5 @@
'rms_norm_dynamic_quant',
'flash_attention_fwd',
'flatten_kv_cache',
+ 'fused_moe_w8a8',
]
diff --git a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
index 13b9841e9b..395b0c427e 100644
--- a/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
+++ b/lmdeploy/pytorch/kernels/cuda/awq_kernels.py
@@ -2,210 +2,95 @@
import triton
from triton import language as tl
-from .triton_utils import get_kernel_meta, wrap_jit_func
-
def get_cuda_autotune_config():
return [
- # most used
- triton.Config(
- {
- 'BLOCK_SIZE_M': 128,
- 'BLOCK_SIZE_N': 64,
- 'BLOCK_SIZE_K': 32,
- 'GROUP_SIZE_M': 8
- },
- num_stages=4,
- num_warps=4),
- triton.Config(
- {
- 'BLOCK_SIZE_M': 64,
- 'BLOCK_SIZE_N': 64,
- 'BLOCK_SIZE_K': 64,
- 'GROUP_SIZE_M': 8
- },
- num_stages=4,
- num_warps=4),
- # # other
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=5,
- # num_warps=2),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 32,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 32,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=5,
- # num_warps=2),
- # # Good config for fp8 inputs.
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 256,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=3,
- # num_warps=8),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 256,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 64,
- # 'BLOCK_SIZE_N': 256,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 128,
- # 'BLOCK_SIZE_K': 128,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 64,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
- # triton.Config(
- # {
- # 'BLOCK_SIZE_M': 128,
- # 'BLOCK_SIZE_N': 32,
- # 'BLOCK_SIZE_K': 64,
- # 'GROUP_SIZE_M': 8
- # },
- # num_stages=4,
- # num_warps=4),
+ triton.Config({
+ 'BLOCK_SIZE_N': 64,
+ 'GROUP_SIZE_M': 8,
+ },
+ num_stages=3,
+ num_warps=4),
]
@triton.jit
-def _get_unpacked_order(offs_n, elem_per_int: tl.constexpr):
- """get unpacked order."""
- origin_order = offs_n % elem_per_int
- unpacked_order = (origin_order & 1) * 4 + origin_order // 2
- return unpacked_order
+def _dequant_s4_to_f16x2(weight, shift: tl.constexpr, is_top: tl.constexpr):
+
+ immLut: tl.constexpr = (0xf0 & 0xcc) | 0xaa
+ BOTTOM_MASK: tl.constexpr = 0x000f000f
+ TOP_MASK: tl.constexpr = 0x00f000f0
+ I4s_TO_F16s_MAGIC_NUM: tl.constexpr = 0x64006400
+ FP16_TOP_MAGIC_NUM: tl.constexpr = 0x64006400
+ ONE_SIXTEENTH: tl.constexpr = 0x2c002c00
+ NEG_64: tl.constexpr = 0xd400d400
+
+ if shift:
+ weight = weight >> 8
+
+ if is_top:
+ return tl.inline_asm_elementwise("""{
+ .reg .b32 tmp;
+ lop3.b32 tmp, $2, $3, $4, $5;
+ fma.rn.f16x2 tmp, tmp, $6, $7;
+ mov.b32 {$0, $1}, tmp;
+ }""",
+ '=h,=h,r,n,n,n,r,r',
+ args=[
+ weight, TOP_MASK,
+ I4s_TO_F16s_MAGIC_NUM, immLut,
+ ONE_SIXTEENTH, NEG_64
+ ],
+ dtype=(tl.float16, tl.float16),
+ is_pure=True,
+ pack=1)
+ else:
+ return tl.inline_asm_elementwise("""{
+ .reg .b32 tmp;
+ lop3.b32 tmp, $2, $3, $4, $5;
+ sub.f16x2 tmp, tmp, $6;
+ mov.b32 {$0, $1}, tmp;
+ }""",
+ '=h,=h,r,n,n,n,r',
+ args=[
+ weight, BOTTOM_MASK,
+ I4s_TO_F16s_MAGIC_NUM, immLut,
+ FP16_TOP_MAGIC_NUM
+ ],
+ dtype=(tl.float16, tl.float16),
+ is_pure=True,
+ pack=1)
@triton.jit
-def _broadcast_pack(weight, width: tl.constexpr):
- """broadcast pack."""
- broadcast_tmp = tl.arange(0, width)
+def _unpack_weight(weight):
+ """unpack weight."""
+ # broadcast and shift
+ width: tl.constexpr = 8
BLOCK_SIZE_K: tl.constexpr = weight.shape[0]
BLOCK_SIZE_QN: tl.constexpr = weight.shape[1]
BLOCK_SIZE_N: tl.constexpr = BLOCK_SIZE_QN * width
- weight = tl.broadcast(weight[:, :, None], broadcast_tmp[None, None, :])
- weight = tl.reshape(weight, (BLOCK_SIZE_K, BLOCK_SIZE_N))
- return weight
+ w0, w1 = _dequant_s4_to_f16x2(weight, False, False)
+ w2, w3 = _dequant_s4_to_f16x2(weight, False, True)
+ w4, w5 = _dequant_s4_to_f16x2(weight, True, False)
+ w6, w7 = _dequant_s4_to_f16x2(weight, True, True)
-@triton.jit
-def _unpack_weight(weight, order):
- """unpack weight."""
- weight = _broadcast_pack(weight, 8)
- weight = weight >> (order * 4)
- # cast to float16
- immLut = (0xf0 & 0xcc) | 0xaa
- BOTTOM_MASK = 0xf
- I4s_TO_F16s_MAGIC_NUM = 0x6400
- FP16_TOP_MAGIC_NUM = 0x6400
- weight = tl.inline_asm_elementwise(
- """lop3.b32 $1, $1, $2, $3, $4;
- sub.f16x2 $1, $1, $5;
- mov.b32 {$0, _}, $1;""",
- '=h, r, n, n, n, r', [
- weight, BOTTOM_MASK, I4s_TO_F16s_MAGIC_NUM, immLut,
- FP16_TOP_MAGIC_NUM
- ],
- dtype=tl.float16,
- is_pure=False,
- pack=1)
- return weight
+ w04 = tl.join(w0, w4)
+ w15 = tl.join(w1, w5)
+ w26 = tl.join(w2, w6)
+ w37 = tl.join(w3, w7)
+ w0246 = tl.join(w04, w26)
+ w1357 = tl.join(w15, w37)
+ weight = tl.join(w0246, w1357)
+
+ return weight.reshape(BLOCK_SIZE_K, BLOCK_SIZE_N)
@triton.autotune(
configs=get_cuda_autotune_config(),
- key=['M_NEXT_P2', 'N', 'K'],
+ key=['N', 'K'],
)
-@wrap_jit_func
@triton.jit
def awq_linear_kernel(
a_ptr,
@@ -225,12 +110,9 @@ def awq_linear_kernel(
stride_zk: tl.constexpr,
stride_zn: tl.constexpr, #
stride_cm,
- stride_ck: tl.constexpr,
stride_cn: tl.constexpr,
# Meta-parameters
- M_NEXT_P2: tl.constexpr,
- Q_GROUP_SIZE: tl.constexpr,
- SPLIT_K_ITERS: tl.constexpr,
+ SPLIT_K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr, #
@@ -239,19 +121,13 @@ def awq_linear_kernel(
"""Kernel for computing the matmul C = A x B.
A has shape (M, K), B has shape (K, N) and C has shape (M, N)
"""
- ELEM_PER_INT = 8
- if Q_GROUP_SIZE > BLOCK_SIZE_K:
- GROUP_SIZE_K: tl.constexpr = BLOCK_SIZE_K
- else:
- GROUP_SIZE_K: tl.constexpr = Q_GROUP_SIZE
- K_PER_GROUP: tl.constexpr = Q_GROUP_SIZE // GROUP_SIZE_K
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
# See above `L2 Cache Optimizations` section for details.
+ kid = tl.program_id(axis=1)
pid = tl.program_id(axis=0)
- split_kid = tl.program_id(axis=1)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
@@ -267,8 +143,7 @@ def awq_linear_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
BLOCK_SIZE_QN: tl.constexpr = BLOCK_SIZE_N // 8
offs_wn = pid_n * BLOCK_SIZE_QN + tl.arange(0, BLOCK_SIZE_QN)
- offs_k = tl.arange(0, GROUP_SIZE_K)
- unpacked_order = _get_unpacked_order(offs_bn, ELEM_PER_INT)
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
offs_k[None, :] * stride_ak)
qw_ptrs = qw_ptr + (offs_k[:, None] * stride_wk +
@@ -276,49 +151,52 @@ def awq_linear_kernel(
s_ptrs = s_ptr + offs_bn * stride_sn
qz_ptrs = qz_ptr + offs_wn * stride_zn
- # split k
- NUM_K_BLOCKS = K // GROUP_SIZE_K
- K_PER_SPLIT = tl.cdiv(NUM_K_BLOCKS, SPLIT_K_ITERS)
- k_start = split_kid * K_PER_SPLIT
- k_last = min(k_start + K_PER_SPLIT, NUM_K_BLOCKS)
- a_ptrs += k_start * GROUP_SIZE_K * stride_ak
- qw_ptrs += k_start * GROUP_SIZE_K * stride_wk
- qg_id = k_start // K_PER_GROUP
-
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
- s = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)
- zs = tl.zeros((1, BLOCK_SIZE_N), dtype=s_ptrs.dtype.element_ty)
+
+ k_start = kid
+ k_last = K // BLOCK_SIZE_K
# prefetch
- next_qw = tl.load(qw_ptrs)
- qw_ptrs += GROUP_SIZE_K * stride_wk
+ a_ptrs += k_start * BLOCK_SIZE_K * stride_ak
+ qw_ptrs += k_start * BLOCK_SIZE_K * stride_wk
+ s_ptrs += k_start * stride_sk
+ qz_ptrs += k_start * stride_zk
+ qw = tl.load(qw_ptrs)
+ qz = tl.load(qz_ptrs)[None, :]
+ s = tl.load(s_ptrs)[None, :]
+ qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
+ s_ptrs += SPLIT_K * stride_sk
+ qz_ptrs += SPLIT_K * stride_zk
+
+ for k in tl.range(k_start, k_last, SPLIT_K, num_stages=3):
- for k in range(k_start, k_last):
+ # load a
a = tl.load(a_ptrs)
- qw = next_qw
- if k + 1 < k_last:
- next_qw = tl.load(qw_ptrs)
- w = _unpack_weight(qw, unpacked_order)
-
- if k == k_start or k % K_PER_GROUP == 0:
- s = tl.load(s_ptrs + qg_id * stride_sk)[None, :]
- qz = tl.load(qz_ptrs + qg_id * stride_zk)[None, :]
- qg_id += 1
- z = _unpack_weight(qz, unpacked_order)
- zs = -z * s
- b = w * s + zs
+
+ # unpack b
+ z = _unpack_weight(qz)
+ w = _unpack_weight(qw)
+ b = (w - z) * s
+
+ # load next q
+ mask = k + SPLIT_K < k_last
+ qz = tl.load(qz_ptrs, mask=mask)[None, :]
+ s = tl.load(s_ptrs, mask=mask)[None, :]
+ qw = tl.load(qw_ptrs, mask=mask)
# We accumulate along the K dimension.
- accumulator += tl.dot(a, b)
+ accumulator = tl.dot(a, b, acc=accumulator)
# Advance the ptrs to the next K block.
- a_ptrs += GROUP_SIZE_K * stride_ak
- qw_ptrs += GROUP_SIZE_K * stride_wk
+ a_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_ak
+ qw_ptrs += SPLIT_K * BLOCK_SIZE_K * stride_wk
+ s_ptrs += SPLIT_K * stride_sk
+ qz_ptrs += SPLIT_K * stride_zk
c = accumulator.to(tl.float16)
@@ -329,11 +207,11 @@ def awq_linear_kernel(
c_ptrs = c_ptr + stride_cm * offs_cm[:,
None] + stride_cn * offs_cn[None, :]
c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
- if stride_ck > 0:
- c_ptrs += split_kid * stride_ck
- tl.store(c_ptrs, c, mask=c_mask)
+
+ if SPLIT_K > 1:
+ tl.atomic_add(c_ptrs, c, mask=c_mask, sem='relaxed', scope='gpu')
else:
- tl.atomic_add(c_ptrs, c, mask=c_mask)
+ tl.store(c_ptrs, c, mask=c_mask)
def awq_linear(x, qweight, scales, qzeros):
@@ -341,18 +219,24 @@ def awq_linear(x, qweight, scales, qzeros):
M = x.size(0)
K = qweight.size(0)
N = scales.size(1)
- SPLIT_K_ITERS = 4
group_size = K // scales.size(0)
+ SPLIT_K = max(1, K // 4096)
def grid(META):
"""grid."""
- return (triton.cdiv(M, META['BLOCK_SIZE_M']) *
- triton.cdiv(N, META['BLOCK_SIZE_N']), SPLIT_K_ITERS)
+ return (
+ triton.cdiv(M, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']),
+ SPLIT_K,
+ )
- out = scales.new_empty(M, SPLIT_K_ITERS, N)
- M_NEXT_P2 = triton.next_power_of_2(M)
+ if SPLIT_K > 1:
+ out = scales.new_zeros(M, N)
+ else:
+ out = scales.new_empty(M, N)
- kernel_meta = get_kernel_meta(x)
+ BLOCK_SIZE_M = triton.next_power_of_2(M)
+ BLOCK_SIZE_M = max(16, min(128, BLOCK_SIZE_M))
awq_linear_kernel[grid](
# Pointers to matrices
x,
@@ -373,12 +257,11 @@ def grid(META):
stride_zk=qzeros.stride(0),
stride_zn=qzeros.stride(1), #
stride_cm=out.stride(0),
- stride_ck=out.stride(1),
- stride_cn=out.stride(2),
+ stride_cn=out.stride(1),
# Meta-parameters
- M_NEXT_P2=M_NEXT_P2,
- Q_GROUP_SIZE=group_size,
- SPLIT_K_ITERS=SPLIT_K_ITERS,
- **kernel_meta)
+ BLOCK_SIZE_M=BLOCK_SIZE_M,
+ BLOCK_SIZE_K=group_size,
+ SPLIT_K=SPLIT_K,
+ )
- return out.sum(1)
+ return out
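
The rewritten AWQ path dequantizes packed int4 weights with inline PTX and computes b = (w - z) * s directly, and split-K partial sums now go through a relaxed atomic add instead of a separate (M, SPLIT_K, N) buffer plus reduction. For checking semantics only, here is an eager PyTorch sketch of the dequantization, using the nibble order (n & 1) * 4 + n // 2 taken from the removed _get_unpacked_order; layouts are inferred from awq_linear's stride usage:

import torch


def awq_dequant_reference(qweight: torch.Tensor, scales: torch.Tensor,
                          qzeros: torch.Tensor, group_size: int) -> torch.Tensor:
    """Eager AWQ dequantization sketch: b = (w - z) * s.

    Assumed layouts (from awq_linear's strides):
      qweight: (K, N // 8) int32, eight int4 values per int32
      qzeros:  (K // group_size, N // 8) int32
      scales:  (K // group_size, N) float16
    """
    n = torch.arange(8, device=qweight.device)
    # nibble position of logical column n inside an int32
    shifts = 4 * ((n & 1) * 4 + n // 2)

    w = (qweight.unsqueeze(-1) >> shifts) & 0xF        # (K, N // 8, 8)
    z = (qzeros.unsqueeze(-1) >> shifts) & 0xF         # (K // g, N // 8, 8)
    w = w.flatten(1).to(scales.dtype)                  # (K, N)
    z = z.flatten(1).to(scales.dtype)                  # (K // g, N)

    z = z.repeat_interleave(group_size, dim=0)         # broadcast zeros over each K group
    s = scales.repeat_interleave(group_size, dim=0)    # broadcast scales over each K group
    return (w - z) * s


def awq_linear_reference(x: torch.Tensor, qweight: torch.Tensor,
                         scales: torch.Tensor, qzeros: torch.Tensor) -> torch.Tensor:
    """The product the split-K kernel accumulates: x @ dequant(qweight)."""
    group_size = qweight.size(0) // scales.size(0)
    return x @ awq_dequant_reference(qweight, scales, qzeros, group_size)

With SPLIT_K = max(1, K // 4096), the atomic-add path requires a zero-initialized output, which is why awq_linear switches between scales.new_zeros and scales.new_empty.
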
diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py
new file mode 100644
index 0000000000..6a87830c6c
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/blocked_fp8_fused_moe.py
@@ -0,0 +1,343 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modify from: https://github.com/vllm-project/vllm
+import torch
+import triton
+import triton.language as tl
+
+from .activation import silu_and_mul
+from .blocked_gemm_fp8 import quant_fp8
+from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config({
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 64,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config({
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ },
+ num_stages=4,
+ num_warps=4),
+ ]
+
+
+@triton.autotune(
+ configs=get_cuda_autotune_config(),
+ key=['N', 'K', 'M_NP2'],
+ warmup=10,
+ rep=25,
+)
+@triton.jit
+def fused_moe_blocked_f8_kernel(
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ SortedIdx,
+ ExpStart,
+ ExpEnd,
+ Weights,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ group_ak: tl.constexpr,
+ group_bk: tl.constexpr,
+ group_bn: tl.constexpr,
+ stride_am: tl.constexpr,
+ stride_ak: tl.constexpr,
+ stride_asm,
+ stride_ask: tl.constexpr,
+ stride_be: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bse: tl.constexpr,
+ stride_bsk: tl.constexpr,
+ stride_bsn: tl.constexpr,
+ stride_cm,
+ stride_cn: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ M_NP2: tl.constexpr,
+ ENABLE_WEIGHTS: tl.constexpr,
+ top_k: tl.constexpr,
+ expert_offset: tl.constexpr,
+ reindex_a: tl.constexpr,
+ reindex_c: tl.constexpr,
+):
+ """fused moe kernel."""
+ exp_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ exp_start = tl.load(ExpStart + exp_id + expert_offset)
+ exp_end = tl.load(ExpEnd + exp_id + expert_offset)
+ M = exp_end - exp_start
+ if M <= 0:
+ return
+
+ num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ if GROUP_SIZE_M == 1:
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+ else:
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
+ return
+
+ offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask_sid = offs_sid < exp_end
+ sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)
+
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ if reindex_a:
+ offs_am = sid // top_k
+ else:
+ offs_am = offs_sid
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
+ BLOCK_SIZE_N)
+
+ # deepseek has 160 experts, exp index would overflow int32
+ exp_id = exp_id.to(tl.int64)
+ exp_off = stride_be * exp_id
+ b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk +
+ offs_bn[None, :] * stride_bn)
+
+ offs_bsn = pid_n * BLOCK_SIZE_N // group_bn
+ as_ptrs = A_scale + (offs_am % M) * stride_asm
+ bs_ptrs = B_scale + stride_bse * exp_id + offs_bsn * stride_bsn
+
+ acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
+ acc_ratio = 1 / acc_scale
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # load scales
+ k_start = (k + 1) * BLOCK_SIZE_K
+ offs_ksa = k_start // group_ak
+ offs_ksb = k_start // group_bk
+ a_scale = tl.load(as_ptrs + offs_ksa * stride_ask,
+ mask=k_start < K,
+ other=1.0)
+ b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk,
+ mask=k_start < K,
+ other=1.0)
+
+ # load ab
+ a = tl.load(a_ptrs,
+ mask=mask_sid[:, None] &
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+ other=0.0)
+ b = tl.load(b_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+
+ # mma
+ accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])
+
+ # update scales and ratio
+ new_acc_scale = a_scale * b_scale
+ acc_ratio = acc_scale / new_acc_scale
+ acc_scale = new_acc_scale
+
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ c = accumulator * (acc_ratio * acc_scale)[:, None]
+
+ if ENABLE_WEIGHTS:
+ weight = tl.load(Weights + sid, mask=mask_sid)
+ c = c * weight[:, None].to(c.dtype)
+
+ c = c.to(C.dtype.element_ty)
+
+ if reindex_c:
+ offs_cm = sid
+ else:
+ offs_cm = offs_sid
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
+ tl.store(c_ptrs, c, mask=mask_sid[:, None])
+
+
+def fused_moe_blocked_fp8_kernel_launcher(
+ A: torch.Tensor,
+ A_scale: torch.Tensor,
+ B: torch.Tensor,
+ B_scale: torch.Tensor,
+ C: torch.Tensor,
+ sorted_idx: torch.Tensor,
+ exp_start: torch.Tensor,
+ exp_end: torch.Tensor,
+ weights: torch.Tensor,
+ enable_weights: bool = False,
+ top_k: int = 1,
+ num_tokens: int = None,
+ expert_offset: int = 0,
+ reindex_a: bool = True,
+ reindex_c: bool = True,
+):
+ """fused moe kernel launcher."""
+
+ if num_tokens is None:
+ num_tokens = A.size(0)
+ M_NP2 = triton.next_power_of_2(num_tokens)
+ M_NP2 = max(64, M_NP2)
+ E, N, K = B.shape
+
+ assert A.dim() == 2
+ assert A_scale.dim() == 2
+ assert B.dim() == 3
+ assert B_scale.dim() == 3
+
+ assert K % A_scale.size(1) == 0
+ assert K % B_scale.size(2) == 0
+ assert N % B_scale.size(1) == 0
+
+ group_ak = K // A_scale.size(1)
+ group_bk = K // B_scale.size(2)
+ group_bn = N // B_scale.size(1)
+
+ def _grid_fn(META):
+ grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']), E)
+ return grid
+
+ A = A.flatten(0, -2)
+ C = C.flatten(0, -2)
+
+ BLOCK_SIZE_K = group_bk
+ GROUP_SIZE_M = 8
+ grid = _grid_fn
+ fused_moe_blocked_f8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ sorted_idx,
+ exp_start,
+ exp_end,
+ weights,
+ N=N,
+ K=K,
+ group_ak=group_ak,
+ group_bk=group_bk,
+ group_bn=group_bn,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_asm=A_scale.stride(0),
+ stride_ask=A_scale.stride(1),
+ stride_be=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bk=B.stride(2),
+ stride_bse=B_scale.stride(0),
+ stride_bsn=B_scale.stride(1),
+ stride_bsk=B_scale.stride(2),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ ENABLE_WEIGHTS=enable_weights,
+ top_k=top_k,
+ expert_offset=expert_offset,
+ reindex_a=reindex_a,
+ reindex_c=reindex_c,
+ M_NP2=M_NP2,
+ BLOCK_SIZE_K=BLOCK_SIZE_K,
+ GROUP_SIZE_M=GROUP_SIZE_M,
+ )
+
+
+def fused_moe_blocked_fp8(input: torch.Tensor,
+ input_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk: int,
+ out_dtype: torch.dtype = torch.float16,
+ expert_offset: int = 0,
+ num_experts: int = None,
+ renormalize: bool = False) -> torch.Tensor:
+ """fused moe."""
+ device = input.device
+ M = input.size(0)
+ E, N, _ = w1.shape
+ if num_experts is None:
+ num_experts = E
+ full_exp = num_experts == E
+ group_size = input.size(-1) // input_scale.size(-1)
+
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
+
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # gate and up
+ fused_moe_blocked_fp8_kernel_launcher(
+ input,
+ input_scale,
+ w1,
+ w1_scale,
+ intermediate_cache1,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=False,
+ top_k=topk,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=True,
+ reindex_c=False,
+ )
+
+ # activate
+ intermediate_cache1 = intermediate_cache1.flatten(0, -2)
+ gate_cache = silu_and_mul(intermediate_cache1)
+ del intermediate_cache1
+ gate_cache, gate_scale = quant_fp8(gate_cache,
+ group_size,
+ dtype=input.dtype)
+
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # down
+ fused_moe_blocked_fp8_kernel_launcher(
+ gate_cache,
+ gate_scale,
+ w2,
+ w2_scale,
+ intermediate_cache2,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=True,
+ top_k=1,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=False,
+ reindex_c=True,
+ )
+
+ ret = intermediate_cache2.sum(dim=1)
+ return ret
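
blocked_fp8_fused_moe.py reuses the routing scheme of fused_moe.py: tokens are sorted by expert id, each expert runs a gate/up GEMM, silu_and_mul and a down GEMM, and the top-k results are summed with their routing weights; the blocked-FP8 part only changes how each K block is rescaled inside the fp32 accumulator. Ignoring quantization, an eager per-token sketch of the routed computation looks like this (it assumes silu_and_mul splits the last dimension in half and computes silu(gate) * up, the usual convention):

import torch
import torch.nn.functional as F


def moe_reference(hidden_states: torch.Tensor,  # (M, K)
                  w1: torch.Tensor,             # (E, 2 * I, K): gate and up stacked
                  w2: torch.Tensor,             # (E, K, I): down projection
                  topk_weights: torch.Tensor,   # (M, topk)
                  topk_ids: torch.Tensor,       # (M, topk)
                  renormalize: bool = False) -> torch.Tensor:
    """Eager per-token sketch of the routed MoE computation (no quantization)."""
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states)
    for m in range(hidden_states.size(0)):
        x = hidden_states[m]
        for weight, expert in zip(topk_weights[m], topk_ids[m]):
            gate_up = w1[expert] @ x                         # (2 * I,)
            half = gate_up.size(0) // 2
            inter = F.silu(gate_up[:half]) * gate_up[half:]  # silu_and_mul
            out[m] += weight * (w2[expert] @ inter)          # weighted down projection
    return out

In the Triton path the two inner matmuls become grouped GEMMs over the sorted token index, which is why reindex_a/reindex_c differ between the gate/up launch and the down launch.
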
diff --git a/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
new file mode 100644
index 0000000000..9f992bcfef
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/blocked_gemm_fp8.py
@@ -0,0 +1,237 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import triton
+import triton.language as tl
+from torch import Tensor
+
+
+@triton.jit
+def _quant_fp8_kernel(
+ a_ptr,
+ out_ptr,
+ scale_ptr,
+ fp8_min: tl.constexpr,
+ fp8_max: tl.constexpr,
+ stride_am,
+ stride_ak: tl.constexpr,
+ stride_om,
+ stride_ok: tl.constexpr,
+ stride_sm,
+ stride_sg: tl.constexpr,
+ GROUP_SIZE: tl.constexpr,
+):
+ """quant fp8 kernel."""
+ group_id = tl.program_id(0)
+ m_id = tl.program_id(1)
+
+ g_offs = group_id * GROUP_SIZE + tl.arange(0, GROUP_SIZE)
+
+ a_ptrs = a_ptr + m_id * stride_am + g_offs * stride_ak
+ o_ptrs = out_ptr + m_id * stride_om + g_offs * stride_ok
+ s_ptr = scale_ptr + m_id * stride_sm + group_id * stride_sg
+
+ rfp8_max = 1 / fp8_max
+
+ a = tl.load(a_ptrs).to(tl.float32)
+ scale = tl.max(tl.abs(a)) * rfp8_max
+ out = a / scale
+
+ out = tl.clamp(out, fp8_min, fp8_max)
+ out = out.to(out_ptr.dtype.element_ty)
+
+ tl.store(o_ptrs, out)
+ tl.store(s_ptr, scale)
+
+
+def quant_fp8(A: Tensor,
+ group_size: int,
+ dtype: torch.dtype = torch.float8_e4m3fn):
+ """quant online."""
+ assert A.dim() == 2
+ M, K = A.shape
+ assert K % group_size == 0
+ num_groups = K // group_size
+
+ finfo = torch.finfo(dtype)
+ fmin = finfo.min
+ fmax = finfo.max
+
+ out = torch.empty_like(A, dtype=dtype)
+ scales = A.new_empty(M, num_groups, dtype=torch.float32)
+ grid = (num_groups, M)
+ num_warps = 4
+ num_stages = 1
+ _quant_fp8_kernel[grid](
+ A,
+ out,
+ scales,
+ fp8_min=fmin,
+ fp8_max=fmax,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_om=out.stride(0),
+ stride_ok=out.stride(1),
+ stride_sm=scales.stride(0),
+ stride_sg=scales.stride(1),
+ GROUP_SIZE=group_size,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+
+ return out, scales
+
+
+@triton.autotune(configs=[
+ triton.Config({
+ 'BLOCK_M': 64,
+ 'BLOCK_N': 128,
+ }, num_stages=3, num_warps=4),
+ triton.Config({
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 64,
+ }, num_stages=3, num_warps=4)
+],
+ key=['N', 'K'],
+ warmup=5,
+ rep=10)
+@triton.jit
+def _gemm_fp8_kernel(
+ A,
+ a_scale_ptr,
+ B,
+ b_scale_ptr,
+ C,
+ M,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ group_ak: tl.constexpr,
+ group_bk: tl.constexpr,
+ group_bn: tl.constexpr,
+ stride_am,
+ stride_ak: tl.constexpr,
+ stride_asm,
+ stride_ask: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bsk: tl.constexpr,
+ stride_bsn: tl.constexpr,
+ stride_cm,
+ stride_cn: tl.constexpr,
+ BLOCK_M: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr,
+ GROUP_M: tl.constexpr,
+):
+ """gemm fp8 kernel."""
+ pid = tl.program_id(axis=0)
+ num_pid_m = tl.cdiv(M, BLOCK_M)
+ num_pid_n = tl.cdiv(N, BLOCK_N)
+ num_pid_in_group = GROUP_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ offs_am = (pid_m * BLOCK_M + tl.arange(0, BLOCK_M)) % M
+ offs_bn = (pid_n * BLOCK_N + tl.arange(0, BLOCK_N)) % N
+ offs_k = tl.arange(0, BLOCK_K)
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+ offs_bsn = pid_n * BLOCK_N // group_bn
+ as_ptrs = a_scale_ptr + offs_am * stride_asm
+ bs_ptrs = b_scale_ptr + offs_bsn * stride_bsn
+
+ acc_scale = tl.load(as_ptrs) * tl.load(bs_ptrs)
+ acc_ratio = 1 / acc_scale
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_K)):
+ # load scales
+ k_start = (k + 1) * BLOCK_K
+ offs_ksa = k_start // group_ak
+ offs_ksb = k_start // group_bk
+ a_scale = tl.load(as_ptrs + offs_ksa * stride_ask,
+ mask=k_start < K,
+ other=1.0)
+ b_scale = tl.load(bs_ptrs + offs_ksb * stride_bsk,
+ mask=k_start < K,
+ other=1.0)
+
+ # load ab
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0.0)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0.0)
+
+ # mma
+ accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None])
+
+ # update scales and ratio
+ new_acc_scale = a_scale * b_scale
+ acc_ratio = acc_scale / new_acc_scale
+ acc_scale = new_acc_scale
+
+ a_ptrs += BLOCK_K * stride_ak
+ b_ptrs += BLOCK_K * stride_bk
+ c = accumulator * (acc_ratio * acc_scale)[:, None]
+
+ offs_cm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
+ offs_cn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+ c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+ tl.store(c_ptrs, c, mask=c_mask)
+
+
+def blocked_gemm_fp8(A: Tensor,
+ A_scale: Tensor,
+ B: Tensor,
+ B_scale: torch.Tensor,
+ out_dtype: torch.dtype = torch.float16):
+ """gemm fp8."""
+
+ def grid(META):
+ return (triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']), )
+
+ assert A.dim() == 2
+ assert A_scale.dim() == 2
+ assert B.dim() == 2
+ assert B_scale.dim() == 2
+
+ M, K = A.shape
+ _, N = B.shape
+
+ group_ak = triton.cdiv(K, A_scale.size(1))
+ group_bk = triton.cdiv(K, B_scale.size(0))
+ group_bn = triton.cdiv(N, B_scale.size(1))
+
+ C = A.new_empty(M, N, dtype=out_dtype)
+
+ BLOCK_K = max(group_ak, group_bk)
+
+ _gemm_fp8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ M=M,
+ N=N,
+ K=K,
+ group_ak=group_ak,
+ group_bk=group_bk,
+ group_bn=group_bn,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_asm=A_scale.stride(0),
+ stride_ask=A_scale.stride(1),
+ stride_bk=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bsk=B_scale.stride(0),
+ stride_bsn=B_scale.stride(1),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ BLOCK_K=BLOCK_K,
+ GROUP_M=8,
+ )
+
+ return C
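
quant_fp8 performs dynamic per-(row, group) quantization: each group_size slice of a row gets its own float32 scale amax / fp8_max, and values are clamped to the fp8 range before the cast. A small eager sketch of the same math (it needs a PyTorch build that provides torch.float8_e4m3fn):

import torch


def quant_fp8_reference(a: torch.Tensor, group_size: int,
                        dtype: torch.dtype = torch.float8_e4m3fn):
    """Eager sketch of per-group dynamic fp8 quantization."""
    M, K = a.shape
    assert K % group_size == 0
    finfo = torch.finfo(dtype)
    grouped = a.float().reshape(M, K // group_size, group_size)
    scales = grouped.abs().amax(dim=-1) / finfo.max          # (M, K // group_size), fp32
    q = grouped / scales.unsqueeze(-1)                       # beware all-zero groups
    q = q.clamp(finfo.min, finfo.max).to(dtype).reshape(M, K)
    return q, scales


# example: x_q, x_scale = quant_fp8_reference(torch.randn(4, 256), group_size=128)

The matching GEMM kernels keep a running acc_ratio = old_scale / new_scale so the fp32 accumulator can be rescaled in place whenever the K-group scales change, as in accumulator = tl.dot(a, b, acc=accumulator * acc_ratio[:, None]).
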
diff --git a/lmdeploy/pytorch/kernels/cuda/fused_lora.py b/lmdeploy/pytorch/kernels/cuda/fused_lora.py
index d7fbb34588..3dc7e3a10b 100644
--- a/lmdeploy/pytorch/kernels/cuda/fused_lora.py
+++ b/lmdeploy/pytorch/kernels/cuda/fused_lora.py
@@ -9,8 +9,8 @@ def get_autotune_config():
return [
triton.Config(
{
- 'BLOCK_SIZE_M': 64,
- 'BLOCK_SIZE_N': 256,
+ 'BLOCK_SIZE_M': 32,
+ 'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 128
},
num_stages=4,
@@ -26,9 +26,26 @@ def get_autotune_config():
]
+@triton.jit
+def _atomic_store(ptrs, val, mask):
+ """atomic store values."""
+ dtype = ptrs.dtype.element_ty
+ if (dtype == torch.float16) | (dtype == torch.float32):
+ tl.atomic_add(ptrs, val, mask=mask, sem='relaxed')
+ else:
+ # bfloat16 does not support atomic add
+ origin = tl.load(ptrs, mask=mask)
+ val = val.to(origin.dtype)
+ val += origin
+ tl.store(ptrs, val, mask=mask)
+
+
@triton.autotune(
configs=get_autotune_config(),
key=['N', 'K'],
+ restore_value=['c_ptr'],
+ warmup=5,
+ rep=20,
)
@triton.jit
def _fused_lora_kernel(
@@ -44,18 +61,19 @@ def _fused_lora_kernel(
adapter_ids_ptr,
N: tl.constexpr,
K: tl.constexpr,
- stride_am: tl.constexpr,
+ stride_am,
stride_ak: tl.constexpr,
stride_lar: tl.constexpr,
stride_lak: tl.constexpr,
stride_lbr: tl.constexpr,
stride_lbn: tl.constexpr,
- stride_cm: tl.constexpr,
+ stride_cm,
stride_cn: tl.constexpr,
BLOCK_SIZE_R: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
+ CUM: tl.constexpr,
):
"""fused lora kernel."""
pid = tl.program_id(axis=0)
@@ -70,87 +88,91 @@ def _fused_lora_kernel(
rank_start = tl.load(rank_start_ptr + adapter_id)
rank = tl.load(ranks_ptr + adapter_id)
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
- num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
- GROUP_SIZE_M: tl.constexpr = 1
- num_pid_in_group = GROUP_SIZE_M * num_pid_n
- group_id = pid // num_pid_in_group
- first_pid_m = group_id * GROUP_SIZE_M
- group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
- pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
- pid_n = (pid % num_pid_in_group) // group_size_m
+ pid_m = pid
if pid_m * BLOCK_SIZE_M >= M:
return
offs_m = (seq_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M))
- offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N))
+ offs_n = tl.arange(0, BLOCK_SIZE_N)
mask_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) < M
- if rank == 0:
- offs_cm = offs_m
- offs_cn = offs_n
- c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[
- None, :]
- c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)
- tl.store(c_ptrs, 0, mask=c_mask)
- return
-
- offs_am = (seq_start +
- (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)
- offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank
- offs_k = tl.arange(0, BLOCK_SIZE_K)
- a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
- offs_k[None, :] * stride_ak)
- la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak +
- offs_r[None, :] * stride_lar)
-
- accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)
- for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
- # Load the next block of A and B
- # If it is out of bounds, set it to 0.
- a = tl.load(a_ptrs,
- mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
- other=0.0)
- la = tl.load(la_ptrs,
- mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
- other=0.0)
- # We accumulate along the K dimension.
- accumulator += tl.dot(a, la)
- # Advance the ptrs to the next K block.
- a_ptrs += BLOCK_SIZE_K * stride_ak
- la_ptrs += BLOCK_SIZE_K * stride_lak
- ar = accumulator.to(lora_b_ptr.dtype.element_ty)
-
- offs_lbn = offs_n % N
- lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr +
- offs_lbn * stride_lbn)
- lb = tl.load(lb_ptrs, mask=tl.arange(0, BLOCK_SIZE_R)[:, None] < rank)
-
- c = tl.dot(ar, lb)
-
- scaling = tl.load(scaling_ptr + adapter_id)
- c *= scaling
-
- c = c.to(c_ptr.dtype.element_ty)
offs_cm = offs_m
- offs_cn = offs_n
- c_ptrs = c_ptr + stride_cm * offs_cm[:,
- None] + stride_cn * offs_cn[None, :]
- c_mask = mask_cm[:, None] & (offs_cn[None, :] < N)
- tl.store(c_ptrs, c, mask=c_mask)
-
-
-def fused_lora(input: torch.Tensor, lora_a: torch.Tensor, lora_b: torch.Tensor,
- scaling: torch.LongTensor, rank_start: torch.LongTensor,
- ranks: torch.LongTensor, seq_start: torch.LongTensor,
- seq_lens: torch.LongTensor, adapter_ids: torch.LongTensor,
- max_rank: int, max_seqlen: int):
+ c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_n[None, :]
+
+ if rank == 0:
+ if not CUM:
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
+ c_mask = mask_cm[:, None] * mask_cn[None, :]
+ tl.store(c_ptrs, 0.0, mask=c_mask)
+ c_ptrs += stride_cn * BLOCK_SIZE_N
+ else:
+
+ offs_am = (seq_start +
+ (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M)
+ offs_r = rank_start + tl.arange(0, BLOCK_SIZE_R) % rank
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ a_ptrs = a_ptr + (offs_am[:, None] * stride_am +
+ offs_k[None, :] * stride_ak)
+ la_ptrs = lora_a_ptr + (offs_k[:, None] * stride_lak +
+ offs_r[None, :] * stride_lar)
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_R), dtype=tl.float32)
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ # Load the next block of A and B
+ # If it is out of bounds, set it to 0.
+ a = tl.load(a_ptrs,
+ mask=offs_k[None, :] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ la = tl.load(la_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ # We accumulate along the K dimension.
+ accumulator = tl.dot(a, la, acc=accumulator)
+ # Advance the ptrs to the next K block.
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ la_ptrs += BLOCK_SIZE_K * stride_lak
+ ar = accumulator.to(lora_b_ptr.dtype.element_ty)
+
+ scaling = tl.load(scaling_ptr + adapter_id).to(ar.dtype)
+ ar *= scaling
+ ar = tl.where(
+ tl.arange(0, BLOCK_SIZE_R)[None, :] < rank, ar, tl.zeros_like(ar))
+ lb_ptrs = lora_b_ptr + (offs_r[:, None] * stride_lbr +
+ offs_n[None, :] * stride_lbn)
+
+ for n in range(0, tl.cdiv(N, BLOCK_SIZE_N)):
+ lb = tl.load(lb_ptrs, mask=offs_n[None, :] < N - n * BLOCK_SIZE_N)
+ c = tl.dot(ar, lb)
+
+ mask_cn = (offs_n < N - n * BLOCK_SIZE_N)
+ c_mask = mask_cm[:, None] * mask_cn[None, :]
+ if CUM:
+ _atomic_store(c_ptrs, c, mask=c_mask)
+ else:
+ tl.store(c_ptrs, c, mask=c_mask)
+ c_ptrs += stride_cn * BLOCK_SIZE_N
+ lb_ptrs += stride_lbn * BLOCK_SIZE_N
+
+
+def fused_lora(input: torch.Tensor,
+ lora_a: torch.Tensor,
+ lora_b: torch.Tensor,
+ scaling: torch.LongTensor,
+ rank_start: torch.LongTensor,
+ ranks: torch.LongTensor,
+ seq_start: torch.LongTensor,
+ seq_lens: torch.LongTensor,
+ adapter_ids: torch.LongTensor,
+ max_rank: int,
+ max_seqlen: int,
+ output: torch.Tensor = None,
+ cum: bool = False):
"""fused lora."""
def grid(META):
- ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M']) *
- triton.cdiv(N, META['BLOCK_SIZE_N'])), batch_size)
+ ret = ((triton.cdiv(max_seqlen, META['BLOCK_SIZE_M'])), batch_size)
return ret
assert input.dim() == 2
@@ -158,7 +180,12 @@ def grid(META):
M, K = input.shape
N = lora_b.size(1)
- output = input.new_empty((M, N))
+ if output is None:
+ output = input.new_empty((M, N))
+ cum = False
+ else:
+ assert output.size(0) == M
+ assert output.size(1) == N
BLOCK_SIZE_R = max(16, max_rank)
_fused_lora_kernel[grid](
@@ -183,6 +210,7 @@ def grid(META):
stride_cm=output.stride(0),
stride_cn=output.stride(1),
BLOCK_SIZE_R=BLOCK_SIZE_R,
+ CUM=cum,
)
return output
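
fused_lora can now accumulate into an existing output (cum=True), using tl.atomic_add for fp16/fp32 and a load-add-store fallback for bfloat16, and each program now sweeps the whole N dimension instead of tiling it across programs. The eager sketch below is only for checking the math; the (total_rank, K) / (total_rank, N) layouts of lora_a / lora_b are inferred from the kernel's stride usage:

import torch


def fused_lora_reference(x: torch.Tensor,           # (M, K), rows grouped per sequence
                         lora_a: torch.Tensor,      # (total_rank, K)
                         lora_b: torch.Tensor,      # (total_rank, N)
                         scaling: torch.Tensor,     # (num_adapters,)
                         rank_start: torch.Tensor,  # (num_adapters,)
                         ranks: torch.Tensor,       # (num_adapters,)
                         seq_start: torch.Tensor,   # (batch,)
                         seq_lens: torch.Tensor,    # (batch,)
                         adapter_ids: torch.Tensor,  # (batch,)
                         output: torch.Tensor = None,
                         cum: bool = False) -> torch.Tensor:
    """Eager sketch: out = scaling * (x @ A.T) @ B per sequence, optionally accumulated."""
    N = lora_b.size(1)
    if output is None:
        output = x.new_zeros(x.size(0), N)
        cum = False
    for start, length, aid in zip(seq_start.tolist(), seq_lens.tolist(),
                                  adapter_ids.tolist()):
        r0, r = rank_start[aid].item(), ranks[aid].item()
        xs = x[start:start + length]
        if r == 0:
            delta = xs.new_zeros(length, N)          # rank-0 adapters contribute nothing
        else:
            a = lora_a[r0:r0 + r]                    # (r, K)
            b = lora_b[r0:r0 + r]                    # (r, N)
            delta = scaling[aid] * (xs @ a.t()) @ b
        if cum:
            output[start:start + length] += delta
        else:
            output[start:start + length] = delta
    return output
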
diff --git a/lmdeploy/pytorch/kernels/cuda/fused_moe.py b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
index 9f9771368e..9d73208c53 100644
--- a/lmdeploy/pytorch/kernels/cuda/fused_moe.py
+++ b/lmdeploy/pytorch/kernels/cuda/fused_moe.py
@@ -91,8 +91,6 @@ def fused_moe_kernel(
if GROUP_SIZE_M == 1:
pid_m = pid % num_pid_m
pid_n = pid // num_pid_m
- # pid_m = pid // num_pid_n
- # pid_n = pid % num_pid_n
else:
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
@@ -133,7 +131,7 @@ def fused_moe_kernel(
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
- accumulator += tl.dot(a, b)
+ accumulator = tl.dot(a, b, acc=accumulator)
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -271,6 +269,33 @@ def get_start_end(topk_idx: torch.Tensor, sorted_idx: torch.Tensor,
return exp_start, exp_end
+def _get_sorted_idx(topk_ids: torch.Tensor, num_experts: int):
+ """get sorted idx."""
+ flatten_topk_ids = topk_ids.flatten()
+ sorted_idx = flatten_topk_ids.argsort()
+
+ exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
+ num_experts)
+ return sorted_idx, exp_start, exp_end
+
+
+def _renormalize(topk_weights: torch.Tensor, renormalize: bool):
+ if renormalize:
+ topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+ if not topk_weights.is_contiguous():
+ topk_weights = topk_weights.contiguous()
+ return topk_weights
+
+
+def _make_intermediate(shape: tuple, dtype: torch.dtype, device: torch.device,
+ zeros: bool):
+ """make intermediate."""
+ if zeros:
+ return torch.zeros(shape, dtype=dtype, device=device)
+ else:
+ return torch.empty(shape, dtype=dtype, device=device)
+
+
def fused_moe(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
@@ -283,31 +308,17 @@ def fused_moe(hidden_states: torch.Tensor,
"""fused moe."""
M = hidden_states.size(0)
E, N, _ = w1.shape
- full_exp = False
if num_experts is None:
num_experts = E
- elif num_experts == E:
- full_exp = True
-
- def __get_sorted_idx(topk_ids: torch.Tensor):
- flatten_topk_ids = topk_ids.flatten()
- sorted_idx = flatten_topk_ids.argsort()
-
- exp_start, exp_end = get_start_end(flatten_topk_ids, sorted_idx,
- num_experts)
- return sorted_idx, exp_start, exp_end
-
- if renormalize:
- topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
- if not topk_weights.is_contiguous():
- topk_weights = topk_weights.contiguous()
+ full_exp = num_experts == E
- sorted_idx, exp_start, exp_end = __get_sorted_idx(topk_ids)
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
- if full_exp:
- intermediate_cache1 = hidden_states.new_empty((M, topk, N))
- else:
- intermediate_cache1 = hidden_states.new_zeros((M, topk, N))
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ zeros=not full_exp)
# gate and up
fused_moe_kernel_launcher(
hidden_states,
@@ -331,10 +342,10 @@ def __get_sorted_idx(topk_ids: torch.Tensor):
gate_cache = silu_and_mul(intermediate_cache1)
gate_cache = gate_cache.unflatten(0, unflat_size)
- if full_exp:
- intermediate_cache2 = hidden_states.new_empty((M, topk, w2.shape[1]))
- else:
- intermediate_cache2 = hidden_states.new_zeros((M, topk, w2.shape[1]))
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=hidden_states.dtype,
+ device=hidden_states.device,
+ zeros=not full_exp)
# down
fused_moe_kernel_launcher(
gate_cache,
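
The fused_moe.py refactor extracts _renormalize, _get_sorted_idx and _make_intermediate so the w8a8 and blocked-FP8 variants can share them. _get_sorted_idx sorts the flattened top-k expert ids and returns, per expert, the start and end offsets of its slice in the sorted order (via the get_start_end kernel). An eager torch equivalent of that bookkeeping, as a sketch:

import torch


def get_sorted_idx_reference(topk_ids: torch.Tensor, num_experts: int):
    """Eager equivalent of _get_sorted_idx: group token slots by expert id."""
    flat = topk_ids.flatten()
    sorted_idx = flat.argsort()
    counts = torch.bincount(flat, minlength=num_experts)
    exp_end = counts.cumsum(dim=0)
    exp_start = exp_end - counts
    return sorted_idx, exp_start, exp_end


# slots routed to expert e are sorted_idx[exp_start[e]:exp_end[e]],
# and slot s belongs to token s // topk (the reindex_a case in the kernels).
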
diff --git a/lmdeploy/pytorch/kernels/cuda/rms_norm.py b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
index bc994012fc..045b55e1ba 100644
--- a/lmdeploy/pytorch/kernels/cuda/rms_norm.py
+++ b/lmdeploy/pytorch/kernels/cuda/rms_norm.py
@@ -4,8 +4,6 @@
import triton.language as tl
from torch import Tensor
-from .triton_utils import get_kernel_meta, wrap_jit_func
-
@triton.jit
def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
@@ -18,15 +16,6 @@ def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
return out
-@wrap_jit_func(type_hint=dict(
- input=Tensor,
- weight=Tensor,
- output=Tensor,
- input_row_stride=int,
- eps=float,
- N_COLS=torch.int32,
- BLOCK_N=torch.int32,
-))
@triton.jit
def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
eps: tl.constexpr, N_COLS: tl.constexpr,
@@ -45,18 +34,6 @@ def rms_norm_kernel(input, weight, output, input_row_stride: tl.constexpr,
tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
-@wrap_jit_func(type_hint=dict(
- input=Tensor,
- weight=Tensor,
- residual=Tensor,
- output=Tensor,
- out_residual=Tensor,
- input_row_stride=int,
- residual_row_stride=int,
- eps=float,
- N_COLS=torch.int32,
- BLOCK_N=torch.int32,
-))
@triton.jit
def add_rms_norm_kernel(input, weight, residual, output, out_residual,
input_row_stride: tl.constexpr,
@@ -95,6 +72,7 @@ def rms_norm(hidden_states: Tensor,
hidden_states = hidden_states.contiguous()
feat_size = weight.shape[0]
+ assert hidden_states.size(-1) == feat_size
seq_len = hidden_states.numel() // hidden_states.size(-1)
input_stride = hidden_states.stride(-2)
@@ -103,39 +81,40 @@ def rms_norm(hidden_states: Tensor,
if out is None:
out = torch.empty_like(hidden_states)
- kernel_meta = get_kernel_meta(hidden_states)
grid = (seq_len, )
if residual is None:
- rms_norm_kernel[grid](hidden_states,
- weight,
- out,
- input_row_stride=input_stride,
- eps=eps,
- N_COLS=feat_size,
- BLOCK_N=BLOCK_N,
- num_warps=4,
- num_stages=2,
- **kernel_meta)
+ rms_norm_kernel[grid](
+ hidden_states,
+ weight,
+ out,
+ input_row_stride=input_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ num_warps=4,
+ num_stages=2,
+ )
return out
else:
if out_residual is None:
out_residual = torch.empty_like(hidden_states)
res_stride = residual.stride(-2)
- add_rms_norm_kernel[grid](hidden_states,
- weight,
- residual,
- out,
- out_residual,
- input_row_stride=input_stride,
- residual_row_stride=res_stride,
- eps=eps,
- N_COLS=feat_size,
- BLOCK_N=BLOCK_N,
- num_warps=4,
- num_stages=2,
- **kernel_meta)
+ add_rms_norm_kernel[grid](
+ hidden_states,
+ weight,
+ residual,
+ out,
+ out_residual,
+ input_row_stride=input_stride,
+ residual_row_stride=res_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ num_warps=4,
+ num_stages=2,
+ )
return out, out_residual
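
The rms_norm.py change is mostly cleanup (dropping wrap_jit_func and get_kernel_meta) plus a feature-size assertion, so the computed result is unchanged. For reference, the usual RMSNorm math these kernels implement, including the residual-add fusion of add_rms_norm_kernel; exact cast points inside the Triton kernels may differ from this eager sketch:

import torch
from torch import Tensor


def rms_norm_reference(hidden_states: Tensor, weight: Tensor, eps: float = 1e-6,
                       residual: Tensor = None):
    """Eager sketch of rms_norm / add_rms_norm: normalize in fp32, scale by weight."""
    if residual is not None:
        hidden_states = hidden_states + residual
        out_residual = hidden_states
    x = hidden_states.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = (x * torch.rsqrt(variance + eps)).to(weight.dtype) * weight
    if residual is None:
        return out
    return out, out_residual
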
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
new file mode 100644
index 0000000000..569a3c5964
--- /dev/null
+++ b/lmdeploy/pytorch/kernels/cuda/w8a8_fused_moe.py
@@ -0,0 +1,331 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# modify from: https://github.com/vllm-project/vllm
+import torch
+import triton
+import triton.language as tl
+
+from .activation import silu_and_mul
+from .fused_moe import _get_sorted_idx, _make_intermediate, _renormalize
+from .triton_utils import get_kernel_meta
+from .w8a8_triton_kernels import per_token_quant_int8
+
+
+def get_cuda_autotune_config():
+ return [
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 256,
+ 'BLOCK_SIZE_K': 32,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 64,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 64,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=4,
+ num_warps=4),
+ triton.Config(
+ {
+ 'BLOCK_SIZE_M': 128,
+ 'BLOCK_SIZE_N': 128,
+ 'BLOCK_SIZE_K': 128,
+ 'GROUP_SIZE_M': 1,
+ },
+ num_stages=3,
+ num_warps=8),
+ ]
+
+
+@triton.autotune(
+ configs=get_cuda_autotune_config(),
+ key=['N', 'K', 'M_NP2'],
+ warmup=10,
+ rep=25,
+)
+@triton.jit
+def fused_moe_w8a8_kernel(
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ SortedIdx,
+ ExpStart,
+ ExpEnd,
+ Weights,
+ N: tl.constexpr,
+ K: tl.constexpr,
+ stride_am: tl.constexpr,
+ stride_ak: tl.constexpr,
+ stride_be: tl.constexpr,
+ stride_bn: tl.constexpr,
+ stride_bk: tl.constexpr,
+ stride_bse: tl.constexpr,
+ stride_cm: tl.constexpr,
+ stride_cn: tl.constexpr,
+ BLOCK_SIZE_M: tl.constexpr,
+ BLOCK_SIZE_N: tl.constexpr,
+ BLOCK_SIZE_K: tl.constexpr,
+ GROUP_SIZE_M: tl.constexpr,
+ M_NP2: tl.constexpr,
+ ENABLE_WEIGHTS: tl.constexpr,
+ top_k: tl.constexpr,
+ expert_offset: tl.constexpr,
+ reindex_a: tl.constexpr,
+ reindex_c: tl.constexpr,
+ ACCUMULATOR_DTYPE: tl.constexpr,
+):
+ """fused moe kernel."""
+ exp_id = tl.program_id(1)
+ pid = tl.program_id(0)
+
+ exp_start = tl.load(ExpStart + exp_id + expert_offset)
+ exp_end = tl.load(ExpEnd + exp_id + expert_offset)
+ M = exp_end - exp_start
+ if M <= 0:
+ return
+
+ num_pid_m = tl.cdiv(M_NP2, BLOCK_SIZE_M)
+ num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+ if GROUP_SIZE_M == 1:
+ pid_m = pid % num_pid_m
+ pid_n = pid // num_pid_m
+ else:
+ num_pid_in_group = GROUP_SIZE_M * num_pid_n
+ group_id = pid // num_pid_in_group
+ first_pid_m = group_id * GROUP_SIZE_M
+ group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+ pid_m = first_pid_m + (pid % group_size_m)
+ pid_n = (pid % num_pid_in_group) // group_size_m
+
+ if pid_m * BLOCK_SIZE_M >= M or pid_n * BLOCK_SIZE_N >= N:
+ return
+
+ offs_sid = exp_start + pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+ mask_sid = offs_sid < exp_end
+ sid = tl.load(SortedIdx + offs_sid, mask=mask_sid, other=0)
+
+ offs_k = tl.arange(0, BLOCK_SIZE_K)
+ if reindex_a:
+ offs_am = sid // top_k
+ else:
+ offs_am = offs_sid
+ a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+ as_ptrs = A_scale + offs_am
+ offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+ offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N),
+ BLOCK_SIZE_N)
+
+    # deepseek has 160 experts; exp_id * stride_be could overflow int32,
+    # so cast the expert index to int64 before computing the offset
+ exp_id = exp_id.to(tl.int64)
+ exp_off = stride_be * exp_id
+ b_ptrs = B + exp_off + (offs_k[:, None] * stride_bk +
+ offs_bn[None, :] * stride_bn)
+ bs_ptrs = B_scale + exp_id * stride_bse + offs_bn
+
+ accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N),
+ dtype=ACCUMULATOR_DTYPE)
+
+ for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+ a = tl.load(a_ptrs,
+ mask=mask_sid[:, None] &
+ (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+ other=0.0)
+ b = tl.load(b_ptrs,
+ mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
+ other=0.0)
+ accumulator = tl.dot(a,
+ b,
+ acc=accumulator,
+ out_dtype=ACCUMULATOR_DTYPE)
+ a_ptrs += BLOCK_SIZE_K * stride_ak
+ b_ptrs += BLOCK_SIZE_K * stride_bk
+
+ ascale = tl.load(as_ptrs, mask=mask_sid)
+ bscale = tl.load(bs_ptrs)
+ c = accumulator.to(ascale.dtype)
+ c = c * ascale[:, None] * bscale[None, :]
+
+ if ENABLE_WEIGHTS:
+ weight = tl.load(Weights + sid, mask=mask_sid)
+ c = c * weight[:, None].to(c.dtype)
+
+ c = c.to(C.dtype.element_ty)
+
+ if reindex_c:
+ offs_cm = sid
+ else:
+ offs_cm = offs_sid
+ c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_bn[None, :]
+ tl.store(c_ptrs, c, mask=mask_sid[:, None])
+
+
+def fused_moe_w8a8_kernel_launcher(
+ A: torch.Tensor,
+ A_scale: torch.Tensor,
+ B: torch.Tensor,
+ B_scale: torch.Tensor,
+ C: torch.Tensor,
+ sorted_idx: torch.Tensor,
+ exp_start: torch.Tensor,
+ exp_end: torch.Tensor,
+ weights: torch.Tensor,
+ enable_weights: bool = False,
+ top_k: int = 1,
+ num_tokens: int = None,
+ expert_offset: int = 0,
+ reindex_a: bool = True,
+ reindex_c: bool = True,
+):
+ """fused moe kernel launcher."""
+
+ if num_tokens is None:
+ num_tokens = A.size(0)
+ M_NP2 = triton.next_power_of_2(num_tokens)
+ M_NP2 = max(64, M_NP2)
+ E, N, K = B.shape
+
+ assert A_scale.is_contiguous()
+ assert B_scale.is_contiguous()
+ accumulator_dtype = tl.float32 if A.is_floating_point() else tl.int32
+
+ def _grid_fn(META):
+ grid = (triton.cdiv(M_NP2, META['BLOCK_SIZE_M']) *
+ triton.cdiv(N, META['BLOCK_SIZE_N']), E)
+ return grid
+
+ A = A.flatten(0, -2)
+ C = C.flatten(0, -2)
+
+ grid = _grid_fn
+ kernel_meta = get_kernel_meta(A)
+ fused_moe_w8a8_kernel[grid](
+ A,
+ A_scale,
+ B,
+ B_scale,
+ C,
+ sorted_idx,
+ exp_start,
+ exp_end,
+ weights,
+ N=N,
+ K=K,
+ stride_am=A.stride(0),
+ stride_ak=A.stride(1),
+ stride_be=B.stride(0),
+ stride_bn=B.stride(1),
+ stride_bk=B.stride(2),
+ stride_bse=B_scale.stride(0),
+ stride_cm=C.stride(0),
+ stride_cn=C.stride(1),
+ ENABLE_WEIGHTS=enable_weights,
+ top_k=top_k,
+ expert_offset=expert_offset,
+ reindex_a=reindex_a,
+ reindex_c=reindex_c,
+ M_NP2=M_NP2,
+ ACCUMULATOR_DTYPE=accumulator_dtype,
+ **kernel_meta,
+ )
+
+
+def fused_moe_w8a8(input: torch.Tensor,
+ input_scale: torch.Tensor,
+ w1: torch.Tensor,
+ w1_scale: torch.Tensor,
+ w2: torch.Tensor,
+ w2_scale: torch.Tensor,
+ topk_weights: torch.Tensor,
+ topk_ids: torch.Tensor,
+ topk: int,
+ out_dtype: torch.dtype = torch.float16,
+ quant_dtype: torch.dtype = torch.int8,
+ expert_offset: int = 0,
+ num_experts: int = None,
+ renormalize: bool = False) -> torch.Tensor:
+ """fused moe."""
+ device = input.device
+ M = input.size(0)
+ E, N, _ = w1.shape
+ if num_experts is None:
+ num_experts = E
+ full_exp = num_experts == E
+
+ topk_weights = _renormalize(topk_weights, renormalize)
+ sorted_idx, exp_start, exp_end = _get_sorted_idx(topk_ids, num_experts)
+
+ intermediate_cache1 = _make_intermediate((M, topk, N),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # gate and up
+ fused_moe_w8a8_kernel_launcher(
+ input,
+ input_scale,
+ w1,
+ w1_scale,
+ intermediate_cache1,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=False,
+ top_k=topk,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=True,
+ reindex_c=False,
+ )
+
+ # activate
+ unflat_size = intermediate_cache1.shape[:-1]
+ intermediate_cache1 = intermediate_cache1.flatten(0, -2)
+ gate_cache = silu_and_mul(intermediate_cache1)
+ del intermediate_cache1
+ gate_cache = gate_cache.unflatten(0, unflat_size)
+ gate_cache, gate_scale = per_token_quant_int8(gate_cache,
+ 1e-7,
+ quant_dtype=quant_dtype)
+
+ intermediate_cache2 = _make_intermediate((M, topk, w2.shape[1]),
+ dtype=out_dtype,
+ device=device,
+ zeros=not full_exp)
+ # down
+ fused_moe_w8a8_kernel_launcher(
+ gate_cache,
+ gate_scale,
+ w2,
+ w2_scale,
+ intermediate_cache2,
+ sorted_idx=sorted_idx,
+ exp_start=exp_start,
+ exp_end=exp_end,
+ weights=topk_weights,
+ enable_weights=True,
+ top_k=1,
+ num_tokens=M,
+ expert_offset=expert_offset,
+ reindex_a=False,
+ reindex_c=True,
+ )
+
+ ret = intermediate_cache2.sum(dim=1)
+ return ret
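Reviewer note: `fused_moe_w8a8` runs the gate/up GEMM on quantized activations, applies `silu_and_mul`, re-quantizes, runs the down GEMM with the router weights enabled, and finally sums over the top-k experts. A dense PyTorch sketch of the same routing math may help validate outputs; it is illustrative only, and the weight layouts (`[E, 2N, K]`, `[E, K, N]`) plus the gate-then-up split are assumptions:

```python
import torch
import torch.nn.functional as F


def moe_w8a8_ref(x_q, x_scale, w1_q, w1_scale, w2_q, w2_scale,
                 topk_weights, topk_ids):
    """Dense reference: dequantize, gate/up + SiLU-and-mul + down, weighted sum."""
    x = x_q.float() * x_scale                    # [M, K], per-token scale
    w1 = w1_q.float() * w1_scale[..., None]      # [E, 2N, K], per-channel scale
    w2 = w2_q.float() * w2_scale[..., None]      # [E, K, N]
    out = torch.zeros(x.shape[0], w2.shape[1])
    for t in range(x.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = topk_ids[t, k]
            gate, up = (x[t] @ w1[e].t()).chunk(2)   # assumed gate-then-up layout
            h = F.silu(gate) * up
            out[t] += topk_weights[t, k] * (h @ w2[e].t())
    return out


M, K, N_inter, E, topk = 4, 64, 32, 8, 2
x_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)
x_s = torch.rand(M, 1)
w1_q = torch.randint(-128, 127, (E, 2 * N_inter, K), dtype=torch.int8)
w1_s = torch.rand(E, 2 * N_inter)
w2_q = torch.randint(-128, 127, (E, K, N_inter), dtype=torch.int8)
w2_s = torch.rand(E, K)
topk_w = torch.rand(M, topk)
topk_ids = torch.randint(0, E, (M, topk))
print(moe_w8a8_ref(x_q, x_s, w1_q, w1_s, w2_q, w2_s, topk_w, topk_ids).shape)
```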
diff --git a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
index 0d0e10ec83..a8eeb63a5f 100644
--- a/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
+++ b/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py
@@ -14,14 +14,13 @@
tl_round = tl.math.round
-def per_channel_quant(x, n_bits, dtype):
+def per_channel_quant(x: torch.Tensor, dtype: torch.dtype):
"""Quantize the input tensor 'x' channel-wise using the given number of
bits.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be a
2-dimensional tensor.
- n_bits (int): The number of bits to use for quantization.
dtype (torch.dtype): The data type to which the quantized tensor should
be converted.
@@ -32,31 +31,40 @@ def per_channel_quant(x, n_bits, dtype):
assert x.ndim == 2
x = x.to(torch.float32)
x_absmax = x.view(x.shape[0], -1).abs().max(dim=1, keepdim=True)[0]
- q_max = 2**(n_bits - 1) - 1
- q_min = -2**(n_bits - 1)
- scale = x_absmax / (2**(n_bits - 1) - 1)
- x_q = torch.round(x / scale).clamp(q_min, q_max).to(dtype)
+ qtype_info = torch.finfo(
+ dtype) if dtype.is_floating_point else torch.iinfo(dtype)
+ q_max = qtype_info.max
+ q_min = qtype_info.min
+ scale = x_absmax / q_max
+ x_q = x / scale
+ if not dtype.is_floating_point:
+ x_q = torch.round(x_q)
+ x_q = x_q.clamp(q_min, q_max).to(dtype)
return x_q, scale
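Reviewer note: `per_channel_quant` now derives the representable range from the target dtype via `torch.iinfo`/`torch.finfo` instead of a bit count, so the same helper serves both int8 and fp8 weights. A quick round-trip check might look like the sketch below; it assumes a PyTorch build with the `torch.float8_e4m3fn` dtype and arbitrary test shapes:

```python
import torch

w = torch.randn(4096, 4096, dtype=torch.float16)
for qdtype in (torch.int8, torch.float8_e4m3fn):
    info = (torch.finfo(qdtype)
            if qdtype.is_floating_point else torch.iinfo(qdtype))
    scale = w.float().abs().amax(dim=1, keepdim=True) / info.max
    w_q = w.float() / scale
    if not qdtype.is_floating_point:
        w_q = w_q.round()
    w_q = w_q.clamp(info.min, info.max).to(qdtype)
    # dequantize and report the worst-case absolute error
    err = (w_q.float() * scale - w.float()).abs().max()
    print(qdtype, float(err))
```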
@triton.autotune(
configs=[
triton.Config({
- 'BLOCK_N': 64,
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 256,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4),
+ num_stages=3,
+ num_warps=8),
triton.Config({
+ 'BLOCK_M': 256,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4)
+ num_stages=3,
+ num_warps=8)
],
key=['N', 'K'],
+ warmup=5,
+ rep=20,
)
-@triton.jit
+@triton.jit(do_not_specialize=['M'])
def _linear(
A,
B,
@@ -76,6 +84,7 @@ def _linear(
GROUP_SIZE_M: tl.constexpr,
rms_scale_ptr,
linear_scale_ptr,
+ ACCUMULATOR_DTYPE: tl.constexpr,
):
"""Triton-accelerated function used to perform linear operations (dot
product) on input tensors `A` and `B`, and store the result in output
@@ -100,12 +109,11 @@ def _linear(
offs_k = tl.arange(0, BLOCK_K)
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
-
- accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
for k in range(0, tl.cdiv(K, BLOCK_K)):
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
- accumulator += tl.dot(a, b)
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
+ accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
@@ -124,42 +132,31 @@ def _linear(
@triton.autotune(
configs=[
triton.Config({
- 'BLOCK_N': 64,
+ 'BLOCK_M': 128,
+ 'BLOCK_N': 256,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4),
+ num_stages=3,
+ num_warps=8),
triton.Config({
+ 'BLOCK_M': 256,
'BLOCK_N': 128,
'BLOCK_K': 128,
},
- num_stages=4,
- num_warps=4)
+ num_stages=3,
+ num_warps=8)
],
key=['N', 'K'],
+ warmup=5,
+ rep=20,
)
-@triton.jit
-def _linear_add(
- A,
- B,
- C,
- residual_ptr,
- M,
- N,
- K,
- stride_am,
- stride_ak,
- stride_bk,
- stride_bn,
- stride_cm,
- stride_cn,
- BLOCK_M: tl.constexpr,
- BLOCK_N: tl.constexpr,
- BLOCK_K: tl.constexpr,
- GROUP_SIZE_M: tl.constexpr,
- rms_scale_ptr,
- linear_scale_ptr,
-):
+@triton.jit(do_not_specialize=['M'])
+def _linear_add(A, B, C, residual_ptr, M, N, K, stride_am, stride_ak,
+ stride_bk, stride_bn, stride_cm, stride_cn,
+ BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
+ BLOCK_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr,
+ rms_scale_ptr, linear_scale_ptr,
+ ACCUMULATOR_DTYPE: tl.constexpr):
"""Triton-accelerated function used to perform a linear operation (dot
product) on input tensors `A` and `B`, with addition of residual.
@@ -183,11 +180,11 @@ def _linear_add(
a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
- accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)
+ accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACCUMULATOR_DTYPE)
for k in range(0, tl.cdiv(K, BLOCK_K)):
- a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=0)
- b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=0)
- accumulator += tl.dot(a, b)
+ a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_K, other=None)
+ b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_K, other=None)
+ accumulator = tl.dot(a, b, accumulator, out_dtype=ACCUMULATOR_DTYPE)
a_ptrs += BLOCK_K * stride_ak
b_ptrs += BLOCK_K * stride_bk
c = accumulator.to(tl.float32)
@@ -231,14 +228,11 @@ def matmul_kernel_dynamic_quant(a,
assert residual.shape == c_shape
assert residual.is_contiguous()
c = a.new_empty(c_shape, dtype=output_dtype)
-
- BLOCK_M = 128
- if M < BLOCK_M:
- BLOCK_M = triton.next_power_of_2(M)
- BLOCK_M = max(BLOCK_M, 16)
+ accumulator_dtype = tl.float32 if a.is_floating_point() else tl.int32
def grid(META):
- return (triton.cdiv(M, BLOCK_M) * triton.cdiv(N, META['BLOCK_N']), )
+ return (triton.cdiv(M, META['BLOCK_M']) *
+ triton.cdiv(N, META['BLOCK_N']), )
kernel_meta = get_kernel_meta(a)
if residual is not None:
@@ -255,10 +249,10 @@ def grid(META):
b.stride(0),
c.stride(-2),
c.stride(-1),
- BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale,
+ ACCUMULATOR_DTYPE=accumulator_dtype,
**kernel_meta)
else:
_linear[grid](a,
@@ -273,10 +267,10 @@ def grid(META):
b.stride(0),
c.stride(-2),
c.stride(-1),
- BLOCK_M=BLOCK_M,
GROUP_SIZE_M=8,
rms_scale_ptr=rms_scale,
linear_scale_ptr=linear_scale,
+ ACCUMULATOR_DTYPE=accumulator_dtype,
**kernel_meta)
if bias is not None:
c += bias
@@ -286,13 +280,16 @@ def grid(META):
@triton.jit
def _per_token_quant_int8(
- y_ptr,
- y_q_ptr,
- y_s_ptr,
- y_stride,
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- BLOCK: tl.constexpr,
+ y_ptr,
+ y_q_ptr,
+ y_s_ptr,
+ y_stride: tl.constexpr,
+ yq_stride: tl.constexpr,
+ N, # number of columns in X
+ eps: tl.constexpr, # epsilon to avoid division by zero
+ BLOCK: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr, # True for floating point dtype
):
"""A Triton-accelerated function to perform per-token quantization on a
tensor.
@@ -302,7 +299,7 @@ def _per_token_quant_int8(
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
y_ptr += row * y_stride
- y_q_ptr += row * y_stride
+ y_q_ptr += row * yq_stride
y_s_ptr += row
cols = tl.arange(0, BLOCK) # N <= BLOCK
@@ -311,21 +308,26 @@ def _per_token_quant_int8(
y = tl.load(y_ptr + cols, mask=mask, other=0.).to(tl.float32)
# Quant
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
- y_s = _absmax / 127
- y_q = tl_round(y / y_s).to(tl.int8)
+ y_s = _absmax / Q_MAX
+ y_q = y / y_s
+ if not IS_FLOATING_POINT:
+ y_q = tl_round(y_q).to(tl.int8)
tl.store(y_q_ptr + cols, y_q, mask=mask)
tl.store(y_s_ptr, y_s)
-def per_token_quant_int8(x, eps):
+def per_token_quant_int8(x, eps, quant_dtype=torch.int8):
"""Function to perform per-token quantization on an input tensor `x`.
It converts the tensor values into signed 8-bit integers and returns the
quantized tensor along with the scaling factor used for quantization.
"""
-
- x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+ q_max = qdtype_info.max
+ x_q = torch.empty_like(x, device=x.device, dtype=quant_dtype)
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_s = torch.empty(x.shape[:-1] + (1, ),
@@ -334,94 +336,184 @@ def per_token_quant_int8(x, eps):
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
+
+ if x.dim() > 2:
+ x = x.flatten(0, -2)
+ assert x.stride(-1) == 1
# enqueue kernel
kernel_meta = get_kernel_meta(x)
- _per_token_quant_int8[(M, )](x,
- x_q,
- x_s,
- x.stride(-2),
- N,
- eps,
- BLOCK=BLOCK,
- num_warps=num_warps,
- **kernel_meta)
+ _per_token_quant_int8[(M, )](
+ x,
+ x_q,
+ x_s,
+ y_stride=x.stride(-2),
+ yq_stride=x_q.stride(-2),
+ N=N,
+ eps=eps,
+ BLOCK=BLOCK,
+ Q_MAX=q_max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=num_warps,
+ **kernel_meta)
return x_q, x_s
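Reviewer note: `per_token_quant_int8` keeps its name but now accepts a `quant_dtype` and returns the quantized tensor plus a per-row scale of shape `x.shape[:-1] + (1,)`. A hedged usage sketch (assumes a CUDA device with Triton; the import path simply mirrors this file's location):

```python
import torch
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import per_token_quant_int8

x = torch.randn(8, 4096, dtype=torch.float16, device='cuda')
x_q, x_scale = per_token_quant_int8(x, eps=1e-7, quant_dtype=torch.int8)
assert x_q.dtype == torch.int8 and x_scale.shape == (8, 1)
# dequantized values should track the input closely
print((x_q.float() * x_scale - x.float()).abs().max())
```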
@triton.jit
-def _rms_norm_fwd_fused_dynamic_symmetric(
- X, # pointer to the input
- Y, # pointer to the output
- W, # pointer to the weights
- Scale, # pointer to the scales of the output activation
- stride, # how much to increase the pointer when moving by 1 row
- N, # number of columns in X
- eps, # epsilon to avoid division by zero
- BLOCK_SIZE: tl.constexpr,
+def _compute_rms_norm(x, w, eps: tl.constexpr, N_COLS: tl.constexpr):
+ """compute rms norm."""
+ xf = x.to(tl.float32)
+
+ var = tl.sum(xf * xf, 0) * float(1.0 / N_COLS)
+ out = xf * tl.math.rsqrt(var + eps)
+ out = (w * out).to(x.dtype)
+ return out
+
+
+@triton.jit
+def rms_norm_quant_kernel(
+ input,
+ weight,
+ output,
+ out_scale,
+ input_row_stride: tl.constexpr,
+ eps: tl.constexpr,
+ N_COLS: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ Q_MIN: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr,
):
- """A Triton kernel that calculates Root Mean Square (RMS) normalization
- with fused dynamic symmetric quantization."""
- row = tl.program_id(0)
- Y += row * stride
- X += row * stride
+ """rms norm kernel."""
+ prog_id = tl.program_id(0)
+ offsets = tl.arange(0, BLOCK_N)
- cols = tl.arange(0, BLOCK_SIZE)
- mask = cols < N
- x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)
- _var = x * x
- var = tl.sum(_var, axis=0) / N
- rstd = tl.math.rsqrt(var + eps)
+ w = tl.load(weight + offsets, mask=offsets < N_COLS)
+
+ x_ptr = input + prog_id * input_row_stride
+ x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+ out = _compute_rms_norm(x, w, eps, N_COLS)
+
+ scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+ out_s_ptr = out_scale + prog_id
+ tl.store(out_s_ptr, scale)
+ out = out / scale
+ if not IS_FLOATING_POINT:
+ out = tl_round(out)
+ out = tl.clamp(out, Q_MIN, Q_MAX)
+ out_ptr = output + prog_id * input_row_stride
+ tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
- w = tl.load(W + cols, mask=mask)
- x_hat = x * rstd
- y = x_hat * w
- scale = tl.max(tl.abs(y)).to(tl.float32) / 127
- tl.store(Scale + row, scale)
+@triton.jit
+def add_rms_norm_quant_kernel(
+ input,
+ weight,
+ residual,
+ output,
+ out_scale,
+ out_residual,
+ input_row_stride: tl.constexpr,
+ residual_row_stride: tl.constexpr,
+ eps: tl.constexpr,
+ N_COLS: tl.constexpr,
+ BLOCK_N: tl.constexpr,
+ Q_MIN: tl.constexpr,
+ Q_MAX: tl.constexpr,
+ IS_FLOATING_POINT: tl.constexpr,
+):
+ """rms norm kernel."""
+ prog_id = tl.program_id(0)
+ offsets = tl.arange(0, BLOCK_N)
+
+ w = tl.load(weight + offsets, mask=offsets < N_COLS)
- y = tl_round(y / scale)
- y = tl.minimum(y, 127)
- y = tl.maximum(y, -128)
- tl.store(Y + cols, y, mask=mask)
+ x_ptr = input + prog_id * input_row_stride
+ x = tl.load(x_ptr + offsets, mask=offsets < N_COLS)
+ res_ptr = residual + prog_id * residual_row_stride
+ res = tl.load(res_ptr + offsets, mask=offsets < N_COLS)
-def rms_norm_dynamic_quant(x, w, eps):
+ new_x = x + res
+ out_res_ptr = out_residual + prog_id * residual_row_stride
+ tl.store(out_res_ptr + offsets, new_x, mask=offsets < N_COLS)
+
+ out = _compute_rms_norm(new_x, w, eps, N_COLS)
+
+ scale = tl.max(tl.abs(out)).to(tl.float32) / Q_MAX
+ out_s_ptr = out_scale + prog_id
+ tl.store(out_s_ptr, scale)
+ out = out / scale
+ if not IS_FLOATING_POINT:
+ out = tl_round(out)
+ out = tl.clamp(out, Q_MIN, Q_MAX)
+ out_ptr = output + prog_id * input_row_stride
+ tl.store(out_ptr + offsets, out, mask=offsets < N_COLS)
+
+
+def rms_norm_dynamic_quant(x, w, eps, residual=None, quant_dtype=torch.int8):
"""Performs RMS normalization with dynamic quantization.
The function reshapes the input tensor `x`, creates an empty tensor `y`
with the same shape as `x`, and calculates RMS normalization on the
- reshaped `x` using a Triton kernel `_rms_norm_fwd_fused_dynamic_symmetric`.
+ reshaped `x` using a Triton kernel `rms_norm_quant_kernel`.
"""
-
- x_arg = x.flatten(0, -2)
- y = torch.empty_like(x, dtype=torch.int8)
- M, K = x_arg.shape
- MAX_FUSED_SIZE = 65536 // x.element_size()
- BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(K))
- if K > BLOCK_SIZE:
- raise RuntimeError(
- "This rms norm doesn't support feature dim >= 64KB.")
- num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+ y = torch.empty_like(x, dtype=quant_dtype)
scale = x.new_empty(x.shape[:-1] + (1, ), dtype=torch.float32)
- kernel_meta = get_kernel_meta(x_arg)
- _rms_norm_fwd_fused_dynamic_symmetric[(M, )](x_arg,
- y,
- w,
- scale,
- x_arg.stride(0),
- K,
- eps,
- BLOCK_SIZE=BLOCK_SIZE,
- num_warps=num_warps,
- **kernel_meta)
- return y, scale
+
+ feat_size = w.shape[0]
+ seq_len = x.numel() // x.size(-1)
+ input_stride = x.stride(-2)
+ BLOCK_N = triton.next_power_of_2(feat_size)
+ grid = (seq_len, )
+
+ if residual is None:
+ rms_norm_quant_kernel[grid](
+ x,
+ w,
+ y,
+ scale,
+ input_row_stride=input_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ Q_MIN=qdtype_info.min,
+ Q_MAX=qdtype_info.max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=4,
+ num_stages=2)
+ return y, scale
+ else:
+ out_residual = torch.empty_like(x)
+ res_stride = residual.stride(-2)
+ add_rms_norm_quant_kernel[grid](
+ x,
+ w,
+ residual,
+ y,
+ scale,
+ out_residual,
+ input_row_stride=input_stride,
+ residual_row_stride=res_stride,
+ eps=eps,
+ N_COLS=feat_size,
+ BLOCK_N=BLOCK_N,
+ Q_MIN=qdtype_info.min,
+ Q_MAX=qdtype_info.max,
+ IS_FLOATING_POINT=quant_dtype.is_floating_point,
+ num_warps=4,
+ num_stages=2)
+ return y, scale, out_residual
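Reviewer note: `rms_norm_dynamic_quant` now returns `(y, scale)` without a residual and `(y, scale, out_residual)` with one. A hedged sketch of both call forms (assumes a CUDA device; the import path mirrors this file):

```python
import torch
from lmdeploy.pytorch.kernels.cuda.w8a8_triton_kernels import rms_norm_dynamic_quant

x = torch.randn(4, 4096, dtype=torch.float16, device='cuda')
w = torch.ones(4096, dtype=torch.float16, device='cuda')
res = torch.randn_like(x)

y, scale = rms_norm_dynamic_quant(x, w, 1e-6, quant_dtype=torch.int8)
y2, scale2, new_res = rms_norm_dynamic_quant(x, w, 1e-6, residual=res,
                                             quant_dtype=torch.int8)
# the fused path should match adding the residual first, then normalizing
torch.testing.assert_close(new_res, x + res)
```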
def test_rms_and_linear(x,
rms_weight,
linear_weight,
- dtype=torch.float16,
+ output_dtype=torch.float16,
+ quant_dtype=torch.int8,
eps=1e-5):
"""Test quantized rms norm and quantized linear layer."""
@@ -434,15 +526,18 @@ def linear_torch(x, b):
return F.linear(x, b)
linear_weight_quant, linear_scale = per_channel_quant(
- linear_weight, 8, torch.int8)
+ linear_weight, quant_dtype)
- rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
+ rms_out, rms_scale = rms_norm_dynamic_quant(x,
+ rms_weight,
+ eps,
+ quant_dtype=quant_dtype)
assert rms_out.shape == x.shape and rms_scale.shape[:-1] == x.shape[:-1]
linear_out = matmul_kernel_dynamic_quant(rms_out,
linear_weight_quant,
rms_scale,
linear_scale,
- output_dtype=dtype)
+ output_dtype=output_dtype)
rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
linear_out_torch = linear_torch(rms_out_torch, linear_weight)
@@ -456,17 +551,26 @@ def linear_torch(x, b):
linear_out_torch.flatten().to(torch.float32)))
-def test_per_token_quant(x, eps):
+def test_per_token_quant(x, eps, quant_dtype=torch.int8):
"""Test per-token quantization."""
- def per_token_quant_int8_torch(x, eps):
+ def per_token_quant_int8_torch(x, eps, quant_dtype):
+ qdtype_info = torch.finfo(
+ quant_dtype) if quant_dtype.is_floating_point else torch.iinfo(
+ quant_dtype)
+
_absmax = torch.clamp(x.abs().max(dim=-1, keepdim=True)[0], min=eps)
- x_s = _absmax / 127
- x_q = torch.clamp((x / x_s).round(), min=-128, max=127)
+ x_s = _absmax / qdtype_info.max
+ x_q = x / x_s
+ if not quant_dtype.is_floating_point:
+ x_q = x_q.round()
+ x_q = torch.clamp(x_q, min=qdtype_info.min, max=qdtype_info.max)
return x_q, x_s
- x_q, x_s = per_token_quant_int8(x, eps)
- x_q_torch, x_s_torch = per_token_quant_int8_torch(x, eps)
+ x_q, x_s = per_token_quant_int8(x, eps, quant_dtype=quant_dtype)
+ x_q_torch, x_s_torch = per_token_quant_int8_torch(x,
+ eps,
+ quant_dtype=quant_dtype)
assert x_q.shape == x_q_torch.shape and x_s.shape == x_s_torch.shape
cos = torch.nn.CosineSimilarity(0)
print(
@@ -479,21 +583,11 @@ def per_token_quant_int8_torch(x, eps):
x_s_torch.flatten().to(torch.float32)))
-@triton.testing.perf_report(
- triton.testing.Benchmark(
- x_names=['M'],
- x_vals=[1, 16, 32, 64, 128, 256] + [512 * i * 2 for i in range(1, 17)],
- line_arg='provider',
- line_vals=['int8_dynamic_triton_op', 'float_torch'],
- line_names=['int8_dynamic_triton_op', 'float_torch'],
- styles=[('blue', '-'), ('green', '-'), ('orange', '-'),
- ('yellow', '-'), ('yellow', '-')],
- ylabel='GB/s',
- plot_name='forward',
- args={
- 'dtype': torch.float16,
- }))
-def bench_rms_and_linear(M, dtype, provider, eps=1e-5, device='cuda'):
+def bench_rms_and_linear(M: int,
+ provider: str,
+ dtype: torch.dtype = torch.float16,
+ eps: float = 1e-5):
+ """benchmark rms and linear."""
def rms_norm_torch(x, w, eps):
variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
@@ -505,6 +599,7 @@ def linear_torch(x, b):
N = 4096
K = 4096
+
x_shape = (M, K)
rms_w_shape = (x_shape[-1], )
rms_weight = torch.randn(rms_w_shape,
@@ -516,14 +611,33 @@ def linear_torch(x, b):
dtype=dtype,
device='cuda',
requires_grad=True)
- linear_weight_quant, linear_scale = per_channel_quant(
- linear_weight, 8, torch.int8)
- alpha = max(x.max().abs(), x.min().abs())
- rms_scale = alpha / 127
+ if provider == 'torch_fp16':
+ rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
- if provider == 'int8_dynamic_triton_op':
- rms_out, rms_scale = rms_norm_dynamic_quant(x, rms_weight, eps)
+ def y_fwd():
+ linear_torch(rms_out_torch, linear_weight)
+ else:
+ if provider == 'triton_int8':
+ quant_dtype = torch.int8
+ elif provider == 'triton_fp8_e4m3':
+ quant_dtype = torch.float8_e4m3fn
+ elif provider == 'triton_fp8_e5m2':
+ quant_dtype = torch.float8_e5m2
+
+ linear_weight_quant, linear_scale = per_channel_quant(
+ linear_weight, quant_dtype)
+
+ alpha = max(x.max().abs(), x.min().abs())
+ if quant_dtype.is_floating_point:
+ qdtype_info = torch.finfo(quant_dtype)
+ else:
+ qdtype_info = torch.iinfo(quant_dtype)
+ rms_scale = alpha / qdtype_info.max
+ rms_out, rms_scale = rms_norm_dynamic_quant(x,
+ rms_weight,
+ eps,
+ quant_dtype=quant_dtype)
def y_fwd():
@@ -532,21 +646,22 @@ def y_fwd():
rms_scale,
linear_scale,
output_dtype=dtype)
- elif provider == 'float_torch':
- rms_out_torch = rms_norm_torch(x, rms_weight, eps).half()
-
- def y_fwd():
- linear_torch(rms_out_torch, linear_weight)
quantiles = [0.5, 0.2, 0.8]
ms, min_ms, max_ms = triton.testing.do_bench(y_fwd,
quantiles=quantiles,
rep=500)
- return ms, max_ms, min_ms
+
+ def perf(ms):
+ return 2 * M * N * K * 1e-12 / (ms * 1e-3)
+
+ return perf(ms), perf(max_ms), perf(min_ms)
if __name__ == '__main__':
torch.manual_seed(0)
+    device_capability = torch.cuda.get_device_capability()
+    is_fp8_supported = device_capability[0] >= 9
dtype = torch.float16
# test (bs, seq_len, dim) x (dim, out_dim)
x = torch.randn((2, 2048, 4096), dtype=dtype, device='cuda')
@@ -559,7 +674,16 @@ def y_fwd():
dtype=dtype,
device='cuda',
requires_grad=True)
- test_rms_and_linear(x, rms_weight, linear_weight)
+ test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e4m3fn)
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e5m2)
# test (M, K) x (K, N)
x = torch.randn((4, 4096), dtype=dtype, device='cuda')
@@ -572,11 +696,45 @@ def y_fwd():
dtype=dtype,
device='cuda',
requires_grad=True)
- test_rms_and_linear(x, rms_weight, linear_weight)
+ test_rms_and_linear(x, rms_weight, linear_weight, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e4m3fn)
+ test_rms_and_linear(x,
+ rms_weight,
+ linear_weight,
+ quant_dtype=torch.float8_e5m2)
# test per-token quant
x = torch.randn((4, 2048, 4096), dtype=dtype, device='cuda')
eps = 1e-7
- test_per_token_quant(x, eps)
-
- bench_rms_and_linear.run(print_data=True)
+ test_per_token_quant(x, eps, quant_dtype=torch.int8)
+ if is_fp8_supported:
+ test_per_token_quant(x, eps, quant_dtype=torch.float8_e4m3fn)
+ test_per_token_quant(x, eps, quant_dtype=torch.float8_e5m2)
+
+ # benchmark triton kernels
+ line_vals = ['triton_int8', 'torch_fp16']
+ line_names = ['triton_int8', 'torch_fp16']
+
+ if is_fp8_supported:
+ line_vals += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+ line_names += ['triton_fp8_e4m3', 'triton_fp8_e5m2']
+ config = triton.testing.Benchmark(x_names=['M'],
+ x_vals=[1, 16, 32, 64, 128, 256] +
+ [512 * i * 2 for i in range(1, 5)],
+ line_arg='provider',
+ line_vals=line_vals,
+ line_names=line_names,
+ styles=[('blue', '-'), ('green', '-'),
+ ('orange', '-'), ('black', '-'),
+ ('yellow', '-')],
+ ylabel='TFLOPS',
+ plot_name='bench-triton',
+ args={
+ 'dtype': torch.float16,
+ })
+    bench_func = triton.testing.perf_report(config)(bench_rms_and_linear)
+    bench_func.run(print_data=True)
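Reviewer note: the benchmark above now reports TFLOPS rather than raw latency; the conversion is just FLOPs divided by seconds. A sketch of the same arithmetic used in `perf(ms)`:

```python
def tflops(ms: float, M: int, N: int, K: int) -> float:
    flops = 2 * M * N * K          # one multiply-add per output element per K step
    return flops * 1e-12 / (ms * 1e-3)


print(tflops(0.5, 4096, 4096, 4096))  # ~275 TFLOPS for a 0.5 ms 4k x 4k x 4k GEMM
```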
diff --git a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
index 0f13f3f38c..0fd07cf10c 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/apply_rotary_pos_emb.py
@@ -15,15 +15,12 @@ def apply_rotary_pos_emb(
) -> Tuple[Tensor, Tensor]:
query_states = query_states.contiguous()
key_states = key_states.contiguous()
- bs = query_states.shape[0]
query_states_reshaped = query_states.unsqueeze(0)
key_states_reshaped = key_states.unsqueeze(0)
- cos_reshaped = cos.reshape(1, bs, 1, -1)
- sin_reshaped = sin.reshape(1, bs, 1, -1)
query_states_reshaped, key_states_reshaped = \
ext_ops.apply_rotary_pos_emb(query_states_reshaped,
key_states_reshaped,
- cos_reshaped, sin_reshaped,
+ cos, sin,
None, None)
if q_embed is None:
q_embed = query_states_reshaped.view(query_states.shape)
diff --git a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
index ded85d476d..584945bb2a 100644
--- a/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
+++ b/lmdeploy/pytorch/kernels/dlinfer/pagedattention.py
@@ -15,7 +15,9 @@ def prefill_attention(
q_start_loc: Tensor,
q_seq_len: Tensor,
kv_seq_len: Tensor,
+ cu_seq_lens_kv: Tensor,
max_q_seq_len: int,
+ max_kv_seq_len: int,
block_size: int,
attn_mask: Sequence[Optional[Tensor]],
is_unpaged_prefill: Optional[bool],
@@ -51,7 +53,9 @@ def prefill_attention(
q_start_loc,
q_seq_len,
kv_seq_len,
+ cu_seq_lens_kv,
max_q_seq_len,
+ max_kv_seq_len,
num_q_heads,
num_kv_heads,
attn_mask,
@@ -105,6 +109,7 @@ def paged_attention_fwd(
q_start_loc: Tensor,
q_seqlens: Tensor,
kv_seqlens: Tensor,
+ cu_seq_lens_kv: Tensor,
max_q_seq_len: int,
max_kv_seq_len: int,
is_decoding: bool,
@@ -127,7 +132,9 @@ def paged_attention_fwd(
q_start_loc,
q_seqlens,
kv_seqlens,
+ cu_seq_lens_kv,
max_q_seq_len,
+ max_kv_seq_len,
block_size,
attn_mask,
is_unpaged_prefill,
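Reviewer note: the dlinfer attention entry points gain `cu_seq_lens_kv` and `max_kv_seq_len`. Judging by the name, `cu_seq_lens_kv` is the usual cumulative-sequence-length layout used by varlen attention backends; a sketch of how such a tensor would be derived from `kv_seqlens` (an assumption based on the naming, not taken from this diff):

```python
import torch
import torch.nn.functional as F

kv_seqlens = torch.tensor([3, 5, 2])
# prepend 0 so entry i / i+1 bracket sequence i in the packed KV layout
cu_seq_lens_kv = F.pad(kv_seqlens.cumsum(0), (1, 0))
print(cu_seq_lens_kv)  # tensor([ 0,  3,  8, 10])
```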
diff --git a/lmdeploy/pytorch/messages.py b/lmdeploy/pytorch/messages.py
index 0aaba98c94..968b71fee1 100644
--- a/lmdeploy/pytorch/messages.py
+++ b/lmdeploy/pytorch/messages.py
@@ -433,8 +433,6 @@ class SchedulerSequence:
sampling_param: SamplingParam = field(default_factory=SamplingParam)
logical_blocks: LogicalTokenBlocks = field(
default_factory=LogicalTokenBlocks)
- sender_id: int = -1
- req_id: int = -1
adapter_name: str = None
arrive_time: float = 0.0
meta: Any = None
diff --git a/lmdeploy/pytorch/models/baichuan.py b/lmdeploy/pytorch/models/baichuan.py
index 583cd19fe9..38d794f1be 100644
--- a/lmdeploy/pytorch/models/baichuan.py
+++ b/lmdeploy/pytorch/models/baichuan.py
@@ -228,7 +228,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -245,7 +244,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/chatglm2.py b/lmdeploy/pytorch/models/chatglm2.py
index 73f64d277c..5a83154167 100644
--- a/lmdeploy/pytorch/models/chatglm2.py
+++ b/lmdeploy/pytorch/models/chatglm2.py
@@ -265,7 +265,6 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = getattr(config, 'quantization_config', None)
self.num_layers = config.num_layers
self.post_layer_norm = config.post_layer_norm
@@ -280,7 +279,6 @@ def build_layer(layer_number):
assert config.rmsnorm
self.final_layernorm = RMSNorm(config.hidden_size,
config.layernorm_epsilon,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py
index c460b8e44f..8010e5cead 100644
--- a/lmdeploy/pytorch/models/cogvlm.py
+++ b/lmdeploy/pytorch/models/cogvlm.py
@@ -617,7 +617,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -634,7 +633,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/dbrx.py b/lmdeploy/pytorch/models/dbrx.py
index e71ff17fe9..7e61fd317d 100644
--- a/lmdeploy/pytorch/models/dbrx.py
+++ b/lmdeploy/pytorch/models/dbrx.py
@@ -9,7 +9,7 @@
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, LayerNorm,
RopeType, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -165,7 +165,7 @@ def __init__(self,
act_fn_name = ffn_act_fn.get('name', None)
assert act_fn_name == 'silu'
- self.mlp = FusedMoE(
+ self.mlp = build_fused_moe(
hidden_size,
ffn_hidden_size,
moe_num_experts,
@@ -522,7 +522,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if '.experts' in name:
loaded_weight = loaded_weight.unflatten(0, (num_experts, -1))
if '.w1' in name:
- name = name.replace('.w1', '.gate_up_weights')
+ name = name.replace('.w1', '.gate_up.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id]
@@ -531,7 +531,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_id=exp_id,
shard_id='gate')
elif '.v1' in name:
- name = name.replace('.v1', '.gate_up_weights')
+ name = name.replace('.v1', '.gate_up.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id]
@@ -540,7 +540,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_id=exp_id,
shard_id='up')
elif '.w2' in name:
- name = name.replace('.w2', '.down_weights')
+ name = name.replace('.w2', '.down.weight')
param = params_dict[name]
for exp_id in range(num_experts):
weight = loaded_weight[exp_id].t()
diff --git a/lmdeploy/pytorch/models/deepseek.py b/lmdeploy/pytorch/models/deepseek.py
index 5742baeee5..09c0b74fcc 100644
--- a/lmdeploy/pytorch/models/deepseek.py
+++ b/lmdeploy/pytorch/models/deepseek.py
@@ -12,7 +12,7 @@
SiluAndMul, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -135,7 +135,7 @@ def __init__(self,
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -265,12 +265,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -315,7 +313,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -332,7 +329,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -528,14 +524,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.n_routed_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/deepseek_v2.py b/lmdeploy/pytorch/models/deepseek_v2.py
index 66f68d90e5..b69ae6650d 100644
--- a/lmdeploy/pytorch/models/deepseek_v2.py
+++ b/lmdeploy/pytorch/models/deepseek_v2.py
@@ -4,6 +4,7 @@
import torch
import torch.distributed as dist
+import torch.nn.functional as F
from torch import nn
from lmdeploy.pytorch.distributed import get_world_rank
@@ -13,7 +14,7 @@
from lmdeploy.pytorch.nn.linear import (build_colwise_linear,
build_merged_colwise_linear,
build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import YarnParameters
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
@@ -81,7 +82,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
self.q_lora_rank = config.q_lora_rank
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
@@ -102,6 +103,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
else:
self.q_a_proj = build_colwise_linear(
@@ -111,6 +113,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=quantization_config,
)
self.q_a_layernorm = RMSNorm(config.q_lora_rank,
1e-6,
@@ -124,6 +127,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
self.kv_a_proj_with_mqa = build_colwise_linear(
@@ -133,6 +137,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=quantization_config,
)
self.kv_a_layernorm = RMSNorm(config.kv_lora_rank,
1e-6,
@@ -176,6 +181,7 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=True,
+ quant_config=quantization_config,
)
def _q_proj(self, hidden_states, num_heads: int, nope_size: int,
@@ -272,6 +278,104 @@ def forward(
return attn_output
+class MoEGate(nn.Module):
+ """Deepseek Gate."""
+
+ def __init__(self,
+ config: Any,
+ dtype: torch.dtype = None,
+ device: torch.device = None):
+ super().__init__()
+ self.config = config
+ self.top_k = config.num_experts_per_tok
+ self.n_routed_experts = config.n_routed_experts
+ self.routed_scaling_factor = config.routed_scaling_factor
+ self.scoring_func = config.scoring_func
+ self.alpha = config.aux_loss_alpha
+ self.seq_aux = config.seq_aux
+ self.topk_method = config.topk_method
+ self.n_group = config.n_group
+ self.topk_group = config.topk_group
+ self.norm_topk_prob = config.norm_topk_prob
+ self.renormalize = self.top_k > 1 and self.norm_topk_prob
+
+ # topk selection algorithm
+ self.norm_topk_prob = config.norm_topk_prob
+ self.gating_dim = config.hidden_size
+ self.weight = nn.Parameter(
+ torch.empty((self.n_routed_experts, self.gating_dim),
+ dtype=dtype,
+ device=device))
+ if self.topk_method == 'noaux_tc':
+ self.e_score_correction_bias = nn.Parameter(
+ torch.empty((self.n_routed_experts, ),
+ dtype=dtype,
+ device=device))
+ self.softmax_topk = SoftmaxTopK(self.top_k)
+
+ def _compute_scores(self, logits: torch.Tensor):
+ """compute scores."""
+ if self.scoring_func == 'softmax':
+ scores = logits.softmax(dim=-1, dtype=torch.float32)
+ elif self.scoring_func == 'sigmoid':
+ scores = logits.sigmoid()
+ else:
+            raise NotImplementedError('unsupported scoring function '
+ f'for MoE gating: {self.scoring_func}')
+ return scores
+
+ def forward(self, hidden_states: torch.Tensor):
+ """forward."""
+ sequence_length, hidden_dim = hidden_states.shape
+ router_logits = F.linear(hidden_states, self.weight)
+
+ if self.topk_method == 'greedy':
+ topk_weight, topk_idx = self.softmax_topk(router_logits)
+ elif self.topk_method == 'group_limited_greedy':
+ scores = self._compute_scores(router_logits)
+ grouped_logits = scores.unflatten(-1, (self.n_group, -1))
+ group_scores = (grouped_logits.max(-1).values)
+ group_idx = torch.topk(group_scores,
+ k=self.topk_group,
+ dim=-1,
+ sorted=False)[1] # [n, top_k_group]
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+ group_mask = ~group_mask.bool()[..., None]
+ grouped_logits = grouped_logits.masked_fill(group_mask, 0.0)
+ scores = grouped_logits.flatten(1, 2)
+ topk_weight, topk_idx = self.softmax_topk(scores)
+ elif self.topk_method == 'noaux_tc':
+ scores = self._compute_scores(router_logits)
+ scores_for_choice = scores.view(
+ sequence_length, -1) + self.e_score_correction_bias[None]
+ group_scores = (scores_for_choice.view(
+ sequence_length, self.n_group,
+ -1).topk(2, dim=-1)[0].sum(dim=-1)) # [n, n_group]
+ group_idx = torch.topk(group_scores,
+ k=self.topk_group,
+ dim=-1,
+ sorted=False)[1] # [n, top_k_group]
+ group_mask = torch.zeros_like(group_scores) # [n, n_group]
+ group_mask.scatter_(1, group_idx, 1) # [n, n_group]
+ score_mask = (group_mask.unsqueeze(-1).expand(
+ sequence_length, self.n_group,
+ self.n_routed_experts // self.n_group).reshape(
+ sequence_length, -1)) # [n, e]
+ tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(),
+ 0.0) # [n, e]
+ _, topk_idx = torch.topk(tmp_scores,
+ k=self.top_k,
+ dim=-1,
+ sorted=False)
+ topk_weight = scores.gather(1, topk_idx)
+ else:
+ raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')
+ if not self.renormalize:
+ topk_weight = topk_weight * self.routed_scaling_factor
+ return topk_weight, topk_idx
+
+
class DeepseekV2MoE(nn.Module):
"""Deepseek v2 MoE."""
@@ -280,6 +384,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.moe_intermediate_size
self.num_experts = config.n_routed_experts
@@ -291,18 +396,9 @@ def __init__(self,
self.n_group = config.n_group
self.topk_group = config.topk_group
- self.gate = build_rowwise_linear(
- self.hidden_dim,
- self.num_experts,
- bias=False,
- dtype=dtype,
- device=device,
- is_tp=False,
- )
+ self.gate = MoEGate(config, dtype=dtype, device=device)
- self.softmax_topk = SoftmaxTopK(self.top_k)
-
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -311,6 +407,7 @@ def __init__(self,
dtype=dtype,
device=device,
all_reduce=False,
+ quant_config=quantization_config,
)
self.shared_experts = None
@@ -335,27 +432,8 @@ def forward(self, hidden_states: torch.Tensor):
"""forward."""
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
- router_logits = self.gate(hidden_states)
+ topk_weights, topk_ids = self.gate(hidden_states)
- if self.topk_method == 'greedy':
- topk_weights, topk_ids = self.softmax_topk(router_logits)
- elif self.topk_method == 'group_limited_greedy':
- grouped_logits = router_logits.unflatten(-1, (self.n_group, -1))
- group_scores = (grouped_logits.max(-1).values)
- group_idx = torch.topk(group_scores,
- k=self.topk_group,
- dim=-1,
- sorted=False)[1] # [n, top_k_group]
- group_mask = torch.zeros_like(group_scores) # [n, n_group]
- group_mask.scatter_(1, group_idx, 1) # [n, n_group]
- group_mask = ~group_mask.bool()[..., None]
- grouped_logits = grouped_logits.masked_fill(group_mask, 0.0)
- router_logits = grouped_logits.flatten(1, 2)
- topk_weights, topk_ids = self.softmax_topk(router_logits)
- else:
- raise RuntimeError(f'Unsupported topk_method: {self.topk_method}')
- if not self.renormalize:
- topk_weights = topk_weights * self.routed_scaling_factor
out_states = self.experts(
hidden_states,
topk_weights,
@@ -452,12 +530,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -574,7 +650,6 @@ def forward(
cos, sin = cos[0], sin[0]
rotary_pos_emb = (cos, sin)
for idx, decoder_layer in enumerate(self.layers):
-
past_key_value = past_key_values[idx]
hidden_states, residual = decoder_layer(
hidden_states,
@@ -603,6 +678,8 @@ def __init__(self,
device: torch.device = None):
super().__init__()
self.config = config
+ self.quantization_config = getattr(config, 'quantization_config', None)
+ self.dtype = dtype
self.ctx_mgr = ctx_mgr
self.model = DeepseekV2Model(config, dtype=dtype, device=device)
# build lm_head
@@ -611,6 +688,7 @@ def __init__(self,
bias=False,
dtype=dtype,
device=device)
+ self._load_buffers = dict()
def forward(
self,
@@ -694,40 +772,99 @@ def __update_pe(weight, head_dim: int, pe_dim_offset: int):
weight = weight.flatten(0, 1)
return weight
+ def __load_kcvc(name: str, weight: torch.Tensor):
+ """load kc and vc from weight."""
+ config = self.config
+ v_head_dim = config.v_head_dim
+ qk_nope_head_dim = config.qk_nope_head_dim
+ w_kc, w_vc = weight.unflatten(
+ 0, (-1, qk_nope_head_dim + v_head_dim)).split(
+ [qk_nope_head_dim, v_head_dim], dim=1)
+ w_vc = w_vc.transpose(1, 2).contiguous()
+ kc_param_name = name.replace('.kv_b_proj', '.kc')
+ param_kc = params_dict[kc_param_name]
+ load_weight(param_kc, w_kc)
+ vc_param_name = name.replace('.kv_b_proj', '.vc')
+ param_vc = params_dict[vc_param_name]
+ load_weight(param_vc, w_vc)
+
+ def __dequant_weight(weight: torch.Tensor, scale: torch.Tensor,
+ dtype: torch.dtype):
+ """dequant weight."""
+ dim_w0, dim_w1 = weight.shape
+ dim_s0, dim_s1 = scale.shape
+ assert dim_w0 % dim_s0 == 0
+ assert dim_w1 % dim_s1 == 0
+ group0 = dim_w0 // dim_s0
+ group1 = dim_w1 // dim_s1
+ weight = weight.reshape(dim_s0, group0, dim_s1, group1)
+ scale = scale.reshape(dim_s0, 1, dim_s1, 1)
+ weight = weight.to(scale.dtype) * scale
+ weight = weight.to(dtype)
+ weight = weight.reshape(dim_w0, dim_w1)
+ return weight
+
+ def __load_kcvc_blocked_fp8(name: str, loaded_weight: torch.Tensor):
+ """dequant weight."""
+ if name.endswith('.weight'):
+ weight_name = name
+ scale_name = name.replace('.weight', '.scale')
+ elif name.endswith('.scale'):
+ weight_name = name.replace('.scale', '.weight')
+ scale_name = name
+ self._load_buffers[name] = loaded_weight
+ if (weight_name in self._load_buffers
+ and scale_name in self._load_buffers):
+ weight = self._load_buffers.pop(weight_name)
+ scale = self._load_buffers.pop(scale_name)
+ kc_param_name = weight_name.replace('.kv_b_proj', '.kc')
+ dtype = params_dict[kc_param_name].dtype
+ weight = __dequant_weight(weight, scale, dtype)
+ __load_kcvc(weight_name, weight)
+
for (mod_name, head_dim, pe_dim_offset) in update_pe_mapping:
if mod_name not in name:
continue
- weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
+ if name.endswith('.scale'):
+ weight = loaded_weight
+ else:
+ weight = __update_pe(loaded_weight, head_dim, pe_dim_offset)
param = params_dict[name]
load_weight(param, weight)
break
else:
if '.kv_b_proj' in name:
- config = self.config
- v_head_dim = config.v_head_dim
- qk_nope_head_dim = config.qk_nope_head_dim
- w_kc, w_vc = loaded_weight.unflatten(
- 0, (-1, qk_nope_head_dim + v_head_dim)).split(
- [qk_nope_head_dim, v_head_dim], dim=1)
- w_vc = w_vc.transpose(1, 2).contiguous()
- kc_param_name = name.replace('.kv_b_proj', '.kc')
- param_kc = params_dict[kc_param_name]
- load_weight(param_kc, w_kc)
- vc_param_name = name.replace('.kv_b_proj', '.vc')
- param_vc = params_dict[vc_param_name]
- load_weight(param_vc, w_vc)
+ quantization_config = self.quantization_config
+ quant_method = None
+ if quantization_config is not None:
+ quant_method = quantization_config.get('quant_method')
+
+ if quant_method == 'fp8':
+ # update blocked fp8 weight
+ __load_kcvc_blocked_fp8(name, loaded_weight)
+ else:
+ __load_kcvc(name, loaded_weight)
else:
param = params_dict[name]
load_weight(param, loaded_weight)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
"""load weights."""
+
+ def __skip_nextn(name, nextn_keys):
+ for nextn_key in nextn_keys:
+ if nextn_key in name:
+ return True
+ return False
+
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
('.gate_up_proj', '.gate_proj', 0),
('.gate_up_proj', '.up_proj', 1),
]
+ scale_suffix = '.weight_scale_inv'
+
config = self.config
qk_rope_head_dim = config.qk_rope_head_dim
kv_lora_rank = config.kv_lora_rank
@@ -741,16 +878,23 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.n_routed_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
+ num_hidden_layers = self.config.num_hidden_layers
+
+ num_nextn_predict_layers = getattr(self.config,
+ 'num_nextn_predict_layers', 1)
+ nextn_keys = [
+ f'.layers.{num_hidden_layers+i}'
+ for i in range(num_nextn_predict_layers)
+ ]
+
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
if 'rotary_emb.inv_freq' in name:
@@ -758,8 +902,14 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
if ('rotary_emb.cos_cached' in name
or 'rotary_emb.sin_cached' in name):
continue
+ if '.layers' in name:
+ # skip nextn
+ if __skip_nextn(name, nextn_keys):
+ continue
if self.config.tie_word_embeddings and 'lm_head.weight' in name:
continue
+ if name.endswith(scale_suffix):
+ name = name[:-len(scale_suffix)] + '.scale'
if '.experts' in name:
self._load_weight_experts(
name,
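Reviewer note: `__dequant_weight` above broadcasts a block-wise scale back over the packed weight before the kc/vc split. A standalone sketch of the same reshape-and-scale math (the 128x128 block size and shapes here are illustrative assumptions):

```python
import torch

weight = torch.randn(512, 1024)           # stand-in for an fp8 weight, as float
scale = torch.rand(4, 8)                  # one scale per 128x128 block
g0 = weight.shape[0] // scale.shape[0]    # 128
g1 = weight.shape[1] // scale.shape[1]    # 128

# expand each scale entry over its (g0 x g1) block, multiply, flatten back
dequant = (weight.reshape(scale.shape[0], g0, scale.shape[1], g1)
           * scale.reshape(scale.shape[0], 1, scale.shape[1], 1))
dequant = dequant.reshape(weight.shape).to(torch.bfloat16)
print(dequant.shape)                      # torch.Size([512, 1024])
```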
diff --git a/lmdeploy/pytorch/models/gemma.py b/lmdeploy/pytorch/models/gemma.py
index 86be85669e..1f24206b16 100644
--- a/lmdeploy/pytorch/models/gemma.py
+++ b/lmdeploy/pytorch/models/gemma.py
@@ -263,7 +263,6 @@ def __init__(self,
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -280,7 +279,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/internlm2.py b/lmdeploy/pytorch/models/internlm2.py
index db246331a1..52f51a3ad1 100644
--- a/lmdeploy/pytorch/models/internlm2.py
+++ b/lmdeploy/pytorch/models/internlm2.py
@@ -221,7 +221,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -241,7 +240,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/internlm2_ve.py b/lmdeploy/pytorch/models/internlm2_ve.py
index b1a2329597..c10faa5f5d 100644
--- a/lmdeploy/pytorch/models/internlm2_ve.py
+++ b/lmdeploy/pytorch/models/internlm2_ve.py
@@ -105,7 +105,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.tok_embeddings = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -125,7 +124,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/internvl.py b/lmdeploy/pytorch/models/internvl.py
index 1059569a09..5fccd627e5 100644
--- a/lmdeploy/pytorch/models/internvl.py
+++ b/lmdeploy/pytorch/models/internvl.py
@@ -124,12 +124,16 @@ def __init__(self,
eps=config.layer_norm_eps,
dtype=dtype,
device=device,
+ tp=True,
+ align=self.head_dim,
)
self.k_norm = RMSNorm(
self.embed_dim,
eps=config.layer_norm_eps,
dtype=dtype,
device=device,
+ tp=True,
+ align=self.head_dim,
)
self.scale = self.head_dim**-0.5
diff --git a/lmdeploy/pytorch/models/minicpmv26.py b/lmdeploy/pytorch/models/minicpmv26.py
index e551dda841..9e47c56437 100644
--- a/lmdeploy/pytorch/models/minicpmv26.py
+++ b/lmdeploy/pytorch/models/minicpmv26.py
@@ -227,7 +227,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -247,7 +246,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/mistral.py b/lmdeploy/pytorch/models/mistral.py
index ad27963093..962cdb3d2b 100644
--- a/lmdeploy/pytorch/models/mistral.py
+++ b/lmdeploy/pytorch/models/mistral.py
@@ -223,7 +223,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -240,7 +239,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/mixtral.py b/lmdeploy/pytorch/models/mixtral.py
index d98efee712..be414d7bff 100644
--- a/lmdeploy/pytorch/models/mixtral.py
+++ b/lmdeploy/pytorch/models/mixtral.py
@@ -8,7 +8,7 @@
from lmdeploy.pytorch.nn import (ApplyRotaryEmb, Attention, RMSNorm, RopeType,
build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -22,7 +22,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
num_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
@@ -112,6 +112,7 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
+ quantization_config = getattr(config, 'quantization_config', None)
self.hidden_dim = config.hidden_size
self.ffn_dim = config.intermediate_size
self.num_experts = config.num_local_experts
@@ -124,11 +125,12 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=False,
+ quant_config=None,
)
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -137,6 +139,7 @@ def __init__(self,
dtype=dtype,
device=device,
all_reduce=True,
+ quant_config=quantization_config,
)
def forward(self, hidden_states: torch.Tensor):
@@ -166,7 +169,7 @@ def __init__(self,
device: torch.device = None):
super().__init__()
self.layer_idx = layer_idx
- quantization_config = None
+ quantization_config = getattr(config, 'quantization_config', None)
# build attention layer
self.self_attn = MixtralAttention(config, dtype=dtype, device=device)
@@ -182,12 +185,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -376,12 +377,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_local_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w1.weight', exp_id, 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w3.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.w2.weight', exp_id, 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id,
+ 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id,
+ 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id,
+ 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index e7b460026a..c01a166b94 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -6,9 +6,11 @@
MODULE_MAP = dict()
ASCEND_MODULE_MAP = dict()
MACA_MODULE_MAP = dict()
+CAMB_MODULE_MAP = dict()
DEVICE_SPECIAL_MODULE_MAP = dict(ascend=ASCEND_MODULE_MAP,
- maca=MACA_MODULE_MAP)
+ maca=MACA_MODULE_MAP,
+ camb=CAMB_MODULE_MAP)
# llama
MODULE_MAP.update({
@@ -82,6 +84,12 @@
f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'
})
+# deepseek-v3
+MODULE_MAP.update({
+ 'DeepseekV3ForCausalLM':
+ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.deepseek_v2.DeepseekV2ForCausalLM'
+})
+
# llava
MODULE_MAP.update(
{
diff --git a/lmdeploy/pytorch/models/phi3.py b/lmdeploy/pytorch/models/phi3.py
index 288fdf3b19..988fee11e5 100644
--- a/lmdeploy/pytorch/models/phi3.py
+++ b/lmdeploy/pytorch/models/phi3.py
@@ -226,7 +226,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -243,7 +242,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/phi3_moe.py b/lmdeploy/pytorch/models/phi3_moe.py
index 080f5e996c..7d0572513a 100644
--- a/lmdeploy/pytorch/models/phi3_moe.py
+++ b/lmdeploy/pytorch/models/phi3_moe.py
@@ -7,7 +7,7 @@
from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager
from lmdeploy.pytorch.nn import ApplyRotaryEmb, Attention, LayerNorm, RopeType
from lmdeploy.pytorch.nn.linear import build_qkv_proj, build_rowwise_linear
-from lmdeploy.pytorch.nn.moe import FusedMoE
+from lmdeploy.pytorch.nn.moe import build_fused_moe
from lmdeploy.pytorch.nn.rotary_embedding import (LongRoPEScalingParameters,
build_rotary_embedding)
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
@@ -180,7 +180,7 @@ def __init__(self,
is_tp=False,
)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -448,12 +448,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_local_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w1.weight', exp_id, 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.w3.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.w2.weight', exp_id, 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.w1', exp_id,
+ 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.w3', exp_id,
+ 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.w2', exp_id,
+ 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/q_modules.py b/lmdeploy/pytorch/models/q_modules.py
index 001fab7a60..8379bb18c9 100644
--- a/lmdeploy/pytorch/models/q_modules.py
+++ b/lmdeploy/pytorch/models/q_modules.py
@@ -34,13 +34,17 @@ class QRMSNorm(nn.Module):
"""It performs traditional RMS normalization and then quantizes the output
to 8-bit integers."""
- def __init__(self, hidden_size, eps=1e-6):
+ def __init__(self, hidden_size, eps=1e-6, quant_dtype=torch.int8):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
+ self.quant_dtype = quant_dtype
@classmethod
- def from_float(cls, mod: nn.Module, initialization: bool = True):
+ def from_float(cls,
+ mod: nn.Module,
+ initialization: bool = True,
+ quant_dtype=torch.int8):
"""Class method to create a QRMSNorm instance from a floating-point
module.
@@ -49,7 +53,7 @@ def from_float(cls, mod: nn.Module, initialization: bool = True):
"""
hidden_size = mod.weight.shape[0]
eps = mod.variance_epsilon
- q_mod = cls(hidden_size, eps)
+ q_mod = cls(hidden_size, eps, quant_dtype=quant_dtype)
if initialization:
q_mod.weight = nn.Parameter(mod.weight.detach())
return q_mod
@@ -62,7 +66,10 @@ def forward(self, hidden_states):
with its scale factor.
"""
hidden_states_quant, rms_scale = rms_norm_dynamic_quant(
- hidden_states, self.weight, self.variance_epsilon)
+ hidden_states,
+ self.weight,
+ self.variance_epsilon,
+ quant_dtype=self.quant_dtype)
return QTensor(hidden_states_quant, rms_scale)
@@ -83,16 +90,18 @@ def __init__(self,
out_features: int,
bias: bool = True,
device=None,
- dtype=None) -> None:
+ dtype=None,
+ quant_dtype=torch.int8) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
+ self.quant_dtype = quant_dtype
self.register_buffer(
'weight',
torch.empty((out_features, in_features),
device=device,
- dtype=torch.int8))
+ dtype=quant_dtype))
self.register_buffer(
'scale',
torch.empty((out_features, 1), device=device, dtype=torch.float32))
@@ -103,7 +112,10 @@ def __init__(self,
self.register_parameter('bias', None)
@classmethod
- def from_float(cls, mod: nn.Module, initialization: bool = True):
+ def from_float(cls,
+ mod: nn.Module,
+ initialization: bool = True,
+ quant_dtype=torch.int8):
"""Class method to create a QLinear instance from a floating-point
module.
@@ -114,11 +126,12 @@ def from_float(cls, mod: nn.Module, initialization: bool = True):
mod.out_features,
mod.bias is not None,
device=mod.weight.device,
- dtype=mod.weight.dtype)
+ dtype=mod.weight.dtype,
+ quant_dtype=quant_dtype)
if initialization:
- weight_quant, scale = per_channel_quant(mod.weight.detach(), 8,
- torch.int8)
+ weight_quant, scale = per_channel_quant(mod.weight.detach(),
+ quant_dtype)
q_mod.weight.data = weight_quant
q_mod.scale = scale
@@ -137,7 +150,8 @@ def forward(self, input):
"""
if isinstance(input, torch.Tensor):
- input_quant, input_scale = per_token_quant_int8(input, 1e-7)
+ input_quant, input_scale = per_token_quant_int8(
+ input, 1e-7, quant_dtype=self.quant_dtype)
else:
assert isinstance(input, QTensor)
input_quant, input_scale = input.tensor, input.scale
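# Illustrative sketch (not part of the patch above): with the widened
# `from_float` signatures in q_modules.py, a non-default quant dtype can be
# threaded through explicitly, e.g.
#
#     q_norm = QRMSNorm.from_float(rms_norm_module, quant_dtype=torch.float8_e4m3fn)
#     q_linear = QLinear.from_float(linear_module, quant_dtype=torch.float8_e4m3fn)
#
# Both default to torch.int8, so existing w8a8 call sites keep their behaviour.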
diff --git a/lmdeploy/pytorch/models/qwen.py b/lmdeploy/pytorch/models/qwen.py
index bf856461a3..20e184bdf8 100644
--- a/lmdeploy/pytorch/models/qwen.py
+++ b/lmdeploy/pytorch/models/qwen.py
@@ -229,7 +229,6 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None):
super().__init__()
- quantization_config = getattr(config, 'quantization_config', None)
self.vocab_size = config.vocab_size
self.embed_dim = config.hidden_size
self.wte = nn.Embedding(self.vocab_size,
@@ -263,7 +262,6 @@ def __init__(self,
self.ln_f = RMSNorm(self.embed_dim,
eps=config.layer_norm_epsilon,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/qwen2.py b/lmdeploy/pytorch/models/qwen2.py
index 38773c21e1..a26aa22d5a 100644
--- a/lmdeploy/pytorch/models/qwen2.py
+++ b/lmdeploy/pytorch/models/qwen2.py
@@ -225,7 +225,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -242,7 +241,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/models/qwen2_moe.py b/lmdeploy/pytorch/models/qwen2_moe.py
index 1aff14483a..de990592d5 100644
--- a/lmdeploy/pytorch/models/qwen2_moe.py
+++ b/lmdeploy/pytorch/models/qwen2_moe.py
@@ -13,7 +13,7 @@
SiluAndMul, build_rotary_embedding)
from lmdeploy.pytorch.nn.linear import (build_merged_colwise_linear,
build_qkv_proj, build_rowwise_linear)
-from lmdeploy.pytorch.nn.moe import FusedMoE, SoftmaxTopK
+from lmdeploy.pytorch.nn.moe import SoftmaxTopK, build_fused_moe
from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight
from .utils.cudagraph import CudaGraphMixin
@@ -185,7 +185,7 @@ def __init__(self,
self.softmax_topk = SoftmaxTopK(self.top_k)
- self.experts = FusedMoE(
+ self.experts = build_fused_moe(
self.hidden_dim,
self.ffn_dim,
self.num_experts,
@@ -280,12 +280,10 @@ def __init__(self,
device=device)
# build attention layer norm
- self.post_attention_layernorm = RMSNorm(
- config.hidden_size,
- config.rms_norm_eps,
- quant_config=quantization_config,
- dtype=dtype,
- device=device)
+ self.post_attention_layernorm = RMSNorm(config.hidden_size,
+ config.rms_norm_eps,
+ dtype=dtype,
+ device=device)
def forward(
self,
@@ -330,7 +328,6 @@ def __init__(self,
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -347,7 +344,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
@@ -531,14 +527,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
num_experts = self.config.num_experts
expert_params_mapping = []
for exp_id in range(num_experts):
- gate_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.gate_proj.weight', exp_id,
- 'gate')
- up_param = ('.experts.gate_up_weights',
- f'.experts.{exp_id}.up_proj.weight', exp_id, 'up')
- down_param = ('.experts.down_weights',
- f'.experts.{exp_id}.down_proj.weight', exp_id,
- 'down')
+ gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj',
+ exp_id, 'gate')
+ up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj',
+ exp_id, 'up')
+ down_param = ('.experts.down', f'.experts.{exp_id}.down_proj',
+ exp_id, 'down')
expert_params_mapping += [gate_param, up_param, down_param]
params_dict = dict(self.named_parameters())
diff --git a/lmdeploy/pytorch/models/qwen2_vl.py b/lmdeploy/pytorch/models/qwen2_vl.py
index 4e2b1017b5..bfd6e352f1 100644
--- a/lmdeploy/pytorch/models/qwen2_vl.py
+++ b/lmdeploy/pytorch/models/qwen2_vl.py
@@ -260,7 +260,6 @@ def __init__(self,
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.mrope_section = config.rope_scaling['mrope_section']
- quantization_config = getattr(config, 'quantization_config', None)
self.embed_tokens = nn.Embedding(config.vocab_size,
config.hidden_size,
@@ -277,7 +276,6 @@ def __init__(self,
# build norm
self.norm = RMSNorm(config.hidden_size,
config.rms_norm_eps,
- quant_config=quantization_config,
dtype=dtype,
device=device)
diff --git a/lmdeploy/pytorch/nn/linear.py b/lmdeploy/pytorch/nn/linear.py
index 486c684a3c..73d0ef918d 100644
--- a/lmdeploy/pytorch/nn/linear.py
+++ b/lmdeploy/pytorch/nn/linear.py
@@ -12,7 +12,7 @@
from ..backends import OpType, get_backend
from ..backends.lora import AdapterInfo
-from .utils import get_distribute_size
+from .utils import chunk_aligned, div_up, get_distribute_size
logger = get_logger('lmdeploy')
@@ -25,20 +25,7 @@ def _check_qkv_split_layout(layout: str):
f'but get: {layout}')
-def _chunk_align(weight: torch.Tensor, chunks: int, dim: int, align: int):
- """chunk aligned."""
- if align == 1:
- return weight.chunk(chunks, dim=dim)
- size = weight.size(dim)
- assert size % align == 0
- aligned_size = size // align
-
- # try best to evenly split chunks
- align_per_chunk = aligned_size // chunks
- remain = aligned_size % chunks
- sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
- sections = [sec * align for sec in sections]
- return weight.split(sections, dim=dim)
+_chunk_align = chunk_aligned
class QKVMixin:
@@ -165,6 +152,239 @@ def weight_loader_B(self, param: nn.Parameter, loaded_weight: torch.Tensor,
param_r.copy_(loaded_weight.t())
+class BlockedF8Linear(nn.Module):
+ """blocked f8 linear."""
+
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ colwise: bool = True,
+ is_tp: bool = False,
+ all_reduce: bool = True,
+ ):
+ super().__init__()
+ if device is None:
+ device = torch.device('cpu')
+ if dtype is None:
+ dtype = torch.float16
+ if is_tp:
+ in_features, out_features = self._get_io_features(
+ in_features, out_features, colwise)
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.LinearBlockedF8)
+ self.impl = impl_builder.build(in_features,
+ out_features,
+ block_size=128,
+ bias=bias is not None,
+ dtype=dtype)
+ self.block_size = 128
+ self.fp8_dtype = fp8_dtype
+ weight, scale, bias = self.create_weights(in_features, out_features,
+ bias, dtype, device)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ weight.weight_loader = self.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = self.weight_loader
+ if bias is not None:
+ bias = torch.nn.Parameter(bias, requires_grad=False)
+ bias.weight_loader = self.weight_loader
+ self.register_parameter('weight', weight)
+ self.register_parameter('scale', scale)
+ self.register_parameter('bias', bias)
+
+ self.in_features = in_features
+ self.out_features = out_features
+ self.lora_adapters = nn.ModuleDict()
+ self.is_tp = is_tp
+ self.colwise = colwise
+ self.all_reduce = all_reduce
+
+ def _get_io_features(self, in_features: int, out_features: int,
+ colwise: bool):
+ """get io features."""
+ world_size, rank = get_world_rank()
+ if colwise:
+ out_features = get_distribute_size(out_features, world_size, rank)
+ else:
+ in_features = get_distribute_size(in_features, world_size, rank)
+ return in_features, out_features
+
+ def _weight_loader_tp_colwise(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, rank: int,
+ world_size: int):
+ """weight loader for colwise linear."""
+ weight = loaded_weight.chunk(world_size, 0)[rank]
+ return default_weight_loader(param, weight)
+
+ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, rank: int,
+ world_size: int):
+ """weight loader for rowwise linear."""
+ if loaded_weight.dim() == 2:
+ weight = loaded_weight.chunk(world_size, 1)[rank]
+ return default_weight_loader(param, weight)
+ else:
+ # bias
+ if rank != 0:
+ loaded_weight = torch.zeros_like(loaded_weight)
+ return default_weight_loader(param, loaded_weight)
+
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor):
+ """weight loader."""
+ if not self.is_tp:
+ return default_weight_loader(param, loaded_weight)
+
+ world_size, rank = get_world_rank()
+ if self.colwise:
+ return self._weight_loader_tp_colwise(param, loaded_weight, rank,
+ world_size)
+ else:
+ return self._weight_loader_tp_rowwise(param, loaded_weight, rank,
+ world_size)
+
+ def create_weights(self, in_features: int, out_features: int, bias: bool,
+ dtype: torch.dtype, device: torch.device):
+ """create weights."""
+ weight = torch.empty((out_features, in_features),
+ dtype=self.fp8_dtype,
+ device=device)
+ scale = torch.empty(
+ (div_up(out_features,
+ self.block_size), div_up(in_features, self.block_size)),
+ dtype=torch.float32,
+ device=device)
+ if bias:
+ bias = torch.empty((out_features, ), dtype=dtype, device=device)
+ else:
+ bias = None
+ return weight, scale, bias
+
+ def update_weights(self):
+ """update weights."""
+ weight, scale, bias = self.impl.update_weights(self.weight, self.scale,
+ self.bias)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ self.weight.weight_loader = self.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.scale.weight_loader = self.weight_loader
+ if bias is not None:
+ bias = torch.nn.Parameter(bias, requires_grad=False)
+ self.bias.weight_loader = self.weight_loader
+ self.register_parameter('weight', weight)
+ self.register_parameter('scale', scale)
+ self.register_parameter('bias', bias)
+
+ def forward(self, x):
+ """forward of blocked fp8 linear."""
+ all_reduce = False if self.colwise else self.is_tp
+ all_reduce = all_reduce and self.all_reduce
+ if len(self.lora_adapters) == 0:
+ return self.impl.forward(x, self.weight, self.scale, self.bias,
+ all_reduce)
+
+ out = self.impl.forward(x, self.weight, self.scale, self.bias, False)
+ for lora_adapter in self.lora_adapters.values():
+ out = lora_adapter(x, out)
+ if all_reduce:
+ dist.all_reduce(out)
+ return out
+
+
+class MergedBlockedF8Linear(BlockedF8Linear):
+ """merged blocked fp8 linear."""
+
+ def __init__(self,
+ in_features: int,
+ all_out_features: List[int],
+ bias: bool,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ replicate: Optional[List[bool]] = None,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ is_tp: bool = True,
+ out_names: Optional[List[int]] = None):
+ if replicate is None:
+ replicate = tuple(False for _ in all_out_features)
+ self.block_size = 128
+ self.split_section = all_out_features
+ self.scale_split_section = [
+ section // self.block_size for section in self.split_section
+ ]
+ all_out_features = self._update_all_out_features(
+ all_out_features, replicate)
+ self.all_out_features = all_out_features
+ self.replicate = replicate
+ if out_names is None:
+ out_names = torch.arange(len(self.all_out_features)).tolist()
+ assert len(out_names) == len(self.all_out_features)
+ self.out_names_map = dict(
+ (name, idx) for idx, name in enumerate(out_names))
+ out_features = sum(all_out_features)
+ super().__init__(in_features,
+ out_features,
+ bias,
+ dtype,
+ device,
+ fp8_dtype=fp8_dtype,
+ colwise=True,
+ is_tp=is_tp)
+ self.weight.weight_loader = self.weight_loader
+ self.scale.weight_loader = self.weight_loader
+ self.weight.weight_spliter = self.weight_spliter
+ self.scale.weight_spliter = self.weight_spliter
+ if self.bias is not None:
+ self.bias.weight_loader = self.weight_loader
+ self.bias.weight_spliter = self.weight_spliter
+
+ def _get_io_features(self, in_features: int, out_features: int,
+ colwise: bool):
+ """get io features."""
+ return in_features, out_features
+
+ def _update_all_out_features(self, all_out_features: List[int],
+ replicate: Optional[List[bool]]):
+ """update all out features."""
+ world_size, rank = get_world_rank()
+ new_all_out_features = []
+ for out_feat, rep in zip(all_out_features, replicate):
+ if rep:
+ new_all_out_features.append(out_feat)
+ new_out_feat = get_distribute_size(out_feat, world_size, rank)
+ new_all_out_features.append(new_out_feat)
+ return new_all_out_features
+
+ def weight_loader(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, shard_id: Any):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ shard_idx = self.out_names_map[shard_id]
+ if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+ all_out_features = [
+ feats // self.block_size for feats in self.all_out_features
+ ]
+ param_w = param.data.split(all_out_features, 0)[shard_idx]
+ else:
+ param_w = param.data.split(self.all_out_features, 0)[shard_idx]
+ if not self.replicate[shard_idx]:
+ loaded_weight = loaded_weight.chunk(world_size, 0)[rank]
+ param_w.copy_(loaded_weight)
+
+ def weight_spliter(self, loaded_weight: torch.Tensor):
+ """weight spliter."""
+ if loaded_weight.dim() == 2 and loaded_weight.dtype == torch.float32:
+ return loaded_weight.split(self.scale_split_section, dim=0)
+ return loaded_weight.split(self.split_section, dim=0)
+
+ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
+ return loaded_weight.split(self.split_section, dim=0)
+
+
class AwqLinear(nn.Module):
"""w4a16 linear."""
@@ -598,17 +818,16 @@ def weight_spliter_lora_b(self, loaded_weight: torch.Tensor):
class W8A8Linear(nn.Module):
"""w8a8 linear."""
- def __init__(
- self,
- in_features: int,
- out_features: int,
- bias: bool,
- dtype: Optional[torch.dtype] = None,
- device: Optional[torch.device] = None,
- colwise: bool = True,
- is_tp: bool = False,
- all_reduce: bool = True,
- ):
+ def __init__(self,
+ in_features: int,
+ out_features: int,
+ bias: bool,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ colwise: bool = True,
+ is_tp: bool = False,
+ all_reduce: bool = True,
+ quant_dtype: Optional[torch.dtype] = torch.int8):
super().__init__()
if device is None:
device = torch.device('cpu')
@@ -618,10 +837,12 @@ def __init__(
in_features, out_features = self._get_io_features(
in_features, out_features, colwise)
impl_builder = get_backend().get_layer_impl_builder(OpType.LinearW8A8)
+ self.quant_dtype = quant_dtype
self.impl = impl_builder.build(in_features,
out_features,
bias is not None,
- dtype=dtype)
+ dtype=dtype,
+ quant_dtype=quant_dtype)
weight, scale, bias = self.create_weights(in_features, out_features,
bias, dtype, device)
weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -663,7 +884,9 @@ def _weight_loader_tp_rowwise(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, rank: int,
world_size: int):
"""weight loader for rowwise linear."""
- if loaded_weight.dim() == 2 and param.dtype == torch.int8:
+ if loaded_weight.dim() == 2 and param.dtype in (torch.int8,
+ torch.float8_e4m3fn,
+ torch.float8_e5m2):
weight = loaded_weight.chunk(world_size, 1)[rank]
return default_weight_loader(param, weight)
elif loaded_weight.dim() == 2 and loaded_weight.size(1) == 1:
@@ -693,7 +916,7 @@ def create_weights(self, in_features: int, out_features: int, bias: bool,
dtype: torch.dtype, device: torch.device):
"""create weights."""
weight = torch.empty((out_features, in_features),
- dtype=torch.int8,
+ dtype=self.quant_dtype,
device=device)
scale = torch.empty((out_features, 1),
dtype=torch.float32,
@@ -745,7 +968,8 @@ def __init__(self,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
- out_names: Optional[List[int]] = None):
+ out_names: Optional[List[int]] = None,
+ quant_dtype: torch.dtype = torch.int8):
self.split_section = all_out_features
all_out_features = self._update_all_out_features(all_out_features)
self.all_out_features = all_out_features
@@ -761,7 +985,8 @@ def __init__(self,
dtype,
device,
colwise=True,
- is_tp=is_tp)
+ is_tp=is_tp,
+ quant_dtype=quant_dtype)
self.weight.weight_loader = self.weight_loader
self.scale.weight_loader = self.weight_loader
self.weight.weight_spliter = self.weight_spliter
@@ -814,7 +1039,9 @@ def __init__(self,
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
is_tp: bool = True,
- num_replicate_kv_heads: int = 1):
+ num_replicate_kv_heads: int = 1,
+ quant_dtype: torch.dtype = torch.int8):
+
self.qkv_split_section = self._get_qkv_out_features(
num_q_heads, num_kv_heads, head_size, head_size_v,
num_replicate_kv_heads)
@@ -835,7 +1062,8 @@ def __init__(self,
dtype=dtype,
device=device,
is_tp=is_tp,
- out_names=out_names)
+ out_names=out_names,
+ quant_dtype=quant_dtype)
def _update_all_out_features(self, all_out_features: List[int]):
"""update all out features."""
@@ -1200,6 +1428,10 @@ def build_linear(in_features: int,
)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
@@ -1215,10 +1447,28 @@ def build_linear(in_features: int,
all_reduce=all_reduce,
)
if quant_method == 'smooth_quant':
- return W8A8Linear(
+ return W8A8Linear(in_features,
+ out_features,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ colwise=colwise,
+ is_tp=is_tp,
+ all_reduce=all_reduce,
+ quant_dtype=quant_dtype)
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return BlockedF8Linear(
in_features,
out_features,
bias=bias,
+ fp8_dtype=fp8_dtype,
dtype=dtype,
device=device,
colwise=colwise,
@@ -1299,6 +1549,10 @@ def build_merged_colwise_linear(
)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
@@ -1312,10 +1566,27 @@ def build_merged_colwise_linear(
is_tp=is_tp,
)
if quant_method == 'smooth_quant':
- return MergedW8A8Linear(
+ return MergedW8A8Linear(in_features=in_features,
+ all_out_features=all_out_features,
+ bias=bias,
+ dtype=dtype,
+ device=device,
+ is_tp=is_tp,
+ out_names=out_names,
+ quant_dtype=quant_dtype)
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return MergedBlockedF8Linear(
in_features=in_features,
all_out_features=all_out_features,
bias=bias,
+ fp8_dtype=fp8_dtype,
dtype=dtype,
device=device,
is_tp=is_tp,
@@ -1357,6 +1628,10 @@ def build_qkv_proj(in_features: int,
num_replicate_kv_heads=num_replicate_kv_heads)
quant_method = quant_config['quant_method']
+ quant_dtype = torch.int8
+ if 'quant_dtype' in quant_config:
+ quant_dtype = eval('torch.' + quant_config['quant_dtype'])
+
if quant_method == 'awq':
w_bit = quant_config.get('bits', 4)
group_size = quant_config.get('group_size', 128)
@@ -1381,6 +1656,7 @@ def build_qkv_proj(in_features: int,
dtype=dtype,
device=device,
is_tp=is_tp,
- num_replicate_kv_heads=num_replicate_kv_heads)
+ num_replicate_kv_heads=num_replicate_kv_heads,
+ quant_dtype=quant_dtype)
else:
raise RuntimeError(f'Unsupported quant method: {quant_method}')
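# Illustrative sketch (not part of the patch): the shape of the quant_config
# dicts consumed by build_linear / build_merged_colwise_linear / build_qkv_proj
# above. The keys mirror the ones read in the hunks; the values are hypothetical.
smooth_quant_cfg = {
    'quant_method': 'smooth_quant',
    'quant_dtype': 'int8',   # resolved via eval('torch.' + ...), default int8
}
blocked_fp8_cfg = {
    'quant_method': 'fp8',
    'fmt': 'e4m3',           # 'e4m3' -> torch.float8_e4m3fn, 'e5m2' -> torch.float8_e5m2
}
# build_linear(..., quant_config=smooth_quant_cfg) returns a W8A8Linear, while
# build_linear(..., quant_config=blocked_fp8_cfg) returns a BlockedF8Linear
# with a fixed block_size of 128.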
diff --git a/lmdeploy/pytorch/nn/moe.py b/lmdeploy/pytorch/nn/moe.py
index 47176335c4..1218a2f581 100644
--- a/lmdeploy/pytorch/nn/moe.py
+++ b/lmdeploy/pytorch/nn/moe.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Optional
+from typing import Any, List, Optional
import torch
import torch.distributed as dist
@@ -8,6 +8,7 @@
from lmdeploy.pytorch.distributed import get_world_rank
from ..backends import OpType, get_backend
+from .utils import div_up
class SoftmaxTopK(nn.Module):
@@ -24,6 +25,102 @@ def forward(self, x: torch.Tensor):
return self.impl.forward(x)
+def create_mlp_weights(hidden_dim: int, ffn_dim: int, num_experts: int,
+ dtype: torch.dtype, device: torch.device):
+ """create weights."""
+ gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim),
+ dtype=dtype,
+ device=device)
+ down_weights = torch.empty((num_experts, hidden_dim, ffn_dim),
+ dtype=dtype,
+ device=device)
+ return gate_up_weights, down_weights
+
+
+def _update_args(hidden_dim: int, ffn_dim: int):
+ """update args."""
+ world_size, _ = get_world_rank()
+ assert ffn_dim % world_size == 0
+ ffn_dim = ffn_dim // world_size
+ return hidden_dim, ffn_dim
+
+
+class LinearWeights(nn.Module):
+ """fused moe linear weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ dtype: torch.dtype,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False):
+ super().__init__()
+ weight = torch.empty((num_experts, out_features, in_features),
+ dtype=dtype,
+ device=device)
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ self.register_parameter('weight', weight)
+ self.ep = ep
+ self.expert_list = expert_list
+ self.weight_type = weight_type
+ self.half_out = out_features // 2
+
+ if self.ep:
+ self.expert_map = dict(
+ (eid, idx) for idx, eid in enumerate(expert_list))
+ self.weight.weight_loader = self.weight_loader_ep
+ else:
+ self.weight.weight_loader = self.weight_loader_tp
+
+ def update_weight(self, weight: torch.Tensor):
+ """update weight."""
+ weight_loader = self.weight.weight_loader
+ weight = torch.nn.Parameter(weight, requires_grad=False)
+ weight.weight_loader = weight_loader
+ self.register_parameter('weight', weight)
+
+ def weight_loader_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ if shard_id == 'gate':
+ param_data = param.data[expert_id, :self.half_out]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
+ elif shard_id == 'up':
+ param_data = param.data[expert_id, self.half_out:]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
+ elif shard_id == 'down':
+ param_data = param.data[expert_id]
+ weight = loaded_weight.chunk(world_size, dim=1)[rank]
+ else:
+ raise RuntimeError(f'Unknown shard_id: {shard_id}')
+ param_data.copy_(weight)
+
+ def weight_loader_ep(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader."""
+ expert_list = self.expert_list
+ if expert_id not in expert_list:
+ return
+
+ expert_map = self.expert_map
+ param_id = expert_map[expert_id]
+ if shard_id == 'gate':
+ param_data = param.data[param_id, :self.half_out]
+ elif shard_id == 'up':
+ param_data = param.data[param_id, self.half_out:]
+ elif shard_id == 'down':
+ param_data = param.data[param_id]
+ else:
+ raise RuntimeError(f'Unknown shard_id: {shard_id}')
+ param_data.copy_(loaded_weight)
+
+
class FusedMoE(nn.Module):
"""fused moe."""
@@ -46,42 +143,33 @@ def __init__(self,
impl_builder = get_backend().get_layer_impl_builder(OpType.FusedMoE)
self.impl = impl_builder.build(top_k, num_experts, renormalize)
- self.expert_list = None
- self.expert_map = None
enable_ep = enable_ep and self.impl.support_ep()
if enable_ep:
world_size, rank = get_world_rank()
expert_list = self.impl.ep_expert_list(world_size, rank)
- self.expert_list = expert_list
- self.expert_map = dict(
- (eid, idx) for idx, eid in enumerate(expert_list))
num_experts = len(expert_list)
- gate_up_weights, down_weights = self.create_weights(hidden_dim,
- ffn_dim,
- num_experts,
- dtype=dtype,
- device=device)
- else:
- hidden_dim, ffn_dim = self._update_args(hidden_dim, ffn_dim)
- gate_up_weights, down_weights = self.create_weights(hidden_dim,
- ffn_dim,
- num_experts,
- dtype=dtype,
- device=device)
- gate_up_weights = torch.nn.Parameter(gate_up_weights,
- requires_grad=False)
- down_weights = torch.nn.Parameter(down_weights, requires_grad=False)
- gate_up_weights._weight_type = 'gate_up_weights'
- down_weights._weight_type = 'down_weights'
- self.register_parameter('gate_up_weights', gate_up_weights)
- self.register_parameter('down_weights', down_weights)
-
- if enable_ep:
- gate_up_weights.weight_loader = self.weight_loader_ep
- down_weights.weight_loader = self.weight_loader_ep
else:
- gate_up_weights.weight_loader = self.weight_loader_tp
- down_weights.weight_loader = self.weight_loader_tp
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+ self.gate_up = LinearWeights(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep)
+ self.down = LinearWeights(
+ num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ )
self.hidden_dim = hidden_dim
self.ffn_dim = ffn_dim
@@ -93,83 +181,384 @@ def __init__(self,
all_reduce = False
self.all_reduce = all_reduce
- def _update_args(self, hidden_dim: int, ffn_dim: int):
- """update args."""
- world_size, _ = get_world_rank()
- assert ffn_dim % world_size == 0
- ffn_dim = ffn_dim // world_size
- return hidden_dim, ffn_dim
-
- def create_weights(self, hidden_dim: int, ffn_dim: int, num_experts: int,
- dtype: torch.dtype, device: torch.device):
- """create weights."""
- gate_up_weights = torch.empty((num_experts, ffn_dim * 2, hidden_dim),
- dtype=dtype,
- device=device)
- down_weights = torch.empty((num_experts, hidden_dim, ffn_dim),
- dtype=dtype,
- device=device)
- return gate_up_weights, down_weights
-
def update_weights(self):
"""update weights."""
- gateup_loader = self.gate_up_weights.weight_loader
- down_loader = self.down_weights.weight_loader
gate_up_weights, down_weights = self.impl.update_weights(
- self.gate_up_weights, self.down_weights)
- gate_up_weights = torch.nn.Parameter(gate_up_weights,
- requires_grad=False)
- down_weights = torch.nn.Parameter(down_weights, requires_grad=False)
- gate_up_weights.weight_loader = gateup_loader
- down_weights.weight_loader = down_loader
- gate_up_weights._weight_type = 'gate_up_weights'
- down_weights._weight_type = 'down_weights'
- self.register_parameter('gate_up_weights', gate_up_weights)
- self.register_parameter('down_weights', down_weights)
+ self.gate_up.weight, self.down.weight)
+ self.gate_up.update_weight(gate_up_weights)
+ self.down.update_weight(down_weights)
- def weight_loader_tp(self, param: torch.nn.Parameter,
- loaded_weight: torch.Tensor, expert_id: int,
- shard_id: str):
- """weight loader."""
+ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor):
+ ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
+ self.gate_up.weight, self.down.weight,
+ self.expert_list)
+ if self.all_reduce:
+ dist.all_reduce(ret)
+ return ret
+
+
+class LinearWeightsW8A8(LinearWeights):
+ """fused moe linear w8a8 weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False,
+ quant_dtype: torch.dtype = torch.int8):
+ super().__init__(
+ num_experts=num_experts,
+ in_features=in_features,
+ out_features=out_features,
+ weight_type=weight_type,
+ dtype=quant_dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=ep,
+ )
+ scale = torch.empty((num_experts, out_features, 1),
+ dtype=torch.float32,
+ device=device)
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.register_parameter('scale', scale)
+
+ if self.ep:
+ self.scale.weight_loader = self.weight_loader_ep
+ else:
+ self.scale.weight_loader = self.weight_loader_scale_tp
+
+ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
+ """update weight."""
+ super().update_weight(weight=weight)
+ weight_loader = self.scale.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = weight_loader
+ self.register_parameter('scale', scale)
+
+ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader scale tp."""
world_size, rank = get_world_rank()
if shard_id == 'gate':
- param_data = param.data[expert_id, :self.ffn_dim]
+ param_data = param.data[expert_id, :self.half_out]
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'up':
- param_data = param.data[expert_id, self.ffn_dim:]
+ param_data = param.data[expert_id, self.half_out:]
weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
param_data = param.data[expert_id]
- weight = loaded_weight.chunk(world_size, dim=1)[rank]
+ weight = loaded_weight
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
param_data.copy_(weight)
- def weight_loader_ep(self, param: torch.nn.Parameter,
- loaded_weight: torch.Tensor, expert_id: int,
- shard_id: str):
- """weight loader."""
- expert_list = self.expert_list
- if expert_id not in expert_list:
- return
- expert_map = self.expert_map
- param_id = expert_map[expert_id]
+class FusedMoEW8A8(nn.Module):
+ """fused moe w8a8."""
+
+ def __init__(self,
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ quant_dtype: Optional[torch.dtype] = torch.int8,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False):
+ super().__init__()
+
+ if device is None:
+ device = torch.device('cpu')
+ dtype = torch.float16 if dtype is None else dtype
+
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.FusedMoEW8A8)
+ self.impl = impl_builder.build(top_k,
+ num_experts,
+ renormalize,
+ dtype,
+ quant_dtype=quant_dtype)
+
+ enable_ep = enable_ep and self.impl.support_ep()
+ if enable_ep:
+ world_size, rank = get_world_rank()
+ expert_list = self.impl.ep_expert_list(world_size, rank)
+ num_experts = len(expert_list)
+ else:
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+
+ self.gate_up = LinearWeightsW8A8(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ quant_dtype=quant_dtype)
+ self.down = LinearWeightsW8A8(num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ quant_dtype=quant_dtype)
+
+ self.hidden_dim = hidden_dim
+ self.ffn_dim = ffn_dim
+ self.num_experts = num_experts
+ self.dtype = dtype
+ self.device = device
+ world_size, _ = get_world_rank()
+ if world_size == 1:
+ all_reduce = False
+ self.all_reduce = all_reduce
+
+ def update_weights(self):
+ """update weights."""
+ (gate_up_weights, down_weights, gate_up_scale,
+ down_scale) = self.impl.update_weights(self.gate_up.weight,
+ self.down.weight,
+ self.gate_up.scale,
+ self.down.scale)
+ self.gate_up.update_weight(gate_up_weights, gate_up_scale)
+ self.down.update_weight(down_weights, down_scale)
+
+ def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
+ topk_ids: torch.LongTensor):
+ ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
+ self.gate_up.weight, self.gate_up.scale,
+ self.down.weight, self.down.scale,
+ self.expert_list)
+ if self.all_reduce:
+ dist.all_reduce(ret)
+ return ret
+
+
+class LinearWeightsBlockedF8(LinearWeights):
+ """fused moe linear blocked fp8 weights."""
+
+ def __init__(self,
+ num_experts: int,
+ in_features: int,
+ out_features: int,
+ weight_type: str,
+ block_size: int,
+ dtype: torch.dtype,
+ device: torch.device,
+ expert_list: List[int] = None,
+ ep: bool = False):
+ super().__init__(
+ num_experts=num_experts,
+ in_features=in_features,
+ out_features=out_features,
+ weight_type=weight_type,
+ dtype=dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=ep,
+ )
+ self.block_size = block_size
+ scale = torch.empty((num_experts, div_up(
+ out_features, block_size), div_up(in_features, block_size)),
+ dtype=torch.float32,
+ device=device)
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ self.register_parameter('scale', scale)
+
+ if self.ep:
+ self.scale.weight_loader = self.weight_loader_ep
+ else:
+ self.scale.weight_loader = self.weight_loader_scale_tp
+
+ def update_weight(self, weight: torch.Tensor, scale: torch.Tensor):
+ """update weight."""
+ super().update_weight(weight=weight)
+ weight_loader = self.scale.weight_loader
+ scale = torch.nn.Parameter(scale, requires_grad=False)
+ scale.weight_loader = weight_loader
+ self.register_parameter('scale', scale)
+
+ def weight_loader_scale_tp(self, param: torch.nn.Parameter,
+ loaded_weight: torch.Tensor, expert_id: int,
+ shard_id: str):
+ """weight loader scale tp."""
+ world_size, rank = get_world_rank()
+ block_size = self.block_size
+ half_out = self.half_out // block_size
if shard_id == 'gate':
- param_data = param.data[param_id, :self.ffn_dim]
+ param_data = param.data[expert_id, :half_out]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'up':
- param_data = param.data[param_id, self.ffn_dim:]
+ param_data = param.data[expert_id, half_out:]
+ weight = loaded_weight.chunk(world_size, dim=0)[rank]
elif shard_id == 'down':
- param_data = param.data[param_id]
+ param_data = param.data[expert_id]
+ weight = loaded_weight.chunk(world_size, dim=1)[rank]
else:
raise RuntimeError(f'Unknown shard_id: {shard_id}')
- param_data.copy_(loaded_weight)
+ param_data.copy_(weight)
+
+
+class FusedMoEBlockedF8(nn.Module):
+ """fused moe blocked f8."""
+
+ def __init__(self,
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ fp8_dtype: torch.dtype = torch.float8_e4m3fn,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False):
+ super().__init__()
+ if device is None:
+ device = torch.device('cpu')
+ dtype = torch.float16 if dtype is None else dtype
+ self.block_size = 128
+ impl_builder = get_backend().get_layer_impl_builder(
+ OpType.FusedMoEBlockedF8)
+ self.impl = impl_builder.build(top_k,
+ num_experts,
+ renormalize,
+ block_size=self.block_size,
+ out_dtype=dtype)
+
+ enable_ep = enable_ep and self.impl.support_ep()
+ if enable_ep:
+ world_size, rank = get_world_rank()
+ expert_list = self.impl.ep_expert_list(world_size, rank)
+ num_experts = len(expert_list)
+ else:
+ hidden_dim, ffn_dim = _update_args(hidden_dim, ffn_dim)
+ expert_list = None
+ self.expert_list = expert_list
+
+ self.gate_up = LinearWeightsBlockedF8(num_experts,
+ hidden_dim,
+ ffn_dim * 2,
+ weight_type='gate_up',
+ block_size=self.block_size,
+ dtype=fp8_dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep)
+ self.down = LinearWeightsBlockedF8(
+ num_experts,
+ ffn_dim,
+ hidden_dim,
+ weight_type='down',
+ block_size=self.block_size,
+ dtype=fp8_dtype,
+ device=device,
+ expert_list=expert_list,
+ ep=enable_ep,
+ )
+
+ self.hidden_dim = hidden_dim
+ self.ffn_dim = ffn_dim
+ self.num_experts = num_experts
+ self.dtype = dtype
+ self.device = device
+ world_size, _ = get_world_rank()
+ if world_size == 1:
+ all_reduce = False
+ self.all_reduce = all_reduce
+
+ def update_weights(self):
+ """update weights."""
+ (gate_up_weights, down_weights, gate_up_scale,
+ down_scale) = self.impl.update_weights(self.gate_up.weight,
+ self.down.weight,
+ self.gate_up.scale,
+ self.down.scale)
+ self.gate_up.update_weight(gate_up_weights, gate_up_scale)
+ self.down.update_weight(down_weights, down_scale)
def forward(self, hidden_states: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.LongTensor):
ret = self.impl.forward(hidden_states, topk_weights, topk_ids,
- self.gate_up_weights, self.down_weights,
+ self.gate_up.weight, self.gate_up.scale,
+ self.down.weight, self.down.scale,
self.expert_list)
if self.all_reduce:
dist.all_reduce(ret)
return ret
+
+
+def build_fused_moe(
+ hidden_dim: int,
+ ffn_dim: int,
+ num_experts: int,
+ top_k: int,
+ renormalize: bool = False,
+ dtype: Optional[torch.dtype] = None,
+ device: Optional[torch.device] = None,
+ all_reduce: bool = True,
+ enable_ep: bool = False,
+ quant_config: Any = None,
+):
+ """fused moe builder."""
+
+ if quant_config is None:
+ return FusedMoE(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ dtype=dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+
+ quant_method = quant_config['quant_method']
+ if quant_method == 'smooth_quant':
+ quant_dtype = eval('torch.' + quant_config.get('quant_dtype', 'int8'))
+ return FusedMoEW8A8(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ dtype=dtype,
+ quant_dtype=quant_dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+ elif quant_method == 'fp8':
+ fmt = quant_config.get('fmt', 'e4m3')
+ if fmt == 'e4m3':
+ fp8_dtype = torch.float8_e4m3fn
+ elif fmt == 'e5m2':
+ fp8_dtype = torch.float8_e5m2
+ else:
+ raise TypeError(f'Unsupported fp8 fmt: {fmt}')
+ return FusedMoEBlockedF8(
+ hidden_dim=hidden_dim,
+ ffn_dim=ffn_dim,
+ num_experts=num_experts,
+ top_k=top_k,
+ renormalize=renormalize,
+ fp8_dtype=fp8_dtype,
+ dtype=dtype,
+ device=device,
+ all_reduce=all_reduce,
+ enable_ep=enable_ep,
+ )
+ else:
+ raise RuntimeError(f'Unsupported quant method: {quant_method}')
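# Illustrative sketch (not part of the patch): how callers such as mixtral.py,
# phi3_moe.py and qwen2_moe.py above are expected to use the new builder. The
# dimensions are hypothetical and constructing the layer requires an
# initialized lmdeploy backend.
#
#     experts = build_fused_moe(hidden_dim=4096, ffn_dim=14336, num_experts=8,
#                               top_k=2, renormalize=True,
#                               quant_config=None)                       # -> FusedMoE
#     experts = build_fused_moe(..., quant_config={'quant_method': 'smooth_quant',
#                                                  'quant_dtype': 'int8'})  # -> FusedMoEW8A8
#     experts = build_fused_moe(..., quant_config={'quant_method': 'fp8',
#                                                  'fmt': 'e4m3'})       # -> FusedMoEBlockedF8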
diff --git a/lmdeploy/pytorch/nn/norm.py b/lmdeploy/pytorch/nn/norm.py
index ef244ff73f..7e2c820399 100644
--- a/lmdeploy/pytorch/nn/norm.py
+++ b/lmdeploy/pytorch/nn/norm.py
@@ -4,19 +4,23 @@
import torch
from torch import nn
+from lmdeploy.pytorch.distributed import get_world_rank
+
from ..backends import OpType, get_backend
+from .utils import chunk_aligned, get_distribute_size
def _is_w8a8(quant_config: Any):
"""is w8a8."""
- if quant_config is None:
- return False
- else:
+ quant_dtype = None
+ w8a8_flag = False
+ if quant_config is not None:
quant_method = quant_config['quant_method']
- if quant_method == 'w8a8':
- return True
- else:
- return False
+ if quant_method == 'smooth_quant':
+ w8a8_flag = True
+ quant_dtype = quant_config.get('quant_dtype', 'int8')
+ quant_dtype = eval(f'torch.{quant_dtype}')
+ return w8a8_flag, quant_dtype
class RMSNorm(nn.Module):
@@ -27,16 +31,44 @@ def __init__(self,
eps: float = 1e-6,
dtype: torch.dtype = None,
device: torch.device = None,
- quant_config: Any = None):
+ quant_config: Any = None,
+ tp: bool = False,
+ align: int = 1):
super().__init__()
backend = get_backend()
- if _is_w8a8(quant_config):
+
+ w8a8_flag, quant_dtype = _is_w8a8(quant_config)
+ if w8a8_flag:
builder = backend.get_layer_impl_builder(OpType.RMSNormW8A8)
else:
builder = backend.get_layer_impl_builder(OpType.RMSNorm)
+
+ if tp:
+ world_size, rank = get_world_rank()
+ hidden_size = get_distribute_size(hidden_size,
+ world_size,
+ rank,
+ align=align)
+
self.register_parameter('weight',
self.create_weight(hidden_size, dtype, device))
- self.impl = builder.build(hidden_size, eps)
+ if w8a8_flag:
+ self.impl = builder.build(hidden_size,
+ eps,
+ quant_dtype=quant_dtype)
+ else:
+ self.impl = builder.build(hidden_size, eps)
+
+ if tp:
+ self.weight.weight_loader = self.weight_loader
+ self.align = align
+
+ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor):
+ """weight loader."""
+ world_size, rank = get_world_rank()
+ loaded_weight = chunk_aligned(loaded_weight, world_size, 0,
+ self.align)[rank]
+ param.copy_(loaded_weight)
@staticmethod
def create_weight(hidden_size: int,
diff --git a/lmdeploy/pytorch/nn/utils.py b/lmdeploy/pytorch/nn/utils.py
index 3b60ca21de..085b12c3e9 100644
--- a/lmdeploy/pytorch/nn/utils.py
+++ b/lmdeploy/pytorch/nn/utils.py
@@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+
def div_up(a: int, b: int):
"""div up."""
return (a + b - 1) // b
@@ -18,3 +21,19 @@ def get_distribute_size(feature_size: int,
if rank < aligned_size % world_size:
updated_aligned_size += 1
return updated_aligned_size * align
+
+
+def chunk_aligned(weight: torch.Tensor, chunks: int, dim: int, align: int):
+ """chunk aligned."""
+ if align == 1:
+ return weight.chunk(chunks, dim=dim)
+ size = weight.size(dim)
+ assert size % align == 0
+ aligned_size = size // align
+
+ # try best to evenly split chunks
+ align_per_chunk = aligned_size // chunks
+ remain = aligned_size % chunks
+ sections = [align_per_chunk + int(c < remain) for c in range(chunks)]
+ sections = [sec * align for sec in sections]
+ return weight.split(sections, dim=dim)
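# Illustrative sketch (not part of the patch): behaviour of chunk_aligned()
# above. It splits `weight` along `dim` into `chunks` sections whose sizes are
# multiples of `align`, keeping the sections as even as possible.
#
#     w = torch.empty(10, 1)
#     parts = chunk_aligned(w, chunks=3, dim=0, align=2)
#     [p.size(0) for p in parts]   # -> [4, 4, 2]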
diff --git a/lmdeploy/pytorch/paging/scheduler.py b/lmdeploy/pytorch/paging/scheduler.py
index 8879863092..722329a906 100644
--- a/lmdeploy/pytorch/paging/scheduler.py
+++ b/lmdeploy/pytorch/paging/scheduler.py
@@ -273,11 +273,19 @@ def has_unfinished(self):
return self.has_running() or self.has_waiting()
def has_running(self):
- return self.seq_manager.num_sequences(MessageStatus.RUNNING) > 0
+ return self.num_running() > 0
def has_waiting(self):
- return self.seq_manager.num_sequences(MessageStatus.WAITING) > 0
+ return self.num_waiting() > 0
def get_block_tables(self, seqs: SeqList):
"""get block table of the sequences."""
return [self.block_manager.get_block_table(seq) for seq in seqs]
+
+ def num_running(self):
+ """num running."""
+ return self.seq_manager.num_sequences(MessageStatus.RUNNING)
+
+ def num_waiting(self):
+ """num waiting."""
+ return self.seq_manager.num_sequences(MessageStatus.WAITING)
diff --git a/lmdeploy/serve/async_engine.py b/lmdeploy/serve/async_engine.py
index dfcf01a69d..797397e660 100644
--- a/lmdeploy/serve/async_engine.py
+++ b/lmdeploy/serve/async_engine.py
@@ -1,22 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
+import atexit
+import concurrent.futures
import dataclasses
import json
import os
import random
import re
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, closing
from copy import deepcopy
+from functools import partial
from itertools import count
-from queue import Empty, Queue
+from queue import Queue
from threading import Thread
-from typing import Any, Dict, List, Literal, Optional, Tuple, Union
+from typing import (Any, AsyncIterator, Dict, Iterator, List, Literal,
+ Optional, Tuple, Union)
+
+import tqdm
from lmdeploy.logger import RequestLogger
from lmdeploy.messages import (GenerationConfig, PytorchEngineConfig, Response,
ResponseType, TurbomindEngineConfig)
from lmdeploy.model import MODELS, ChatTemplateConfig, best_match_model
-from lmdeploy.serve.utils import LogitsMixin, _get_event_loop
+from lmdeploy.serve.utils import LogitsMixin
from lmdeploy.tokenizer import DetokenizeState
from lmdeploy.utils import _get_and_verify_max_len, _stop_words, get_logger
@@ -50,6 +56,37 @@ class GenOut:
finish_reason: Optional[Literal['stop', 'length', 'error']] = None
token_ids: List[int] = None
logprobs: List[Dict[int, float]] = None
+ logits: Any = None
+ last_hidden_state: Any = None
+
+
+def _gen_out_to_response(out: GenOut, index) -> Response:
+ return Response(text=out.response,
+ generate_token_len=out.generate_token_len,
+ input_token_len=out.input_token_len,
+ finish_reason=out.finish_reason,
+ token_ids=out.token_ids,
+ logprobs=out.logprobs,
+ last_hidden_state=out.last_hidden_state,
+ logits=out.logits,
+ index=index)
+
+
+def _append_response(dst: Response, src: Response):
+ """dst += src."""
+ if not dst:
+ return src
+ dst.text += src.text
+ dst.generate_token_len = src.generate_token_len
+ dst.input_token_len = src.input_token_len
+ dst.finish_reason = src.finish_reason
+ dst.index = src.index
+ if src.token_ids:
+ dst.token_ids += src.token_ids
+ if src.logprobs:
+ dst.logprobs = dst.logprobs or []
+ dst.logprobs += src.logprobs
+ return dst
class Session:
@@ -63,14 +100,17 @@ class Session:
_engine (Any): engine for internal use.
history (List[Any, str]): chat history.
"""
- _ids = count(0)
- def __init__(self):
- self._id: int = next(self._ids)
+ def __init__(self,
+ session_id: int,
+ engine: Any,
+ gen_config: GenerationConfig = None):
+ self._id: int = session_id
+ self._engine = engine
self._step: int = 0
self._prompt: Any = None
self._response: Response = None
- self._engine: Any = None
+ self._gen_config = gen_config
self.history: List[Tuple[Any, str]] = []
def _merge_response(self, resp: Response, step: Union[Response, GenOut]):
@@ -89,8 +129,8 @@ def response(self) -> Response:
def close(self):
"""release engine storage for this session."""
if self._engine:
- inst = self._engine.create_instance()
- inst.end(self._id)
+ self._engine._run(coro=self._engine.end_session(self._id)).result()
+ self._engine = None
def __repr__(self) -> str:
res = ''
@@ -100,6 +140,89 @@ def __repr__(self) -> str:
res += f'USER:\n{user}\nASSISTANT:\n{assistant}\n'
return res
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback):
+ self.close()
+
+ def __call__(
+ self,
+ prompt: str,
+ gen_config: Optional[GenerationConfig] = None,
+ stream_response: bool = True,
+ do_preprocess: bool = True) -> Union[Response, Iterator[Response]]:
+ self._engine.chat(prompt=prompt,
+ gen_config=gen_config or self._gen_config,
+ stream_response=stream_response,
+ do_preprocess=do_preprocess,
+ session=self)
+ if stream_response:
+ return self.generator
+ else:
+ return self.response
+
+
+class _EventLoopThread:
+
+ def __init__(self, daemon=False):
+ fut = concurrent.futures.Future()
+ self.thread = Thread(target=partial(self._thread_entry, fut),
+ daemon=daemon)
+ self.thread.start()
+ self.loop: asyncio.AbstractEventLoop = fut.result()
+ self.closed = False
+ if daemon:
+ atexit.register(self.close)
+
+ def _thread_entry(self, fut):
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+ fut.set_result(loop)
+ try:
+ loop.run_forever()
+ except BaseException as e:
+ logger.error(f'[internal_thread] {type(e).__name__} {e}')
+ finally:
+ try:
+ self._cancel_all_tasks()
+ loop.run_until_complete(loop.shutdown_asyncgens())
+ finally:
+ asyncio.set_event_loop(None)
+ loop.close()
+
+ def _cancel_all_tasks(self):
+ """Modified from asyncio/runners.py."""
+ to_cancel = asyncio.all_tasks(self.loop)
+ if not to_cancel:
+ return
+
+ for task in to_cancel:
+ task.cancel()
+
+ async def _gather():
+ await asyncio.gather(*to_cancel, return_exceptions=True)
+
+ self.loop.run_until_complete(_gather())
+
+ for task in to_cancel:
+ if task.cancelled():
+ continue
+ if task.exception() is not None:
+ self.loop.call_exception_handler({
+ 'message':
+ 'unhandled exception during worker thread shutdown',
+ 'exception': task.exception(),
+ 'task': task,
+ })
+
+ def close(self):
+ if self.closed:
+ return
+ self.closed = True
+ self.loop.call_soon_threadsafe(self.loop.stop)
+ self.thread.join()
+
class AsyncEngine(LogitsMixin):
"""Async inference engine. Maintaining a bunch of tm_model instances.
@@ -179,13 +302,26 @@ def __init__(self,
self.instance_num = self.backend_config.max_batch_size
self.tokenizer = self.engine.tokenizer
self.id2step = {}
- self.id2generator = {}
- self.running_session_ids = set()
- self.gens_set = set()
- for i in range(self.instance_num):
- self.gens_set.add(self.engine.create_instance())
+ self.id2inst = {}
+ self.free_insts: asyncio.Queue = None
+ self.instances = [
+ self.engine.create_instance() for _ in range(self.instance_num)
+ ]
self._session_id = count(0)
self.request_logger = RequestLogger(max_log_len)
+ self.internal_thread = _EventLoopThread(daemon=True)
+ self.limiter: asyncio.Semaphore = None
+
+ def close(self):
+ self.internal_thread.close()
+
+ def _get_free_insts(self):
+ if self.free_insts is None:
+ # `asyncio.Queue` must be created in an async context
+ self.free_insts = asyncio.Queue()
+ for inst in self.instances:
+ self.free_insts.put_nowait(inst)
+ return self.free_insts
def _build_turbomind(
self,
@@ -246,45 +382,117 @@ def __call__(self,
async def stop_session(self, session_id: int):
"""Stop a session by a session_id."""
- if str(session_id) in self.id2generator:
- await self.id2generator[str(session_id)].async_cancel(session_id)
- self.gens_set.add(self.id2generator[str(session_id)])
-
- self.running_session_ids.discard(session_id)
+ generator = self.id2inst.get(session_id)
+ if generator:
+ await generator.async_cancel(session_id)
+ # else it's not running at all
async def end_session(self, session_id: int):
- """Clear a session by a session_id."""
- if str(session_id) in self.id2generator:
- await self.id2generator[str(session_id)].async_end(session_id)
- self.id2step[str(session_id)] = 0
- self.gens_set.add(self.id2generator[str(session_id)])
-
- self.running_session_ids.discard(session_id)
-
- @asynccontextmanager
- async def safe_run(self, session_id: Optional[int] = None):
- """A context manager to make sure server's safe running."""
+ """For ending a session that is not running."""
+ inst = self.id2inst.get(session_id)
+ if inst:
+ await inst._active.wait()
+ assert session_id not in self.id2inst
+ inst = await self._get_free_insts().get()
try:
- yield
+ await inst.async_end(session_id)
+ self.id2step[session_id] = 0
except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa
- # TODO: find out why await would block the coroutine here
- _get_event_loop().create_task(self.stop_session(session_id))
- raise e
- if str(session_id) in self.id2generator:
- self.gens_set.add(self.id2generator[str(session_id)])
- self.running_session_ids.discard(session_id)
-
- async def get_generator(self, stop: bool, session_id: int):
- """Only return the model instance if it is available."""
- if stop:
- return self.engine.create_instance()
- # waiting no generator is available or the same session_id is running
- while self.gens_set == set() or session_id in self.running_session_ids:
- await asyncio.sleep(0.1)
- generator = self.gens_set.pop()
- self.id2generator[str(session_id)] = generator
- self.running_session_ids.add(session_id)
- return generator
+ logger.error(f'[end_session] exception caught: {e}')
+ finally:
+ self._get_free_insts().put_nowait(inst)
+
+ def _get_limiter(self):
+ if not self.limiter:
+ self.limiter = asyncio.Semaphore(self.instance_num)
+ return self.limiter
+
+ async def _async_infer(self, requests: AsyncIterator[Dict],
+ **kwargs) -> AsyncIterator[AsyncIterator[Response]]:
+ async for req in requests:
+ gen = self.generate(**req, **kwargs)
+ yield gen
+
+ def _infer(self,
+ requests: Iterator[Dict],
+ multiplex: bool,
+ pbar=None,
+ loop=None) -> Iterator[Iterator[Response]]:
+
+ async def _sync_resp(g, que: Queue, idx: int, sem: asyncio.Semaphore):
+ async for out in g:
+ que.put(_gen_out_to_response(out, idx))
+ sem.release()
+ if not multiplex:
+ que.put(None) # sentinel of inner generator
+ if pbar:
+ pbar.update(1)
+
+ que = Queue()
+
+ async def _infer():
+ sem = self._get_limiter()
+ tasks = []
+ for idx, req in enumerate(requests):
+ await sem.acquire()
+ gen = self.generate(**req)
+ dst = que if multiplex else Queue()
+ if not multiplex:
+ que.put(iter(dst.get, None))
+ # create a task to send the responses
+ task = asyncio.create_task(_sync_resp(gen, dst, idx, sem))
+ tasks.append(task)
+ if not multiplex: # sentinel of outer generator
+ que.put(None)
+ await asyncio.gather(*tasks)
+ if multiplex:
+ que.put(None) # sentinel of inner generator
+
+ loop = loop or self.internal_thread.loop
+ # submit the coroutine to async world
+ asyncio.run_coroutine_threadsafe(
+ _infer(), loop).add_done_callback(lambda x: x.result())
+
+ return iter(que.get, None)
+
+ @staticmethod
+ def _is_single(prompts):
+ return isinstance(prompts, str) or isinstance(prompts[0], Dict)
+
+ def infer(self,
+ prompts: Union[List[str], str, List[Dict], List[List[Dict]]],
+ gen_config: Optional[Union[GenerationConfig,
+ List[GenerationConfig]]] = None,
+ do_preprocess: bool = True,
+ adapter_name: Optional[str] = None,
+ stream_response: bool = False,
+ multiplex: bool = False,
+ pbar: Optional[tqdm.tqdm] = None,
+ **kwargs):
+
+ prompts = [prompts] if AsyncEngine._is_single(prompts) else prompts
+ assert isinstance(prompts, List), 'prompts should be a list'
+ gen_config = gen_config or GenerationConfig()
+ if not isinstance(gen_config, List):
+ gen_config = [gen_config] * len(prompts)
+ assert len(prompts) == len(gen_config), \
+            'input gen_config length differs from the length of prompts'  # noqa
+
+ def requests():
+ for prompt, gen_cfg in zip(prompts, gen_config):
+ r = dict(messages=prompt,
+ gen_config=gen_cfg,
+ do_preprocess=do_preprocess,
+ adapter_name=adapter_name,
+ stream_response=stream_response,
+ **kwargs)
+ r.setdefault('sequence_start', True)
+ r.setdefault('sequence_end', True)
+ if 'session_id' not in r:
+ r['session_id'] = next(self._session_id)
+ yield r
+
+ return self._infer(requests(), multiplex, pbar)
def batch_infer(self,
prompts: Union[List[str], str, List[Dict],
@@ -310,59 +518,26 @@ def batch_infer(self,
Pick one from adapters. Default to None, using the base model.
use_tqdm (bool): Whether use the progress bar. Default to False
"""
- need_list_wrap = isinstance(prompts, str) or isinstance(
- prompts[0], Dict)
- prompts = [prompts] if need_list_wrap else prompts
- assert isinstance(prompts, List), 'prompts should be a list'
- if gen_config is None:
- gen_config = GenerationConfig()
- if not isinstance(gen_config, List):
- gen_config = [gen_config] * len(prompts)
- assert len(prompts) == len(gen_config), \
- 'input gen_confg length differs from the length of prompts' # noqa
- prompt_num = len(prompts)
- session_ids = [next(self._session_id) for _ in range(prompt_num)]
- outputs = [
- Response('', 0, 0, session_ids[i], index=i)
- for i in range(prompt_num)
- ]
- generators = []
- if use_tqdm:
- import tqdm
- pbar = tqdm.tqdm(total=len(prompts))
- for i, prompt in enumerate(prompts):
- generators.append(
- self.generate(prompt,
- session_ids[i],
- gen_config=gen_config[i],
- stream_response=True,
- sequence_start=True,
- sequence_end=True,
- do_preprocess=do_preprocess,
- adapter_name=adapter_name,
- **kwargs))
-
- async def _inner_call(i, generator):
- async for out in generator:
- outputs[i].text += out.response
- outputs[i].generate_token_len = out.generate_token_len
- outputs[i].input_token_len = out.input_token_len
- outputs[i].finish_reason = out.finish_reason
- if out.token_ids:
- outputs[i].token_ids.extend(out.token_ids)
- if out.logprobs:
- if outputs[i].logprobs is None:
- outputs[i].logprobs = []
- outputs[i].logprobs.extend(out.logprobs)
- if use_tqdm and out.finish_reason is not None:
- pbar.update(1)
-
- async def gather():
- await asyncio.gather(
- *[_inner_call(i, generators[i]) for i in range(len(prompts))])
-
- _get_event_loop().run_until_complete(gather())
- outputs = outputs[0] if need_list_wrap else outputs
+ is_single = AsyncEngine._is_single(prompts)
+ outputs = []
+ pbar = tqdm.tqdm(
+ total=1 if is_single else len(prompts)) if use_tqdm else None
+ try:
+ for g in self.infer(prompts,
+ gen_config,
+ do_preprocess,
+ adapter_name,
+ stream_response=False,
+ pbar=pbar,
+ **kwargs):
+ res = None
+ for out in g:
+ res = _append_response(res, out)
+ outputs.append(res)
+ finally:
+ if pbar: pbar.close() # noqa
+ if is_single:
+ return outputs[0]
return outputs
def stream_infer(
@@ -372,6 +547,7 @@ def stream_infer(
List[GenerationConfig]]] = None,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
+ stream_response: bool = True,
**kwargs):
"""Inference a batch of prompts with stream mode.
@@ -387,62 +563,13 @@ def stream_infer(
adapter_name (str): the adapter name of slora for pytorch backend.
Pick one from adapters. Default to None, using the base model.
"""
- need_list_wrap = isinstance(prompts, str) or isinstance(
- prompts[0], Dict)
- prompts = [prompts] if need_list_wrap else prompts
- assert isinstance(prompts, List), 'prompts should be a list'
- if gen_config is None:
- gen_config = GenerationConfig()
- if not isinstance(gen_config, List):
- gen_config = [gen_config] * len(prompts)
- assert len(prompts) == len(gen_config), \
- 'input gen_confg length differs from the length of prompts' # noqa
- session_ids = [next(self._session_id) for _ in range(len(prompts))]
- outputs = Queue()
- generators = []
- for i, prompt in enumerate(prompts):
- generators.append(
- self.generate(prompt,
- session_ids[i],
- gen_config=gen_config[i],
- stream_response=True,
- sequence_start=True,
- sequence_end=True,
- do_preprocess=do_preprocess,
- adapter_name=adapter_name,
- **kwargs))
-
- async def _inner_call(i, generator):
- async for out in generator:
- outputs.put(
- Response(out.response,
- out.generate_token_len,
- out.input_token_len,
- session_ids[i],
- out.finish_reason,
- out.token_ids,
- out.logprobs,
- index=i))
-
- async def gather():
- await asyncio.gather(
- *[_inner_call(i, generators[i]) for i in range(len(prompts))])
- outputs.put(None)
-
- loop = _get_event_loop()
- proc = Thread(target=lambda: loop.run_until_complete(gather()))
- proc.start()
-
- while True:
- try:
- out = outputs.get(timeout=0.001)
- if out is None:
- break
- yield out
- except Empty:
- pass
-
- proc.join()
+ return self.infer(prompts,
+ gen_config,
+ do_preprocess,
+ adapter_name,
+ stream_response,
+ multiplex=True,
+ **kwargs)
async def _get_prompt_input(self,
prompt: str,
@@ -466,6 +593,34 @@ async def _get_prompt_input(self,
input_ids = self.tokenizer.encode(prompt, add_bos=sequence_start)
return {'prompt': prompt, 'input_ids': input_ids}
+ @asynccontextmanager
+ async def model_inst(self, session_id: int):
+        """A context manager that acquires a free model instance for the
+        session and returns it to the free pool when the session exits."""
+ assert session_id not in self.id2inst
+ free_insts = self._get_free_insts()
+ inst = await free_insts.get()
+ inst._active = asyncio.Event()
+ self.id2inst[session_id] = inst
+ try:
+ yield inst
+ finally:
+ self.id2inst.pop(session_id)
+ inst._active.set()
+ free_insts.put_nowait(inst)
+
+ @asynccontextmanager
+ async def safe_run(self, inst, session_id, **kwargs):
+ generator = inst.async_stream_infer(session_id, **kwargs)
+ try:
+ yield generator
+ except (Exception, asyncio.CancelledError, GeneratorExit) as e: # noqa
+ logger.error(
+ f'[safe_run] exception caught: {type(e).__name__} {e}')
+ # TODO: remove session_id from async cancel
+ await inst.async_cancel(session_id)
+ finally:
+ await generator.aclose()
+
async def generate(
self,
messages,
@@ -478,6 +633,9 @@ async def generate(
step: int = 0,
do_preprocess: bool = True,
adapter_name: Optional[str] = None,
+ skip_stop_tokens: bool = True,
+ rewind_stop_tokens: bool = False,
+ input_ids: Optional[List] = None,
**kwargs):
"""Generate responses.
@@ -493,10 +651,13 @@ async def generate(
do_preprocess (bool): whether pre-process the messages. Default to
True, which means chat_template will be applied.
"""
- if str(session_id) not in self.id2step:
- self.id2step[str(session_id)] = 0
+ if (messages is not None) ^ (input_ids is None):
+ raise ValueError(
+ 'You must specify exactly one of messages or input_ids')
+ if session_id not in self.id2step:
+ self.id2step[session_id] = 0
if step != 0:
- self.id2step[str(session_id)] = step
+ self.id2step[session_id] = step
if gen_config is None:
gen_config = GenerationConfig()
else:
@@ -523,119 +684,164 @@ async def generate(
logger.ERROR(f"n({gen_config.n}) > 1 hasn't been supported yet. "
f'Fallback to 1')
gen_config.n = 1
- prompt = messages
- self.request_logger.log_prompt(session_id=session_id, prompt=prompt)
- prompt_input = await self._get_prompt_input(prompt,
- do_preprocess,
- sequence_start,
- adapter_name,
- tools=tools)
- prompt = prompt_input['prompt']
- input_ids = prompt_input['input_ids']
- finish_reason = None
- self.request_logger.log_inputs(session_id=session_id,
- prompt=prompt,
- prompt_token_ids=input_ids,
- gen_config=gen_config,
- adapter_name=adapter_name)
- logger.info(f'session_id={session_id}, '
- f'history_tokens={self.id2step[str(session_id)]}, '
- f'input_tokens={len(input_ids)}, '
- f'max_new_tokens={gen_config.max_new_tokens}, '
- f'seq_start={sequence_start}, seq_end={sequence_end}, '
- f'step={step}, prep={do_preprocess}')
-
+ if messages:
+ prompt = messages
+ self.request_logger.log_prompt(session_id=session_id,
+ prompt=prompt)
+ prompt_input = await self._get_prompt_input(prompt,
+ do_preprocess,
+ sequence_start,
+ adapter_name,
+ tools=tools)
+ prompt = prompt_input['prompt']
+ input_ids = prompt_input['input_ids']
+ self.request_logger.log_inputs(session_id=session_id,
+ prompt=prompt,
+ prompt_token_ids=input_ids,
+ gen_config=gen_config,
+ adapter_name=adapter_name)
+ logger.info(f'session_id={session_id}, '
+ f'history_tokens={self.id2step[session_id]}, '
+ f'input_tokens={len(input_ids)}, '
+ f'max_new_tokens={gen_config.max_new_tokens}, '
+ f'seq_start={sequence_start}, seq_end={sequence_end}, '
+ f'step={step}, prep={do_preprocess}')
+ else:
+ # TODO(lvhan) VLM doesn't support input_ids as an argument.
+ # Figure out a graceful way to handle the invalid input
+ prompt_input = dict(input_ids=input_ids)
if gen_config.max_new_tokens is None:
# for interactive endpoint, will try maximum possible token num
gen_config.max_new_tokens = max(
- 128, self.session_len - self.id2step[str(session_id)] -
- len(input_ids))
- elif self.id2step[str(session_id)] + len(
+ 128,
+ self.session_len - self.id2step[session_id] - len(input_ids))
+ elif self.id2step[session_id] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
gen_config.max_new_tokens = max(
- self.session_len - self.id2step[str(session_id)] -
- len(input_ids), 128)
+ self.session_len - self.id2step[session_id] - len(input_ids),
+ 128)
logger.error(
f'Truncate max_new_tokens to {gen_config.max_new_tokens}')
- if self.id2step[str(session_id)] + len(
+ if self.id2step[session_id] + len(
input_ids) + gen_config.max_new_tokens > self.session_len:
logger.error(f'run out of tokens. session_id={session_id}.')
- yield GenOut('', self.id2step[str(session_id)], len(input_ids), 0,
+ yield GenOut('', self.id2step[session_id], len(input_ids), 0,
'length')
if sequence_end is True and sequence_start is False:
await self.end_session(session_id)
- else:
-
- def is_error(status):
- return status not in [
- ResponseType.SUCCESS, ResponseType.FINISH
- ]
-
- generator = await self.get_generator(False, session_id)
- async with self.safe_run(session_id):
- state = DetokenizeState(len(input_ids))
- start_ids_offset = state.ids_offset
- response = ''
- async for outputs in generator.async_stream_infer(
- session_id=session_id,
- **prompt_input,
- gen_config=gen_config,
- adapter_name=adapter_name,
- stream_output=stream_response,
- sequence_start=sequence_start,
- sequence_end=sequence_end,
- step=self.id2step[str(session_id)]):
+ return
+
+ def is_error(status):
+ return status not in [ResponseType.SUCCESS, ResponseType.FINISH]
+
+ # used to skip / rewind stop words in interactive mode
+ stop_ids = []
+ if skip_stop_tokens and not gen_config.ignore_eos:
+ stop_ids = gen_config.stop_token_ids or []
+ if self.tokenizer.eos_token_id not in stop_ids:
+ stop_ids.append(self.tokenizer.eos_token_id)
+
+ async with self.model_inst(session_id) as inst:
+ token_ids = input_ids.copy()
+ history_len = self.id2step[session_id]
+ input_len = len(input_ids)
+ output_len, gen_len = 0, 0
+ state = DetokenizeState(len(input_ids))
+ start_ids_offset = state.ids_offset
+ response = ''
+ finish_reason = None
+ async with self.safe_run(inst,
+ session_id=session_id,
+ **prompt_input,
+ gen_config=gen_config,
+ adapter_name=adapter_name,
+ stream_output=stream_response,
+ sequence_start=sequence_start,
+ sequence_end=sequence_end,
+ step=history_len) as gen:
+ prev_len = 0
+ hit_stop_token = 0
+ async for outputs in gen:
# decode res
if is_error(outputs.status):
- tokens = 0
break
- res, tokens = input_ids + outputs.token_ids, outputs.num_token # noqa
- if len(res) <= state.ids_offset:
+
+ output_len = outputs.num_token
+
+ if hit_stop_token or prev_len == output_len:
continue
+                    # assumes the engine stops as soon as a stop token is hit
+ if output_len and outputs.token_ids[-1] in stop_ids:
+ hit_stop_token = 1
+                        # the only new token is the stop token; nothing to emit
+ if output_len == prev_len + 1:
+ continue
+
+ mask = slice(prev_len - output_len,
+ output_len - hit_stop_token)
+
+ token_ids += outputs.token_ids[mask]
+ gen_len = len(token_ids) - input_len
+
+ prev_len = output_len
+
ids_offset = state.ids_offset
response, state = self.tokenizer.detokenize_incrementally(
- res,
+ token_ids,
state,
- skip_special_tokens=gen_config.skip_special_tokens)
+ skip_special_tokens=gen_config.skip_special_tokens,
+ spaces_between_special_tokens=gen_config.
+ spaces_between_special_tokens)
+ res = token_ids[ids_offset:]
+
+ out = GenOut(response, history_len, input_len, gen_len,
+ finish_reason, res)
- res = res[ids_offset:]
- logprobs = None
- if outputs.logprobs:
+ if outputs.logprobs is not None:
log_offset = ids_offset - start_ids_offset
- logprobs = outputs.logprobs[log_offset:]
+ out.logprobs = outputs.logprobs[log_offset:]
+ if outputs.last_hidden_state is not None:
+ out.last_hidden_state = outputs.last_hidden_state
+ if hit_stop_token:
+ out.last_hidden_state = \
+ out.last_hidden_state[:-hit_stop_token]
+ if outputs.logits is not None:
+ out.logits = outputs.logits
+ if hit_stop_token:
+ out.logits = out.logits[:-hit_stop_token]
+
+ yield out
+ # end of generator loop
- # response, history token len,
- # input token len, gen token len
- yield GenOut(response, self.id2step[str(session_id)],
- len(input_ids), tokens, finish_reason, res,
- logprobs)
if not is_error(outputs.status):
finish_reason = 'length' \
- if tokens >= gen_config.max_new_tokens else 'stop'
+ if gen_len >= gen_config.max_new_tokens else 'stop'
# utf-8 char at the end means it's a potential unfinished
# byte sequence
if not response.endswith('�'):
- # avaid returning the last response twice
+ # avoid returning the last response twice
response = ''
- yield GenOut(response, self.id2step[str(session_id)],
- len(input_ids), tokens, finish_reason)
+ yield GenOut(response, self.id2step[session_id],
+ len(input_ids), gen_len, finish_reason)
else:
- yield GenOut(
- response='internal error happened',
- history_token_len=self.id2step[str(session_id)],
- input_token_len=len(input_ids),
- generate_token_len=0,
- finish_reason='error',
- token_ids=[])
- # update step
- self.id2step[str(session_id)] += len(input_ids) + tokens
- if sequence_end:
- self.id2step[str(session_id)] = 0
- # manually end pytorch session
- # TODO modify pytorch or turbomind api
- if self.backend == 'pytorch' and sequence_end:
- await self.end_session(session_id)
+ yield GenOut(response='internal error happened',
+ history_token_len=self.id2step[session_id],
+ input_token_len=len(input_ids),
+ generate_token_len=0,
+ finish_reason='error',
+ token_ids=[])
+ # update step
+ if sequence_end:
+ self.id2step[session_id] = 0
+ if self.backend == 'pytorch':
+ # manually end pytorch session
+ await inst.async_end(session_id)
+ else:
+ if rewind_stop_tokens:
+ # rewind the step to the token before the stop token
+ output_len = gen_len
+ self.id2step[session_id] += input_len + output_len
def parse_tool_response(self, text, tools, **kwargs):
"""Parse model response containing tool information.
@@ -684,12 +890,28 @@ def parse_tool_response(self, text, tools, **kwargs):
for call_info in call_info_list]
return text, call_info_list
+ def _run(self, fn=None, coro=None, loop=None):
+ assert (fn or coro) and not (fn and coro)
+ loop = loop or self.internal_thread.loop
+ if fn:
+
+ async def _coro():
+ return fn()
+
+ coro = _coro()
+ return asyncio.run_coroutine_threadsafe(coro, loop)
+
+ def session(self, gen_config: GenerationConfig = None):
+ return Session(self._run(fn=lambda: next(self._session_id)).result(),
+ engine=self,
+ gen_config=gen_config)
+
def chat(self,
prompt: str,
session=None,
gen_config: Optional[GenerationConfig] = None,
- do_preprocess: bool = True,
- **kwargs) -> Session:
+ stream_response=False,
+ **kwargs) -> Union[Session, Iterator]:
"""Chat.
Args:
@@ -702,8 +924,7 @@ def chat(self,
**kwargs (dict): ad hoc parametrization of `gen_config
"""
if session is None:
- session = Session()
- session._engine = self.engine
+ session = self.session()
# sync & init
session._prompt = prompt
@@ -711,25 +932,35 @@ def chat(self,
sequence_start = session._step == 0
- async def _work():
- resp = Response('', -1, -1, session._id)
- async for output in self.generate(prompt,
- session_id=session._id,
- gen_config=gen_config,
- stream_response=False,
- sequence_start=sequence_start,
- sequence_end=False,
- step=session._step,
- do_preprocess=do_preprocess,
- **kwargs):
- resp = session._merge_response(resp, output)
- return resp
-
- from lmdeploy.pytorch.engine.request import _run_until_complete
- resp = _run_until_complete(_work())
-
- session._response = resp
- session._step += resp.generate_token_len + resp.input_token_len
- session.history.append((session._prompt, resp.text))
+ generator = self.infer(prompt,
+ gen_config,
+ sequence_start=sequence_start,
+ sequence_end=False,
+ session_id=session._id,
+ stream_response=stream_response,
+ multiplex=True)
+
+ def _gen():
+ resp = None
+ try:
+ for out in generator:
+ resp = _append_response(resp, out)
+ yield out
+ except: # noqa
+ self._run(coro=self.stop_session(session._id)).result()
+ raise
+ else:
+ session._response = resp
+ session._step += resp.generate_token_len + resp.input_token_len
+ session.history.append((session._prompt, resp.text))
+
+ if stream_response:
+ session.generator = _gen()
+ else:
+ # run the generator until finish
+ with closing(_gen()) as gen:
+ for _ in gen:
+ pass
+ session.generator = None
return session
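
The hunks above replace the ad-hoc event-loop handling with a single internal
`_EventLoopThread` plus a pool of free instances, and route `batch_infer`,
`stream_infer` and `chat` through the new blocking `infer()` bridge. A minimal
usage sketch of the reworked entry points, assuming an already constructed
`AsyncEngine` (or pipeline) instance and that `GenerationConfig` is importable
from the top-level `lmdeploy` package:

from lmdeploy import GenerationConfig

def demo(engine):
    # `engine` is assumed to be an AsyncEngine / pipeline built elsewhere
    gen_cfg = GenerationConfig(max_new_tokens=64)

    # batch_infer(): blocking, returns one Response per prompt
    responses = engine.batch_infer(['hi', 'hello'], gen_config=gen_cfg)
    print([r.text for r in responses])

    # stream_infer(): responses of all prompts multiplexed into one iterator;
    # `out.index` tells which prompt a chunk belongs to
    for out in engine.stream_infer(['hi', 'hello'], gen_config=gen_cfg):
        print(out.index, out.text)

    # chat(): with stream_response=True the returned Session exposes a
    # generator that yields the streamed responses
    session = engine.chat('What is LMDeploy?', gen_config=gen_cfg,
                          stream_response=True)
    for out in session.generator:
        print(out.text, end='', flush=True)

    engine.close()  # stop the internal event-loop thread
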
diff --git a/lmdeploy/serve/gradio/vl.py b/lmdeploy/serve/gradio/vl.py
index 103bcc5889..26f23613af 100644
--- a/lmdeploy/serve/gradio/vl.py
+++ b/lmdeploy/serve/gradio/vl.py
@@ -70,8 +70,6 @@ def run_local(model_path: str,
**kwargs):
from lmdeploy.serve.vl_async_engine import VLAsyncEngine
- if isinstance(backend_config, PytorchEngineConfig):
- backend_config.thread_safe = True
vision_config = VisionConfig(thread_safe=True)
engine = VLAsyncEngine(model_path=model_path,
model_name=model_name,
@@ -115,10 +113,13 @@ def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature):
else:
prompt = history[-1][0][0]
images = history[-1][0][1:]
- prompt = (prompt, images)
-
- logger.info('prompt: ' + str(prompt))
- prompt = engine.vl_prompt_template.prompt_to_messages(prompt)
+ # convert prompt into GPT4V format
+ prompt = [
+ dict(role='user', content=[dict(type='text', text=prompt)])
+ ]
+ for image in images:
+ prompt[0]['content'].append(
+ dict(type='image_data', image_data=dict(data=image)))
t0 = time.perf_counter()
inputs = _run_until_complete(
engine._get_prompt_input(prompt, True, sequence_start, ''))
@@ -150,7 +151,9 @@ def chat(chatbot, session, max_new_tokens, top_p, top_k, temperature):
response, state = engine.tokenizer.detokenize_incrementally(
res,
state,
- skip_special_tokens=gen_config.skip_special_tokens)
+ skip_special_tokens=gen_config.skip_special_tokens,
+ spaces_between_special_tokens=gen_config.
+ spaces_between_special_tokens) # noqa
if chatbot[-1][1] is None:
chatbot[-1][1] = ''
history[-1][1] = ''
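
For reference, the inline conversion added above corresponds to the GPT-4V
style message layout. A standalone sketch of the same transformation (the
helper name is illustrative, not part of lmdeploy):

def to_gpt4v_messages(text, images):
    """Pack a text prompt and a list of images into a single user message
    with `text` and `image_data` content items."""
    content = [dict(type='text', text=text)]
    for image in images:
        content.append(dict(type='image_data', image_data=dict(data=image)))
    return [dict(role='user', content=content)]
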
diff --git a/lmdeploy/serve/openai/api_server.py b/lmdeploy/serve/openai/api_server.py
index cce9567896..a284250f21 100644
--- a/lmdeploy/serve/openai/api_server.py
+++ b/lmdeploy/serve/openai/api_server.py
@@ -149,7 +149,8 @@ def _create_completion_logprobs(tokenizer: Tokenizer,
skip_special_tokens: bool = True,
offset: int = 0,
all_token_ids: List[int] = None,
- state: DetokenizeState = None):
+ state: DetokenizeState = None,
+ spaces_between_special_tokens: bool = True):
"""create openai LogProbs for completion.
Args:
@@ -162,6 +163,9 @@ def _create_completion_logprobs(tokenizer: Tokenizer,
offset (int): text offset.
all_token_ids (int): the history output token ids.
state (DetokenizeState): tokenizer decode state.
+        spaces_between_special_tokens (bool): Whether or not to add spaces
+            around special tokens during detokenization. Fast tokenizers set
+            this to False by default; slow tokenizers set it to True.
"""
if logprobs is None or len(logprobs) == 0:
return None, None, None, None
@@ -183,7 +187,8 @@ def _create_completion_logprobs(tokenizer: Tokenizer,
response, _state = tokenizer.detokenize_incrementally(
all_token_ids + [top_id],
copy.deepcopy(state),
- skip_special_tokens=skip_special_tokens)
+ skip_special_tokens=skip_special_tokens,
+ spaces_between_special_tokens=spaces_between_special_tokens)
res[response] = prob
if top_id == token_id:
out_state = _state
@@ -323,6 +328,9 @@ async def chat_completions_v1(request: ChatCompletionRequest,
- ignore_eos (bool): indicator for ignoring eos
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
+    - spaces_between_special_tokens (bool): Whether or not to add spaces
+        around special tokens during detokenization. Fast tokenizers set
+        this to False by default; slow tokenizers set it to True.
- min_new_tokens (int): To generate at least numbers of tokens.
- min_p (float): Minimum token probability, which will be scaled by the
probability of the most likely token. It must be a value between
@@ -340,8 +348,7 @@ async def chat_completions_v1(request: ChatCompletionRequest,
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
- if VariableInterface.async_engine.id2step.get(str(request.session_id),
- 0) != 0:
+ if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f'The session_id `{request.session_id}` is occupied.')
@@ -394,7 +401,8 @@ async def chat_completions_v1(request: ChatCompletionRequest,
logits_processors=logits_processors,
min_new_tokens=request.min_new_tokens,
min_p=request.min_p,
- random_seed=random_seed)
+ random_seed=random_seed,
+ spaces_between_special_tokens=request.spaces_between_special_tokens)
tools = None
if request.tools and request.tool_choice != 'none':
@@ -509,7 +517,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
for call_info in call_info_list
]
except Exception as e:
- logger.error(f'Exception: {e}')
+ logger.error(f'Failed to parse {text}. Exception: {e}.')
return create_error_response(
HTTPStatus.BAD_REQUEST,
'Failed to parse fc related info to json format!')
@@ -582,6 +590,9 @@ async def completions_v1(request: CompletionRequest,
- ignore_eos (bool): indicator for ignoring eos
- skip_special_tokens (bool): Whether or not to remove special tokens
in the decoding. Default to be True.
+    - spaces_between_special_tokens (bool): Whether or not to add spaces
+        around special tokens during detokenization. Fast tokenizers set
+        this to False by default; slow tokenizers set it to True.
- top_k (int): The number of the highest probability vocabulary
tokens to keep for top-k-filtering
@@ -596,8 +607,7 @@ async def completions_v1(request: CompletionRequest,
error_check_ret = await check_request(request)
if error_check_ret is not None:
return error_check_ret
- if VariableInterface.async_engine.id2step.get(str(request.session_id),
- 0) != 0:
+ if VariableInterface.async_engine.id2step.get(request.session_id, 0) != 0:
return create_error_response(
HTTPStatus.BAD_REQUEST,
f'The session_id `{request.session_id}` is occupied.')
@@ -625,7 +635,8 @@ async def completions_v1(request: CompletionRequest,
ignore_eos=request.ignore_eos,
stop_words=request.stop,
skip_special_tokens=request.skip_special_tokens,
- random_seed=random_seed)
+ random_seed=random_seed,
+ spaces_between_special_tokens=request.spaces_between_special_tokens)
generators = []
for i in range(len(request.prompt)):
result_generator = VariableInterface.async_engine.generate(
@@ -674,7 +685,7 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]:
VariableInterface.async_engine.tokenizer,
res.token_ids, res.logprobs,
gen_config.skip_special_tokens, offset, all_token_ids,
- state)
+ state, gen_config.spaces_between_special_tokens)
if request.stream_options and request.stream_options.include_usage: # noqa E501
final_res = res
total_tokens = sum([
@@ -726,8 +737,12 @@ async def _inner_call(i, generator):
logprobs = None
if request.logprobs and len(final_logprobs):
logprobs, _, _, _ = _create_completion_logprobs(
- VariableInterface.async_engine.tokenizer, final_token_ids,
- final_logprobs, gen_config.skip_special_tokens)
+ VariableInterface.async_engine.tokenizer,
+ final_token_ids,
+ final_logprobs,
+ gen_config.skip_special_tokens,
+ spaces_between_special_tokens=gen_config.
+ spaces_between_special_tokens)
assert final_res is not None
choice_data = CompletionResponseChoice(
@@ -865,11 +880,22 @@ async def chat_interactive_v1(request: GenerateRequest,
request.session_id = VariableInterface.session_id
async_engine = VariableInterface.async_engine
- sequence_start = async_engine.id2step.get(str(request.session_id), 0) == 0
+ sequence_start = async_engine.id2step.get(request.session_id, 0) == 0
sequence_end = not request.interactive_mode
if isinstance(request.stop, str):
request.stop = [request.stop]
+ end_session = sequence_end and not sequence_start \
+ and request.prompt == '' and request.request_output_len == 0
+ if end_session:
+ await async_engine.end_session(request.session_id)
+ return JSONResponse(
+ dict(text='',
+ tokens=0,
+ input_tokens=0,
+ history_tokens=0,
+ finish_reason=None))
+
random_seed = request.seed if request.seed else None
gen_config = GenerationConfig(
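
On the client side, the new `spaces_between_special_tokens` field and the
empty-prompt path of `/v1/chat/interactive` can be exercised roughly as
follows. This is a hedged sketch: it assumes a server running on lmdeploy's
default port 23333 and uses a placeholder model name.

import requests

BASE = 'http://0.0.0.0:23333'
MODEL = 'your-served-model-name'  # placeholder; must match the served model

# forward spaces_between_special_tokens down to detokenization
resp = requests.post(f'{BASE}/v1/chat/completions',
                     json=dict(model=MODEL,
                               messages=[dict(role='user', content='hi')],
                               spaces_between_special_tokens=False))
print(resp.json())

# an empty prompt with request_output_len=0 while leaving interactive mode
# explicitly ends the server-side session; session_id must refer to an
# ongoing interactive session
requests.post(f'{BASE}/v1/chat/interactive',
              json=dict(prompt='', request_output_len=0,
                        interactive_mode=False, session_id=1))
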
diff --git a/lmdeploy/serve/openai/protocol.py b/lmdeploy/serve/openai/protocol.py
index 2b9d39c7b7..a6f945ac13 100644
--- a/lmdeploy/serve/openai/protocol.py
+++ b/lmdeploy/serve/openai/protocol.py
@@ -135,6 +135,7 @@ class ChatCompletionRequest(BaseModel):
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
+ spaces_between_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40
seed: Optional[int] = None
min_new_tokens: Optional[int] = Field(default=None, examples=[None])
@@ -251,6 +252,7 @@ class CompletionRequest(BaseModel):
session_id: Optional[int] = -1
ignore_eos: Optional[bool] = False
skip_special_tokens: Optional[bool] = True
+ spaces_between_special_tokens: Optional[bool] = True
top_k: Optional[int] = 40 # for opencompass
seed: Optional[int] = None
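
The field is plain pydantic with a default of True, so existing clients are
unaffected. A quick sanity check, assuming `model` and `messages` remain the
only required fields of the request model:

from lmdeploy.serve.openai.protocol import ChatCompletionRequest

req = ChatCompletionRequest(model='dummy',
                            messages=[dict(role='user', content='hi')])
assert req.spaces_between_special_tokens is True
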
diff --git a/lmdeploy/serve/utils.py b/lmdeploy/serve/utils.py
index 3a16f0a65b..afcec3d4ab 100644
--- a/lmdeploy/serve/utils.py
+++ b/lmdeploy/serve/utils.py
@@ -4,8 +4,8 @@
import numpy as np
import torch
-from torch.nn.utils.rnn import pad_sequence
+from lmdeploy.messages import GenerationConfig
from lmdeploy.utils import get_logger
logger = get_logger('lmdeploy')
@@ -16,166 +16,44 @@
PromptType = Union[str, List[Dict]]
-def _get_event_loop():
- """get event loop."""
- try:
- event_loop = asyncio.get_event_loop()
- except Exception:
- logger.warning('Can not found event loop in current thread.'
- ' Create a new event loop.')
- event_loop = asyncio.new_event_loop()
- asyncio.set_event_loop(event_loop)
- return event_loop
-
-
class LogitsMixin:
- """Helper class to calculate logits and ppl."""
-
- def prepare_inputs(self, prompts: Union[PromptType, List[PromptType]]):
- if hasattr(self, '_convert_prompts'):
- prompts = self._convert_prompts(prompts)
- need_list_wrap = isinstance(prompts, str) or isinstance(
- prompts[0], Dict)
- prompts = [prompts] if need_list_wrap else prompts
-
- decorated = []
- input_ids = []
- input_embeddings = []
- input_embedding_ranges = []
- for prompt in prompts:
- out = _get_event_loop().run_until_complete(
- self._get_prompt_input(prompt,
- do_preprocess=True,
- sequence_start=True,
- adapter_name=None))
- decorated.append(out['prompt'])
- input_ids.append(out['input_ids'])
- input_embeddings.append(out.get('input_embeddings', None))
- input_embedding_ranges.append(
- out.get('input_embedding_ranges', None))
-
- outputs = dict(prompts=decorated, input_ids=input_ids)
- if not any(input_embeddings):
- input_embeddings = None
- input_embedding_ranges = None
- outputs['input_embeddings'] = input_embeddings
- outputs['input_embedding_ranges'] = input_embedding_ranges
-
- return outputs
-
- def get_logits(
- self,
- input_ids: Union[InputIdsType, List[InputIdsType]],
- input_embeddings: Union[InputEmbsType, List[InputEmbsType]] = None,
- input_embedding_ranges: Union[InputEmbRngsType,
- List[InputEmbRngsType]] = None):
- """Get logits given a list of input tokens.
-
- Args:
- input_ids (Union[List[int], List[List[int]]]): the batch of
- input token ids
- """
- assert len(input_ids) > 0
- if isinstance(input_ids[0], int):
- input_ids = [input_ids]
- for input_id in input_ids:
- assert len(input_id) > 0
-
- bs = len(input_ids)
- # TODO: a better way to determine `max_input_len`, at most allocate
- # 2G mem for logits with shape [bs, max_input_len, vocab_size]
- vocab_size = self.hf_tm_cfg.vocab_size
- max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)
-
- n_max_iter = np.ceil(
- max([len(input_id)
- for input_id in input_ids]) / max_input_len).astype(int)
+ """Helper class to calculate ppl."""
+
+ async def _async_get_logits(
+ self,
+ input_ids,
+ steps: List[int] = None,
+ sequence_start: bool = True,
+ sequence_end: bool = True) -> List[torch.Tensor]:
+ assert input_ids and all(isinstance(_, List) for _ in input_ids)
+ assert steps is None or (len(steps) == len(input_ids))
+
+ logits = [None] * len(input_ids)
+
+ async def _proc(i):
+ async for out in self.generate(
+ messages=None,
+ input_ids=input_ids[i],
+ step=0 if steps is None else steps[i],
+ session_id=i,
+                    # `max_new_tokens=0` means the engine does not need to
+                    # generate any token; `output_logits='all'` asks it to
+                    # return the logits of all input tokens
+ gen_config=GenerationConfig(max_new_tokens=0,
+ output_logits='all'),
+ stream_response=False,
+ sequence_start=sequence_start,
+ sequence_end=sequence_end):
+ # In the last iteration, the yielded `out` is an empty response
+ # indicating the finish_reason, which should be ignored here
+ if out.finish_reason is None:
+                    # Avoid returning inside the async for loop; otherwise a
+                    # `GeneratorExit` exception would be raised
+ logits[i] = out.logits
+
+ tasks = [_proc(i) for i in range(len(input_ids))]
+ await asyncio.gather(*tasks)
- index_range_starts = []
- index_range_ends = []
- for input_id in input_ids:
- index_range_start = np.array(
- [i * max_input_len for i in range(n_max_iter)])
- index_range_end = index_range_start + max_input_len
- index_range_start[index_range_start >= len(input_id)] = len(
- input_id)
- index_range_end[index_range_end >= len(input_id)] = len(input_id)
- index_range_starts.append(index_range_start)
- index_range_ends.append(index_range_end)
-
- def _split_embeddings(input_ids, niter, iter_len, embeddings,
- embedding_ranges):
- embs = [None] * niter
- ranges = [None] * niter
-
- if embeddings is None:
- return embs, ranges
-
- for i in range(niter):
- iembs = []
- iranges = []
- for emb, (begin, end) in zip(embeddings, embedding_ranges):
- assert end <= len(input_ids)
- if begin >= (i + 1) * iter_len or end <= i * iter_len:
- continue
- if isinstance(emb, np.ndarray):
- emb = torch.from_numpy(emb)
- emb = emb.squeeze()
- offx = max(iter_len * i - begin, 0)
- offy = max(end - iter_len * (i + 1), 0)
- emb = emb[offx:emb.shape[0] - offy]
- off = max(begin - iter_len * i, 0)
- rng = [off, off + emb.shape[0]]
- iembs.append(emb)
- iranges.append(rng)
-
- iembs = iembs or None
- iranges = iranges or None
- embs[i] = iembs
- ranges[i] = iranges
-
- return embs, ranges
-
- if input_embeddings is not None:
- if not isinstance(input_embeddings[0], list):
- input_embeddings = [input_embeddings]
- input_embedding_ranges = [input_embedding_ranges]
- _input_embeddings = []
- _input_embedding_ranges = []
- for i in range(len(input_ids)):
- embeddings, ranges = _split_embeddings(
- input_ids[i], n_max_iter, max_input_len,
- input_embeddings[i], input_embedding_ranges[i])
- _input_embeddings.append(embeddings)
- _input_embedding_ranges.append(ranges)
- input_embeddings = _input_embeddings
- input_embedding_ranges = _input_embedding_ranges
-
- logits = []
- generator = self.engine.create_instance()
- for i in range(n_max_iter):
- steps = [start[i] for start in index_range_starts]
- _input_ids = [
- input_id[start[i]:end[i]] for input_id, start, end in zip(
- input_ids, index_range_starts, index_range_ends)
- ]
- embeddings = None
- ranges = None
- if input_embeddings is not None:
- embeddings = [x[i] for x in input_embeddings]
- ranges = [x[i] for x in input_embedding_ranges]
-
- _logits = generator.decode(_input_ids,
- steps=steps,
- input_embeddings=embeddings,
- input_embedding_ranges=ranges,
- sequence_start=(i == 0),
- sequence_end=(i == n_max_iter - 1))
- _logits = _logits.cpu()
- logits.append(_logits)
-
- # concat logits. Shape is [bsz, seq_len, vocab_size]
- logits = torch.cat(logits, dim=1)
return logits
def get_ppl(self, input_ids: Union[List[int],
@@ -188,21 +66,19 @@ def get_ppl(self, input_ids: Union[List[int],
input token ids
Returns:
- Union[float, List[float]]: A list of perplexity scores.
+ List[float]: A list of perplexity scores.
"""
assert isinstance(input_ids, List)
if isinstance(input_ids[0], int):
input_ids = [input_ids]
-
- generator = self.engine.create_instance()
+ assert all(len(_) > 1 for _ in input_ids)
# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (vocab_size * 4)
sizes = [len(_) for _ in input_ids]
- losses = []
- target_counts = []
+ result = []
sorted_index_values = sorted(list(enumerate(sizes)),
key=lambda x: x[1],
reverse=True)
@@ -214,29 +90,20 @@ def get_ppl(self, input_ids: Union[List[int],
logger.info(f'start: {start}, end: {end}')
if start == end:
_input_ids = input_ids[indices[start]]
- loss, target_count = self._get_long_text_ppl(
- generator=generator,
- input_ids=_input_ids,
- max_input_len=max_input_len)
- losses.append(loss)
- target_counts.append(target_count)
+ res = self._get_long_text_ppl(input_ids=_input_ids,
+ max_input_len=max_input_len)
+ result.append(res)
else:
_input_ids = [input_ids[indices[i]] for i in range(start, end)]
- loss, target_count = self._get_ppl(
- generator=generator,
+ res = self._get_ppl(
input_ids=_input_ids,
max_input_len=max_input_len,
)
- losses.append(loss)
- target_counts.append(target_count)
- loss = torch.concatenate(losses)
- target_count = torch.concatenate(target_counts)
- loss_avg = loss / target_count
- loss_avg = loss_avg.numpy().tolist()
- result = list(range(len(loss_avg)))
+ result.extend(res)
+ output = list(range(len(result)))
for index, sorted_index in enumerate(indices):
- result[sorted_index] = loss_avg[index]
- return result
+ output[sorted_index] = result[index]
+ return output
def _batch_iterator(self, sizes, max_value):
"""Return an iterator that calculates intervals (start, end) of a
@@ -261,7 +128,7 @@ def _batch_iterator(self, sizes, max_value):
else:
i += 1
- def _get_long_text_ppl(self, generator, input_ids, max_input_len):
+ def _get_long_text_ppl(self, input_ids, max_input_len):
assert all(isinstance(_, int) for _ in input_ids)
seq_len = len(input_ids)
assert seq_len > max_input_len
@@ -276,31 +143,29 @@ def _get_long_text_ppl(self, generator, input_ids, max_input_len):
target_ids = input_ids[i + 1:i + 1 + max_input_len]
loss, target_count = self._get_ppl(
- generator=generator,
input_ids=[token_ids],
max_input_len=max_input_len,
target_ids=[target_ids],
steps=step,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= seq_len))
- losses.append(loss)
- target_counts.append(target_count)
- loss_sum = torch.concatenate(losses).sum().unsqueeze(0)
- target_count = torch.concatenate(target_counts).sum().unsqueeze(0)
- return loss_sum, target_count
+ losses.extend(loss)
+ target_counts.extend(target_count)
+ loss_sum = sum(losses)
+ target_count = sum(target_counts)
+ return loss_sum / target_count
def _get_ppl(self,
- generator,
input_ids,
max_input_len,
target_ids=None,
steps=None,
sequence_start: bool = True,
sequence_end: bool = True):
- assert isinstance(input_ids, List)
- assert all(isinstance(_, List) for _ in input_ids)
- if target_ids:
- assert all(isinstance(_, List) for _ in target_ids)
+ assert (isinstance(input_ids, List)
+ and all(isinstance(_, List) for _ in input_ids))
+ assert steps is None or len(steps) == len(input_ids)
+ assert target_ids is None or len(target_ids) == len(input_ids)
lens = [len(_) for _ in input_ids]
total_len = sum(lens)
@@ -309,41 +174,41 @@ def _get_ppl(self,
logger.info(f'get_ppl: bs: {len(input_ids)}, lens: {lens}, '
f'total_len: {total_len}')
torch.cuda.empty_cache()
- logits = generator.decode(input_ids=input_ids,
- steps=steps,
- sequence_start=sequence_start,
- sequence_end=sequence_end)
- bsz, seq_len, vocab_size = logits.shape
- logits = logits.float()
+
+ logits = self._run(
+ coro=self._async_get_logits(input_ids=input_ids,
+ steps=steps,
+ sequence_start=sequence_start,
+ sequence_end=sequence_end)).result()
padding_token_id = -100
if target_ids is None:
- # shift token_ids by 1 to the left
target_ids = [x[1:] + [padding_token_id] for x in input_ids]
else:
target_ids = [
target_ids[i] + [padding_token_id]
if len(target_ids[i]) < len(input_ids[i]) else target_ids[i]
- for i in range(bsz)
+ for i in range(len(input_ids))
]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in target_ids
]
- target_ids = pad_sequence(target_ids,
- batch_first=True,
- padding_value=padding_token_id)
- target_ids = target_ids.to(logits.device)
- target_mask = target_ids != padding_token_id
- # compute cross entropy loss
- flat_logits = logits.contiguous().view(-1, vocab_size)
- flat_target_ids = target_ids.contiguous().view(-1)
- flat_loss_matrix = torch.nn.functional.cross_entropy(
- flat_logits,
- flat_target_ids,
- reduction='none',
- ignore_index=padding_token_id)
- flat_loss_matrix = flat_loss_matrix.view(bsz, seq_len)
- loss = flat_loss_matrix.sum(dim=-1).cpu()
- target_count = target_mask.sum(dim=-1).cpu()
- return loss, target_count
+ result = []
+ for _logits, _target_ids in zip(logits, target_ids):
+ _logits = _logits.float()
+ vocab_size = _logits.shape[-1]
+ _target_ids = _target_ids.to(_logits.device)
+ target_mask = _target_ids != padding_token_id
+ # compute cross entropy loss
+ flat_logits = _logits.contiguous().view(-1, vocab_size)
+ flat_target_ids = _target_ids.contiguous().view(-1)
+ flat_loss_matrix = torch.nn.functional.cross_entropy(
+ flat_logits,
+ flat_target_ids,
+ reduction='none',
+ ignore_index=padding_token_id)
+ loss = flat_loss_matrix.sum()
+ target_count = target_mask.sum()
+ result.append(loss.item() / target_count.item())
+ return result
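
With this rework `get_ppl` returns one plain float per input sequence (the
summed cross-entropy divided by the target count), so a caller can stay
entirely in synchronous code. A small sketch, assuming `pipe` is an
AsyncEngine / pipeline instance and `pipe.tokenizer.encode` is its tokenizer
entry point:

def ppl_of_texts(pipe, texts):
    # each sequence must contain at least two tokens (see the added assert)
    input_ids = [pipe.tokenizer.encode(t) for t in texts]
    return pipe.get_ppl(input_ids)  # -> List[float], one score per text
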
diff --git a/lmdeploy/turbomind/chat.py b/lmdeploy/turbomind/chat.py
index e106beae17..6985e3dc27 100644
--- a/lmdeploy/turbomind/chat.py
+++ b/lmdeploy/turbomind/chat.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
+import asyncio
import os
import random
@@ -28,6 +29,28 @@ def input_prompt(model_name):
return '\n'.join(iter(input, sentinel))
+async def async_infer(generator, session_id, input_ids, gen_config,
+ sequence_start, step, stream_output, tokenizer, state):
+ token_ids = input_ids.copy()
+ prev_len = 0
+ async for output in generator.async_stream_infer(
+ session_id=session_id,
+ input_ids=input_ids,
+ gen_config=gen_config,
+ sequence_start=sequence_start,
+ sequence_end=False,
+ step=step,
+ stream_output=stream_output):
+ tokens = output.num_token
+ if tokens > prev_len:
+ token_ids += output.token_ids[prev_len - tokens:]
+ response, state = tokenizer.detokenize_incrementally(token_ids,
+ state=state)
+ prev_len = tokens
+ print(response, end='', flush=True)
+ return tokens
+
+
def main(model_path: str,
session_id: int = 1,
top_k: float = 40,
@@ -130,6 +153,9 @@ def main(model_path: str,
repetition_penalty=repetition_penalty,
stop_token_ids=stop_words)
+ loop = asyncio.new_event_loop()
+ asyncio.set_event_loop(loop)
+
nth_round = 1
step = 0
seed = random.getrandbits(64)
@@ -138,7 +164,7 @@ def main(model_path: str,
if prompt == 'exit':
exit(0)
elif prompt == 'end':
- generator.end(session_id)
+ loop.run_until_complete(generator.async_end(session_id))
nth_round = 1
step = 0
seed = random.getrandbits(64)
@@ -149,10 +175,8 @@ def main(model_path: str,
if model.capability == 'chat':
sequence_start = (nth_round == 1)
- sequence_end = False
else:
sequence_start = True
- sequence_end = True
step = 0
if step + len(
@@ -163,20 +187,11 @@ def main(model_path: str,
print(f'{prompt}', end='', flush=True)
state = DetokenizeState(len(input_ids))
- for outputs in generator.stream_infer(
- session_id=session_id,
- input_ids=[input_ids],
- gen_config=gen_config,
- sequence_start=sequence_start,
- sequence_end=sequence_end,
- step=step,
- stream_output=stream_output):
-
- res, tokens = input_ids + outputs.token_ids, outputs.num_token
- # decode res
- response, state = tokenizer.detokenize_incrementally(
- res, state=state)
- print(response, end='', flush=True)
+
+ coro = async_infer(generator, session_id, input_ids, gen_config,
+ sequence_start, step, stream_output, tokenizer,
+ state)
+ tokens = loop.run_until_complete(coro)
# update step
step += len(input_ids) + tokens
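
The incremental update in `async_infer` relies on `output.token_ids` being
cumulative: with `tokens = output.num_token` tokens reported so far and
`prev_len` already consumed, the negative slice `token_ids[prev_len - tokens:]`
is exactly the new suffix. A tiny arithmetic check of that pattern:

outputs_so_far = [101, 102, 103, 104, 105]    # cumulative generated ids
prev_len = 3                                  # ids already consumed
tokens = len(outputs_so_far)                  # num_token reported this step
new_ids = outputs_so_far[prev_len - tokens:]  # slice(-2, None) -> [104, 105]
assert new_ids == [104, 105]
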
diff --git a/lmdeploy/turbomind/decode.py b/lmdeploy/turbomind/decode.py
deleted file mode 100644
index 5ba4675c59..0000000000
--- a/lmdeploy/turbomind/decode.py
+++ /dev/null
@@ -1,41 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import os
-import os.path as osp
-
-import torch
-
-from lmdeploy import turbomind as tm
-from lmdeploy.tokenizer import Tokenizer
-
-os.environ['TM_LOG_LEVEL'] = 'ERROR'
-
-
-def main(model_path, inputs):
- """An example to perform model inference through the command line
- interface.
-
- Args:
- model_path (str): the path of the deployed model
- inputs (str): the path of text file contatin input text lines
- """
- tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
- tokenizer = Tokenizer(tokenizer_model_path)
- tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id)
- generator = tm_model.create_instance()
-
- with open(inputs, 'r') as f:
- lines = f.readlines()
-
- input_ids = [tokenizer.encode(x) for x in lines]
-
- logits = generator.decode(input_ids)
-
- top_1 = torch.argmax(logits, -1)
-
- print(top_1)
-
-
-if __name__ == '__main__':
- import fire
-
- fire.Fire(main)
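
decode.py is removed along with the batched `decode()` path. Roughly the same
top-1 inspection can be rebuilt on the new logits route; the sketch below
leans on internal helpers added in this diff (`_run`, `_async_get_logits`) and
requires a backend that supports `output_logits`, so treat it as illustrative
rather than a supported replacement for the deleted CLI:

import torch

def top1_tokens(engine, text):
    # encode, fetch per-token logits via the async logits path, take argmax
    input_ids = engine.tokenizer.encode(text)
    logits = engine._run(
        coro=engine._async_get_logits(input_ids=[input_ids])).result()[0]
    return torch.argmax(logits.float(), dim=-1)
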
diff --git a/lmdeploy/turbomind/turbomind.py b/lmdeploy/turbomind/turbomind.py
index a1b2fff944..02c22314f4 100644
--- a/lmdeploy/turbomind/turbomind.py
+++ b/lmdeploy/turbomind/turbomind.py
@@ -4,11 +4,13 @@
import json
import os.path as osp
import sys
+from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict
+from functools import partial
from itertools import repeat
-from queue import LifoQueue, Queue
-from typing import Dict, Iterable, List
+from queue import Queue
+from typing import Dict, List
import numpy as np
import torch
@@ -317,6 +319,92 @@ def create_instance(self, cuda_stream_id=0):
return TurboMindInstance(self, self.config, cuda_stream_id)
+def _get_logits(outputs, offset: int):
+ logits = outputs['logits']
+
+ def _func(out: EngineOutput, step: int):
+ out.logits = logits[:step - offset - 1, :]
+
+ return _func
+
+
+def _get_last_hidden_state(outputs, offset: int):
+ last_hidden_state = outputs['last_hidden_state']
+
+ def _func(out: EngineOutput, step: int):
+ out.last_hidden_state = last_hidden_state[:step - offset - 1, :]
+
+ return _func
+
+
+def _get_logprobs_impl(logprob_vals: torch.Tensor,
+ logprob_idxs: torch.Tensor,
+ logprob_nums: torch.Tensor,
+ output_ids: List[int],
+ logprobs: int,
+ out_logprobs: List[Dict[int, float]] = None):
+ length = len(output_ids)
+ offset = len(out_logprobs)
+ if length == offset:
+ return out_logprobs
+ for (pos, idx, val, n) in zip(range(offset,
+ length), logprob_idxs[offset:length],
+ logprob_vals[offset:length],
+ logprob_nums[offset:length]):
+ topn = min(n.item(), logprobs)
+ tok_res = {idx[i].item(): val[i].item() for i in range(topn)}
+ token_id = output_ids[pos]
+ if token_id not in tok_res:
+ print(token_id, tok_res)
+ valid_n = n.item()
+ tok_res[token_id] = \
+ val[:valid_n][idx[:valid_n] == token_id].item()
+ ids = list(tok_res.keys())
+ for k in ids:
+ if tok_res[k] == float('-inf'):
+ tok_res.pop(k)
+ out_logprobs.append(tok_res)
+ return out_logprobs
+
+
+def _get_logprobs(outputs, output_logprobs: int):
+ logprob_vals = outputs['logprob_vals']
+ logprob_idxs = outputs['logprob_indexes']
+ logprob_nums = outputs['logprob_nums']
+
+ logprobs = []
+
+ def _func(out: EngineOutput, step: int):
+ _get_logprobs_impl(logprob_vals, logprob_idxs, logprob_nums,
+ out.token_ids, output_logprobs, logprobs)
+ out.logprobs = logprobs
+
+ return _func
+
+
+class StreamingSemaphore:
+
+ def __init__(self):
+ self.loop = asyncio.get_running_loop()
+ self.fut = None
+ self.val = 0
+
+ async def acquire(self):
+ if self.val:
+ self.val = 0
+ return
+ self.fut = self.loop.create_future()
+ await self.fut
+ self.fut = None
+ self.val = 0
+
+ def release(self):
+ if not self.val:
+ self.val = 1
+ if self.fut:
+ self.fut.set_result(None)
+
+
class TurboMindInstance:
"""Instance of TurboMind.
@@ -343,116 +431,30 @@ def __init__(self,
# create model instances
self.model_inst = self._create_model_instance(0)
- self.que = Queue()
- self.executor: ThreadPoolExecutor = None
- self.future = None
self.config = config
+ self.lock = None
def _create_model_instance(self, device_id):
- rank = self.node_id * self.gpu_count + device_id
- model_inst = self.tm_model.model_comm.create_model_instance(
- device_id, rank, self.cuda_stream_id, self.nccl_params)
+ model_inst = self.tm_model.model_comm.create_model_instance(device_id)
return model_inst
- def _forward_callback(self, result, ctx):
- self.que.put((False, result))
-
- def _forward_thread(self, inputs):
-
- def _func():
- try:
- output = self.model_inst.forward(inputs)
- except Exception as e:
- logger.error(f'unhandled exception: {e}')
- self.que.put((-1, None))
- return
- self.que.put((True, output))
-
- self.executor = ThreadPoolExecutor(1)
- self.future = self.executor.submit(_func)
-
- def _async_forward_callback(self, result, ctx, que: LifoQueue):
- que.put((False, result))
-
- def _async_forward_thread(self, inputs, que: LifoQueue):
-
- def _func():
- try:
- output = self.model_inst.forward(inputs)
- except Exception as e:
- logger.error(f'unhandled exception: {e}')
- que.put((-1, None))
- return
- que.put((True, output))
-
- self.executor = ThreadPoolExecutor(1)
- self.future = self.executor.submit(_func)
-
- def _get_logprobs(self,
- logprob_vals: torch.Tensor,
- logprob_indexes: torch.Tensor,
- logprob_nums: torch.Tensor,
- output_ids: torch.Tensor,
- logprobs: int = None,
- length: int = None,
- out_logprobs: List[Dict[int, float]] = None,
- session_id: int = None):
- if logprobs is None:
- return None
- if out_logprobs is None:
- out_logprobs = []
- if len(output_ids) <= len(out_logprobs):
- return out_logprobs
- offset = len(out_logprobs)
- for (token_id, idx, val, n) in zip(output_ids[offset:length],
- logprob_indexes[offset:length],
- logprob_vals[offset:length],
- logprob_nums[offset:length]):
- topn = min(n.item(), logprobs)
- tok_res = {idx[i].item(): val[i].item() for i in range(topn)}
- if token_id.item() not in tok_res:
- valid_n = n.item()
- tok_res[token_id.item()] = \
- val[:valid_n][idx[:valid_n] == token_id].item()
- ids = list(tok_res.keys())
- for k in ids:
- if tok_res[k] == float('-inf'):
- tok_res.pop(k)
- out_logprobs.append(tok_res)
- return out_logprobs
+ def _get_extra_output_processors(self, outputs: Dict[str, torch.Tensor],
+ gen_config: GenerationConfig,
+ input_len: int):
+
+ def _get_offset(type):
+ return input_len - 1 if type == 'generation' else 0
- def end(self, session_id: int):
- """End the given session."""
- input_ids = [self.tm_model.tokenizer.eos_token_id]
- end_generator = self.tm_model.create_instance()
- for outputs in end_generator.stream_infer(
- session_id,
- input_ids,
- sequence_start=False,
- sequence_end=True,
- gen_config=GenerationConfig(max_new_tokens=0)):
- pass
-
- async def async_end(self, session_id: int):
- """End the given session."""
- self.end(session_id)
-
- def cancel(self, session_id: int):
- """Stop current streaming inference."""
- input_ids = [self.tm_model.tokenizer.eos_token_id]
- stop_generator = self.tm_model.create_instance()
- for outputs in stop_generator.stream_infer(
- session_id,
- input_ids,
- sequence_start=False,
- sequence_end=False,
- stop=True,
- gen_config=GenerationConfig(max_new_tokens=0)):
- pass
-
- async def async_cancel(self, session_id: int):
- """End the given session."""
- self.cancel(session_id)
+ fs = []
+ if gen_config.output_logits:
+ offset = _get_offset(gen_config.output_logits)
+ fs.append(_get_logits(outputs, offset))
+ if gen_config.output_last_hidden_state:
+ offset = _get_offset(gen_config.output_last_hidden_state)
+ fs.append(_get_last_hidden_state(outputs, offset))
+ if gen_config.logprobs:
+ fs.append(_get_logprobs(outputs, gen_config.logprobs))
+ return fs
def prepare_embeddings(self,
input_embeddings=None,
@@ -506,61 +508,17 @@ def prepare_embeddings(self,
return input_embeddings, input_embedding_ranges
def prepare_inputs(self,
- session_id,
input_ids,
gen_config: GenerationConfig,
input_embeddings=None,
- input_embedding_ranges=None,
- sequence_start: bool = True,
- sequence_end: bool = False,
- step=0,
- stop=False):
+ input_embedding_ranges=None):
"""Convert inputs format."""
- if len(input_ids) == 0:
- input_ids = [[]]
- if isinstance(input_ids[0], int):
- input_ids = [input_ids]
-
- batch_size = len(input_ids)
-
- def _broadcast_np(data, dtype, shape=(batch_size, )):
- if isinstance(data, Iterable):
- assert len(data) == batch_size
- return data
+ assert isinstance(input_ids, Sequence)
- return np.full(shape, data, dtype=dtype)
+ input_ids = torch.IntTensor(input_ids)
+ input_len = len(input_ids)
- input_ids = [torch.IntTensor(ids) for ids in input_ids]
- input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
- input_ids = pad_sequence(input_ids,
- batch_first=True,
- padding_value=self.eos_id)
-
- if isinstance(session_id, int):
- session_id = [session_id]
- assert len(session_id) == batch_size
-
- step = _broadcast_np(step, np.int32)
-
- inputs = dict(
- input_ids=input_ids,
- input_lengths=input_lengths,
- request_output_len=np.full(input_lengths.shape,
- gen_config.max_new_tokens,
- dtype=np.uint32),
- runtime_top_k=_broadcast_np(gen_config.top_k, np.uint32),
- runtime_top_p=_broadcast_np(gen_config.top_p, np.float32),
- runtime_min_p=_broadcast_np(gen_config.min_p, np.float32),
- temperature=_broadcast_np(gen_config.temperature, np.float32),
- repetition_penalty=_broadcast_np(gen_config.repetition_penalty,
- np.float32),
- step=step,
-
- # session input
- START=_broadcast_np((1 if sequence_start else 0), np.int32),
- END=_broadcast_np((1 if sequence_end else 0), np.int32),
- CORRID=np.array(session_id, dtype=np.uint64),
- STOP=_broadcast_np((1 if stop else 0), np.int32))
+ inputs = dict(input_ids=input_ids, )
input_embeddings, input_embedding_ranges = self.prepare_embeddings(
input_embeddings, input_embedding_ranges)
@@ -568,17 +526,6 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
inputs['input_embeddings'] = input_embeddings
inputs['input_embedding_ranges'] = input_embedding_ranges
- if gen_config.min_new_tokens is not None:
- inputs['min_length'] = _broadcast_np(gen_config.min_new_tokens,
- np.int32)
-
- if gen_config.logprobs is not None and gen_config.logprobs > 0:
- if gen_config.logprobs > MAX_LOGPROBS:
- gen_config.logprobs = MAX_LOGPROBS
- logger.warning('logprobs shoudd be in range [1, 1024]'
- f'update logprobs={gen_config.logprobs}')
- inputs['logprobs'] = _broadcast_np(gen_config.logprobs, np.int32)
-
bad_words = []
if gen_config.bad_token_ids is not None:
bad_words.extend(gen_config.bad_token_ids)
@@ -597,10 +544,24 @@ def _broadcast_np(data, dtype, shape=(batch_size, )):
if bad_words is not None:
inputs['bad_words_list'] = bad_words
- if gen_config.random_seed is not None:
- inputs['random_seed'] = _broadcast_np(gen_config.random_seed,
- np.uint64)
- return inputs, input_lengths
+ return inputs, input_len
+
+ async def async_cancel(self, session_id: int = None):
+ self.model_inst.cancel()
+
+ def async_end_cb(self, fut: asyncio.Future, status: int):
+ """executing on engine's signaling thread."""
+ logger.info(f'[async_end_cb] session ended, status = {status}')
+ fut.get_loop().call_soon_threadsafe(fut.set_result, status)
+
+ async def async_end(self, session_id):
+ fut = asyncio.get_running_loop().create_future()
+ self.model_inst.end(partial(self.async_end_cb, fut), session_id)
+ await fut
+
+ def async_signal_cb(self, s: StreamingSemaphore):
+ """executing on engine's signaling thread."""
+ s.loop.call_soon_threadsafe(s.release)
async def async_stream_infer(self,
session_id,
@@ -610,7 +571,6 @@ async def async_stream_infer(self,
sequence_start: bool = True,
sequence_end: bool = False,
step=0,
- stop=False,
gen_config: GenerationConfig = None,
stream_output=False,
**kwargs):
@@ -630,295 +590,116 @@ async def async_stream_infer(self,
stream_output (bool): indicator for stream output
kwargs (dict): kwargs for backward compatibility
"""
- # start forward thread
- que = LifoQueue()
- from functools import partial
- _forward_callback = partial(self._async_forward_callback, que=que)
- _forward_thread = partial(self._async_forward_thread, que=que)
- if stream_output and not stop:
- logger.info(f'Register stream callback for {session_id}')
- self.model_inst.register_callback(_forward_callback)
-
- inputs, input_lengths = self.prepare_inputs(
- session_id=session_id,
- input_ids=input_ids,
- input_embeddings=input_embeddings,
- input_embedding_ranges=input_embedding_ranges,
- sequence_start=sequence_start,
- sequence_end=sequence_end,
- step=step,
- stop=stop,
- gen_config=gen_config)
-
- tm_inputs = _np_dict_to_tm_dict(inputs)
- _forward_thread(tm_inputs)
-
- seq_start = input_lengths + input_lengths.new_tensor(step)
-
- out_logprobs = None
- prev_len = 0
- # generator
- while True:
- while que.qsize() == 0: # let other requests in
- await asyncio.sleep(0.002)
-
- finish, tm_outputs = que.get()
- if finish < 0:
- yield EngineOutput(status=ResponseType.INTERNAL_ENGINE_ERROR,
- token_ids=[],
- num_token=0)
- self.executor.shutdown()
- break
-
- outputs = _tm_dict_to_torch_dict(tm_outputs)
-
- output_ids = outputs['output_ids'][:, 0, :]
- sequence_length = outputs['sequence_length'].long()[:, 0]
- output_ids = [
- output_id[s:l] for output_id, s, l in zip(
- output_ids, seq_start, sequence_length)
- ]
- sequence_length -= seq_start.to(sequence_length.device)
-
- if 'logprob_vals' in outputs:
- logprob_vals = outputs['logprob_vals'][0, 0]
- logprob_indexes = outputs['logprob_indexes'][0, 0]
- logprob_nums = outputs['logprob_nums'][0, 0]
- out_logprobs = self._get_logprobs(logprob_vals,
- logprob_indexes,
- logprob_nums, output_ids[0],
- gen_config.logprobs,
- sequence_length.cpu().item(),
- out_logprobs, session_id)
-
- outputs = []
- status = ResponseType.FINISH if finish else ResponseType.SUCCESS
- for output, len_ in zip(output_ids, sequence_length):
- output, len_ = output, len_.item()
- if len(output) > 0 and output[-1].item() == self.eos_id \
- and not gen_config.ignore_eos:
- outputs = EngineOutput(status, output[:-1].tolist(),
- len_ - 1)
- elif len(output) > 0 and \
- gen_config.stop_token_ids is not None and \
- output[-1].item() in gen_config.stop_token_ids:
- outputs = EngineOutput(status, output[:-1].tolist(), len_)
- else:
- outputs = EngineOutput(status, output.tolist(), len_)
- if outputs.num_token < prev_len and not finish:
- continue
- else:
- prev_len = outputs.num_token
-
- if out_logprobs:
- output_token_len = len(outputs.token_ids)
- outputs.logprobs = out_logprobs[:output_token_len]
-
- yield outputs
-
- if finish:
- self.future.result()
- self.executor.shutdown()
- break
-
- if stream_output and not stop:
- logger.info(f'UN-register stream callback for {session_id}')
- self.model_inst.unregister_callback()
-
- def stream_infer(self,
- session_id,
- input_ids,
- input_embeddings=None,
- input_embedding_ranges=None,
- sequence_start: bool = True,
- sequence_end: bool = False,
- step=0,
- stop=False,
- gen_config: GenerationConfig = None,
- stream_output=False,
- **kwargs):
- """Perform model inference.
+ logger.info(f'[async_stream_infer] session {session_id} start')
+ gen_cfg = self._get_generation_config(gen_config)
- Args:
- session_id (int): the id of a session
- input_ids (numpy.ndarray): the token ids of a prompt
- input_embeddings (List[numpy.ndarray]): embeddings features
- input_embedding_ranges (List[Tuple[int,int]]): the begin/end
- offsets of input_embeddings to input_ids
- sequence_start (bool): indicator for starting a sequence
- sequence_end (bool): indicator for ending a sequence
- step (int): the offset of the k/v cache
- stop (bool): indicator for cancelling the session
- gen_config (GenerationConfig): generation config
- stream_output (bool): indicator for stream output
- kwargs (dict): kwargs for backward compatibility
- """
- if stream_output and not stop:
- logger.info(f'Register stream callback for {session_id}')
- self.model_inst.register_callback(self._forward_callback)
-
- inputs, input_lengths = self.prepare_inputs(
- session_id=session_id,
+ inputs, input_len = self.prepare_inputs(
input_ids=input_ids,
input_embeddings=input_embeddings,
input_embedding_ranges=input_embedding_ranges,
- sequence_start=sequence_start,
- sequence_end=sequence_end,
- step=step,
- stop=stop,
gen_config=gen_config)
- tm_inputs = _np_dict_to_tm_dict(inputs)
- # start forward thread
- self.que = Queue()
- self._forward_thread(tm_inputs)
-
- seq_start = input_lengths + input_lengths.new_tensor(step)
- out_logprobs = None
-
- # generator
- while True:
- while self.que.qsize() > 1:
- self.que.get()
-
- finish, tm_outputs = self.que.get()
- if finish < 0:
- yield EngineOutput(status=ResponseType.INTERNAL_ENGINE_ERROR,
- token_ids=[],
- num_token=0)
- self.executor.shutdown()
- break
-
- outputs = _tm_dict_to_torch_dict(tm_outputs)
-
- output_ids = outputs['output_ids'][:, 0, :]
- sequence_length = outputs['sequence_length'].long()[:, 0]
- output_ids = [
- output_id[s:l] for output_id, s, l in zip(
- output_ids, seq_start, sequence_length)
- ]
- sequence_length -= seq_start.to(sequence_length.device)
-
- if 'logprob_vals' in outputs:
- logprob_vals = outputs['logprob_vals'][0, 0]
- logprob_indexes = outputs['logprob_indexes'][0, 0]
- logprob_nums = outputs['logprob_nums'][0, 0]
- out_logprobs = self._get_logprobs(logprob_vals,
- logprob_indexes,
- logprob_nums, output_ids[0],
- gen_config.logprobs,
- sequence_length.cpu().item(),
- out_logprobs, session_id)
-
- outputs = []
- status = ResponseType.FINISH if finish else ResponseType.SUCCESS
- for output, len_ in zip(output_ids, sequence_length):
- output, len_ = output, len_.item()
- if len(output) > 0 and output[-1].item() == self.eos_id \
- and not gen_config.ignore_eos:
- outputs = EngineOutput(status, output[:-1].tolist(),
- len_ - 1, out_logprobs)
- elif len(output) > 0 and \
- gen_config.stop_token_ids is not None and \
- output[-1].item() in gen_config.stop_token_ids:
- outputs = EngineOutput(status, output[:-1].tolist(), len_,
- out_logprobs)
- else:
- outputs = EngineOutput(status, output.tolist(), len_,
- out_logprobs)
-
- if out_logprobs:
- output_token_len = len(outputs.token_ids)
- outputs.logprobs = out_logprobs[:output_token_len]
-
- yield outputs
-
- if finish:
- self.future.result()
- self.executor.shutdown()
- while self.que.qsize() > 0:
- self.que.get()
- break
-
- if stream_output and not stop:
- logger.info(f'UN-register stream callback for {session_id}')
- self.model_inst.unregister_callback()
-
- def decode(self,
- input_ids,
- steps: List[int] = None,
- input_embeddings=None,
- input_embedding_ranges=None,
- sequence_start: bool = True,
- sequence_end: bool = True):
- """Perform context decode on input tokens.
-
- Args:
- input_ids (numpy.ndarray): the batch of input token ids
- steps (List[int]): the offset of the k/v cache
- input_embeddings (List[List[Union[torch.Tensor, np.ndarray]]]):
- embeddings features
- input_embedding_ranges: (List[List[Tuple[int, int]]]):
- the begin/end offsets of input_embeddings to input_ids
- sequence_start (bool): indicator for starting a sequence
- sequence_end (bool): indicator for ending a sequence
- """
-
- if len(input_ids) == 0:
- input_ids = [[]]
- if isinstance(input_ids[0], int):
- input_ids = [input_ids]
- if steps is None:
- steps = [0] * len(input_ids)
- assert isinstance(steps, List) and len(steps) == len(input_ids)
-
- # append an extra token since input_len-1 tokens will be
- # decoded by context decoder
- input_ids = [x[:] for x in input_ids]
- for inputs in input_ids:
- inputs.append(0)
-
- batch_size = len(input_ids)
-
- def _broadcast_np(data, dtype, shape=(batch_size, )):
- if isinstance(data, Iterable):
- assert len(data) == batch_size
- return data
-
- return np.full(shape, data, dtype=dtype)
-
- input_ids = [torch.IntTensor(ids) for ids in input_ids]
- input_lengths = torch.IntTensor([len(ids) for ids in input_ids])
- input_ids = pad_sequence(input_ids,
- batch_first=True,
- padding_value=self.eos_id)
- steps = torch.IntTensor([step for step in steps])
-
- inputs = dict(input_ids=input_ids,
- input_lengths=input_lengths,
- request_output_len=_broadcast_np(0, dtype=np.uint32),
- is_return_logits=_broadcast_np(1, np.uint32),
- START=_broadcast_np((1 if sequence_start else 0),
- np.int32),
- END=_broadcast_np((1 if sequence_end else 0), np.int32),
- step=steps)
-
- input_embeddings, input_embedding_ranges = self.prepare_embeddings(
- input_embeddings, input_embedding_ranges)
- if input_embeddings is not None:
- inputs['input_embeddings'] = input_embeddings
- inputs['input_embedding_ranges'] = input_embedding_ranges
-
- tm_inputs = _np_dict_to_tm_dict(inputs)
-
- # start forward thread
- self._forward_thread(tm_inputs)
-
- res, tm_outputs = self.que.get()
- if res < 0:
- return None
-
- outputs = _tm_dict_to_torch_dict(tm_outputs)
- logits = outputs['logits']
-
- return logits[:, :-1, :]
+ session = _tm.SessionParam(id=session_id,
+ step=step,
+ start=sequence_start,
+ end=sequence_end)
+
+ inputs = _np_dict_to_tm_dict(inputs)
+
+ sem = StreamingSemaphore()
+ signal_cb = partial(self.async_signal_cb, sem)
+
+ outputs, shared_state = self.model_inst.forward(
+ inputs, session, gen_cfg, stream_output, signal_cb)
+
+ outputs = _tm_dict_to_torch_dict(outputs)
+
+ extra_fs = self._get_extra_output_processors(outputs, gen_config,
+ input_len)
+
+ output_ids_buf = outputs['output_ids']
+
+ finish = False
+ state = None
+
+ output_ids = []
+ output_len = 0
+ prev_len = step + input_len
+ try:
+ while True:
+ await sem.acquire()
+ state = shared_state.consume()
+
+ status, seq_len = state.status, state.seq_len
+
+ if status in [7, 8]: # finish / canceled
+ finish, status = True, 0
+ elif status:
+ yield self._get_error_output()
+ break
+
+ if seq_len == prev_len and not finish:
+ continue
+
+ output_ids += output_ids_buf[prev_len:seq_len].tolist()
+ output_len += seq_len - prev_len
+ status = ResponseType.FINISH if finish else ResponseType.SUCCESS # noqa
+ output = EngineOutput(status, output_ids, output_len)
+
+ for f in extra_fs:
+ f(output, seq_len)
+
+ prev_len = seq_len
+
+ yield output
+
+ if finish:
+ break
+
+ except (GeneratorExit, asyncio.CancelledError) as e:
+ logger.info(f'[async_stream_infer] {type(e).__name__}')
+ self.model_inst.cancel()
+ except Exception as e:
+ logger.error(f'[async_stream_infer] {type(e).__name__} {e}')
+ self.model_inst.cancel()
+ yield self._get_error_output()
+ finally:
+ # Contract: `cb` won't be called again if status is non-zero
+ # wait for status to be set as `finish` or `error`
+ while not state or state.status == 0:
+ await sem.acquire()
+ state = shared_state.consume()
+ logger.info(f'[async_stream_infer] session {session_id} done')
+
+ def _get_error_output(self):
+ return EngineOutput(status=ResponseType.INTERNAL_ENGINE_ERROR,
+ token_ids=[],
+ num_token=0)
+
+ def _get_generation_config(self, cfg: GenerationConfig):
+ c = _tm.GenerationConfig()
+ c.max_new_tokens = cfg.max_new_tokens
+ c.top_k = cfg.top_k
+ c.top_p = cfg.top_p
+ c.min_p = cfg.min_p
+ c.temperature = cfg.temperature
+ c.repetition_penalty = cfg.repetition_penalty
+ if cfg.min_new_tokens:
+ c.min_new_tokens = cfg.min_new_tokens
+ output_type = dict(all=1, generation=2)
+ if cfg.output_last_hidden_state:
+ c.output_last_hidden_state = output_type[
+ cfg.output_last_hidden_state]
+ if cfg.output_logits:
+ c.output_logits = output_type[cfg.output_logits]
+ if cfg.logprobs:
+ if cfg.logprobs > MAX_LOGPROBS:
+ cfg.logprobs = MAX_LOGPROBS
+                logger.warning(
+                    f'logprobs should be in range [1, {MAX_LOGPROBS}]; '
+                    f'update logprobs={cfg.logprobs}')
+ c.output_logprobs = cfg.logprobs
+ if cfg.random_seed is not None:
+ c.random_seed = cfg.random_seed
+ # print (c)
+ return c
diff --git a/lmdeploy/utils.py b/lmdeploy/utils.py
index fbdd374f80..e9ef0ba2bb 100644
--- a/lmdeploy/utils.py
+++ b/lmdeploy/utils.py
@@ -332,7 +332,7 @@ def get_max_batch_size(device_type: str):
Args:
device_type (str): the type of device
"""
- assert device_type in ['cuda', 'ascend', 'maca']
+ assert device_type in ['cuda', 'ascend', 'maca', 'camb']
if device_type == 'cuda':
max_batch_size_map = {
'a100': 256,
@@ -352,6 +352,8 @@ def get_max_batch_size(device_type: str):
return 16
elif device_type == 'maca':
return 128
+ elif device_type == 'camb':
+ return 128
def is_bf16_supported(device_type: str = 'cuda'):
@@ -387,5 +389,7 @@ def is_bf16_supported(device_type: str = 'cuda'):
# return False
elif device_type == 'maca':
return True
+ elif device_type == 'camb':
+ return True
else:
return False
diff --git a/lmdeploy/version.py b/lmdeploy/version.py
index f705fcb332..0b4b8a5379 100644
--- a/lmdeploy/version.py
+++ b/lmdeploy/version.py
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
-__version__ = '0.6.4'
+__version__ = '0.6.5'
short_version = __version__
diff --git a/lmdeploy/vl/engine.py b/lmdeploy/vl/engine.py
index 7f786d5f90..7d490b2b77 100644
--- a/lmdeploy/vl/engine.py
+++ b/lmdeploy/vl/engine.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import asyncio
+from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Union
import torch
@@ -40,12 +41,13 @@ def __init__(
vision_config = VisionConfig()
self.vision_config = vision_config
self.max_batch_size = vision_config.max_batch_size
+ self.executor = ThreadPoolExecutor(max_workers=1)
torch.cuda.empty_cache()
async def preprocess(self, messages: List[Dict]) -> List[Dict]:
"""preprocess multimodal data in the messages."""
future = asyncio.get_event_loop().run_in_executor(
- None, self.model.preprocess, messages)
+ self.executor, self.model.preprocess, messages)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
return outputs
@@ -58,7 +60,7 @@ async def async_infer(self, messages: List[Dict]) -> List[Dict]:
of `preprocess()`
"""
future = asyncio.get_event_loop().run_in_executor(
- None, self.model.forward, messages, self.max_batch_size)
+ self.executor, self.model.forward, messages, self.max_batch_size)
future.add_done_callback(_raise_exception_on_finish)
outputs = await future
return outputs
diff --git a/lmdeploy/vl/model/xcomposer2.py b/lmdeploy/vl/model/xcomposer2.py
index 3c72d0c29f..312ef9132b 100644
--- a/lmdeploy/vl/model/xcomposer2.py
+++ b/lmdeploy/vl/model/xcomposer2.py
@@ -156,6 +156,9 @@ def build_model(self):
trust_remote_code=True)
model.vit.load_model()
model.vit.resize_pos()
+ if hasattr(self.hf_config, 'img_size'):
+ model.vit.vision_tower.vision_model.embeddings.image_size = \
+ self.hf_config.img_size
model.vit.vision_tower.vision_model.post_layernorm.to_empty(
device='cpu').half()
self.vl_model = model
diff --git a/requirements/runtime_camb.txt b/requirements/runtime_camb.txt
new file mode 100644
index 0000000000..e56d0cb494
--- /dev/null
+++ b/requirements/runtime_camb.txt
@@ -0,0 +1,21 @@
+accelerate==1.2.0
+einops
+fastapi
+fire
+mmengine-lite
+numpy<2.0.0
+openai
+outlines<0.1.0
+peft<=0.11.1
+pillow
+protobuf
+pydantic>2.0.0
+pynvml
+safetensors
+sentencepiece
+shortuuid
+tiktoken
+torch==2.4.0
+torchvision<=0.19.0,>=0.15.0
+transformers
+uvicorn
diff --git a/requirements_camb.txt b/requirements_camb.txt
new file mode 100644
index 0000000000..24b1f3e796
--- /dev/null
+++ b/requirements_camb.txt
@@ -0,0 +1,4 @@
+-r requirements/build.txt
+-r requirements/runtime_camb.txt
+-r requirements/lite.txt
+-r requirements/serve.txt
diff --git a/src/turbomind/CMakeLists.txt b/src/turbomind/CMakeLists.txt
index aec443a1aa..62adb94e5a 100644
--- a/src/turbomind/CMakeLists.txt
+++ b/src/turbomind/CMakeLists.txt
@@ -16,6 +16,7 @@ add_subdirectory(utils)
add_subdirectory(kernels)
add_subdirectory(layers)
add_subdirectory(models)
+add_subdirectory(engine)
if(BUILD_PYT)
add_subdirectory(th_op)
endif()
diff --git a/src/turbomind/engine/CMakeLists.txt b/src/turbomind/engine/CMakeLists.txt
new file mode 100644
index 0000000000..1d68116cf6
--- /dev/null
+++ b/src/turbomind/engine/CMakeLists.txt
@@ -0,0 +1,7 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+
+cmake_minimum_required(VERSION 3.8)
+
+add_library(engine STATIC gateway.cc request_queue.cc model_request.cc)
+set_property(TARGET engine PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET engine PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
diff --git a/src/turbomind/engine/gateway.cc b/src/turbomind/engine/gateway.cc
new file mode 100644
index 0000000000..e949ec7cd3
--- /dev/null
+++ b/src/turbomind/engine/gateway.cc
@@ -0,0 +1,40 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include
+
+#include "src/turbomind/engine/gateway.h"
+#include "src/turbomind/engine/request_queue.h"
+
+namespace turbomind {
+
+Gateway::Gateway(std::function<std::shared_ptr<void>()> ctx_factory): request_queue_{this}, ctx_factory_{ctx_factory}
+{
+ signal_thread_ = std::thread(&Gateway::signal_thread_entry, this);
+}
+
+void Gateway::shutdown()
+{
+ request_queue_.close();
+ signal_buffer_.close();
+
+ signal_thread_.join();
+}
+
+void Gateway::signal_thread_entry() noexcept
+{
+ while (true) {
+ bool abort{};
+        std::vector<Signal> signals = signal_buffer_.take_all(abort);
+ if (abort) {
+ break;
+ }
+ else {
+ auto ctx = ctx_factory_();
+ for (const auto& s : signals) {
+ s();
+ }
+ }
+ }
+}
+
+} // namespace turbomind
diff --git a/src/turbomind/engine/gateway.h b/src/turbomind/engine/gateway.h
new file mode 100644
index 0000000000..d939c0bcc2
--- /dev/null
+++ b/src/turbomind/engine/gateway.h
@@ -0,0 +1,61 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "src/turbomind/engine/request_queue.h"
+#include "src/turbomind/engine/signal_buffer.h"
+
+namespace turbomind {
+
+class Gateway {
+public:
+    Gateway(std::function<std::shared_ptr<void>()> ctx_factory);
+
+ void shutdown();
+
+    void push(std::vector<std::shared_ptr<Request>> reqs)
+ {
+ return request_queue_.push(std::move(reqs));
+ }
+
+    void pop(std::vector<std::shared_ptr<Request>>& infer_reqs,
+             std::vector<std::shared_ptr<Request>>& kill_reqs,
+ unsigned max_infer_num,
+ bool blocking,
+ bool& abort)
+ {
+ return request_queue_.pop(infer_reqs, kill_reqs, max_infer_num, blocking, abort);
+ }
+
+    void cancel(std::shared_ptr<Request> req)
+ {
+ return request_queue_.cancel(std::move(req));
+ }
+
+    void kill(std::shared_ptr<Request> req)
+ {
+ return request_queue_.kill(std::move(req));
+ }
+
+    void notify(std::vector<Signal> signals)
+ {
+ return signal_buffer_.push(std::move(signals));
+ }
+
+private:
+ void signal_thread_entry() noexcept;
+
+private:
+ RequestQueue request_queue_;
+ SignalBuffer signal_buffer_;
+
+    std::function<std::shared_ptr<void>()> ctx_factory_;
+
+ std::thread signal_thread_;
+};
+
+} // namespace turbomind
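
The Gateway splits work between a request queue drained by the engine's main loop and a signal buffer drained by a dedicated signaling thread. A minimal sketch of how an engine loop could drive it, assuming the headers above build standalone and that `ctx_factory` returns a `std::shared_ptr<void>`; `engine_loop` and the batch limit of 32 are illustrative names, not part of the diff:

```cpp
#include <memory>
#include <vector>

#include "src/turbomind/engine/gateway.h"
#include "src/turbomind/engine/request.h"

using namespace turbomind;

void engine_loop(Gateway& gateway)
{
    std::vector<std::shared_ptr<Request>> infer_reqs;
    std::vector<std::shared_ptr<Request>> kill_reqs;
    bool abort{};
    while (!abort) {
        // block until new work arrives or the gateway is shut down
        gateway.pop(infer_reqs, kill_reqs, /*max_infer_num=*/32, /*blocking=*/true, abort);
        for (auto& r : infer_reqs) {
            // ... run forward steps, then publish the final state; the lambda
            // runs later on the gateway's signaling thread
            gateway.notify({[r] { UpdateState(*r, Request::kFinish, /*seq_len=*/0); }});
        }
    }
}

int main()
{
    Gateway gateway([] { return std::shared_ptr<void>{}; });
    // ... frontend threads would call gateway.push({...}) with prepared requests
    gateway.shutdown();  // closes both queues and joins the signaling thread
}
```
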
diff --git a/src/turbomind/engine/model_request.cc b/src/turbomind/engine/model_request.cc
new file mode 100644
index 0000000000..6ba355e896
--- /dev/null
+++ b/src/turbomind/engine/model_request.cc
@@ -0,0 +1,174 @@
+
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "src/turbomind/engine/model_request.h"
+#include "src/turbomind/engine/request.h"
+#include "src/turbomind/utils/Tensor.h"
+#include "src/turbomind/utils/constant.h"
+#include "src/turbomind/utils/cuda_utils.h"
+
+namespace turbomind {
+
+static ManagedTensor create(DataType dtype, MemoryType where, const std::vector<int64_t>& size, int64_t& byte_size)
+{
+ byte_size = std::accumulate(size.begin(), size.end(), Tensor::getTypeSize(dtype), std::multiplies<>{});
+ void* data{};
+
+ if (where == MEMORY_GPU) {
+ check_cuda_error(cudaMallocAsync(&data, byte_size, nullptr));
+ }
+ else {
+ data = std::malloc(byte_size);
+ }
+
+ ManagedTensor ret;
+    ret.tensor = Tensor{where, dtype, std::vector<size_t>(size.begin(), size.end()), data};
+ ret.data_holder.reset((void*)nullptr, [data, where](auto) {
+ // std::cerr << "turbomind tensor deallocate" << std::endl;
+ if (where == MEMORY_GPU) {
+ /// TODO: guard device id
+ check_cuda_error(cudaFreeAsync(data, nullptr));
+ }
+ else {
+ std::free(data);
+ }
+ });
+ return ret;
+}
+
+template<class T>
+static T get(const std::unordered_map<std::string, ManagedTensor>& m, const std::string& key, T fallback = {})
+{
+ auto it = m.find(key);
+ if (it != m.end()) {
+        return it->second->getVal<T>();
+ }
+ return fallback;
+}
+
+ModelRequest::ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim):
+ gateway_{gateway},
+ data_type_{data_type},
+ session_len_{session_len},
+ vocab_size_{vocab_size},
+ hidden_dim_{hidden_dim}
+{
+}
+
+void ModelRequest::Cancel()
+{
+    // the request has already finished if locking fails
+ if (auto r = request_.lock()) {
+ gateway_->cancel(std::move(r));
+ }
+}
+
+void ModelRequest::End(std::function<void(int)> cb, uint64_t session_id)
+{
+    auto r = std::make_shared<Request>();
+
+ r->id = r->session.id = session_id;
+ r->session.kill_flag = true;
+
+ r->end_cb = std::move(cb);
+
+ gateway_->kill(std::move(r));
+}
+
+auto ModelRequest::Forward(InputParam param, std::function<void()> cb) -> OutputParam
+{
+    inputs_  = std::make_shared<TensorMap_>();
+    outputs_ = std::make_shared<TensorMap_>();
+
+ auto add = [](auto& dest, auto key, auto dtype, auto where, auto shape, auto&&... dims) {
+        std::vector<int64_t> shape_;
+        if constexpr (std::is_integral_v<decltype(shape)>) {
+ shape_ = {shape, dims...};
+ }
+ else {
+ shape_ = {shape.cbegin(), shape.cend()};
+ }
+ int64_t byte_size{};
+ auto it = dest->emplace(key, create(dtype, where, shape_, byte_size)).first;
+ return std::make_pair(it->second->data, byte_size);
+ };
+
+ auto& inputs = *param.tensors;
+
+ FT_CHECK(inputs.at("input_ids")->shape.size() == 1);
+
+ const int input_len = inputs.at("input_ids")->shape[0];
+ const int output_len = param.gen_cfg.max_new_tokens;
+
+    // Max possible length of a sequence; this depends on `history_len`, which isn't available here, so
+    // `session_len` is used instead
+ const int max_seq_len = session_len_ + 1;
+ const int max_out_len = std::min(output_len, session_len_) + 1;
+    // This does not include history length in interactive mode
+ const int max_in_out_len = std::min(input_len + output_len, session_len_) + 1;
+
+ for (auto& [k, v] : *param.tensors) {
+ inputs_->emplace(k, v);
+ }
+
+ add(outputs_, "output_ids", TYPE_INT32, MEMORY_CPU, max_seq_len);
+ add(outputs_, "sequence_length", TYPE_INT32, MEMORY_CPU, 1);
+
+ if (param.gen_cfg.output_logits) {
+ const int len = param.gen_cfg.output_logits == GenerationConfig::kAll ? max_in_out_len : max_out_len;
+ add(outputs_, "logits", TYPE_FP32, MEMORY_CPU, len, vocab_size_);
+ }
+
+ if (param.gen_cfg.output_last_hidden_state) {
+ const int len = param.gen_cfg.output_last_hidden_state == GenerationConfig::kAll ? max_in_out_len : max_out_len;
+ add(outputs_, "last_hidden_state", data_type_, MEMORY_CPU, len, hidden_dim_);
+ }
+
+ if (param.gen_cfg.output_logprobs) {
+ add(outputs_, "logprob_vals", TYPE_FP32, MEMORY_CPU, max_out_len, kMaxLogProb);
+ add(outputs_, "logprob_indexes", TYPE_INT32, MEMORY_CPU, max_out_len, kMaxLogProb);
+ add(outputs_, "logprob_nums", TYPE_INT32, MEMORY_CPU, max_out_len);
+ }
+
+    auto r = std::make_shared<Request>();
+
+ for (const auto& [k, v] : *inputs_) {
+ r->inputs.insert(k, *v);
+ }
+ for (const auto& [k, v] : *outputs_) {
+ r->outputs.insert(k, *v);
+ }
+
+    auto state = std::make_shared<AtomicRequestState>();
+
+ if (param.session.start_flag) {
+ session_id_ = param.session.id;
+ }
+
+ r->id = param.session.id;
+ r->session = param.session;
+ r->gen_cfg = param.gen_cfg;
+ r->stream_output = param.stream_output;
+ r->forward_cb = std::move(cb);
+ r->state = state;
+
+ r->output_ids = *outputs_->at("output_ids");
+ r->sequence_length = *outputs_->at("sequence_length");
+
+ // Keep a weak reference for canceling the request
+ request_ = r;
+
+ gateway_->push({std::move(r)});
+
+ return OutputParam{outputs_, state};
+}
+
+} // namespace turbomind
diff --git a/src/turbomind/engine/model_request.h b/src/turbomind/engine/model_request.h
new file mode 100644
index 0000000000..aea889e856
--- /dev/null
+++ b/src/turbomind/engine/model_request.h
@@ -0,0 +1,59 @@
+
+
+#pragma once
+
+#include
+
+#include "src/turbomind/engine/gateway.h"
+#include "src/turbomind/utils/Tensor.h"
+
+namespace turbomind {
+
+class ModelRequest {
+public:
+ virtual ~ModelRequest() = default;
+
+ ModelRequest(Gateway* gateway, DataType data_type, int session_len, int vocab_size, int hidden_dim);
+
+ // Cancel running request
+ void Cancel();
+
+    // Reset the channel to uninitialized state, calls `notify` when done
+    void End(std::function<void(int)> cb, uint64_t session_id);
+
+    using TensorMap_ = std::unordered_map<std::string, ManagedTensor>;
+
+ struct InputParam {
+        std::shared_ptr<TensorMap_> tensors;
+
+ SessionParam session;
+ GenerationConfig gen_cfg;
+
+ bool stream_output;
+ };
+
+ struct OutputParam {
+        std::shared_ptr<TensorMap_> tensors;
+        std::shared_ptr<AtomicRequestState> state;
+ };
+
+    OutputParam Forward(InputParam param, std::function<void()> cb);
+
+protected:
+ Gateway* const gateway_;
+
+ const DataType data_type_;
+
+ const int session_len_;
+ const int hidden_dim_;
+ const int vocab_size_;
+
+ uint64_t session_id_;
+
+    std::weak_ptr<Request> request_;
+
+    std::shared_ptr<TensorMap_> inputs_;   // owned by caller
+    std::shared_ptr<TensorMap_> outputs_;  // owned by `this`
+};
+
+} // namespace turbomind
diff --git a/src/turbomind/engine/request.h b/src/turbomind/engine/request.h
new file mode 100644
index 0000000000..6bf706c9b8
--- /dev/null
+++ b/src/turbomind/engine/request.h
@@ -0,0 +1,148 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+#include
+#include
+
+#include "src/turbomind/utils/Tensor.h"
+
+namespace turbomind {
+
+struct GenerationConfig {
+ int max_new_tokens = 0;
+ int min_new_tokens = 0;
+
+ int top_k = 1;
+ float top_p = 0.f;
+ float min_p = 0.f;
+ float temperature = 1.f;
+
+ float repetition_penalty = 1.f;
+
+ uint64_t random_seed = 0;
+
+ int output_logprobs = 0;
+
+ enum OutType
+ {
+ kNone = 0,
+ kAll = 1,
+ kGeneration = 2
+ };
+ int output_last_hidden_state = 0;
+ int output_logits = 0;
+};
+
+inline std::ostream& operator<<(std::ostream& os, const GenerationConfig& c)
+{
+ os << "GenerationConfig { ";
+ os << "max_new_tokens=" << c.max_new_tokens;
+ os << ", min_new_tokens=" << c.min_new_tokens;
+ os << ", top_p=" << c.top_p;
+ os << ", top_k=" << c.top_k;
+ os << ", min_p=" << c.min_p;
+ os << ", temperature=" << c.temperature;
+ os << ", repetition_penalty=" << c.repetition_penalty;
+ os << ", random_seed=" << c.random_seed;
+ os << ", output_logprobs=" << c.output_logprobs;
+ os << ", output_hidden_states=" << c.output_last_hidden_state;
+ os << ", output_logits=" << c.output_logits;
+ os << " }";
+ return os;
+}
+
+struct SessionParam {
+ uint64_t id;
+
+ int step;
+
+ bool start_flag;
+ bool end_flag;
+ bool kill_flag;
+};
+
+struct RequestState {
+ int status;
+ int seq_len;
+};
+
+struct AtomicRequestState {
+
+    std::atomic<RequestState*> data_;
+
+    static_assert(std::atomic<RequestState*>::is_always_lock_free);
+
+ ~AtomicRequestState()
+ {
+ auto data = exchange(nullptr);
+ }
+
+    std::unique_ptr<RequestState> exchange(RequestState* data)
+ {
+        return std::unique_ptr<RequestState>{data_.exchange(data, std::memory_order_acq_rel)};
+ }
+};
+
+struct Request {
+ uint64_t id; // sequence id
+ uint64_t unique_id; // monotonic increasing
+
+ SessionParam session;
+ GenerationConfig gen_cfg;
+
+ bool stream_output;
+
+ // reference to IO tensors
+ TensorMap inputs;
+ TensorMap outputs;
+ // fast path for accessing common output buffers
+ Tensor output_ids;
+ Tensor sequence_length;
+
+    std::function<void(int)> end_cb;
+
+    std::atomic<int> cancel_flag;
+ bool is_canceled{};
+
+    std::function<void()> forward_cb;
+
+    std::shared_ptr<AtomicRequestState> state;
+
+ int ec; // set when disabling conflicting requests
+
+ enum
+ {
+ kOk = 0,
+        kInvalid = 1,  // Sequence does not exist, or both `start` & `stop` (instead of `end`) are set
+ kConflict = 2, // Concurrent requests to the same sequence
+ kBusy = 3, // Sequence is already running
+ kInactive = 4, // Sequence to `stop` is not active
+ kFail = 5, // Can't find sequence for `stop` request or internal error during inference
+ kTooLong = 6, // history + prompt > session_len,
+ kFinish = 7,
+ kCancel = 8,
+ };
+};
+
+inline void UpdateState(Request& r, int status, int seq_len)
+{
+ try {
+ auto new_state = new RequestState{status, seq_len};
+ auto old_state = r.state->exchange(new_state);
+ if (!old_state && r.forward_cb) {
+ r.forward_cb();
+ }
+ }
+ catch (const std::exception& e) {
+ TM_LOG_ERROR("Error invoking callback for (%lu): %s", r.id, e.what());
+ }
+ catch (...) {
+ TM_LOG_ERROR("Unknown error invoking callback for (%lu)", r.id);
+ }
+}
+
+} // namespace turbomind
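
AtomicRequestState is a single-slot mailbox: the engine publishes the newest `(status, seq_len)` pair with an atomic exchange, and the consumer drains the slot with `exchange(nullptr)`; `forward_cb` only fires on the empty-to-nonempty transition, which is what lets bursts of updates coalesce into a single wake-up on the client side. A hedged sketch of the two sides (`produce`/`consume` are hypothetical names, not from the diff):

```cpp
#include <cstdio>

#include "src/turbomind/engine/request.h"

using namespace turbomind;

void produce(Request& r, int seq_len)
{
    // engine side: publish the latest state; the callback fires only if the
    // previous state had already been consumed (slot was empty)
    UpdateState(r, Request::kOk, seq_len);
}

void consume(Request& r)
{
    // consumer side: take ownership of the latest state, leaving the slot empty
    if (auto state = r.state->exchange(nullptr)) {
        std::printf("status=%d seq_len=%d\n", state->status, state->seq_len);
    }
}
```
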
diff --git a/src/turbomind/engine/request_queue.cc b/src/turbomind/engine/request_queue.cc
new file mode 100644
index 0000000000..8c0b52b5bf
--- /dev/null
+++ b/src/turbomind/engine/request_queue.cc
@@ -0,0 +1,93 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#include "src/turbomind/engine/request_queue.h"
+#include "src/turbomind/engine/gateway.h"
+
+#include "src/turbomind/engine/request.h"
+
+namespace turbomind {
+
+void RequestQueue::push(std::vector<std::shared_ptr<Request>> reqs)
+{
+ {
+ std::lock_guard lock(mutex_);
+ if (closed_) {
+ throw std::runtime_error("Queue is closed");
+ }
+ for (auto& r : reqs) {
+ queue_.push(std::move(r));
+ }
+ }
+ cv_.notify_one();
+}
+
+void RequestQueue::cancel(std::shared_ptr<Request> r)
+{
+ // -1 canceled
+ // 0 queued
+ // 1 active
+ if (r->cancel_flag.exchange(-1, std::memory_order_acq_rel) != 0) {
+ // request is picked up by engine
+ return;
+ }
+ else {
+ // not picked by engine yet, skip directly
+ gateway_->notify({[r = std::move(r)] { //
+ UpdateState(*r, Request::kCancel, 0);
+ }});
+ }
+}
+
+void RequestQueue::kill(std::shared_ptr<Request> r)
+{
+ {
+ std::lock_guard lock(mutex_);
+ if (closed_) {
+ throw std::runtime_error("Queue is closed");
+ }
+ kill_.push_back(std::move(r));
+ }
+ cv_.notify_one();
+}
+
+void RequestQueue::pop(std::vector<std::shared_ptr<Request>>& infer_reqs,
+                       std::vector<std::shared_ptr<Request>>& kill_reqs,
+ unsigned max_infer_num,
+ bool blocking,
+ bool& abort)
+{
+ std::unique_lock lock(mutex_);
+
+ if (blocking) {
+ cv_.wait(lock, [this] { return !queue_.empty() || !kill_.empty() || closed_; });
+ if (closed_) {
+ abort = true;
+ return;
+ }
+ }
+
+ infer_reqs.clear();
+ while (!queue_.empty() && infer_reqs.size() <= max_infer_num) {
+ auto& r = queue_.front();
+ if (r->cancel_flag.exchange(1, std::memory_order_acq_rel) == 0) {
+ infer_reqs.push_back(std::move(r));
+ }
+ else {
+ // Canceled requests are simply ignored
+ }
+ queue_.pop();
+ }
+
+ kill_reqs = std::move(kill_);
+}
+
+void RequestQueue::close()
+{
+ {
+ std::lock_guard lock(mutex_);
+ closed_ = true;
+ }
+ cv_.notify_all();
+}
+
+} // namespace turbomind
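
Cancellation is a lock-free handshake on `cancel_flag`: 0 means still queued, 1 means the engine has picked the request up, -1 means the client cancelled first. Whichever side loses the exchange race defers to the winner, so a request is either dropped from the queue or interrupted inside the running batch, never both. The same three-state protocol in isolation (illustrative, not part of the diff):

```cpp
#include <atomic>

enum { kQueued = 0, kActive = 1, kCanceled = -1 };

// Client side: returns true if the request was still queued and is now
// cancelled; false if the engine already claimed it (interrupt it instead).
bool try_cancel(std::atomic<int>& flag)
{
    return flag.exchange(kCanceled, std::memory_order_acq_rel) == kQueued;
}

// Engine side: claim the request unless it was cancelled first.
bool try_claim(std::atomic<int>& flag)
{
    return flag.exchange(kActive, std::memory_order_acq_rel) == kQueued;
}
```
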
diff --git a/src/turbomind/engine/request_queue.h b/src/turbomind/engine/request_queue.h
new file mode 100644
index 0000000000..c029f38f4b
--- /dev/null
+++ b/src/turbomind/engine/request_queue.h
@@ -0,0 +1,46 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+
+#include "src/turbomind/engine/request.h"
+
+namespace turbomind {
+
+class Gateway;
+
+class RequestQueue {
+public:
+ RequestQueue(Gateway* gateway): gateway_{gateway} {}
+
+ void push(std::vector> reqs);
+
+ void pop(std::vector>& infer_reqs,
+ std::vector>& kill_reqs,
+ unsigned max_infer_num,
+ bool blocking,
+ bool& abort);
+
+ void cancel(std::shared_ptr r);
+
+ void kill(std::shared_ptr r);
+
+ void close();
+
+private:
+ Gateway* gateway_;
+
+ std::queue> queue_;
+
+ std::vector> kill_;
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+
+ bool closed_{false};
+};
+
+} // namespace turbomind
diff --git a/src/turbomind/engine/signal_buffer.h b/src/turbomind/engine/signal_buffer.h
new file mode 100644
index 0000000000..cb09be7909
--- /dev/null
+++ b/src/turbomind/engine/signal_buffer.h
@@ -0,0 +1,61 @@
+// Copyright (c) OpenMMLab. All rights reserved.
+
+#pragma once
+
+#include
+#include
+#include
+
+namespace turbomind {
+
+using Signal = std::function<void()>;
+
+class SignalBuffer {
+public:
+    void push(std::vector<Signal> signals)
+ {
+ if (signals.empty()) {
+ return;
+ }
+ {
+ std::lock_guard lock{mutex_};
+ signals_.insert(signals_.end(), std::move_iterator{signals.begin()}, std::move_iterator{signals.end()});
+ }
+ cv_.notify_one();
+ }
+
+ void close()
+ {
+ {
+ std::lock_guard lock{mutex_};
+ aborted_ = true;
+ }
+ cv_.notify_all();
+ }
+
+    std::vector<Signal> take_all(bool& abort)
+ {
+        std::vector<Signal> signals;
+ {
+ std::unique_lock lock{mutex_};
+ cv_.wait(lock, [&] { return !signals_.empty() || aborted_; });
+ if (aborted_) {
+ abort = true;
+ }
+ else {
+ signals.swap(signals_);
+ }
+ }
+ return signals;
+ }
+
+private:
+    std::vector<Signal> signals_;
+
+ std::mutex mutex_;
+ std::condition_variable cv_;
+
+ bool aborted_{false};
+};
+
+} // namespace turbomind
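
SignalBuffer batches callbacks from producers and hands them to a single consumer thread, so callbacks run outside the producers' locks. A minimal standalone use, assuming the header above builds on its own; everything except `SignalBuffer`/`Signal` is illustrative:

```cpp
#include <chrono>
#include <cstdio>
#include <thread>

#include "src/turbomind/engine/signal_buffer.h"

int main()
{
    turbomind::SignalBuffer buf;

    std::thread consumer([&] {
        bool abort{};
        while (!abort) {
            for (auto& sig : buf.take_all(abort)) {
                sig();  // run callbacks outside the producer's critical section
            }
        }
    });

    buf.push({[] { std::puts("token ready"); }});
    // give the consumer a moment; close() discards signals not yet taken
    std::this_thread::sleep_for(std::chrono::milliseconds(10));
    buf.close();  // wakes the consumer with abort = true
    consumer.join();
}
```
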
diff --git a/src/turbomind/kernels/gemm/moe_utils_v2.cu b/src/turbomind/kernels/gemm/moe_utils_v2.cu
index a9e4f7da51..44fec67748 100644
--- a/src/turbomind/kernels/gemm/moe_utils_v2.cu
+++ b/src/turbomind/kernels/gemm/moe_utils_v2.cu
@@ -2,6 +2,7 @@
#include
#include
+#include
#include
#include
#include
diff --git a/src/turbomind/kernels/gemm/test/test_utils.cu b/src/turbomind/kernels/gemm/test/test_utils.cu
index 8f2b4007f6..8ee595ab9b 100644
--- a/src/turbomind/kernels/gemm/test/test_utils.cu
+++ b/src/turbomind/kernels/gemm/test/test_utils.cu
@@ -84,7 +84,7 @@ FastCompare(const T* src, const T* ref, int dims, int bsz, cudaStream_t stream,
thrust::cuda::par.on(stream),
zip_iter,
zip_iter + count,
- [=] __device__(auto tup) {
+ [=] __host__ __device__(thrust::tuple tup) -> Tuple {
float s = thrust::get<0>(tup);
float r = thrust::get<1>(tup);
float abs_diff = fabsf(s - r);
diff --git a/src/turbomind/kernels/gpt_kernels.cu b/src/turbomind/kernels/gpt_kernels.cu
index 4f47631fa5..d611cfab43 100644
--- a/src/turbomind/kernels/gpt_kernels.cu
+++ b/src/turbomind/kernels/gpt_kernels.cu
@@ -269,4 +269,61 @@ void invokeTransposeAxis01(
template void invokeTransposeAxis01(
int* out, int* in, const int* in_skipping_dim1, const int dim0, const int dim1, cudaStream_t stream);
+template<class T, int TILE_DIM, int BLOCK_ROWS>
+__global__ void transpose_2d_kernel(T* __restrict__ dst, const T* __restrict__ src, int rows, int cols, bool swap_xy)
+{
+ __shared__ T smem[TILE_DIM][TILE_DIM + 1];
+
+ const int block_idx_x = swap_xy ? blockIdx.y : blockIdx.x;
+ const int block_idx_y = swap_xy ? blockIdx.x : blockIdx.y;
+
+ {
+ const int j = block_idx_x * TILE_DIM + threadIdx.x;
+ const int i = block_idx_y * TILE_DIM + threadIdx.y;
+
+#pragma unroll
+ for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {
+ if (i + y < rows && j < cols) {
+ smem[threadIdx.y + y][threadIdx.x] = src[(i + y) * cols + j];
+ }
+ }
+ }
+
+ __syncthreads();
+
+ {
+ const int j = block_idx_y * TILE_DIM + threadIdx.x;
+ const int i = block_idx_x * TILE_DIM + threadIdx.y;
+
+#pragma unroll
+ for (int y = 0; y < TILE_DIM; y += BLOCK_ROWS) {
+ if (i + y < cols && j < rows) {
+ dst[(i + y) * rows + j] = smem[threadIdx.x][threadIdx.y + y];
+ }
+ }
+ }
+}
+
+template<class T>
+void invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st)
+{
+ constexpr int TILE_DIM = 32; // warp size
+ constexpr int BLOCK_ROWS = 8;
+
+ const dim3 block(TILE_DIM, BLOCK_ROWS);
+
+ dim3 grid((cols + TILE_DIM - 1) / TILE_DIM, //
+ (rows + TILE_DIM - 1) / TILE_DIM);
+ bool swap_xy = false;
+
+ if (grid.y > 65535) { // max dim for grid.y
+ std::swap(grid.x, grid.y);
+ swap_xy = true;
+ }
+
+    transpose_2d_kernel<T, TILE_DIM, BLOCK_ROWS><<<grid, block, 0, st>>>(dst, src, rows, cols, swap_xy);
+}
+
+template void invokeTranspose2D_(uint32_t*, const uint32_t*, int, int, cudaStream_t);
+
} // namespace turbomind
diff --git a/src/turbomind/kernels/gpt_kernels.h b/src/turbomind/kernels/gpt_kernels.h
index 4e1dc49be8..a351473332 100644
--- a/src/turbomind/kernels/gpt_kernels.h
+++ b/src/turbomind/kernels/gpt_kernels.h
@@ -238,4 +238,19 @@ void invokeSumLengthDimension(float* out_buf,
const size_t hidden_dim,
cudaStream_t stream = 0);
+template<class T>
+void invokeTranspose2D_(T* dst, const T* src, int rows, int cols, cudaStream_t st);
+
+template<class T>
+void invokeTranspose2D(T* dst, const T* src, int rows, int cols, cudaStream_t st)
+{
+ if constexpr (sizeof(T) == 4) {
+ // FT_CHECK(0);
+ invokeTranspose2D_((uint32_t*)dst, (const uint32_t*)src, rows, cols, st);
+ }
+ else {
+ FT_CHECK(0);
+ }
+}
+
} // namespace turbomind
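
`invokeTranspose2D` replaces the old axis-0/1 transpose when copying `output_ids` into `token_ids_buf_`: the kernel does a tiled `[rows, cols] -> [cols, rows]` transpose through shared memory and swaps `grid.x`/`grid.y` when the row tile count would exceed CUDA's 65535 limit on `gridDim.y`. A CPU reference for what it computes, useful only for checking results (hypothetical helper, not part of the diff):

```cpp
#include <cstdint>
#include <vector>

// dst is the [cols, rows] transpose of src viewed as a row-major [rows, cols] matrix
template<class T>
void transpose2d_ref(std::vector<T>& dst, const std::vector<T>& src, int rows, int cols)
{
    dst.resize(static_cast<size_t>(rows) * cols);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            dst[static_cast<size_t>(j) * rows + i] = src[static_cast<size_t>(i) * cols + j];
        }
    }
}
```
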
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.cu b/src/turbomind/kernels/sampling_penalty_kernels.cu
index 1d4cfe24b0..cf360580b9 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.cu
+++ b/src/turbomind/kernels/sampling_penalty_kernels.cu
@@ -17,6 +17,8 @@
#include
#include
+#include "src/turbomind/kernels/core/array_ops.h"
+#include "src/turbomind/kernels/core/common.h"
#include "src/turbomind/kernels/sampling_penalty_kernels.h"
namespace turbomind {
@@ -221,6 +223,81 @@ template void invokeBatchApplyTemperaturePenalty(half* logits,
const int vocab_size_padd,
cudaStream_t stream);
#endif
+
+template<int vec_size>
+__global__ void batchApplyTemperaturePenalty_v2(float* logits,
+ const float* bias,
+ const float* temperatures,
+ const int batch_size,
+ const int vocab_size,
+ const int vocab_size_padded)
+{
+ const int vi = blockIdx.x * blockDim.x + threadIdx.x;
+ const int bi = blockIdx.y;
+
+ __shared__ float shared_scale;
+
+ if (threadIdx.x == 0) {
+ shared_scale = fdividef(1.f, temperatures[bi] + 1e-6f);
+ }
+
+ __syncthreads();
+
+ const float scale = shared_scale;
+
+ logits += (size_t)bi * vocab_size_padded;
+
+ const int step = gridDim.x * blockDim.x * vec_size;
+
+ for (int i = vi * vec_size; i < vocab_size_padded; i += step) {
+        Array<float, vec_size> vec;
+ Load(vec, logits + i);
+ PRAGMA_UNROLL
+ for (int c = 0; c < vec_size; ++c) {
+ if (i + c < vocab_size) {
+ vec[c] *= scale;
+ }
+ else {
+ vec[c] = -FLT_MAX;
+ }
+ }
+ Store(logits + i, vec);
+ }
+}
+
+void invokeBatchApplyTemperaturePenalty_v2(float* logits,
+ const float* bias,
+ const float* temperatures,
+ const int batch_size,
+ const int vocab_size,
+ const int vocab_size_padded,
+ cudaStream_t stream)
+{
+
+ auto invoke = [&](auto vec_size) {
+ constexpr int threads = 256;
+ const int blocks_per_tok = (vocab_size_padded + threads * vec_size - 1) / (threads * vec_size);
+ const dim3 blocks(blocks_per_tok, batch_size);
+        batchApplyTemperaturePenalty_v2<vec_size.value><<<blocks, threads, 0, stream>>>(  //
+ logits,
+ bias,
+ temperatures,
+ batch_size,
+ vocab_size,
+ vocab_size_padded);
+ };
+
+ if (vocab_size_padded % 4 == 0) {
+        invoke(std::integral_constant<int, 4>{});
+ }
+ else if (vocab_size_padded % 2 == 0) {
+        invoke(std::integral_constant<int, 2>{});
+ }
+ else {
+        invoke(std::integral_constant<int, 1>{});
+ }
+}
+
template
__global__ void applyRepetitionPenalty(T* logits,
const float penalty,
diff --git a/src/turbomind/kernels/sampling_penalty_kernels.h b/src/turbomind/kernels/sampling_penalty_kernels.h
index e12698cdf7..1f26b7d352 100644
--- a/src/turbomind/kernels/sampling_penalty_kernels.h
+++ b/src/turbomind/kernels/sampling_penalty_kernels.h
@@ -69,6 +69,14 @@ void invokeBatchApplyTemperaturePenalty(T* logits,
const int vocab_size_padd,
cudaStream_t stream);
+void invokeBatchApplyTemperaturePenalty_v2(float* logits,
+ const float* bias,
+ const float* temperatures,
+ const int batch_size,
+ const int vocab_size,
+ const int vocab_size_padd,
+ cudaStream_t stream);
+
template
void invokeMinLengthPenalty(T* logits,
const int* min_lengths,
diff --git a/src/turbomind/kernels/sampling_topp_kernels.cu b/src/turbomind/kernels/sampling_topp_kernels.cu
index 04ea0577d1..4d4cff464c 100644
--- a/src/turbomind/kernels/sampling_topp_kernels.cu
+++ b/src/turbomind/kernels/sampling_topp_kernels.cu
@@ -22,6 +22,7 @@
#include "3rdparty/cub/cub.cuh"
#endif
+#include "src/turbomind/kernels/core/math.h"
#include "src/turbomind/kernels/reduce_kernel_utils.cuh"
#include "src/turbomind/kernels/sampling_topp_kernels.h"
#include "src/turbomind/utils/constant.h"
@@ -216,9 +217,9 @@ void invokeTopPSort(TopPSortParams& params, cudaStream_t stream)
size_t topp_id_val_buf_size = sizeof(int) * params.batch_size * params.vocab_size_padded;
size_t begin_offset_buf_size = sizeof(int) * params.batch_size;
size_t end_offset_buf_size = sizeof(int) * params.batch_size;
- topp_id_val_buf_size = div_up(topp_id_val_buf_size, 256) * 256;
- begin_offset_buf_size = div_up(begin_offset_buf_size, 256) * 256;
- end_offset_buf_size = div_up(end_offset_buf_size, 256) * 256;
+ topp_id_val_buf_size = cdiv(topp_id_val_buf_size, 256) * 256;
+ begin_offset_buf_size = cdiv(begin_offset_buf_size, 256) * 256;
+ end_offset_buf_size = cdiv(end_offset_buf_size, 256) * 256;
if (params.workspace == nullptr) {
size_t cub_temp_storage_size;
@@ -236,7 +237,7 @@ void invokeTopPSort(TopPSortParams& params, cudaStream_t stream)
0, // begin_bit
sizeof(T) * 8, // end_bit = sizeof(KeyT) * 8
stream)); // cudaStream_t
- cub_temp_storage_size = div_up(cub_temp_storage_size, 256) * 256;
+ cub_temp_storage_size = cdiv(cub_temp_storage_size, 256) * 256;
params.workspace_size =
topp_id_val_buf_size + begin_offset_buf_size + end_offset_buf_size + cub_temp_storage_size;
return;
diff --git a/src/turbomind/layers/sampling_layers/LogitsProcessorLayer.cc b/src/turbomind/layers/sampling_layers/LogitsProcessorLayer.cc
index b588d8b6f5..c458998031 100644
--- a/src/turbomind/layers/sampling_layers/LogitsProcessorLayer.cc
+++ b/src/turbomind/layers/sampling_layers/LogitsProcessorLayer.cc
@@ -178,7 +178,7 @@ void LogitsProcessorLayer::forward(TensorMap* output_tensors, TensorMap* inpu
// temperature
{
if (!ALL_OF(temperature_.begin(), batch_size, float, 1.f)) {
- invokeBatchApplyTemperaturePenalty(
+ invokeBatchApplyTemperaturePenalty_v2(
logits, (T*)nullptr, temperature_buf_, batch_size, args_.vocab_size, args_.vocab_size_padded, stream_);
sync_check_cuda_error();
}
diff --git a/src/turbomind/models/llama/CMakeLists.txt b/src/turbomind/models/llama/CMakeLists.txt
index 3c714bd234..6c297e3d56 100644
--- a/src/turbomind/models/llama/CMakeLists.txt
+++ b/src/turbomind/models/llama/CMakeLists.txt
@@ -25,6 +25,7 @@ add_library(Llama STATIC
set_property(TARGET Llama PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET Llama PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
target_link_libraries(Llama PUBLIC CUDA::cudart
+ engine
gemm2
rms_norm
cublasMMWrapper
diff --git a/src/turbomind/models/llama/LlamaBatch.cc b/src/turbomind/models/llama/LlamaBatch.cc
index ea321d06a0..e37af1bb76 100644
--- a/src/turbomind/models/llama/LlamaBatch.cc
+++ b/src/turbomind/models/llama/LlamaBatch.cc
@@ -1,19 +1,43 @@
// Copyright (c) OpenMMLab. All rights reserved.
-#include "src/turbomind/models/llama/LlamaBatch.h"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "src/turbomind/macro.h"
+
+#include "src/turbomind/engine/gateway.h"
+#include "src/turbomind/engine/request.h"
+
#include "src/turbomind/kernels/core/data_type.h"
#include "src/turbomind/kernels/decoding_kernels.h"
#include "src/turbomind/kernels/gemm/tuner/params.h"
#include "src/turbomind/kernels/sampling_topk_kernels.h"
-#include "src/turbomind/macro.h"
+
#include "src/turbomind/models/llama/BlockManager.h"
+#include "src/turbomind/models/llama/LlamaBatch.h"
#include "src/turbomind/models/llama/LlamaNcclGuard.h"
#include "src/turbomind/models/llama/LlamaV2.h"
-#include "src/turbomind/models/llama/Request.h"
#include "src/turbomind/models/llama/SequenceManager.h"
#include "src/turbomind/models/llama/copy.h"
#include "src/turbomind/models/llama/llama_kernels.h"
#include "src/turbomind/models/llama/llama_utils.h"
+
#include "src/turbomind/utils/Tensor.h"
#include "src/turbomind/utils/anomaly_handler.h"
#include "src/turbomind/utils/constant.h"
@@ -21,20 +45,6 @@
#include "src/turbomind/utils/debug_utils.h"
#include "src/turbomind/utils/logger.h"
#include "src/turbomind/utils/nccl_utils.h"
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
namespace turbomind {
@@ -84,150 +94,92 @@ void DropEmbeddings(const Sequence& seq)
}
template<typename T>
-void LlamaBatch<T>::RejectInvalidRequests(Requests& stop_reqs, Requests& infer_reqs)
+void LlamaBatch<T>::DisableConflictRequests(Requests& infer_reqs, Requests& kill_reqs)
{
- std::unordered_map occurrence;
-
- auto count_occurrence = [&occurrence](const Requests& rs) {
- for (const auto& r : rs) {
- ++occurrence[r->id];
- }
- };
-
- auto reject = [](const char* type, std::shared_ptr& req, int ec) {
- TM_LOG_WARNING(
- "[RejectInvalidRequests] Skipping invalid %s request for id %ld, code = %d", type, (long)req->id, ec);
- req->signal.set_value(ec);
- req.reset();
- };
-
- auto handle_conflict_or_invalid = [this, &occurrence, &reject](Requests& rs, const char* type) {
- for (auto& r : rs) {
- if (r) {
- int ec = 0;
+ NvtxScope _("disable conflict");
- const int input_length = r->inputs.getVal("input_lengths", 0);
- const auto get_offset = [&](int token_count) {
- return std::max(0, std::min(token_count, r->inputs.getVal("step", token_count)));
- };
+ std::pmr::monotonic_buffer_resource mbr;
+    std::pmr::unordered_map<uint64_t, int> occur(&mbr);
- if (occurrence[r->id] != 1) {
- ec = Request::kConflict;
- }
- else if (r->start_flag && r->stop_flag) {
- ec = Request::kInvalid;
- }
- else if (input_length > session_len_) {
- ec = Request::kTooLong;
- }
- else if (!r->start_flag) {
- if (auto seq = sequence_manager_->Get(r->id); seq == nullptr) {
- ec = Request::kInvalid;
- }
- else if (get_offset(seq->tokens.size()) + input_length > session_len_) {
- ec = Request::kTooLong;
- }
- }
-
- if (ec) {
- reject(type, r, ec);
- }
- }
+ auto count = [&occur](const auto& reqs) {
+ for (const auto& r : reqs) {
+ ++occur[r->id];
}
};
- auto drop_invalid = [](Requests& rs) {
- int count = 0;
- for (int i = 0; i < rs.size(); ++i) {
- if (rs[i]) {
- rs[count++] = std::move(rs[i]);
+ auto validate = [&occur](auto& reqs, const char* type) {
+ for (const auto& r : reqs) {
+ if (occur[r->id] > 1) {
+ TM_LOG_ERROR("Skip conflicting %s request for ID %lu", type, r->id);
+ r->ec = Request::kConflict;
}
}
- rs.resize(count);
};
- count_occurrence(stop_reqs);
- count_occurrence(infer_reqs);
-
- if (!stop_reqs.empty()) {
- handle_conflict_or_invalid(stop_reqs, "stop");
-
- // invalidate stop-only requests for inactive sequences
- for (auto& r : stop_reqs) {
- if (r && r->end_flag == false) {
- int ec = Request::kInactive;
- for (int i = 0; i < state_->size; ++i) {
- if (state_->requests[i] && state_->requests[i]->id == r->id) {
- ec = 0;
- break;
- }
- }
- if (ec) {
- reject("stop", r, ec);
- }
- }
+ for (int i = 0; i < state_->size; ++i) {
+ if (state_->requests[i]) {
+ ++occur[state_->requests[i]->id];
}
-
- drop_invalid(stop_reqs);
}
- if (!infer_reqs.empty()) {
- handle_conflict_or_invalid(infer_reqs, "infer");
+ count(kill_reqs);
+ count(infer_reqs);
- // invalidate requests for busy sequences
- for (auto& r : infer_reqs) {
- if (r) {
- for (int i = 0; i < state_->size; ++i) {
- if (state_->requests[i] && state_->requests[i]->id == r->id) {
- reject("infer", r, Request::kBusy);
- break;
- }
- }
- }
- }
+ validate(kill_reqs, "kill");
+ validate(infer_reqs, "infer");
+}
- drop_invalid(infer_reqs);
+template<typename T>
+void LlamaBatch<T>::BroadcastCancelFlags()
+{
+ for (int i = 0; i < state_->size; ++i) {
+ const auto& r = state_->requests[i];
+ if (r && r->cancel_flag.load(std::memory_order_acquire) == -1) {
+ r->is_canceled = true;
+ }
}
}
-template<typename T>
-auto LlamaBatch<T>::ProcessStopRequests(const Requests& requests) -> std::vector<Signal>
+template<typename T>
+void LlamaBatch<T>::ProcessCancelRequests(std::vector<Signal>& signals)
{
- NvtxScope scope("stop_request");
-    std::vector<Signal> signals;
- int count = 0;
- for (const auto& r : requests) {
- int ec = Request::kFail;
- // find matching active sequence
- for (int i = 0; i < state_->size; ++i) {
- // stop & optionally erase active sequence
- if (state_->requests[i] && state_->requests[i]->id == r->id) {
- ec = 0;
- signals.push_back(Interrupt(i, true, r->end_flag));
- ++count;
- break;
- }
- }
- // mismatch, try erase inactive sequence, in this case there is no active request to interrupt
- if (ec && r->end_flag) {
- if (sequence_manager_->Erase(r->id)) {
- ec = 0;
- }
+ int count = 0;
+ for (int i = 0; i < state_->size; ++i) {
+ const auto& r = state_->requests[i];
+ if (r && r->is_canceled) {
+ ++count;
+ signals.push_back(Interrupt(i, true));
+ // Interrupt should reset r
+ FT_CHECK(!r);
}
- signals.push_back([=] {
- if (rank_ == 0) {
- r->signal.set_value(ec);
- }
- });
}
if (count) {
check_cuda_error(cudaStreamSynchronize(stream_));
}
- return signals;
+}
+
+template<typename T>
+void LlamaBatch<T>::ProcessKillRequests(const Requests& kill_reqs, std::vector<Signal>& signals)
+{
+ for (auto& r : kill_reqs) {
+ if (r) {
+ int ec = r->ec;
+ if (!ec) {
+ if (!sequence_manager_->Erase(r->id)) {
+ ec = Request::kInvalid;
+ }
+ }
+ signals.push_back([=] {
+ if (r->end_cb) {
+ r->end_cb(ec);
+ }
+ });
+ }
+ }
}
template<typename T>
-void LlamaBatch<T>::ProcessInferRequests(const Requests& requests)
+void LlamaBatch<T>::ProcessInferRequests(const Requests& reqs, std::vector<Signal>& signals)
{
NvtxScope scope("infer_request");
auto& state = *incoming_;
@@ -238,58 +190,90 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
std::vector existing_idx;
int idx = 0;
- for (const auto& r : requests) {
- FT_CHECK(!state.requests[idx]);
+ for (const auto& r : reqs) {
if (rank_ == 0) {
TM_LOG_INFO("[ProcessInferRequests] Request for %ld received.", (long)r->id);
}
- state.requests[idx] = r;
+ if (r->ec) {
+ signals.push_back([r] { UpdateState(*r, r->ec, 0); });
+ continue;
+ }
- // get sequence for the request
- state.sequences[idx] = r->start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
- FT_CHECK(state.sequences[idx]);
+ const int input_length = r->inputs.at("input_ids").shape[0];
- auto& seq = *state.sequences[idx];
+ if (input_length > session_len_) {
+ signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });
+ continue;
+ }
- if (int step = r->inputs.getVal("step", -1); step >= 0) {
- if (step <= seq.tokens.size()) {
- seq.tokens.resize(step);
- seq.cache_len = std::min(seq.cache_len, step);
- DropEmbeddings(seq);
+ auto ptr = r->session.start_flag ? sequence_manager_->Create(r->id) : sequence_manager_->Get(r->id);
+ if (!ptr) {
+ signals.push_back([r] { UpdateState(*r, Request::kInvalid, 0); });
+ continue;
+ }
+
+ const int step = [&] {
+ int s = r->session.step;
+ if (s < 0) {
+ s = ptr->tokens.size();
}
- else if (rank_ == 0) {
- TM_LOG_WARNING(
- "[ProcessInferRequests] Skipping invalid step (%d) setting for ID %ld", step, (long)seq.id);
+ else if (s > ptr->tokens.size()) {
+ if (rank_ == 0) {
+ TM_LOG_WARNING("[ProcessInferRequests] Skipping invalid step (%d) setting for ID %lu", s, ptr->id);
+ }
+ s = ptr->tokens.size();
}
+ return s;
+ }();
+
+ if (step + input_length > session_len_) {
+ signals.push_back([r] { UpdateState(*r, Request::kTooLong, 0); });
+ continue;
}
- const int input_length = r->inputs.getVal("input_lengths");
- const int* input_ids = r->inputs.getPtr("input_ids");
+ FT_CHECK(!state.requests[idx]);
+
+ state.requests[idx] = r;
+ state.sequences[idx] = ptr;
+
+ auto& seq = *state.sequences[idx];
+
+ if (step < seq.tokens.size()) {
+ // resize sequence tokens to match step
+ seq.tokens.resize(step);
+ seq.cache_len = std::min(seq.cache_len, step);
+ DropEmbeddings(seq);
+ }
+
+        const int* input_ids = r->inputs.getPtr<int>("input_ids");
{
// `output_ids` contains all token ids of the sequences
const auto output_ids_base = state.output_ids + session_len_ * idx;
- auto output_ids = output_ids_base;
+ auto d_output_ids = output_ids_base;
+            auto h_output_ids = r->output_ids.getPtr<int>();
// copy history tokens
if (!seq.tokens.empty()) {
- output_ids = Copy(seq.tokens.data(), seq.tokens.size(), output_ids);
+ d_output_ids = Copy(seq.tokens.data(), seq.tokens.size(), d_output_ids);
+ h_output_ids = std::copy_n(seq.tokens.data(), seq.tokens.size(), h_output_ids);
}
// copy input tokens
if (input_length) {
- output_ids = Copy(input_ids, input_length, output_ids);
+ d_output_ids = Copy(input_ids, input_length, d_output_ids);
+ h_output_ids = std::copy_n(input_ids, input_length, h_output_ids);
}
// total context length (history + input)
- state.h_prompt_length[idx] = output_ids - output_ids_base;
- state.h_context_length[idx] = output_ids - output_ids_base;
+ state.h_prompt_length[idx] = d_output_ids - output_ids_base;
+ state.h_context_length[idx] = d_output_ids - output_ids_base;
state.h_finished[idx] = false;
}
// copy input tokens to prompt for prefix matching
- if (input_length && r->start_flag && !r->inputs.isExist("input_embedding_ranges")) {
+ if (input_length && r->session.start_flag && !r->inputs.isExist("input_embedding_ranges")) {
// TODO: truncate prompt to enable prefix caching for VLM
seq.prompt.resize(input_length);
std::copy_n(input_ids, input_length, seq.prompt.data());
@@ -348,8 +332,8 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
}
}
- const int request_output_len = state.requests[idx]->inputs.getVal("request_output_len");
- state.seq_len_limit[idx] = state.h_context_length[idx] + request_output_len;
+ const int max_new_tokens = state.requests[idx]->gen_cfg.max_new_tokens;
+ state.seq_len_limit[idx] = state.h_context_length[idx] + max_new_tokens;
// `length_criterion` sets finish flag when step >= seq_limit_len, however when step == seq_limit_len
// the actual sequence length is seq_limit_len + 1, hence seq_limit_len must truncated to session_len - 1
if (state.seq_len_limit[idx] >= session_len_) {
@@ -357,17 +341,17 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
if (rank_ == 0) {
const int trunc_output_len = state.seq_len_limit[idx] - state.h_context_length[idx];
TM_LOG_WARNING(
- "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `request_output_len` is truncated to %d",
+ "[ProcessInferRequests] [%ld] total sequence length (%d + %d) exceeds `session_len` (%d), `max_new_tokens` is truncated to %d",
(long)seq.id,
state.h_context_length[idx],
- request_output_len,
+ max_new_tokens,
(int)session_len_,
trunc_output_len);
}
}
// compute rope scaling factor
- if (r->start_flag) {
+ if (r->session.start_flag) {
seq.rope_theta = model_->attn_param_.rotary_embedding_base;
if (model_->attn_param_.use_dynamic_ntk) {
auto scaling_factor = model_->attn_param_.rope_scaling_factor;
@@ -388,9 +372,9 @@ void LlamaBatch::ProcessInferRequests(const Requests& requests)
}
state.h_rope_theta[idx] = seq.rope_theta;
- if (r->start_flag) {
+ if (r->session.start_flag) {
// prepare to initialize random state for new sequence
- h_random_seed_[idx] = r->inputs.getVal("random_seed", 0);
+ h_random_seed_[idx] = r->gen_cfg.random_seed;
}
else {
// Recover device states if not a new sequence
@@ -799,12 +783,6 @@ void LlamaBatch::AllocatePersistantBuffer(size_t max_batch_size, int cache_bl
sampling_params_ = {
{"stop_words_list", (std::byte*)h_stop_words_, (std::byte*)d_stop_words_},
{"bad_words_list", (std::byte*)h_bad_words_, (std::byte*)d_bad_words_},
- {"min_length", (std::byte*)h_min_length_, nullptr},
- {"runtime_top_k", (std::byte*)h_runtime_top_k_, nullptr},
- {"runtime_top_p", (std::byte*)h_runtime_top_p_, nullptr},
- {"runtime_min_p", (std::byte*)h_runtime_min_p_, nullptr},
- {"temperature", (std::byte*)h_temperature_, nullptr},
- {"repetition_penalty", (std::byte*)h_repetition_penalty_, nullptr},
};
for (auto& s : states_) {
@@ -941,19 +919,9 @@ template
LlamaBatch::~LlamaBatch()
{
TM_LOG_DEBUG("~LlamaBatch()");
- shared_state_->request_queue.close();
internal_thread_.join();
- if (output_thread_.joinable()) {
- {
- std::lock_guard lock{output_mutex_};
- output_stop_token_ = true;
- }
- output_cv_.notify_one();
- output_thread_.join();
- }
-
// The dtor maybe called from unknown thread, set device id before CUDA calls
check_cuda_error(cudaSetDevice(device_id_));
check_cuda_error(cudaStreamSynchronize(stream_));
@@ -970,8 +938,10 @@ LlamaBatch::LlamaBatch(const EngineParam& param,
std::unique_ptr> model, // ! This is moved
std::unique_ptr> ctx, // ! This is moved
std::shared_ptr state,
+ std::shared_ptr gateway,
int device_id):
param_(param),
+ gateway_(gateway),
shared_state_(state),
max_batch_size_(param.max_batch_size),
max_forward_token_num_(param.max_prefill_token_num + param.max_batch_size),
@@ -1068,7 +1038,7 @@ void LlamaBatch::InitializeSampling(const GenerationState& g)
sync_check_cuda_error();
Clear(token_ids_buf_, batch_size * session_len_);
- invokeTransposeAxis01(token_ids_buf_, state_->output_ids, batch_size, session_len_, 1, stream_);
+ invokeTranspose2D(token_ids_buf_, state_->output_ids, batch_size, session_len_, stream_);
sync_check_cuda_error();
// token_ids_buf_[s, b]
@@ -1087,6 +1057,27 @@ void LlamaBatch::InitializeSampling(const GenerationState& g)
Copy(h_seq_limit_len_, batch_size, seq_limit_len_);
TensorMap inputs;
+
+ auto member_to_tensor = [&](auto getter, auto key, auto dest, auto init) {
+ int count = 0;
+ for (int i = 0; i < batch_size; ++i) {
+ // `std::invoke`
+ dest[i] = state_->requests[i]->gen_cfg.*getter;
+ count += dest[i] != init;
+ }
+ if (count) {
+ inputs.insert(key, {MEMORY_CPU, getTensorType(), {(size_t)batch_size}, dest});
+ }
+ };
+
+ using G = GenerationConfig;
+ member_to_tensor(&G::top_k, "runtime_top_k", h_runtime_top_k_, 0);
+ member_to_tensor(&G::top_p, "runtime_top_p", h_runtime_top_p_, 0);
+ member_to_tensor(&G::min_p, "runtime_min_p", h_runtime_min_p_, 0);
+ member_to_tensor(&G::temperature, "temperature", h_temperature_, 0.f);
+ member_to_tensor(&G::repetition_penalty, "repetition_penalty", h_repetition_penalty_, 1.f);
+ member_to_tensor(&G::min_new_tokens, "min_length", h_min_length_, 0);
+
for (const auto& [name, h_ptr, d_ptr] : sampling_params_) {
// find an exemplar that matches the param name
const Tensor* ptr{};
@@ -1173,7 +1164,7 @@ void LlamaBatch::InitializeSampling(const GenerationState& g)
TensorMap outputs;
for (int i = 0; i < batch_size; i++) {
- if (state_->requests[i]->inputs.isExist("logprobs")) {
+ if (state_->requests[i]->gen_cfg.output_logprobs) {
outputs.insert(
{"sampled_logprobs", {MEMORY_GPU, TYPE_FP32, {(size_t)batch_size, 1, kMaxLogProb}, sampled_logprobs_}});
outputs.insert(
@@ -1187,89 +1178,157 @@ void LlamaBatch::InitializeSampling(const GenerationState& g)
sync_check_cuda_error();
}
-template<typename T>
-void LlamaBatch<T>::OutputContextLogits(T* context_decoder_output,
-                                        const std::vector<int>& indices,
-                                        const std::vector<int>& lengths,
-                                        const std::vector<const Sequence*>& sequences)
+template<typename T>
+void LlamaBatch<T>::ComputeAndOutputLogits(T* hidden_states, int first, int last)
{
- std::vector<float*> output_logits;
- int num_token = 0;
- {
- bool is_return_logits = false;
- for (int k = 0; k < indices.size(); ++k) {
- auto& request = state_->requests[indices[k]];
- auto logits = request->outputs.getPtr<float>("logits", nullptr);
- if (logits && sequences[k]->cache_len + lengths[k] <= sequences[k]->tokens.size()) {
- logits = nullptr;
- }
- output_logits.push_back(logits);
- num_token += lengths[k];
- if (output_logits.back()) {
- is_return_logits = true;
+ int token_num = 0;
+ bool found = false;
+ for (int i = first; i < last; ++i) {
+ if (state_->requests[i]->gen_cfg.output_logits == GenerationConfig::kAll) {
+ const auto& s = *state_->sequences[i];
+ // Skip when the seq is filling missed cache only
+ if (s.cache_len + h_input_length_buf_[i] > s.tokens.size()) {
+ found = true;
}
}
- if (!is_return_logits) {
- return;
- }
+ token_num += h_input_length_buf_[i];
}
- {
- context_logits_buf_ = (float*)allocator_->reMalloc(
- context_logits_buf_, sizeof(float) * model_->vocab_size_padded_ * num_token, false);
- const auto tp = model_->tensor_para_.world_size_;
- if (tp > 1) {
- NcclGuard guard(model_->tensor_para_, stream_, true);
- FT_CHECK(model_->vocab_size_padded_ % tp == 0);
- const auto local_vocab_size = model_->vocab_size_padded_ / tp;
- local_context_logits_buf_ = (float*)peer_allocator_->reMalloc(
- local_context_logits_buf_, sizeof(float) * model_->vocab_size_padded_ * num_token, false);
- }
+ if (!found) {
+ return;
}
- model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, context_decoder_output, num_token);
+ context_logits_buf_ = (float*)allocator_->reMalloc(
+ context_logits_buf_, sizeof(float) * model_->vocab_size_padded_ * token_num, false);
+ const auto tp = model_->tensor_para_.world_size_;
+
+ if (tp > 1) {
+ NcclGuard guard(model_->tensor_para_, stream_, true);
+ FT_CHECK(model_->vocab_size_padded_ % tp == 0);
+ const auto local_vocab_size = model_->vocab_size_padded_ / tp;
+ local_context_logits_buf_ = (float*)peer_allocator_->reMalloc(
+ local_context_logits_buf_, sizeof(float) * model_->vocab_size_padded_ * token_num, false);
+ }
- auto logits = context_logits_buf_;
+ model_->postDecodeEmbedding(context_logits_buf_, local_context_logits_buf_, hidden_states, token_num);
- // Only rank-0 writes to output
if (rank_ != 0) {
return;
}
- for (int k = 0; k < indices.size(); ++k) {
- if (output_logits[k]) {
- auto src_ptr = logits;
- auto dst_ptr = output_logits[k];
- int num_new_token = 0;
- if (sequences[k]->cache_len < sequences[k]->tokens.size()) {
- num_new_token = sequences[k]->cache_len + lengths[k] - sequences[k]->tokens.size();
- src_ptr += (lengths[k] - num_new_token) * model_->vocab_size_padded_;
- }
- else {
- num_new_token = lengths[k];
- dst_ptr += (sequences[k]->cache_len - sequences[k]->tokens.size()) * model_->vocab_size_;
+ OutputLogits(context_logits_buf_, first, last, GenerationConfig::kAll);
+}
+
+template<typename T>
+void LlamaBatch<T>::OutputLogits(const float* logits, int first, int last, GenerationConfig::OutType out_type)
+{
+ // when `is_all` is false, `logits` contains only the last token of each sequence
+ const bool is_all = out_type == GenerationConfig::kAll;
+
+ for (int i = first; i < last; ++i) {
+
+ const int input_len = h_input_length_buf_[i]; // input length for this iter
+ const float* src_ptr = logits;
+
+ logits += (is_all ? input_len : 1) * model_->vocab_size_padded_;
+
+ if (state_->requests[i]->gen_cfg.output_logits == out_type) {
+
+ auto dst_ptr = state_->requests[i]->outputs.getPtr<float>("logits");
+
+ const int cache_len = state_->sequences[i]->cache_len;
+ const int history_len = state_->sequences[i]->tokens.size();
+
+ // ----------H------I-------P-----------
+ // C C C C
+
+ // offset to the last prompt token
+ const int offset = is_all ? 0 : state_->requests[i]->inputs.at("input_ids").shape[0] - 1;
+
+ int diff = (history_len + offset) - cache_len;
+
+ const int valid_len = input_len - std::max(0, (history_len + offset) - cache_len);
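+ // valid_len: tokens of this chunk whose logits are still needed; the part of
+ // the chunk preceding `history_len + offset` has already been covered.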
+
+ // TM_LOG_ERROR("%d %d %d %d %d %d %d",
+ // history_len,
+ // offset,
+ // cache_len,
+ // input_len,
+ // valid_len,
+ // std::max(0, diff),
+ // std::max(0, -diff));
+
+ if (valid_len <= 0) {
+ continue;
}
- if (model_->vocab_size_padded_ == model_->vocab_size_) {
- Copy(src_ptr, model_->vocab_size_ * num_new_token, dst_ptr);
+
+ if (is_all) {
+ // Skip invalid tokens caused by cache miss
+ src_ptr += std::max(0, (history_len + offset) - cache_len) * model_->vocab_size_padded_;
}
- else {
- for (int tok = 0; tok < num_new_token; tok++) {
- Copy(src_ptr, model_->vocab_size_, dst_ptr);
- src_ptr += model_->vocab_size_padded_;
- dst_ptr += model_->vocab_size_;
- }
+ // Skip previous chunks
+ dst_ptr += std::max(0, cache_len - (history_len + offset)) * model_->vocab_size_;
+
+ check_cuda_error(cudaMemcpy2DAsync(dst_ptr,
+ sizeof(float) * model_->vocab_size_,
+ src_ptr,
+ sizeof(float) * model_->vocab_size_padded_,
+ sizeof(float) * model_->vocab_size_,
+ valid_len,
+ cudaMemcpyDefault,
+ stream_));
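+ // The strided 2D copy strips the vocab padding: source rows have pitch
+ // `vocab_size_padded_`, destination rows are packed at `vocab_size_`.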
+ }
+ }
+}
+
+template<typename T>
+void LlamaBatch<T>::OutputLastHiddenState(const T* hidden_states, int first, int last)
+{
+ for (int i = first; i < last; ++i) {
+
+ const int input_len = h_input_length_buf_[i]; // input length for this iter
+ const T* src_ptr = hidden_states;
+
+ hidden_states += input_len * model_->hidden_units_;
+
+ if (auto out_type = state_->requests[i]->gen_cfg.output_last_hidden_state) {
+
+ const bool is_all = out_type == GenerationConfig::kAll;
+
+ T* dst_ptr = state_->requests[i]->outputs.getPtr<T>("last_hidden_state");
+
+ const int cache_len = state_->sequences[i]->cache_len;
+ const int history_len = state_->sequences[i]->tokens.size();
+
+ // offset to the last prompt token
+ const int offset = is_all ? 0 : state_->requests[i]->inputs.at("input_ids").shape[0] - 1;
+
+ const int valid_len = input_len - std::max(0, (history_len + offset) - cache_len);
+
+ // TM_LOG_ERROR("%d %d %d %d %d", history_len, offset, cache_len, input_len, valid_len);
+
+ if (valid_len <= 0) {
+ continue;
}
+
+ // Skip invalid tokens caused by cache miss
+ src_ptr += std::max(0, (history_len + offset) - cache_len) * model_->hidden_units_;
+ // Skip previous chunks
+ dst_ptr += std::max(0, cache_len - (history_len + offset)) * model_->hidden_units_;
+
+ Copy(src_ptr, valid_len * model_->hidden_units_, dst_ptr);
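+ // Hidden states carry no padding, so a single contiguous copy is enough here.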
}
- logits += model_->vocab_size_padded_ * lengths[k];
}
}
template<typename T>
-auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
+void LlamaBatch<T>::Finish(GenerationState& g, std::vector<Signal>& signals)
{
NvtxScope scope("Finish");
const int batch_size = state_->active_size;
+ signals.reserve(batch_size);
+
if (batch_size - g.partial) {
FT_CHECK(g.step >= 0);
@@ -1285,13 +1344,22 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
sync_check_cuda_error();
}
- Copy(state_->output_ids, batch_size * session_len_, h_output_ids_);
+ Copy(token_ids_buf_ + (g.step - 1) * (batch_size - g.partial), batch_size - g.partial, h_output_ids_);
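+ // Only the token generated at this step is copied to the host for each active
+ // slot; rank-0 appends it to the request's output_ids buffer below.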
Copy(finished_buf_, batch_size, state_->h_finished);
Copy(sequence_lengths_, batch_size, state_->h_context_length);
- Copy(sampled_logprobs_, batch_size * kMaxLogProb, h_sampled_logprobs_);
- Copy(sampled_indexes_, batch_size * kMaxLogProb, h_sampled_indexes_);
- Copy(sampled_nums_, batch_size, h_sampled_nums_);
+ bool output_logprobs = false;
+ for (int i = 0; i < batch_size - g.partial; ++i) {
+ if (state_->requests[i]->gen_cfg.output_logprobs) {
+ output_logprobs = true;
+ break;
+ }
+ }
+ if (output_logprobs) {
+ Copy(sampled_logprobs_, batch_size * kMaxLogProb, h_sampled_logprobs_);
+ Copy(sampled_indexes_, batch_size * kMaxLogProb, h_sampled_indexes_);
+ Copy(sampled_nums_, batch_size, h_sampled_nums_);
+ }
check_cuda_error(cudaStreamSynchronize(stream_));
@@ -1302,13 +1370,14 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
}
// ! Only rank-0 writes to output
- if (rank_ == 0) {
+ if (rank_ == 0 && output_logprobs) {
+ NvtxScope scope("logprobs");
// output logprobs, should be set before sequence_length
float* sampled_logprobs_ptr = h_sampled_logprobs_;
uint32_t* sampled_indexes_ptr = h_sampled_indexes_;
uint32_t* sampled_nums_ptr = h_sampled_nums_;
for (int i = 0; i < batch_size - g.partial; ++i) {
- if (state_->requests[i] && state_->requests[i]->inputs.isExist("logprobs")) {
+ if (state_->requests[i] && state_->requests[i]->gen_cfg.output_logprobs) {
auto logprob_vals = state_->requests[i]->outputs.getPtr<float>("logprob_vals");
auto logprob_indexes = state_->requests[i]->outputs.getPtr<uint32_t>("logprob_indexes");
auto logprob_nums = state_->requests[i]->outputs.getPtr<uint32_t>("logprob_nums");
@@ -1330,18 +1399,37 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
// ! Only rank-0 writes to output
if (rank_ == 0) {
- // set output tokens ids and sequence length
- int* output_ptr = h_output_ids_;
- for (int i = 0; i < batch_size - g.partial; ++i) {
- if (state_->requests[i] && (state_->requests[i]->stream_cb || state_->h_finished[i])) {
- auto output_ids = state_->requests[i]->outputs.getPtr("output_ids");
- auto output_len = state_->requests[i]->outputs.getPtr("sequence_length");
- const int count = state_->h_context_length[i];
- // TODO: sync history output tokens at when receiving the request and copy the last token here
- std::copy(output_ptr, output_ptr + count, output_ids);
- *output_len = count;
+ NvtxScope scope("output_ids");
+ if constexpr (0) {
+ // set output tokens ids and sequence length
+ int* output_ptr = h_output_ids_;
+ for (int i = 0; i < batch_size - g.partial; ++i) {
+ if (auto& r = state_->requests[i]) {
+ auto output_ids = static_cast<int*>(r->output_ids.data);
+ auto output_len = static_cast<int*>(r->sequence_length.data);
+ const int count = state_->h_context_length[i];
+ if (r->stream_output) {
+ output_ids[count - 1] = output_ptr[count - 1];
+ *output_len = count;
+ }
+ else if (state_->h_finished[i]) {
+ std::copy(output_ptr, output_ptr + count, output_ids);
+ *output_len = count;
+ }
+ }
+ output_ptr += session_len_;
+ }
+ }
+ else {
+ for (int i = 0; i < batch_size - g.partial; ++i) {
+ if (auto& r = state_->requests[i]) {
+ auto output_ids = static_cast<int*>(r->output_ids.data);
+ auto output_len = static_cast<int*>(r->sequence_length.data);
+ const int count = state_->h_context_length[i];
+ output_ids[count - 1] = h_output_ids_[i];
+ *output_len = count;
+ }
}
- output_ptr += session_len_;
}
}
@@ -1362,48 +1450,53 @@ auto LlamaBatch<T>::Finish(GenerationState& g) -> std::vector<Signal>
}
}
- std::vector signals;
{
- NvtxScope _("stream_and_completion_signal");
+ NvtxScope _("count and sync");
+ bool need_sync = false;
for (int i = 0; i < batch_size - g.partial; ++i) {
- if (state_->requests[i]) {
- if (state_->h_finished[i]) {
- // Interrupt finished sequences and move the request handle into the signal closure
- signals.push_back(Interrupt(i));
- ++g.finished_count;
- }
- else if (state_->requests[i]->stream_cb) {
- // Create signals by copying the request handles for non-finished streaming requests
- signals.push_back([this, r = state_->requests[i]] {
- if (rank_ == 0) {
- try {
- r->stream_cb(&r->outputs.get());
- }
- catch (const std::bad_function_call& e) {
- TM_LOG_ERROR("Null stream callback for (%s)", std::to_string(r->id).c_str());
- }
- catch (...) {
- TM_LOG_ERROR("Unknown exception invoking stream callback for (%s)",
- std::to_string(r->id).c_str());
- }
- }
- });
+ if (state_->h_finished[i]) {
+ ++g.finished_count;
+ if (!state_->requests[i]->session.end_flag) {
+ need_sync = true;
}
}
}
- if (g.finished_count) {
- // synchronize for interrupted sequences
- check_cuda_error(cudaStreamSynchronize(stream_));
+ if (need_sync) {
+ // Release updates on request output buffers to all ranks (`Interrupt` will use it)
+ shared_state_->barrier->wait();
}
}
+ {
+ NvtxScope _("stream_and_completion_signal");
+ for (int i = 0; i < batch_size - g.partial; ++i) {
+ auto& r = state_->requests[i];
+ if (state_->h_finished[i]) {
+ // Interrupt finished sequences and move the request handle into the signal closure
+ signals.push_back(Interrupt(i));
+ // Interrupt should reset r
+ FT_CHECK(!r);
+ }
+ else if (r->stream_output && rank_ == 0) {
+ const auto seq_len = r->sequence_length.getVal<int>();
+ // Create signals by copying the request handles for non-finished streaming requests
+ signals.push_back([this, r, seq_len] { //
+ UpdateState(*r, Request::kOk, seq_len);
+ });
+ }
+ }
+ }
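+ // All signals collected here are executed by the gateway (`gateway_->notify`
+ // in `InternalThreadEntry`), which replaces the removed output thread.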
+
+ if (g.finished_count) {
+ // synchronize for interrupted sequences
+ check_cuda_error(cudaStreamSynchronize(stream_));
+ }
+
if (g.partial) {
const int i = batch_size - 1;
// recover full context length of partial
state_->h_context_length[i] = g.partial_context_legnth;
}
-
- return signals;
}
template<typename T>
@@ -1424,7 +1517,7 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
TM_LOG_INFO("[Interrupt] slot %d, tokens [%s]", index, ss.str().c_str());
}
- if (state_->requests[index]->end_flag || force_end) {
+ if (state_->requests[index]->session.end_flag || force_end) {
// Sequence is ending this round or a stop request is issued to end it
FT_CHECK(sequence_manager_->Erase(state_->requests[index]->id));
}
@@ -1434,17 +1527,10 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
// Update token IDs
seq.tokens.resize(output_len);
- const auto output_ids_data = [&] {
- if (force_stop) {
- // `h_output_ids_` is UNDEFINED at `ProcessStopRequests`
- return state_->requests[index]->outputs.at("output_ids").getPtr<int>();
- }
- else {
- // `h_output_ids_` just updated by `Finish`, but `outputs` is NOT synced atm
- return h_output_ids_ + index * (size_t)session_len_;
- }
- }();
- std::copy_n(output_ids_data, output_len, seq.tokens.data());
+
+ // output_ids is updated & synced in `Finish`
+ const auto output_ids = state_->requests[index]->output_ids.getPtr<int>();
+ std::copy_n(output_ids, output_len, seq.tokens.data());
// Save random state in host memory
seq.random_state.resize(sizeof(curandState_t));
@@ -1457,13 +1543,12 @@ auto LlamaBatch<T>::Interrupt(int index, bool force_stop, bool force_end) -> Sig
state_->sequences[index] = nullptr;
- auto ec = std::exchange(state_->errors[index], 0);
+ auto ec = std::exchange(state_->errors[index], Request::kOk);
+ const auto len = state_->requests[index]->sequence_length.getVal<int>();
// move the request handle into the signal
- return [this, ec, r = std::move(state_->requests[index])] {
- if (rank_ == 0) {
- r->signal.set_value(ec);
- }
+ return [this, len, force_stop, r = std::move(state_->requests[index])] { //
+ UpdateState(*r, force_stop ? Request::kCancel : Request::kFinish, len);
};
}
@@ -1476,33 +1561,30 @@ void LlamaBatch<T>::InternalThreadEntry()
// Initialize `AnomalyHandler`
AnomalyHandler::instance().Init(rank_, model_->vocab_size_padded_, model_->end_id_, max_batch_size_, stream_);
- auto& request_queue = shared_state_->request_queue;
- auto& infer_requests = shared_state_->infer_requests;
- auto& stop_requests = shared_state_->stop_requests;
+ // auto& request_queue = shared_state_->request_queue;
+ auto& infer_reqs = shared_state_->infer_reqs;
+ auto& kill_reqs = shared_state_->kill_reqs;
GenerationState g{};
- constexpr int request_interval = 1;
- long request_counter = 0;
-
while (1) {
+
if (rank_ == 0) {
- const int free_slot_count = max_batch_size_ - state_->size + g.finished_count;
- const bool is_empty = (free_slot_count == max_batch_size_);
- stop_requests.clear();
- infer_requests.clear();
- if (is_empty || request_counter % request_interval == 0) {
+ {
+ NvtxScope _("pop");
+ const int free_slot_count = max_batch_size_ - state_->size + g.finished_count;
+ const bool is_empty = (free_slot_count == max_batch_size_);
// Block if batch is empty
- request_queue.dequeue(stop_requests, infer_requests, free_slot_count, is_empty, shared_state_->abort);
- if (!shared_state_->abort) {
- RejectInvalidRequests(stop_requests, infer_requests);
- }
+ gateway_->pop(infer_reqs, kill_reqs, free_slot_count, is_empty, shared_state_->abort);
}
+ // Mark reqs with the same session_id as invalid (they are dangerous to the engine)
+ DisableConflictRequests(infer_reqs, kill_reqs);
}
NvtxScope scope("mainloop");
- // wait while rank-0 is dequeueing
+ // 1. Wait while rank-0 is dequeueing
+ // 2. Broadcast `ec` from rank-0
shared_state_->barrier->wait();
if (shared_state_->abort) {
@@ -1510,90 +1592,58 @@ void LlamaBatch<T>::InternalThreadEntry()
return;
}
- auto signals = ProcessStopRequests(stop_requests);
+ std::vector signals;
+
+ ProcessKillRequests(kill_reqs, signals);
// Shared `priority` field will be assigned by rank-0
- ProcessInferRequests(infer_requests);
+ ProcessInferRequests(infer_reqs, signals);
+
+ // is_canceled <- cancel_flag.load()
+ if (rank_ == 0) {
+ BroadcastCancelFlags();
+ }
- // Wait while shared `requests` is being used
+ // 1. Wait while shared `requests` is being used
+ // 2. Broadcast modifications from rank-0
shared_state_->barrier->wait();
- SendSignals(std::move(signals));
+ ProcessCancelRequests(signals);
+
+ if (rank_ == 0) {
+ gateway_->notify(std::move(signals));
+ }
Initialize(g);
if (state_->active_size) {
//
- (void)Forward(g);
- //
- if (auto signals = Finish(g); !signals.empty()) {
- if (g.finished_count) {
- // Finished requests and corresponding output tensors will be released when notified
- // wait for all ranks to ensure no rank (except for output thread) will access related
- // resources
- shared_state_->barrier->wait();
- }
- SendSignals(std::move(signals));
+ Forward(g);
+
+ Finish(g, signals);
+
+ if (g.finished_count) {
+ // Finished requests and corresponding output tensors will be released when notified
+ // wait for all ranks to ensure no rank (except for output thread) will access related
+ // resources
+ shared_state_->barrier->wait();
}
- }
- ++request_counter;
+ if (rank_ == 0) {
+ gateway_->notify(std::move(signals));
+ }
+ }
}
+ // Unreachable
FT_CHECK(0);
}
-template<typename T>
-void LlamaBatch<T>::SendSignals(std::vector<Signal> signals)
-{
- if (rank_ != 0 || signals.empty()) {
- return;
- }
- {
- std::lock_guard lock{output_mutex_};
- output_signals_.insert(output_signals_.end(), //
- std::move_iterator{signals.begin()},
- std::move_iterator{signals.end()});
- }
- output_cv_.notify_one();
-}
-
template<typename T>
void LlamaBatch<T>::Start()
{
TM_LOG_INFO("LlamaBatch::Start()");
internal_thread_ = std::thread(&LlamaBatch::InternalThreadEntry, this);
- if (rank_ == 0) {
- output_thread_ = std::thread(&LlamaBatch::OutputThreadEntry, this);
- }
-}
-
-template<typename T>
-void LlamaBatch<T>::OutputThreadEntry()
-{
- while (true) {
- std::vector<Signal> signals;
- {
- // Wait for signals to come
- std::unique_lock lock(output_mutex_);
- output_cv_.wait(lock, [&] { return !output_signals_.empty() || output_stop_token_; });
- if (output_stop_token_) {
- TM_LOG_INFO("[OutputThreadEntry] stop requested.");
- return;
- }
- signals = std::move(output_signals_);
- }
- if (rank_ == 0 && ffi_lock_) {
- ffi_lock_(1);
- }
- // invoke stream cbs & signals
- for (const auto& s : signals) {
- s();
- }
- if (rank_ == 0 && ffi_lock_) {
- ffi_lock_(0);
- }
- }
}
template<typename T>